diff --git a/.gitattributes b/.gitattributes index f544f0f54783f8f245da01807c81600528af66cb..c78f1ab778a6c54558a8c3047cb60791e8c528aa 100644 --- a/.gitattributes +++ b/.gitattributes @@ -414,3 +414,8 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/ .venv/lib/python3.11/site-packages/cv2/cv2.abi3.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_engines_runtime_compiled.so.9 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/tokenizers/tokenizers.abi3.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/blake3/blake3.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/blake3/blake3.cpython-311-x86_64-linux-gnu.so b/.venv/lib/python3.11/site-packages/blake3/blake3.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..f0e7348bcaf7710d2ce7f33c9574424ddc294133 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/blake3/blake3.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:859ec8e9dfc40a1b944a570a9f941f68d3744c38d834da744bdc5e7597b9bde7 +size 964720 diff --git a/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/LICENSE b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ae3828669397e90260062082642f57cb75c6ffed --- /dev/null +++ b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/LICENSE @@ -0,0 +1,23 @@ +Httplib2 Software License + +Copyright (c) 2006 by Joe Gregorio + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Software, +and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/METADATA b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..933bfa0bda2cc032d975ce8ab5be357a840cb464 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/METADATA @@ -0,0 +1,75 @@ +Metadata-Version: 2.1 +Name: httplib2 +Version: 0.22.0 +Summary: A comprehensive HTTP client library. +Home-page: https://github.com/httplib2/httplib2 +Author: Joe Gregorio +Author-email: joe@bitworking.org +License: MIT +Classifier: Development Status :: 4 - Beta +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Internet :: WWW/HTTP +Classifier: Topic :: Software Development :: Libraries +Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.* +License-File: LICENSE +Requires-Dist: pyparsing (<3,>=2.4.2) ; python_version < "3.0" +Requires-Dist: pyparsing (!=3.0.0,!=3.0.1,!=3.0.2,!=3.0.3,<4,>=2.4.2) ; python_version > "3.0" + + + +A comprehensive HTTP client library, ``httplib2`` supports many features left out of other HTTP libraries. + +**HTTP and HTTPS** + HTTPS support is only available if the socket module was compiled with SSL support. + + +**Keep-Alive** + Supports HTTP 1.1 Keep-Alive, keeping the socket open and performing multiple requests over the same connection if possible. + + +**Authentication** + The following three types of HTTP Authentication are supported. These can be used over both HTTP and HTTPS. + + * Digest + * Basic + * WSSE + +**Caching** + The module can optionally operate with a private cache that understands the Cache-Control: + header and uses both the ETag and Last-Modified cache validators. Both file system + and memcached based caches are supported. + + +**All Methods** + The module can handle any HTTP request method, not just GET and POST. + + +**Redirects** + Automatically follows 3XX redirects on GETs. + + +**Compression** + Handles both 'deflate' and 'gzip' types of compression. + + +**Lost update support** + Automatically adds back ETags into PUT requests to resources we have already cached. This implements Section 3.2 of Detecting the Lost Update Problem Using Unreserved Checkout + + +**Unit Tested** + A large and growing set of unit tests. diff --git a/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/RECORD b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..e5e3c7517ef4960fd6442d6526a266df177bdc60 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/RECORD @@ -0,0 +1,19 @@ +httplib2-0.22.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +httplib2-0.22.0.dist-info/LICENSE,sha256=WJ7sOPct8r4gNxHTuMvs6bkIxef_ALw8q39juunjZrQ,1086 +httplib2-0.22.0.dist-info/METADATA,sha256=KKy58CVIaYnc6oBjD0upeaA1dgTbB6z7JK5I6FmDSEM,2618 +httplib2-0.22.0.dist-info/RECORD,, +httplib2-0.22.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92 +httplib2-0.22.0.dist-info/top_level.txt,sha256=BEY8ChKwagUWmu9x8yN9JObJpZKNeWCr1E-sIECb56I,9 +httplib2/__init__.py,sha256=UOzaxGwGweHiLsxKBc39_Ez0N8aDHwAu--TkTbYLWCw,69396 +httplib2/__pycache__/__init__.cpython-311.pyc,, +httplib2/__pycache__/auth.cpython-311.pyc,, +httplib2/__pycache__/certs.cpython-311.pyc,, +httplib2/__pycache__/error.cpython-311.pyc,, +httplib2/__pycache__/iri2uri.cpython-311.pyc,, +httplib2/__pycache__/socks.cpython-311.pyc,, +httplib2/auth.py,sha256=Fcb7KqrqRCpUaGD-5l84nT5F2aU6ore6ujWLk5idK0o,2158 +httplib2/cacerts.txt,sha256=AbmYP54iGeKRQ1APtfQvHlo9wul2jVmznmbTzy2fTV4,137365 +httplib2/certs.py,sha256=guhfjMNhDdKJEyYBb5ZyLxVO5q1I7Y_P-4BG8MniBk8,971 +httplib2/error.py,sha256=GyqPUvZeKdVLq0f3xg0uX4rjtv7jVGJuPerAdyc-jfk,954 +httplib2/iri2uri.py,sha256=PhIzEzeR6C73l7piwrNAJlVvlWgsqxtJTlFeXgznzQo,4153 +httplib2/socks.py,sha256=oaeEOnT2rkTNm6wnn0CSdhWzVaVshnnkAKiP4kxKzzc,19701 diff --git a/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..1f37c02f2eb2e26b306202feaccb31e522b8b169 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.40.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/top_level.txt b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..fb881ece05bf2f998d84be544449ac976da2897c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/httplib2-0.22.0.dist-info/top_level.txt @@ -0,0 +1 @@ +httplib2 diff --git a/.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc b/.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3ccf8cf8a1301f44f911bbe2556a076f7c1de1e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4921e03ee20ee695d4deb4b8dd20aefdaed193d9e132919db29f632cf56dfcf +size 101525 diff --git a/.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc b/.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4a74a990dd1ce024f2726eacd556a9a09bb0374 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:645124158744aa51ea31bcde2a74f677aff393980f2a491fe0f2436733f1fbff +size 163155 diff --git a/.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so b/.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..be84e547a28e4d318e94a4a9aa1505208c0cd66b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:227acb07aa26025745fcabc2e60f436ead690b0bf9050835e73f38873ce18794 +size 812104 diff --git a/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/License.txt b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/License.txt new file mode 100644 index 0000000000000000000000000000000000000000..b491c70e0aef319022ded661e111ddbd45b8a17f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/License.txt @@ -0,0 +1,1568 @@ +End User License Agreement +-------------------------- + + +Preface +------- + +The Software License Agreement in Chapter 1 and the Supplement +in Chapter 2 contain license terms and conditions that govern +the use of NVIDIA software. By accepting this agreement, you +agree to comply with all the terms and conditions applicable +to the product(s) included herein. + + +NVIDIA Driver + + +Description + +This package contains the operating system driver and +fundamental system software components for NVIDIA GPUs. + + +NVIDIA CUDA Toolkit + + +Description + +The NVIDIA CUDA Toolkit provides command-line and graphical +tools for building, debugging and optimizing the performance +of applications accelerated by NVIDIA GPUs, runtime and math +libraries, and documentation including programming guides, +user manuals, and API references. + + +Default Install Location of CUDA Toolkit + +Windows platform: + +%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v#.# + +Linux platform: + +/usr/local/cuda-#.# + +Mac platform: + +/Developer/NVIDIA/CUDA-#.# + + +NVIDIA CUDA Samples + + +Description + +This package includes over 100+ CUDA examples that demonstrate +various CUDA programming principles, and efficient CUDA +implementation of algorithms in specific application domains. + + +Default Install Location of CUDA Samples + +Windows platform: + +%ProgramData%\NVIDIA Corporation\CUDA Samples\v#.# + +Linux platform: + +/usr/local/cuda-#.#/samples + +and + +$HOME/NVIDIA_CUDA-#.#_Samples + +Mac platform: + +/Developer/NVIDIA/CUDA-#.#/samples + + +NVIDIA Nsight Visual Studio Edition (Windows only) + + +Description + +NVIDIA Nsight Development Platform, Visual Studio Edition is a +development environment integrated into Microsoft Visual +Studio that provides tools for debugging, profiling, analyzing +and optimizing your GPU computing and graphics applications. + + +Default Install Location of Nsight Visual Studio Edition + +Windows platform: + +%ProgramFiles(x86)%\NVIDIA Corporation\Nsight Visual Studio Edition #.# + + +1. License Agreement for NVIDIA Software Development Kits +--------------------------------------------------------- + + +Release Date: July 26, 2018 +--------------------------- + + +Important NoticeRead before downloading, installing, +copying or using the licensed software: +------------------------------------------------------- + +This license agreement, including exhibits attached +("Agreement”) is a legal agreement between you and NVIDIA +Corporation ("NVIDIA") and governs your use of a NVIDIA +software development kit (“SDK”). + +Each SDK has its own set of software and materials, but here +is a description of the types of items that may be included in +a SDK: source code, header files, APIs, data sets and assets +(examples include images, textures, models, scenes, videos, +native API input/output files), binary software, sample code, +libraries, utility programs, programming code and +documentation. + +This Agreement can be accepted only by an adult of legal age +of majority in the country in which the SDK is used. + +If you are entering into this Agreement on behalf of a company +or other legal entity, you represent that you have the legal +authority to bind the entity to this Agreement, in which case +“you” will mean the entity you represent. + +If you don’t have the required age or authority to accept +this Agreement, or if you don’t accept all the terms and +conditions of this Agreement, do not download, install or use +the SDK. + +You agree to use the SDK only for purposes that are permitted +by (a) this Agreement, and (b) any applicable law, regulation +or generally accepted practices or guidelines in the relevant +jurisdictions. + + +1.1. License + + +1.1.1. License Grant + +Subject to the terms of this Agreement, NVIDIA hereby grants +you a non-exclusive, non-transferable license, without the +right to sublicense (except as expressly provided in this +Agreement) to: + + 1. Install and use the SDK, + + 2. Modify and create derivative works of sample source code + delivered in the SDK, and + + 3. Distribute those portions of the SDK that are identified + in this Agreement as distributable, as incorporated in + object code format into a software application that meets + the distribution requirements indicated in this Agreement. + + +1.1.2. Distribution Requirements + +These are the distribution requirements for you to exercise +the distribution grant: + + 1. Your application must have material additional + functionality, beyond the included portions of the SDK. + + 2. The distributable portions of the SDK shall only be + accessed by your application. + + 3. The following notice shall be included in modifications + and derivative works of sample source code distributed: + “This software contains source code provided by NVIDIA + Corporation.” + + 4. Unless a developer tool is identified in this Agreement + as distributable, it is delivered for your internal use + only. + + 5. The terms under which you distribute your application + must be consistent with the terms of this Agreement, + including (without limitation) terms relating to the + license grant and license restrictions and protection of + NVIDIA’s intellectual property rights. Additionally, you + agree that you will protect the privacy, security and + legal rights of your application users. + + 6. You agree to notify NVIDIA in writing of any known or + suspected distribution or use of the SDK not in compliance + with the requirements of this Agreement, and to enforce + the terms of your agreements with respect to distributed + SDK. + + +1.1.3. Authorized Users + +You may allow employees and contractors of your entity or of +your subsidiary(ies) to access and use the SDK from your +secure network to perform work on your behalf. + +If you are an academic institution you may allow users +enrolled or employed by the academic institution to access and +use the SDK from your secure network. + +You are responsible for the compliance with the terms of this +Agreement by your authorized users. If you become aware that +your authorized users didn’t follow the terms of this +Agreement, you agree to take reasonable steps to resolve the +non-compliance and prevent new occurrences. + + +1.1.4. Pre-Release SDK + +The SDK versions identified as alpha, beta, preview or +otherwise as pre-release, may not be fully functional, may +contain errors or design flaws, and may have reduced or +different security, privacy, accessibility, availability, and +reliability standards relative to commercial versions of +NVIDIA software and materials. Use of a pre-release SDK may +result in unexpected results, loss of data, project delays or +other unpredictable damage or loss. + +You may use a pre-release SDK at your own risk, understanding +that pre-release SDKs are not intended for use in production +or business-critical systems. + +NVIDIA may choose not to make available a commercial version +of any pre-release SDK. NVIDIA may also choose to abandon +development and terminate the availability of a pre-release +SDK at any time without liability. + + +1.1.5. Updates + +NVIDIA may, at its option, make available patches, workarounds +or other updates to this SDK. Unless the updates are provided +with their separate governing terms, they are deemed part of +the SDK licensed to you as provided in this Agreement. You +agree that the form and content of the SDK that NVIDIA +provides may change without prior notice to you. While NVIDIA +generally maintains compatibility between versions, NVIDIA may +in some cases make changes that introduce incompatibilities in +future versions of the SDK. + + +1.1.6. Third Party Licenses + +The SDK may come bundled with, or otherwise include or be +distributed with, third party software licensed by a NVIDIA +supplier and/or open source software provided under an open +source license. Use of third party software is subject to the +third-party license terms, or in the absence of third party +terms, the terms of this Agreement. Copyright to third party +software is held by the copyright holders indicated in the +third-party software or license. + + +1.1.7. Reservation of Rights + +NVIDIA reserves all rights, title, and interest in and to the +SDK, not expressly granted to you under this Agreement. + + +1.2. Limitations + +The following license limitations apply to your use of the +SDK: + + 1. You may not reverse engineer, decompile or disassemble, + or remove copyright or other proprietary notices from any + portion of the SDK or copies of the SDK. + + 2. Except as expressly provided in this Agreement, you may + not copy, sell, rent, sublicense, transfer, distribute, + modify, or create derivative works of any portion of the + SDK. For clarity, you may not distribute or sublicense the + SDK as a stand-alone product. + + 3. Unless you have an agreement with NVIDIA for this + purpose, you may not indicate that an application created + with the SDK is sponsored or endorsed by NVIDIA. + + 4. You may not bypass, disable, or circumvent any + encryption, security, digital rights management or + authentication mechanism in the SDK. + + 5. You may not use the SDK in any manner that would cause it + to become subject to an open source software license. As + examples, licenses that require as a condition of use, + modification, and/or distribution that the SDK be: + + a. Disclosed or distributed in source code form; + + b. Licensed for the purpose of making derivative works; + or + + c. Redistributable at no charge. + + 6. Unless you have an agreement with NVIDIA for this + purpose, you may not use the SDK with any system or + application where the use or failure of the system or + application can reasonably be expected to threaten or + result in personal injury, death, or catastrophic loss. + Examples include use in avionics, navigation, military, + medical, life support or other life critical applications. + NVIDIA does not design, test or manufacture the SDK for + these critical uses and NVIDIA shall not be liable to you + or any third party, in whole or in part, for any claims or + damages arising from such uses. + + 7. You agree to defend, indemnify and hold harmless NVIDIA + and its affiliates, and their respective employees, + contractors, agents, officers and directors, from and + against any and all claims, damages, obligations, losses, + liabilities, costs or debt, fines, restitutions and + expenses (including but not limited to attorney’s fees + and costs incident to establishing the right of + indemnification) arising out of or related to your use of + the SDK outside of the scope of this Agreement, or not in + compliance with its terms. + + +1.3. Ownership + + 1. NVIDIA or its licensors hold all rights, title and + interest in and to the SDK and its modifications and + derivative works, including their respective intellectual + property rights, subject to your rights described in this + section. This SDK may include software and materials from + NVIDIA’s licensors, and these licensors are intended + third party beneficiaries that may enforce this Agreement + with respect to their intellectual property rights. + + 2. You hold all rights, title and interest in and to your + applications and your derivative works of the sample + source code delivered in the SDK, including their + respective intellectual property rights, subject to + NVIDIA’s rights described in this section. + + 3. You may, but don’t have to, provide to NVIDIA + suggestions, feature requests or other feedback regarding + the SDK, including possible enhancements or modifications + to the SDK. For any feedback that you voluntarily provide, + you hereby grant NVIDIA and its affiliates a perpetual, + non-exclusive, worldwide, irrevocable license to use, + reproduce, modify, license, sublicense (through multiple + tiers of sublicensees), and distribute (through multiple + tiers of distributors) it without the payment of any + royalties or fees to you. NVIDIA will use feedback at its + choice. NVIDIA is constantly looking for ways to improve + its products, so you may send feedback to NVIDIA through + the developer portal at https://developer.nvidia.com. + + +1.4. No Warranties + +THE SDK IS PROVIDED BY NVIDIA “AS IS” AND “WITH ALL +FAULTS.” TO THE MAXIMUM EXTENT PERMITTED BY LAW, NVIDIA AND +ITS AFFILIATES EXPRESSLY DISCLAIM ALL WARRANTIES OF ANY KIND +OR NATURE, WHETHER EXPRESS, IMPLIED OR STATUTORY, INCLUDING, +BUT NOT LIMITED TO, ANY WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE, TITLE, NON-INFRINGEMENT, OR THE +ABSENCE OF ANY DEFECTS THEREIN, WHETHER LATENT OR PATENT. NO +WARRANTY IS MADE ON THE BASIS OF TRADE USAGE, COURSE OF +DEALING OR COURSE OF TRADE. + + +1.5. Limitation of Liability + +TO THE MAXIMUM EXTENT PERMITTED BY LAW, NVIDIA AND ITS +AFFILIATES SHALL NOT BE LIABLE FOR ANY SPECIAL, INCIDENTAL, +PUNITIVE OR CONSEQUENTIAL DAMAGES, OR ANY LOST PROFITS, LOSS +OF USE, LOSS OF DATA OR LOSS OF GOODWILL, OR THE COSTS OF +PROCURING SUBSTITUTE PRODUCTS, ARISING OUT OF OR IN CONNECTION +WITH THIS AGREEMENT OR THE USE OR PERFORMANCE OF THE SDK, +WHETHER SUCH LIABILITY ARISES FROM ANY CLAIM BASED UPON BREACH +OF CONTRACT, BREACH OF WARRANTY, TORT (INCLUDING NEGLIGENCE), +PRODUCT LIABILITY OR ANY OTHER CAUSE OF ACTION OR THEORY OF +LIABILITY. IN NO EVENT WILL NVIDIA’S AND ITS AFFILIATES +TOTAL CUMULATIVE LIABILITY UNDER OR ARISING OUT OF THIS +AGREEMENT EXCEED US$10.00. THE NATURE OF THE LIABILITY OR THE +NUMBER OF CLAIMS OR SUITS SHALL NOT ENLARGE OR EXTEND THIS +LIMIT. + +These exclusions and limitations of liability shall apply +regardless if NVIDIA or its affiliates have been advised of +the possibility of such damages, and regardless of whether a +remedy fails its essential purpose. These exclusions and +limitations of liability form an essential basis of the +bargain between the parties, and, absent any of these +exclusions or limitations of liability, the provisions of this +Agreement, including, without limitation, the economic terms, +would be substantially different. + + +1.6. Termination + + 1. This Agreement will continue to apply until terminated by + either you or NVIDIA as described below. + + 2. If you want to terminate this Agreement, you may do so by + stopping to use the SDK. + + 3. NVIDIA may, at any time, terminate this Agreement if: + + a. (i) you fail to comply with any term of this + Agreement and the non-compliance is not fixed within + thirty (30) days following notice from NVIDIA (or + immediately if you violate NVIDIA’s intellectual + property rights); + + b. (ii) you commence or participate in any legal + proceeding against NVIDIA with respect to the SDK; or + + c. (iii) NVIDIA decides to no longer provide the SDK in + a country or, in NVIDIA’s sole discretion, the + continued use of it is no longer commercially viable. + + 4. Upon any termination of this Agreement, you agree to + promptly discontinue use of the SDK and destroy all copies + in your possession or control. Your prior distributions in + accordance with this Agreement are not affected by the + termination of this Agreement. Upon written request, you + will certify in writing that you have complied with your + commitments under this section. Upon any termination of + this Agreement all provisions survive except for the + license grant provisions. + + +1.7. General + +If you wish to assign this Agreement or your rights and +obligations, including by merger, consolidation, dissolution +or operation of law, contact NVIDIA to ask for permission. Any +attempted assignment not approved by NVIDIA in writing shall +be void and of no effect. NVIDIA may assign, delegate or +transfer this Agreement and its rights and obligations, and if +to a non-affiliate you will be notified. + +You agree to cooperate with NVIDIA and provide reasonably +requested information to verify your compliance with this +Agreement. + +This Agreement will be governed in all respects by the laws of +the United States and of the State of Delaware as those laws +are applied to contracts entered into and performed entirely +within Delaware by Delaware residents, without regard to the +conflicts of laws principles. The United Nations Convention on +Contracts for the International Sale of Goods is specifically +disclaimed. You agree to all terms of this Agreement in the +English language. + +The state or federal courts residing in Santa Clara County, +California shall have exclusive jurisdiction over any dispute +or claim arising out of this Agreement. Notwithstanding this, +you agree that NVIDIA shall still be allowed to apply for +injunctive remedies or an equivalent type of urgent legal +relief in any jurisdiction. + +If any court of competent jurisdiction determines that any +provision of this Agreement is illegal, invalid or +unenforceable, such provision will be construed as limited to +the extent necessary to be consistent with and fully +enforceable under the law and the remaining provisions will +remain in full force and effect. Unless otherwise specified, +remedies are cumulative. + +Each party acknowledges and agrees that the other is an +independent contractor in the performance of this Agreement. + +The SDK has been developed entirely at private expense and is +“commercial items” consisting of “commercial computer +software” and “commercial computer software +documentation” provided with RESTRICTED RIGHTS. Use, +duplication or disclosure by the U.S. Government or a U.S. +Government subcontractor is subject to the restrictions in +this Agreement pursuant to DFARS 227.7202-3(a) or as set forth +in subparagraphs (c)(1) and (2) of the Commercial Computer +Software - Restricted Rights clause at FAR 52.227-19, as +applicable. Contractor/manufacturer is NVIDIA, 2788 San Tomas +Expressway, Santa Clara, CA 95051. + +The SDK is subject to United States export laws and +regulations. You agree that you will not ship, transfer or +export the SDK into any country, or use the SDK in any manner, +prohibited by the United States Bureau of Industry and +Security or economic sanctions regulations administered by the +U.S. Department of Treasury’s Office of Foreign Assets +Control (OFAC), or any applicable export laws, restrictions or +regulations. These laws include restrictions on destinations, +end users and end use. By accepting this Agreement, you +confirm that you are not a resident or citizen of any country +currently embargoed by the U.S. and that you are not otherwise +prohibited from receiving the SDK. + +Any notice delivered by NVIDIA to you under this Agreement +will be delivered via mail, email or fax. You agree that any +notices that NVIDIA sends you electronically will satisfy any +legal communication requirements. Please direct your legal +notices or other correspondence to NVIDIA Corporation, 2788 +San Tomas Expressway, Santa Clara, California 95051, United +States of America, Attention: Legal Department. + +This Agreement and any exhibits incorporated into this +Agreement constitute the entire agreement of the parties with +respect to the subject matter of this Agreement and supersede +all prior negotiations or documentation exchanged between the +parties relating to this SDK license. Any additional and/or +conflicting terms on documents issued by you are null, void, +and invalid. Any amendment or waiver under this Agreement +shall be in writing and signed by representatives of both +parties. + + +2. CUDA Toolkit Supplement to Software License Agreement for +NVIDIA Software Development Kits +------------------------------------------------------------ + + +Release date: August 16, 2018 +----------------------------- + +The terms in this supplement govern your use of the NVIDIA +CUDA Toolkit SDK under the terms of your license agreement +(“Agreement”) as modified by this supplement. Capitalized +terms used but not defined below have the meaning assigned to +them in the Agreement. + +This supplement is an exhibit to the Agreement and is +incorporated as an integral part of the Agreement. In the +event of conflict between the terms in this supplement and the +terms in the Agreement, the terms in this supplement govern. + + +2.1. License Scope + +The SDK is licensed for you to develop applications only for +use in systems with NVIDIA GPUs. + + +2.2. Distribution + +The portions of the SDK that are distributable under the +Agreement are listed in Attachment A. + + +2.3. Operating Systems + +Those portions of the SDK designed exclusively for use on the +Linux or FreeBSD operating systems, or other operating systems +derived from the source code to these operating systems, may +be copied and redistributed for use in accordance with this +Agreement, provided that the object code files are not +modified in any way (except for unzipping of compressed +files). + + +2.4. Audio and Video Encoders and Decoders + +You acknowledge and agree that it is your sole responsibility +to obtain any additional third-party licenses required to +make, have made, use, have used, sell, import, and offer for +sale your products or services that include or incorporate any +third-party software and content relating to audio and/or +video encoders and decoders from, including but not limited +to, Microsoft, Thomson, Fraunhofer IIS, Sisvel S.p.A., +MPEG-LA, and Coding Technologies. NVIDIA does not grant to you +under this Agreement any necessary patent or other rights with +respect to any audio and/or video encoders and decoders. + + +2.5. Licensing + +If the distribution terms in this Agreement are not suitable +for your organization, or for any questions regarding this +Agreement, please contact NVIDIA at +nvidia-compute-license-questions@nvidia.com. + + +2.6. Attachment A + +The following portions of the SDK are distributable under the +Agreement: + +Component + +CUDA Runtime + +Windows + +cudart.dll, cudart_static.lib, cudadevrt.lib + +Mac OSX + +libcudart.dylib, libcudart_static.a, libcudadevrt.a + +Linux + +libcudart.so, libcudart_static.a, libcudadevrt.a + +Android + +libcudart.so, libcudart_static.a, libcudadevrt.a + +Component + +CUDA FFT Library + +Windows + +cufft.dll, cufftw.dll, cufft.lib, cufftw.lib + +Mac OSX + +libcufft.dylib, libcufft_static.a, libcufftw.dylib, +libcufftw_static.a + +Linux + +libcufft.so, libcufft_static.a, libcufftw.so, +libcufftw_static.a + +Android + +libcufft.so, libcufft_static.a, libcufftw.so, +libcufftw_static.a + +Component + +CUDA BLAS Library + +Windows + +cublas.dll, cublasLt.dll + +Mac OSX + +libcublas.dylib, libcublasLt.dylib, libcublas_static.a, +libcublasLt_static.a + +Linux + +libcublas.so, libcublasLt.so, libcublas_static.a, +libcublasLt_static.a + +Android + +libcublas.so, libcublasLt.so, libcublas_static.a, +libcublasLt_static.a + +Component + +NVIDIA "Drop-in" BLAS Library + +Windows + +nvblas.dll + +Mac OSX + +libnvblas.dylib + +Linux + +libnvblas.so + +Component + +CUDA Sparse Matrix Library + +Windows + +cusparse.dll, cusparse.lib + +Mac OSX + +libcusparse.dylib, libcusparse_static.a + +Linux + +libcusparse.so, libcusparse_static.a + +Android + +libcusparse.so, libcusparse_static.a + +Component + +CUDA Linear Solver Library + +Windows + +cusolver.dll, cusolver.lib + +Mac OSX + +libcusolver.dylib, libcusolver_static.a + +Linux + +libcusolver.so, libcusolver_static.a + +Android + +libcusolver.so, libcusolver_static.a + +Component + +CUDA Random Number Generation Library + +Windows + +curand.dll, curand.lib + +Mac OSX + +libcurand.dylib, libcurand_static.a + +Linux + +libcurand.so, libcurand_static.a + +Android + +libcurand.so, libcurand_static.a + +Component + +CUDA Accelerated Graph Library + +Component + +NVIDIA Performance Primitives Library + +Windows + +nppc.dll, nppc.lib, nppial.dll, nppial.lib, nppicc.dll, +nppicc.lib, nppicom.dll, nppicom.lib, nppidei.dll, +nppidei.lib, nppif.dll, nppif.lib, nppig.dll, nppig.lib, +nppim.dll, nppim.lib, nppist.dll, nppist.lib, nppisu.dll, +nppisu.lib, nppitc.dll, nppitc.lib, npps.dll, npps.lib + +Mac OSX + +libnppc.dylib, libnppc_static.a, libnppial.dylib, +libnppial_static.a, libnppicc.dylib, libnppicc_static.a, +libnppicom.dylib, libnppicom_static.a, libnppidei.dylib, +libnppidei_static.a, libnppif.dylib, libnppif_static.a, +libnppig.dylib, libnppig_static.a, libnppim.dylib, +libnppisu_static.a, libnppitc.dylib, libnppitc_static.a, +libnpps.dylib, libnpps_static.a + +Linux + +libnppc.so, libnppc_static.a, libnppial.so, +libnppial_static.a, libnppicc.so, libnppicc_static.a, +libnppicom.so, libnppicom_static.a, libnppidei.so, +libnppidei_static.a, libnppif.so, libnppif_static.a +libnppig.so, libnppig_static.a, libnppim.so, +libnppim_static.a, libnppist.so, libnppist_static.a, +libnppisu.so, libnppisu_static.a, libnppitc.so +libnppitc_static.a, libnpps.so, libnpps_static.a + +Android + +libnppc.so, libnppc_static.a, libnppial.so, +libnppial_static.a, libnppicc.so, libnppicc_static.a, +libnppicom.so, libnppicom_static.a, libnppidei.so, +libnppidei_static.a, libnppif.so, libnppif_static.a +libnppig.so, libnppig_static.a, libnppim.so, +libnppim_static.a, libnppist.so, libnppist_static.a, +libnppisu.so, libnppisu_static.a, libnppitc.so +libnppitc_static.a, libnpps.so, libnpps_static.a + +Component + +NVIDIA JPEG Library + +Linux + +libnvjpeg.so, libnvjpeg_static.a + +Component + +Internal common library required for statically linking to +cuBLAS, cuSPARSE, cuFFT, cuRAND, nvJPEG and NPP + +Mac OSX + +libculibos.a + +Linux + +libculibos.a + +Component + +NVIDIA Runtime Compilation Library and Header + +All + +nvrtc.h + +Windows + +nvrtc.dll, nvrtc-builtins.dll + +Mac OSX + +libnvrtc.dylib, libnvrtc-builtins.dylib + +Linux + +libnvrtc.so, libnvrtc-builtins.so + +Component + +NVIDIA Optimizing Compiler Library + +Windows + +nvvm.dll + +Mac OSX + +libnvvm.dylib + +Linux + +libnvvm.so + +Component + +NVIDIA Common Device Math Functions Library + +Windows + +libdevice.10.bc + +Mac OSX + +libdevice.10.bc + +Linux + +libdevice.10.bc + +Component + +CUDA Occupancy Calculation Header Library + +All + +cuda_occupancy.h + +Component + +CUDA Half Precision Headers + +All + +cuda_fp16.h, cuda_fp16.hpp + +Component + +CUDA Profiling Tools Interface (CUPTI) Library + +Windows + +cupti.dll + +Mac OSX + +libcupti.dylib + +Linux + +libcupti.so + +Component + +NVIDIA Tools Extension Library + +Windows + +nvToolsExt.dll, nvToolsExt.lib + +Mac OSX + +libnvToolsExt.dylib + +Linux + +libnvToolsExt.so + +Component + +NVIDIA CUDA Driver Libraries + +Linux + +libcuda.so, libnvidia-fatbinaryloader.so, +libnvidia-ptxjitcompiler.so + +The NVIDIA CUDA Driver Libraries are only distributable in +applications that meet this criteria: + + 1. The application was developed starting from a NVIDIA CUDA + container obtained from Docker Hub or the NVIDIA GPU + Cloud, and + + 2. The resulting application is packaged as a Docker + container and distributed to users on Docker Hub or the + NVIDIA GPU Cloud only. + + +2.7. Attachment B + + +Additional Licensing Obligations + +The following third party components included in the SOFTWARE +are licensed to Licensee pursuant to the following terms and +conditions: + + 1. Licensee's use of the GDB third party component is + subject to the terms and conditions of GNU GPL v3: + + This product includes copyrighted third-party software licensed + under the terms of the GNU General Public License v3 ("GPL v3"). + All third-party software packages are copyright by their respective + authors. GPL v3 terms and conditions are hereby incorporated into + the Agreement by this reference: http://www.gnu.org/licenses/gpl.txt + + Consistent with these licensing requirements, the software + listed below is provided under the terms of the specified + open source software licenses. To obtain source code for + software provided under licenses that require + redistribution of source code, including the GNU General + Public License (GPL) and GNU Lesser General Public License + (LGPL), contact oss-requests@nvidia.com. This offer is + valid for a period of three (3) years from the date of the + distribution of this product by NVIDIA CORPORATION. + + Component License + CUDA-GDB GPL v3 + + 2. Licensee represents and warrants that any and all third + party licensing and/or royalty payment obligations in + connection with Licensee's use of the H.264 video codecs + are solely the responsibility of Licensee. + + 3. Licensee's use of the Thrust library is subject to the + terms and conditions of the Apache License Version 2.0. + All third-party software packages are copyright by their + respective authors. Apache License Version 2.0 terms and + conditions are hereby incorporated into the Agreement by + this reference. + http://www.apache.org/licenses/LICENSE-2.0.html + + In addition, Licensee acknowledges the following notice: + Thrust includes source code from the Boost Iterator, + Tuple, System, and Random Number libraries. + + Boost Software License - Version 1.0 - August 17th, 2003 + . . . . + + Permission is hereby granted, free of charge, to any person or + organization obtaining a copy of the software and accompanying + documentation covered by this license (the "Software") to use, + reproduce, display, distribute, execute, and transmit the Software, + and to prepare derivative works of the Software, and to permit + third-parties to whom the Software is furnished to do so, all + subject to the following: + + The copyright notices in the Software and this entire statement, + including the above license grant, this restriction and the following + disclaimer, must be included in all copies of the Software, in whole + or in part, and all derivative works of the Software, unless such + copies or derivative works are solely in the form of machine-executable + object code generated by a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND + NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR + ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR + OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + 4. Licensee's use of the LLVM third party component is + subject to the following terms and conditions: + + ====================================================== + LLVM Release License + ====================================================== + University of Illinois/NCSA + Open Source License + + Copyright (c) 2003-2010 University of Illinois at Urbana-Champaign. + All rights reserved. + + Developed by: + + LLVM Team + + University of Illinois at Urbana-Champaign + + http://llvm.org + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to + deal with the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimers. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimers in the + documentation and/or other materials provided with the distribution. + + * Neither the names of the LLVM Team, University of Illinois at Urbana- + Champaign, nor the names of its contributors may be used to endorse or + promote products derived from this Software without specific prior + written permission. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS WITH THE SOFTWARE. + + 5. Licensee's use (e.g. nvprof) of the PCRE third party + component is subject to the following terms and + conditions: + + ------------ + PCRE LICENCE + ------------ + PCRE is a library of functions to support regular expressions whose syntax + and semantics are as close as possible to those of the Perl 5 language. + Release 8 of PCRE is distributed under the terms of the "BSD" licence, as + specified below. The documentation for PCRE, supplied in the "doc" + directory, is distributed under the same terms as the software itself. The + basic library functions are written in C and are freestanding. Also + included in the distribution is a set of C++ wrapper functions, and a just- + in-time compiler that can be used to optimize pattern matching. These are + both optional features that can be omitted when the library is built. + + THE BASIC LIBRARY FUNCTIONS + --------------------------- + Written by: Philip Hazel + Email local part: ph10 + Email domain: cam.ac.uk + University of Cambridge Computing Service, + Cambridge, England. + Copyright (c) 1997-2012 University of Cambridge + All rights reserved. + + PCRE JUST-IN-TIME COMPILATION SUPPORT + ------------------------------------- + Written by: Zoltan Herczeg + Email local part: hzmester + Emain domain: freemail.hu + Copyright(c) 2010-2012 Zoltan Herczeg + All rights reserved. + + STACK-LESS JUST-IN-TIME COMPILER + -------------------------------- + Written by: Zoltan Herczeg + Email local part: hzmester + Emain domain: freemail.hu + Copyright(c) 2009-2012 Zoltan Herczeg + All rights reserved. + + THE C++ WRAPPER FUNCTIONS + ------------------------- + Contributed by: Google Inc. + Copyright (c) 2007-2012, Google Inc. + All rights reserved. + + THE "BSD" LICENCE + ----------------- + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the University of Cambridge nor the name of Google + Inc. nor the names of their contributors may be used to endorse or + promote products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + 6. Some of the cuBLAS library routines were written by or + derived from code written by Vasily Volkov and are subject + to the Modified Berkeley Software Distribution License as + follows: + + Copyright (c) 2007-2009, Regents of the University of California + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of the University of California, Berkeley nor + the names of its contributors may be used to endorse or promote + products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING + IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + 7. Some of the cuBLAS library routines were written by or + derived from code written by Davide Barbieri and are + subject to the Modified Berkeley Software Distribution + License as follows: + + Copyright (c) 2008-2009 Davide Barbieri @ University of Rome Tor Vergata. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * The name of the author may not be used to endorse or promote + products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING + IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + 8. Some of the cuBLAS library routines were derived from + code developed by the University of Tennessee and are + subject to the Modified Berkeley Software Distribution + License as follows: + + Copyright (c) 2010 The University of Tennessee. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer listed in this license in the documentation and/or + other materials provided with the distribution. + * Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 9. Some of the cuBLAS library routines were written by or + derived from code written by Jonathan Hogg and are subject + to the Modified Berkeley Software Distribution License as + follows: + + Copyright (c) 2012, The Science and Technology Facilities Council (STFC). + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of the STFC nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE STFC BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN + IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 10. Some of the cuBLAS library routines were written by or + derived from code written by Ahmad M. Abdelfattah, David + Keyes, and Hatem Ltaief, and are subject to the Apache + License, Version 2.0, as follows: + + -- (C) Copyright 2013 King Abdullah University of Science and Technology + Authors: + Ahmad Abdelfattah (ahmad.ahmad@kaust.edu.sa) + David Keyes (david.keyes@kaust.edu.sa) + Hatem Ltaief (hatem.ltaief@kaust.edu.sa) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the King Abdullah University of Science and + Technology nor the names of its contributors may be used to endorse + or promote products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE + + 11. Some of the cuSPARSE library routines were written by or + derived from code written by Li-Wen Chang and are subject + to the NCSA Open Source License as follows: + + Copyright (c) 2012, University of Illinois. + + All rights reserved. + + Developed by: IMPACT Group, University of Illinois, http://impact.crhc.illinois.edu + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal with the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimers in the documentation and/or other materials provided + with the distribution. + * Neither the names of IMPACT Group, University of Illinois, nor + the names of its contributors may be used to endorse or promote + products derived from this Software without specific prior + written permission. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR + IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE + SOFTWARE. + + 12. Some of the cuRAND library routines were written by or + derived from code written by Mutsuo Saito and Makoto + Matsumoto and are subject to the following license: + + Copyright (c) 2009, 2010 Mutsuo Saito, Makoto Matsumoto and Hiroshima + University. All rights reserved. + + Copyright (c) 2011 Mutsuo Saito, Makoto Matsumoto, Hiroshima + University and University of Tokyo. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of the Hiroshima University nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 13. Some of the cuRAND library routines were derived from + code developed by D. E. Shaw Research and are subject to + the following license: + + Copyright 2010-2011, D. E. Shaw Research. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions, and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions, and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of D. E. Shaw Research nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 14. Some of the Math library routines were written by or + derived from code developed by Norbert Juffa and are + subject to the following license: + + Copyright (c) 2015-2017, Norbert Juffa + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 15. Licensee's use of the lz4 third party component is + subject to the following terms and conditions: + + Copyright (C) 2011-2013, Yann Collet. + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 16. The NPP library uses code from the Boost Math Toolkit, + and is subject to the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + . . . . + + Permission is hereby granted, free of charge, to any person or + organization obtaining a copy of the software and accompanying + documentation covered by this license (the "Software") to use, + reproduce, display, distribute, execute, and transmit the Software, + and to prepare derivative works of the Software, and to permit + third-parties to whom the Software is furnished to do so, all + subject to the following: + + The copyright notices in the Software and this entire statement, + including the above license grant, this restriction and the following + disclaimer, must be included in all copies of the Software, in whole + or in part, and all derivative works of the Software, unless such + copies or derivative works are solely in the form of machine-executable + object code generated by a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND + NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR + ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR + OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + 17. Portions of the Nsight Eclipse Edition is subject to the + following license: + + The Eclipse Foundation makes available all content in this plug-in + ("Content"). Unless otherwise indicated below, the Content is provided + to you under the terms and conditions of the Eclipse Public License + Version 1.0 ("EPL"). A copy of the EPL is available at http:// + www.eclipse.org/legal/epl-v10.html. For purposes of the EPL, "Program" + will mean the Content. + + If you did not receive this Content directly from the Eclipse + Foundation, the Content is being redistributed by another party + ("Redistributor") and different terms and conditions may apply to your + use of any object code in the Content. Check the Redistributor's + license that was provided with the Content. If no such license exists, + contact the Redistributor. Unless otherwise indicated below, the terms + and conditions of the EPL still apply to any source code in the + Content and such source code may be obtained at http://www.eclipse.org. + + 18. Some of the cuBLAS library routines uses code from + OpenAI, which is subject to the following license: + + License URL + https://github.com/openai/openai-gemm/blob/master/LICENSE + + License Text + The MIT License + + Copyright (c) 2016 OpenAI (http://openai.com), 2016 Google Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + + 19. Licensee's use of the Visual Studio Setup Configuration + Samples is subject to the following license: + + The MIT License (MIT) + Copyright (C) Microsoft Corporation. All rights reserved. + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of the Software, + and to permit persons to whom the Software is furnished to do so, + subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + 20. Licensee's use of linmath.h header for CPU functions for + GL vector/matrix operations from lunarG is subject to the + Apache License Version 2.0. + + 21. The DX12-CUDA sample uses the d3dx12.h header, which is + subject to the MIT license . + +----------------- diff --git a/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/METADATA b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..ee3b960c6d9841aa220d526eb7d9ebe6140dd370 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/METADATA @@ -0,0 +1,35 @@ +Metadata-Version: 2.1 +Name: nvidia-cuda-nvrtc-cu12 +Version: 12.4.127 +Summary: NVRTC native runtime libraries +Home-page: https://developer.nvidia.com/cuda-zone +Author: Nvidia CUDA Installer Team +Author-email: cuda_installer@nvidia.com +License: NVIDIA Proprietary Software +Keywords: cuda,nvidia,runtime,machine learning,deep learning +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: Other/Proprietary License +Classifier: Natural Language :: English +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Topic :: Scientific/Engineering +Classifier: Topic :: Scientific/Engineering :: Mathematics +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Classifier: Topic :: Software Development +Classifier: Topic :: Software Development :: Libraries +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: POSIX :: Linux +Requires-Python: >=3 +License-File: License.txt + +NVRTC native runtime libraries diff --git a/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/RECORD b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..2e7e4df75417bf35974f58eb5bddd834dd282645 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/RECORD @@ -0,0 +1,17 @@ +nvidia/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/__pycache__/__init__.cpython-311.pyc,, +nvidia/cuda_nvrtc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc,, +nvidia/cuda_nvrtc/include/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc,, +nvidia/cuda_nvrtc/include/nvrtc.h,sha256=K3X9i14crxxUBVAdXyNFIW0BVYUdxTqM6TfOffDXL7U,36119 +nvidia/cuda_nvrtc/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc,, +nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.4,sha256=RCVWGjGdoHu94NwYCpBCLjUbNkw9xPaj6h5pkl1i23I,5343112 +nvidia/cuda_nvrtc/lib/libnvrtc.so.12,sha256=Rm1vFNbP3pmDihMs-noj-Qugis-tqk2eTJ3v4Sd_e_E,60418376 +nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/License.txt,sha256=rW9YU_ugyg0VnQ9Y1JrkmDDC-Mk_epJki5zpCttMbM0,59262 +nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/METADATA,sha256=ZFn72NXguvRFDpzLRoJbcR4F9zUzHvokH8XSt3lduNs,1507 +nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/RECORD,, +nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/WHEEL,sha256=XDTs3wIbcE-BcRO08VJlZpA6z9OaC1mOKPCGGGwuM2g,109 +nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/top_level.txt,sha256=fTkAtiFuL16nUrB9ytDDtpytz2t0B4NvYTnRzwAhO14,7 diff --git a/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..e6c30e957cfb045017a9fef3430bb8ee87c4a074 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.42.0) +Root-Is-Purelib: true +Tag: py3-none-manylinux2014_x86_64 + diff --git a/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/top_level.txt b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..862f7abf232cdfbb928609856247292e81c9decb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia_cuda_nvrtc_cu12-12.4.127.dist-info/top_level.txt @@ -0,0 +1 @@ +nvidia diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e13fce0abaac2f236789e3fcaaa709089169afa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/appdirs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/appdirs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35ead2e2e19425b589c2715158c5b8007149d4c7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/appdirs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/zipp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/zipp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b66275f7ae8a20b59946d3a9c4459ab8ba2b40b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/__pycache__/zipp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__init__.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34e3a9950cc557879af8d797f9382b18a870fb56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__init__.py @@ -0,0 +1,36 @@ +"""Read resources contained within a package.""" + +from ._common import ( + as_file, + files, + Package, +) + +from ._legacy import ( + contents, + open_binary, + read_binary, + open_text, + read_text, + is_resource, + path, + Resource, +) + +from .abc import ResourceReader + + +__all__ = [ + 'Package', + 'Resource', + 'ResourceReader', + 'as_file', + 'contents', + 'files', + 'is_resource', + 'open_binary', + 'open_text', + 'path', + 'read_binary', + 'read_text', +] diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41584eaefed8219360c6710084d94244ce3b78df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_adapters.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_adapters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb62ea2a3a7dd7a54e6d6984e5d763d8a0aaf4e4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_adapters.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d989891ae33cb5c065447ac56b4884bd59c46e34 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_compat.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_compat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..750348a49d331c37b920c813e7828fbababb7932 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_compat.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_itertools.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_itertools.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3e8e84762144ee7db977e16f3ee5a1440930b8e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_itertools.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_legacy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_legacy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8b6355e73eb1de64df92e0e83919b8865a6baec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/_legacy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/abc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/abc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e0559cb2a23cd1ec244937703cfe6c60f80124d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/abc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/readers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/readers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64cf69598c756ac82bfb7220d052e9072be4cb7a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/readers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/simple.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/simple.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86f741d5311472915d1b21ea49fd2e8071632efc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/__pycache__/simple.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_adapters.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..ea363d86a564b5450666aa00aecd46353326a75a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_adapters.py @@ -0,0 +1,170 @@ +from contextlib import suppress +from io import TextIOWrapper + +from . import abc + + +class SpecLoaderAdapter: + """ + Adapt a package spec to adapt the underlying loader. + """ + + def __init__(self, spec, adapter=lambda spec: spec.loader): + self.spec = spec + self.loader = adapter(spec) + + def __getattr__(self, name): + return getattr(self.spec, name) + + +class TraversableResourcesLoader: + """ + Adapt a loader to provide TraversableResources. + """ + + def __init__(self, spec): + self.spec = spec + + def get_resource_reader(self, name): + return CompatibilityFiles(self.spec)._native() + + +def _io_wrapper(file, mode='r', *args, **kwargs): + if mode == 'r': + return TextIOWrapper(file, *args, **kwargs) + elif mode == 'rb': + return file + raise ValueError( + "Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode) + ) + + +class CompatibilityFiles: + """ + Adapter for an existing or non-existent resource reader + to provide a compatibility .files(). + """ + + class SpecPath(abc.Traversable): + """ + Path tied to a module spec. + Can be read and exposes the resource reader children. + """ + + def __init__(self, spec, reader): + self._spec = spec + self._reader = reader + + def iterdir(self): + if not self._reader: + return iter(()) + return iter( + CompatibilityFiles.ChildPath(self._reader, path) + for path in self._reader.contents() + ) + + def is_file(self): + return False + + is_dir = is_file + + def joinpath(self, other): + if not self._reader: + return CompatibilityFiles.OrphanPath(other) + return CompatibilityFiles.ChildPath(self._reader, other) + + @property + def name(self): + return self._spec.name + + def open(self, mode='r', *args, **kwargs): + return _io_wrapper(self._reader.open_resource(None), mode, *args, **kwargs) + + class ChildPath(abc.Traversable): + """ + Path tied to a resource reader child. + Can be read but doesn't expose any meaningful children. + """ + + def __init__(self, reader, name): + self._reader = reader + self._name = name + + def iterdir(self): + return iter(()) + + def is_file(self): + return self._reader.is_resource(self.name) + + def is_dir(self): + return not self.is_file() + + def joinpath(self, other): + return CompatibilityFiles.OrphanPath(self.name, other) + + @property + def name(self): + return self._name + + def open(self, mode='r', *args, **kwargs): + return _io_wrapper( + self._reader.open_resource(self.name), mode, *args, **kwargs + ) + + class OrphanPath(abc.Traversable): + """ + Orphan path, not tied to a module spec or resource reader. + Can't be read and doesn't expose any meaningful children. + """ + + def __init__(self, *path_parts): + if len(path_parts) < 1: + raise ValueError('Need at least one path part to construct a path') + self._path = path_parts + + def iterdir(self): + return iter(()) + + def is_file(self): + return False + + is_dir = is_file + + def joinpath(self, other): + return CompatibilityFiles.OrphanPath(*self._path, other) + + @property + def name(self): + return self._path[-1] + + def open(self, mode='r', *args, **kwargs): + raise FileNotFoundError("Can't open orphan path") + + def __init__(self, spec): + self.spec = spec + + @property + def _reader(self): + with suppress(AttributeError): + return self.spec.loader.get_resource_reader(self.spec.name) + + def _native(self): + """ + Return the native reader if it supports files(). + """ + reader = self._reader + return reader if hasattr(reader, 'files') else self + + def __getattr__(self, attr): + return getattr(self._reader, attr) + + def files(self): + return CompatibilityFiles.SpecPath(self.spec, self._reader) + + +def wrap_spec(package): + """ + Construct a package spec with traversable compatibility + on the spec/loader/reader. + """ + return SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_common.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a12e2c75d132c73b556702159d535d15ed9abfd2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_common.py @@ -0,0 +1,104 @@ +import os +import pathlib +import tempfile +import functools +import contextlib +import types +import importlib + +from typing import Union, Optional +from .abc import ResourceReader, Traversable + +from ._compat import wrap_spec + +Package = Union[types.ModuleType, str] + + +def files(package): + # type: (Package) -> Traversable + """ + Get a Traversable resource from a package + """ + return from_package(get_package(package)) + + +def get_resource_reader(package): + # type: (types.ModuleType) -> Optional[ResourceReader] + """ + Return the package's loader if it's a ResourceReader. + """ + # We can't use + # a issubclass() check here because apparently abc.'s __subclasscheck__() + # hook wants to create a weak reference to the object, but + # zipimport.zipimporter does not support weak references, resulting in a + # TypeError. That seems terrible. + spec = package.__spec__ + reader = getattr(spec.loader, 'get_resource_reader', None) # type: ignore + if reader is None: + return None + return reader(spec.name) # type: ignore + + +def resolve(cand): + # type: (Package) -> types.ModuleType + return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand) + + +def get_package(package): + # type: (Package) -> types.ModuleType + """Take a package name or module object and return the module. + + Raise an exception if the resolved module is not a package. + """ + resolved = resolve(package) + if wrap_spec(resolved).submodule_search_locations is None: + raise TypeError(f'{package!r} is not a package') + return resolved + + +def from_package(package): + """ + Return a Traversable object for the given package. + + """ + spec = wrap_spec(package) + reader = spec.loader.get_resource_reader(spec.name) + return reader.files() + + +@contextlib.contextmanager +def _tempfile(reader, suffix=''): + # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try' + # blocks due to the need to close the temporary file to work on Windows + # properly. + fd, raw_path = tempfile.mkstemp(suffix=suffix) + try: + try: + os.write(fd, reader()) + finally: + os.close(fd) + del reader + yield pathlib.Path(raw_path) + finally: + try: + os.remove(raw_path) + except FileNotFoundError: + pass + + +@functools.singledispatch +def as_file(path): + """ + Given a Traversable object, return that object as a + path on the local file system in a context manager. + """ + return _tempfile(path.read_bytes, suffix=path.name) + + +@as_file.register(pathlib.Path) +@contextlib.contextmanager +def _(path): + """ + Degenerate behavior for pathlib.Path objects. + """ + yield path diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_compat.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9fc820cb352aa6e92705aab4f55cbc2eff96bc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/_compat.py @@ -0,0 +1,98 @@ +# flake8: noqa + +import abc +import sys +import pathlib +from contextlib import suppress + +if sys.version_info >= (3, 10): + from zipfile import Path as ZipPath # type: ignore +else: + from ..zipp import Path as ZipPath # type: ignore + + +try: + from typing import runtime_checkable # type: ignore +except ImportError: + + def runtime_checkable(cls): # type: ignore + return cls + + +try: + from typing import Protocol # type: ignore +except ImportError: + Protocol = abc.ABC # type: ignore + + +class TraversableResourcesLoader: + """ + Adapt loaders to provide TraversableResources and other + compatibility. + + Used primarily for Python 3.9 and earlier where the native + loaders do not yet implement TraversableResources. + """ + + def __init__(self, spec): + self.spec = spec + + @property + def path(self): + return self.spec.origin + + def get_resource_reader(self, name): + from . import readers, _adapters + + def _zip_reader(spec): + with suppress(AttributeError): + return readers.ZipReader(spec.loader, spec.name) + + def _namespace_reader(spec): + with suppress(AttributeError, ValueError): + return readers.NamespaceReader(spec.submodule_search_locations) + + def _available_reader(spec): + with suppress(AttributeError): + return spec.loader.get_resource_reader(spec.name) + + def _native_reader(spec): + reader = _available_reader(spec) + return reader if hasattr(reader, 'files') else None + + def _file_reader(spec): + try: + path = pathlib.Path(self.path) + except TypeError: + return None + if path.exists(): + return readers.FileReader(self) + + return ( + # native reader if it supplies 'files' + _native_reader(self.spec) + or + # local ZipReader if a zip module + _zip_reader(self.spec) + or + # local NamespaceReader if a namespace module + _namespace_reader(self.spec) + or + # local FileReader + _file_reader(self.spec) + # fallback - adapt the spec ResourceReader to TraversableReader + or _adapters.CompatibilityFiles(self.spec) + ) + + +def wrap_spec(package): + """ + Construct a package spec with traversable compatibility + on the spec/loader/reader. + + Supersedes _adapters.wrap_spec to use TraversableResourcesLoader + from above for older Python compatibility (<3.10). + """ + from . import _adapters + + return _adapters.SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/abc.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/abc.py new file mode 100644 index 0000000000000000000000000000000000000000..d39dc1adba0f00d2f7bdf6fa2cd1abcd82475e2e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/abc.py @@ -0,0 +1,137 @@ +import abc +from typing import BinaryIO, Iterable, Text + +from ._compat import runtime_checkable, Protocol + + +class ResourceReader(metaclass=abc.ABCMeta): + """Abstract base class for loaders to provide resource reading support.""" + + @abc.abstractmethod + def open_resource(self, resource: Text) -> BinaryIO: + """Return an opened, file-like object for binary reading. + + The 'resource' argument is expected to represent only a file name. + If the resource cannot be found, FileNotFoundError is raised. + """ + # This deliberately raises FileNotFoundError instead of + # NotImplementedError so that if this method is accidentally called, + # it'll still do the right thing. + raise FileNotFoundError + + @abc.abstractmethod + def resource_path(self, resource: Text) -> Text: + """Return the file system path to the specified resource. + + The 'resource' argument is expected to represent only a file name. + If the resource does not exist on the file system, raise + FileNotFoundError. + """ + # This deliberately raises FileNotFoundError instead of + # NotImplementedError so that if this method is accidentally called, + # it'll still do the right thing. + raise FileNotFoundError + + @abc.abstractmethod + def is_resource(self, path: Text) -> bool: + """Return True if the named 'path' is a resource. + + Files are resources, directories are not. + """ + raise FileNotFoundError + + @abc.abstractmethod + def contents(self) -> Iterable[str]: + """Return an iterable of entries in `package`.""" + raise FileNotFoundError + + +@runtime_checkable +class Traversable(Protocol): + """ + An object with a subset of pathlib.Path methods suitable for + traversing directories and opening files. + """ + + @abc.abstractmethod + def iterdir(self): + """ + Yield Traversable objects in self + """ + + def read_bytes(self): + """ + Read contents of self as bytes + """ + with self.open('rb') as strm: + return strm.read() + + def read_text(self, encoding=None): + """ + Read contents of self as text + """ + with self.open(encoding=encoding) as strm: + return strm.read() + + @abc.abstractmethod + def is_dir(self) -> bool: + """ + Return True if self is a directory + """ + + @abc.abstractmethod + def is_file(self) -> bool: + """ + Return True if self is a file + """ + + @abc.abstractmethod + def joinpath(self, child): + """ + Return Traversable child in self + """ + + def __truediv__(self, child): + """ + Return Traversable child in self + """ + return self.joinpath(child) + + @abc.abstractmethod + def open(self, mode='r', *args, **kwargs): + """ + mode may be 'r' or 'rb' to open as text or binary. Return a handle + suitable for reading (same as pathlib.Path.open). + + When opening as text, accepts encoding parameters such as those + accepted by io.TextIOWrapper. + """ + + @abc.abstractproperty + def name(self) -> str: + """ + The base name of this object without any parent references. + """ + + +class TraversableResources(ResourceReader): + """ + The required interface for providing traversable + resources. + """ + + @abc.abstractmethod + def files(self): + """Return a Traversable object for the loaded package.""" + + def open_resource(self, resource): + return self.files().joinpath(resource).open('rb') + + def resource_path(self, resource): + raise FileNotFoundError(resource) + + def is_resource(self, path): + return self.files().joinpath(path).is_file() + + def contents(self): + return (item.name for item in self.files().iterdir()) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/readers.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/readers.py new file mode 100644 index 0000000000000000000000000000000000000000..f1190ca452a1ce22ee9a1b304991d475281df8ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/readers.py @@ -0,0 +1,122 @@ +import collections +import pathlib +import operator + +from . import abc + +from ._itertools import unique_everseen +from ._compat import ZipPath + + +def remove_duplicates(items): + return iter(collections.OrderedDict.fromkeys(items)) + + +class FileReader(abc.TraversableResources): + def __init__(self, loader): + self.path = pathlib.Path(loader.path).parent + + def resource_path(self, resource): + """ + Return the file system path to prevent + `resources.path()` from creating a temporary + copy. + """ + return str(self.path.joinpath(resource)) + + def files(self): + return self.path + + +class ZipReader(abc.TraversableResources): + def __init__(self, loader, module): + _, _, name = module.rpartition('.') + self.prefix = loader.prefix.replace('\\', '/') + name + '/' + self.archive = loader.archive + + def open_resource(self, resource): + try: + return super().open_resource(resource) + except KeyError as exc: + raise FileNotFoundError(exc.args[0]) + + def is_resource(self, path): + # workaround for `zipfile.Path.is_file` returning true + # for non-existent paths. + target = self.files().joinpath(path) + return target.is_file() and target.exists() + + def files(self): + return ZipPath(self.archive, self.prefix) + + +class MultiplexedPath(abc.Traversable): + """ + Given a series of Traversable objects, implement a merged + version of the interface across all objects. Useful for + namespace packages which may be multihomed at a single + name. + """ + + def __init__(self, *paths): + self._paths = list(map(pathlib.Path, remove_duplicates(paths))) + if not self._paths: + message = 'MultiplexedPath must contain at least one path' + raise FileNotFoundError(message) + if not all(path.is_dir() for path in self._paths): + raise NotADirectoryError('MultiplexedPath only supports directories') + + def iterdir(self): + files = (file for path in self._paths for file in path.iterdir()) + return unique_everseen(files, key=operator.attrgetter('name')) + + def read_bytes(self): + raise FileNotFoundError(f'{self} is not a file') + + def read_text(self, *args, **kwargs): + raise FileNotFoundError(f'{self} is not a file') + + def is_dir(self): + return True + + def is_file(self): + return False + + def joinpath(self, child): + # first try to find child in current paths + for file in self.iterdir(): + if file.name == child: + return file + # if it does not exist, construct it with the first path + return self._paths[0] / child + + __truediv__ = joinpath + + def open(self, *args, **kwargs): + raise FileNotFoundError(f'{self} is not a file') + + @property + def name(self): + return self._paths[0].name + + def __repr__(self): + paths = ', '.join(f"'{path}'" for path in self._paths) + return f'MultiplexedPath({paths})' + + +class NamespaceReader(abc.TraversableResources): + def __init__(self, namespace_path): + if 'NamespacePath' not in str(namespace_path): + raise ValueError('Invalid path') + self.path = MultiplexedPath(*list(namespace_path)) + + def resource_path(self, resource): + """ + Return the file system path to prevent + `resources.path()` from creating a temporary + copy. + """ + return str(self.path.joinpath(resource)) + + def files(self): + return self.path diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/simple.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..da073cbdb11e6c24c19a2d388c53c8842228595f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/importlib_resources/simple.py @@ -0,0 +1,116 @@ +""" +Interface adapters for low-level readers. +""" + +import abc +import io +import itertools +from typing import BinaryIO, List + +from .abc import Traversable, TraversableResources + + +class SimpleReader(abc.ABC): + """ + The minimum, low-level interface required from a resource + provider. + """ + + @abc.abstractproperty + def package(self): + # type: () -> str + """ + The name of the package for which this reader loads resources. + """ + + @abc.abstractmethod + def children(self): + # type: () -> List['SimpleReader'] + """ + Obtain an iterable of SimpleReader for available + child containers (e.g. directories). + """ + + @abc.abstractmethod + def resources(self): + # type: () -> List[str] + """ + Obtain available named resources for this virtual package. + """ + + @abc.abstractmethod + def open_binary(self, resource): + # type: (str) -> BinaryIO + """ + Obtain a File-like for a named resource. + """ + + @property + def name(self): + return self.package.split('.')[-1] + + +class ResourceHandle(Traversable): + """ + Handle to a named resource in a ResourceReader. + """ + + def __init__(self, parent, name): + # type: (ResourceContainer, str) -> None + self.parent = parent + self.name = name # type: ignore + + def is_file(self): + return True + + def is_dir(self): + return False + + def open(self, mode='r', *args, **kwargs): + stream = self.parent.reader.open_binary(self.name) + if 'b' not in mode: + stream = io.TextIOWrapper(*args, **kwargs) + return stream + + def joinpath(self, name): + raise RuntimeError("Cannot traverse into a resource") + + +class ResourceContainer(Traversable): + """ + Traversable container for a package's resources via its reader. + """ + + def __init__(self, reader): + # type: (SimpleReader) -> None + self.reader = reader + + def is_dir(self): + return True + + def is_file(self): + return False + + def iterdir(self): + files = (ResourceHandle(self, name) for name in self.reader.resources) + dirs = map(ResourceContainer, self.reader.children()) + return itertools.chain(files, dirs) + + def open(self, *args, **kwargs): + raise IsADirectoryError() + + def joinpath(self, name): + return next( + traversable for traversable in self.iterdir() if traversable.name == name + ) + + +class TraversableReader(TraversableResources, SimpleReader): + """ + A TraversableResources based on SimpleReader. Resource providers + may derive from this class to provide the TraversableResources + interface by supplying the SimpleReader interface. + """ + + def files(self): + return ResourceContainer(self) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__init__.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d6b4cdb818490ac4e5d3c58ebe1320f0ff961c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67da56a60c724041b71a096ff9583c93643c0146 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/functools.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/functools.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59ee2306e0472615f21c9fb37b8ca2815cb53e6b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/__pycache__/functools.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/context.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/context.py new file mode 100644 index 0000000000000000000000000000000000000000..87a4e3dca299c4201ac50f6ef589dc73f1c45576 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/context.py @@ -0,0 +1,213 @@ +import os +import subprocess +import contextlib +import functools +import tempfile +import shutil +import operator + + +@contextlib.contextmanager +def pushd(dir): + orig = os.getcwd() + os.chdir(dir) + try: + yield dir + finally: + os.chdir(orig) + + +@contextlib.contextmanager +def tarball_context(url, target_dir=None, runner=None, pushd=pushd): + """ + Get a tarball, extract it, change to that directory, yield, then + clean up. + `runner` is the function to invoke commands. + `pushd` is a context manager for changing the directory. + """ + if target_dir is None: + target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') + if runner is None: + runner = functools.partial(subprocess.check_call, shell=True) + # In the tar command, use --strip-components=1 to strip the first path and + # then + # use -C to cause the files to be extracted to {target_dir}. This ensures + # that we always know where the files were extracted. + runner('mkdir {target_dir}'.format(**vars())) + try: + getter = 'wget {url} -O -' + extract = 'tar x{compression} --strip-components=1 -C {target_dir}' + cmd = ' | '.join((getter, extract)) + runner(cmd.format(compression=infer_compression(url), **vars())) + with pushd(target_dir): + yield target_dir + finally: + runner('rm -Rf {target_dir}'.format(**vars())) + + +def infer_compression(url): + """ + Given a URL or filename, infer the compression code for tar. + """ + # cheat and just assume it's the last two characters + compression_indicator = url[-2:] + mapping = dict(gz='z', bz='j', xz='J') + # Assume 'z' (gzip) if no match + return mapping.get(compression_indicator, 'z') + + +@contextlib.contextmanager +def temp_dir(remover=shutil.rmtree): + """ + Create a temporary directory context. Pass a custom remover + to override the removal behavior. + """ + temp_dir = tempfile.mkdtemp() + try: + yield temp_dir + finally: + remover(temp_dir) + + +@contextlib.contextmanager +def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir): + """ + Check out the repo indicated by url. + + If dest_ctx is supplied, it should be a context manager + to yield the target directory for the check out. + """ + exe = 'git' if 'git' in url else 'hg' + with dest_ctx() as repo_dir: + cmd = [exe, 'clone', url, repo_dir] + if branch: + cmd.extend(['--branch', branch]) + devnull = open(os.path.devnull, 'w') + stdout = devnull if quiet else None + subprocess.check_call(cmd, stdout=stdout) + yield repo_dir + + +@contextlib.contextmanager +def null(): + yield + + +class ExceptionTrap: + """ + A context manager that will catch certain exceptions and provide an + indication they occurred. + + >>> with ExceptionTrap() as trap: + ... raise Exception() + >>> bool(trap) + True + + >>> with ExceptionTrap() as trap: + ... pass + >>> bool(trap) + False + + >>> with ExceptionTrap(ValueError) as trap: + ... raise ValueError("1 + 1 is not 3") + >>> bool(trap) + True + + >>> with ExceptionTrap(ValueError) as trap: + ... raise Exception() + Traceback (most recent call last): + ... + Exception + + >>> bool(trap) + False + """ + + exc_info = None, None, None + + def __init__(self, exceptions=(Exception,)): + self.exceptions = exceptions + + def __enter__(self): + return self + + @property + def type(self): + return self.exc_info[0] + + @property + def value(self): + return self.exc_info[1] + + @property + def tb(self): + return self.exc_info[2] + + def __exit__(self, *exc_info): + type = exc_info[0] + matches = type and issubclass(type, self.exceptions) + if matches: + self.exc_info = exc_info + return matches + + def __bool__(self): + return bool(self.type) + + def raises(self, func, *, _test=bool): + """ + Wrap func and replace the result with the truth + value of the trap (True if an exception occurred). + + First, give the decorator an alias to support Python 3.8 + Syntax. + + >>> raises = ExceptionTrap(ValueError).raises + + Now decorate a function that always fails. + + >>> @raises + ... def fail(): + ... raise ValueError('failed') + >>> fail() + True + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with ExceptionTrap(self.exceptions) as trap: + func(*args, **kwargs) + return _test(trap) + + return wrapper + + def passes(self, func): + """ + Wrap func and replace the result with the truth + value of the trap (True if no exception). + + First, give the decorator an alias to support Python 3.8 + Syntax. + + >>> passes = ExceptionTrap(ValueError).passes + + Now decorate a function that always fails. + + >>> @passes + ... def fail(): + ... raise ValueError('failed') + + >>> fail() + False + """ + return self.raises(func, _test=operator.not_) + + +class suppress(contextlib.suppress, contextlib.ContextDecorator): + """ + A version of contextlib.suppress with decorator support. + + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/functools.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/functools.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fea3a1ae12be660a94c277cd748bd43e67b5dc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/functools.py @@ -0,0 +1,525 @@ +import functools +import time +import inspect +import collections +import types +import itertools + +import pkg_resources.extern.more_itertools + +from typing import Callable, TypeVar + + +CallableT = TypeVar("CallableT", bound=Callable[..., object]) + + +def compose(*funcs): + """ + Compose any number of unary functions into a single unary function. + + >>> import textwrap + >>> expected = str.strip(textwrap.dedent(compose.__doc__)) + >>> strip_and_dedent = compose(str.strip, textwrap.dedent) + >>> strip_and_dedent(compose.__doc__) == expected + True + + Compose also allows the innermost function to take arbitrary arguments. + + >>> round_three = lambda x: round(x, ndigits=3) + >>> f = compose(round_three, int.__truediv__) + >>> [f(3*x, x+1) for x in range(1,10)] + [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] + """ + + def compose_two(f1, f2): + return lambda *args, **kwargs: f1(f2(*args, **kwargs)) + + return functools.reduce(compose_two, funcs) + + +def method_caller(method_name, *args, **kwargs): + """ + Return a function that will call a named method on the + target object with optional positional and keyword + arguments. + + >>> lower = method_caller('lower') + >>> lower('MyString') + 'mystring' + """ + + def call_method(target): + func = getattr(target, method_name) + return func(*args, **kwargs) + + return call_method + + +def once(func): + """ + Decorate func so it's only ever called the first time. + + This decorator can ensure that an expensive or non-idempotent function + will not be expensive on subsequent calls and is idempotent. + + >>> add_three = once(lambda a: a+3) + >>> add_three(3) + 6 + >>> add_three(9) + 6 + >>> add_three('12') + 6 + + To reset the stored value, simply clear the property ``saved_result``. + + >>> del add_three.saved_result + >>> add_three(9) + 12 + >>> add_three(8) + 12 + + Or invoke 'reset()' on it. + + >>> add_three.reset() + >>> add_three(-3) + 0 + >>> add_three(0) + 0 + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not hasattr(wrapper, 'saved_result'): + wrapper.saved_result = func(*args, **kwargs) + return wrapper.saved_result + + wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result') + return wrapper + + +def method_cache( + method: CallableT, + cache_wrapper: Callable[ + [CallableT], CallableT + ] = functools.lru_cache(), # type: ignore[assignment] +) -> CallableT: + """ + Wrap lru_cache to support storing the cache data in the object instances. + + Abstracts the common paradigm where the method explicitly saves an + underscore-prefixed protected property on first call and returns that + subsequently. + + >>> class MyClass: + ... calls = 0 + ... + ... @method_cache + ... def method(self, value): + ... self.calls += 1 + ... return value + + >>> a = MyClass() + >>> a.method(3) + 3 + >>> for x in range(75): + ... res = a.method(x) + >>> a.calls + 75 + + Note that the apparent behavior will be exactly like that of lru_cache + except that the cache is stored on each instance, so values in one + instance will not flush values from another, and when an instance is + deleted, so are the cached values for that instance. + + >>> b = MyClass() + >>> for x in range(35): + ... res = b.method(x) + >>> b.calls + 35 + >>> a.method(0) + 0 + >>> a.calls + 75 + + Note that if method had been decorated with ``functools.lru_cache()``, + a.calls would have been 76 (due to the cached value of 0 having been + flushed by the 'b' instance). + + Clear the cache with ``.cache_clear()`` + + >>> a.method.cache_clear() + + Same for a method that hasn't yet been called. + + >>> c = MyClass() + >>> c.method.cache_clear() + + Another cache wrapper may be supplied: + + >>> cache = functools.lru_cache(maxsize=2) + >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache) + >>> a = MyClass() + >>> a.method2() + 3 + + Caution - do not subsequently wrap the method with another decorator, such + as ``@property``, which changes the semantics of the function. + + See also + http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ + for another implementation and additional justification. + """ + + def wrapper(self: object, *args: object, **kwargs: object) -> object: + # it's the first call, replace the method with a cached, bound method + bound_method: CallableT = types.MethodType( # type: ignore[assignment] + method, self + ) + cached_method = cache_wrapper(bound_method) + setattr(self, method.__name__, cached_method) + return cached_method(*args, **kwargs) + + # Support cache clear even before cache has been created. + wrapper.cache_clear = lambda: None # type: ignore[attr-defined] + + return ( # type: ignore[return-value] + _special_method_cache(method, cache_wrapper) or wrapper + ) + + +def _special_method_cache(method, cache_wrapper): + """ + Because Python treats special methods differently, it's not + possible to use instance attributes to implement the cached + methods. + + Instead, install the wrapper method under a different name + and return a simple proxy to that wrapper. + + https://github.com/jaraco/jaraco.functools/issues/5 + """ + name = method.__name__ + special_names = '__getattr__', '__getitem__' + if name not in special_names: + return + + wrapper_name = '__cached' + name + + def proxy(self, *args, **kwargs): + if wrapper_name not in vars(self): + bound = types.MethodType(method, self) + cache = cache_wrapper(bound) + setattr(self, wrapper_name, cache) + else: + cache = getattr(self, wrapper_name) + return cache(*args, **kwargs) + + return proxy + + +def apply(transform): + """ + Decorate a function with a transform function that is + invoked on results returned from the decorated function. + + >>> @apply(reversed) + ... def get_numbers(start): + ... "doc for get_numbers" + ... return range(start, start+3) + >>> list(get_numbers(4)) + [6, 5, 4] + >>> get_numbers.__doc__ + 'doc for get_numbers' + """ + + def wrap(func): + return functools.wraps(func)(compose(transform, func)) + + return wrap + + +def result_invoke(action): + r""" + Decorate a function with an action function that is + invoked on the results returned from the decorated + function (for its side-effect), then return the original + result. + + >>> @result_invoke(print) + ... def add_two(a, b): + ... return a + b + >>> x = add_two(2, 3) + 5 + >>> x + 5 + """ + + def wrap(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + action(result) + return result + + return wrapper + + return wrap + + +def call_aside(f, *args, **kwargs): + """ + Call a function for its side effect after initialization. + + >>> @call_aside + ... def func(): print("called") + called + >>> func() + called + + Use functools.partial to pass parameters to the initial call + + >>> @functools.partial(call_aside, name='bingo') + ... def func(name): print("called with", name) + called with bingo + """ + f(*args, **kwargs) + return f + + +class Throttler: + """ + Rate-limit a function (or other callable) + """ + + def __init__(self, func, max_rate=float('Inf')): + if isinstance(func, Throttler): + func = func.func + self.func = func + self.max_rate = max_rate + self.reset() + + def reset(self): + self.last_called = 0 + + def __call__(self, *args, **kwargs): + self._wait() + return self.func(*args, **kwargs) + + def _wait(self): + "ensure at least 1/max_rate seconds from last call" + elapsed = time.time() - self.last_called + must_wait = 1 / self.max_rate - elapsed + time.sleep(max(0, must_wait)) + self.last_called = time.time() + + def __get__(self, obj, type=None): + return first_invoke(self._wait, functools.partial(self.func, obj)) + + +def first_invoke(func1, func2): + """ + Return a function that when invoked will invoke func1 without + any parameters (for its side-effect) and then invoke func2 + with whatever parameters were passed, returning its result. + """ + + def wrapper(*args, **kwargs): + func1() + return func2(*args, **kwargs) + + return wrapper + + +def retry_call(func, cleanup=lambda: None, retries=0, trap=()): + """ + Given a callable func, trap the indicated exceptions + for up to 'retries' times, invoking cleanup on the + exception. On the final attempt, allow any exceptions + to propagate. + """ + attempts = itertools.count() if retries == float('inf') else range(retries) + for attempt in attempts: + try: + return func() + except trap: + cleanup() + + return func() + + +def retry(*r_args, **r_kwargs): + """ + Decorator wrapper for retry_call. Accepts arguments to retry_call + except func and then returns a decorator for the decorated function. + + Ex: + + >>> @retry(retries=3) + ... def my_func(a, b): + ... "this is my funk" + ... print(a, b) + >>> my_func.__doc__ + 'this is my funk' + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*f_args, **f_kwargs): + bound = functools.partial(func, *f_args, **f_kwargs) + return retry_call(bound, *r_args, **r_kwargs) + + return wrapper + + return decorate + + +def print_yielded(func): + """ + Convert a generator into a function that prints all yielded elements + + >>> @print_yielded + ... def x(): + ... yield 3; yield None + >>> x() + 3 + None + """ + print_all = functools.partial(map, print) + print_results = compose(more_itertools.consume, print_all, func) + return functools.wraps(func)(print_results) + + +def pass_none(func): + """ + Wrap func so it's not called if its first param is None + + >>> print_text = pass_none(print) + >>> print_text('text') + text + >>> print_text(None) + """ + + @functools.wraps(func) + def wrapper(param, *args, **kwargs): + if param is not None: + return func(param, *args, **kwargs) + + return wrapper + + +def assign_params(func, namespace): + """ + Assign parameters from namespace where func solicits. + + >>> def func(x, y=3): + ... print(x, y) + >>> assigned = assign_params(func, dict(x=2, z=4)) + >>> assigned() + 2 3 + + The usual errors are raised if a function doesn't receive + its required parameters: + + >>> assigned = assign_params(func, dict(y=3, z=4)) + >>> assigned() + Traceback (most recent call last): + TypeError: func() ...argument... + + It even works on methods: + + >>> class Handler: + ... def meth(self, arg): + ... print(arg) + >>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))() + crystal + """ + sig = inspect.signature(func) + params = sig.parameters.keys() + call_ns = {k: namespace[k] for k in params if k in namespace} + return functools.partial(func, **call_ns) + + +def save_method_args(method): + """ + Wrap a method such that when it is called, the args and kwargs are + saved on the method. + + >>> class MyClass: + ... @save_method_args + ... def method(self, a, b): + ... print(a, b) + >>> my_ob = MyClass() + >>> my_ob.method(1, 2) + 1 2 + >>> my_ob._saved_method.args + (1, 2) + >>> my_ob._saved_method.kwargs + {} + >>> my_ob.method(a=3, b='foo') + 3 foo + >>> my_ob._saved_method.args + () + >>> my_ob._saved_method.kwargs == dict(a=3, b='foo') + True + + The arguments are stored on the instance, allowing for + different instance to save different args. + + >>> your_ob = MyClass() + >>> your_ob.method({str('x'): 3}, b=[4]) + {'x': 3} [4] + >>> your_ob._saved_method.args + ({'x': 3},) + >>> my_ob._saved_method.args + () + """ + args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + attr_name = '_saved_' + method.__name__ + attr = args_and_kwargs(args, kwargs) + setattr(self, attr_name, attr) + return method(self, *args, **kwargs) + + return wrapper + + +def except_(*exceptions, replace=None, use=None): + """ + Replace the indicated exceptions, if raised, with the indicated + literal replacement or evaluated expression (if present). + + >>> safe_int = except_(ValueError)(int) + >>> safe_int('five') + >>> safe_int('5') + 5 + + Specify a literal replacement with ``replace``. + + >>> safe_int_r = except_(ValueError, replace=0)(int) + >>> safe_int_r('five') + 0 + + Provide an expression to ``use`` to pass through particular parameters. + + >>> safe_int_pt = except_(ValueError, use='args[0]')(int) + >>> safe_int_pt('five') + 'five' + + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except exceptions: + try: + return eval(use) + except TypeError: + return replace + + return wrapper + + return decorate diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/text/__init__.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c466378ceba69a335d2beb4d3af92703d52b3831 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/text/__init__.py @@ -0,0 +1,599 @@ +import re +import itertools +import textwrap +import functools + +try: + from importlib.resources import files # type: ignore +except ImportError: # pragma: nocover + from pkg_resources.extern.importlib_resources import files # type: ignore + +from pkg_resources.extern.jaraco.functools import compose, method_cache +from pkg_resources.extern.jaraco.context import ExceptionTrap + + +def substitution(old, new): + """ + Return a function that will perform a substitution on a string + """ + return lambda s: s.replace(old, new) + + +def multi_substitution(*substitutions): + """ + Take a sequence of pairs specifying substitutions, and create + a function that performs those substitutions. + + >>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo') + 'baz' + """ + substitutions = itertools.starmap(substitution, substitutions) + # compose function applies last function first, so reverse the + # substitutions to get the expected order. + substitutions = reversed(tuple(substitutions)) + return compose(*substitutions) + + +class FoldedCase(str): + """ + A case insensitive string class; behaves just like str + except compares equal when the only variation is case. + + >>> s = FoldedCase('hello world') + + >>> s == 'Hello World' + True + + >>> 'Hello World' == s + True + + >>> s != 'Hello World' + False + + >>> s.index('O') + 4 + + >>> s.split('O') + ['hell', ' w', 'rld'] + + >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) + ['alpha', 'Beta', 'GAMMA'] + + Sequence membership is straightforward. + + >>> "Hello World" in [s] + True + >>> s in ["Hello World"] + True + + You may test for set inclusion, but candidate and elements + must both be folded. + + >>> FoldedCase("Hello World") in {s} + True + >>> s in {FoldedCase("Hello World")} + True + + String inclusion works as long as the FoldedCase object + is on the right. + + >>> "hello" in FoldedCase("Hello World") + True + + But not if the FoldedCase object is on the left: + + >>> FoldedCase('hello') in 'Hello World' + False + + In that case, use ``in_``: + + >>> FoldedCase('hello').in_('Hello World') + True + + >>> FoldedCase('hello') > FoldedCase('Hello') + False + """ + + def __lt__(self, other): + return self.lower() < other.lower() + + def __gt__(self, other): + return self.lower() > other.lower() + + def __eq__(self, other): + return self.lower() == other.lower() + + def __ne__(self, other): + return self.lower() != other.lower() + + def __hash__(self): + return hash(self.lower()) + + def __contains__(self, other): + return super().lower().__contains__(other.lower()) + + def in_(self, other): + "Does self appear in other?" + return self in FoldedCase(other) + + # cache lower since it's likely to be called frequently. + @method_cache + def lower(self): + return super().lower() + + def index(self, sub): + return self.lower().index(sub.lower()) + + def split(self, splitter=' ', maxsplit=0): + pattern = re.compile(re.escape(splitter), re.I) + return pattern.split(self, maxsplit) + + +# Python 3.8 compatibility +_unicode_trap = ExceptionTrap(UnicodeDecodeError) + + +@_unicode_trap.passes +def is_decodable(value): + r""" + Return True if the supplied value is decodable (using the default + encoding). + + >>> is_decodable(b'\xff') + False + >>> is_decodable(b'\x32') + True + """ + value.decode() + + +def is_binary(value): + r""" + Return True if the value appears to be binary (that is, it's a byte + string and isn't decodable). + + >>> is_binary(b'\xff') + True + >>> is_binary('\xff') + False + """ + return isinstance(value, bytes) and not is_decodable(value) + + +def trim(s): + r""" + Trim something like a docstring to remove the whitespace that + is common due to indentation and formatting. + + >>> trim("\n\tfoo = bar\n\t\tbar = baz\n") + 'foo = bar\n\tbar = baz' + """ + return textwrap.dedent(s).strip() + + +def wrap(s): + """ + Wrap lines of text, retaining existing newlines as + paragraph markers. + + >>> print(wrap(lorem_ipsum)) + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad + minim veniam, quis nostrud exercitation ullamco laboris nisi ut + aliquip ex ea commodo consequat. Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur. Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum. + + Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam + varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus + magna felis sollicitudin mauris. Integer in mauris eu nibh euismod + gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis + risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, + eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas + fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla + a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, + neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing + sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque + nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus + quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, + molestie eu, feugiat in, orci. In hac habitasse platea dictumst. + """ + paragraphs = s.splitlines() + wrapped = ('\n'.join(textwrap.wrap(para)) for para in paragraphs) + return '\n\n'.join(wrapped) + + +def unwrap(s): + r""" + Given a multi-line string, return an unwrapped version. + + >>> wrapped = wrap(lorem_ipsum) + >>> wrapped.count('\n') + 20 + >>> unwrapped = unwrap(wrapped) + >>> unwrapped.count('\n') + 1 + >>> print(unwrapped) + Lorem ipsum dolor sit amet, consectetur adipiscing ... + Curabitur pretium tincidunt lacus. Nulla gravida orci ... + + """ + paragraphs = re.split(r'\n\n+', s) + cleaned = (para.replace('\n', ' ') for para in paragraphs) + return '\n'.join(cleaned) + + + + +class Splitter(object): + """object that will split a string with the given arguments for each call + + >>> s = Splitter(',') + >>> s('hello, world, this is your, master calling') + ['hello', ' world', ' this is your', ' master calling'] + """ + + def __init__(self, *args): + self.args = args + + def __call__(self, s): + return s.split(*self.args) + + +def indent(string, prefix=' ' * 4): + """ + >>> indent('foo') + ' foo' + """ + return prefix + string + + +class WordSet(tuple): + """ + Given an identifier, return the words that identifier represents, + whether in camel case, underscore-separated, etc. + + >>> WordSet.parse("camelCase") + ('camel', 'Case') + + >>> WordSet.parse("under_sep") + ('under', 'sep') + + Acronyms should be retained + + >>> WordSet.parse("firstSNL") + ('first', 'SNL') + + >>> WordSet.parse("you_and_I") + ('you', 'and', 'I') + + >>> WordSet.parse("A simple test") + ('A', 'simple', 'test') + + Multiple caps should not interfere with the first cap of another word. + + >>> WordSet.parse("myABCClass") + ('my', 'ABC', 'Class') + + The result is a WordSet, so you can get the form you need. + + >>> WordSet.parse("myABCClass").underscore_separated() + 'my_ABC_Class' + + >>> WordSet.parse('a-command').camel_case() + 'ACommand' + + >>> WordSet.parse('someIdentifier').lowered().space_separated() + 'some identifier' + + Slices of the result should return another WordSet. + + >>> WordSet.parse('taken-out-of-context')[1:].underscore_separated() + 'out_of_context' + + >>> WordSet.from_class_name(WordSet()).lowered().space_separated() + 'word set' + + >>> example = WordSet.parse('figured it out') + >>> example.headless_camel_case() + 'figuredItOut' + >>> example.dash_separated() + 'figured-it-out' + + """ + + _pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))') + + def capitalized(self): + return WordSet(word.capitalize() for word in self) + + def lowered(self): + return WordSet(word.lower() for word in self) + + def camel_case(self): + return ''.join(self.capitalized()) + + def headless_camel_case(self): + words = iter(self) + first = next(words).lower() + new_words = itertools.chain((first,), WordSet(words).camel_case()) + return ''.join(new_words) + + def underscore_separated(self): + return '_'.join(self) + + def dash_separated(self): + return '-'.join(self) + + def space_separated(self): + return ' '.join(self) + + def trim_right(self, item): + """ + Remove the item from the end of the set. + + >>> WordSet.parse('foo bar').trim_right('foo') + ('foo', 'bar') + >>> WordSet.parse('foo bar').trim_right('bar') + ('foo',) + >>> WordSet.parse('').trim_right('bar') + () + """ + return self[:-1] if self and self[-1] == item else self + + def trim_left(self, item): + """ + Remove the item from the beginning of the set. + + >>> WordSet.parse('foo bar').trim_left('foo') + ('bar',) + >>> WordSet.parse('foo bar').trim_left('bar') + ('foo', 'bar') + >>> WordSet.parse('').trim_left('bar') + () + """ + return self[1:] if self and self[0] == item else self + + def trim(self, item): + """ + >>> WordSet.parse('foo bar').trim('foo') + ('bar',) + """ + return self.trim_left(item).trim_right(item) + + def __getitem__(self, item): + result = super(WordSet, self).__getitem__(item) + if isinstance(item, slice): + result = WordSet(result) + return result + + @classmethod + def parse(cls, identifier): + matches = cls._pattern.finditer(identifier) + return WordSet(match.group(0) for match in matches) + + @classmethod + def from_class_name(cls, subject): + return cls.parse(subject.__class__.__name__) + + +# for backward compatibility +words = WordSet.parse + + +def simple_html_strip(s): + r""" + Remove HTML from the string `s`. + + >>> str(simple_html_strip('')) + '' + + >>> print(simple_html_strip('A stormy day in paradise')) + A stormy day in paradise + + >>> print(simple_html_strip('Somebody tell the truth.')) + Somebody tell the truth. + + >>> print(simple_html_strip('What about
\nmultiple lines?')) + What about + multiple lines? + """ + html_stripper = re.compile('()|(<[^>]*>)|([^<]+)', re.DOTALL) + texts = (match.group(3) or '' for match in html_stripper.finditer(s)) + return ''.join(texts) + + +class SeparatedValues(str): + """ + A string separated by a separator. Overrides __iter__ for getting + the values. + + >>> list(SeparatedValues('a,b,c')) + ['a', 'b', 'c'] + + Whitespace is stripped and empty values are discarded. + + >>> list(SeparatedValues(' a, b , c, ')) + ['a', 'b', 'c'] + """ + + separator = ',' + + def __iter__(self): + parts = self.split(self.separator) + return filter(None, (part.strip() for part in parts)) + + +class Stripper: + r""" + Given a series of lines, find the common prefix and strip it from them. + + >>> lines = [ + ... 'abcdefg\n', + ... 'abc\n', + ... 'abcde\n', + ... ] + >>> res = Stripper.strip_prefix(lines) + >>> res.prefix + 'abc' + >>> list(res.lines) + ['defg\n', '\n', 'de\n'] + + If no prefix is common, nothing should be stripped. + + >>> lines = [ + ... 'abcd\n', + ... '1234\n', + ... ] + >>> res = Stripper.strip_prefix(lines) + >>> res.prefix = '' + >>> list(res.lines) + ['abcd\n', '1234\n'] + """ + + def __init__(self, prefix, lines): + self.prefix = prefix + self.lines = map(self, lines) + + @classmethod + def strip_prefix(cls, lines): + prefix_lines, lines = itertools.tee(lines) + prefix = functools.reduce(cls.common_prefix, prefix_lines) + return cls(prefix, lines) + + def __call__(self, line): + if not self.prefix: + return line + null, prefix, rest = line.partition(self.prefix) + return rest + + @staticmethod + def common_prefix(s1, s2): + """ + Return the common prefix of two lines. + """ + index = min(len(s1), len(s2)) + while s1[:index] != s2[:index]: + index -= 1 + return s1[:index] + + +def remove_prefix(text, prefix): + """ + Remove the prefix from the text if it exists. + + >>> remove_prefix('underwhelming performance', 'underwhelming ') + 'performance' + + >>> remove_prefix('something special', 'sample') + 'something special' + """ + null, prefix, rest = text.rpartition(prefix) + return rest + + +def remove_suffix(text, suffix): + """ + Remove the suffix from the text if it exists. + + >>> remove_suffix('name.git', '.git') + 'name' + + >>> remove_suffix('something special', 'sample') + 'something special' + """ + rest, suffix, null = text.partition(suffix) + return rest + + +def normalize_newlines(text): + r""" + Replace alternate newlines with the canonical newline. + + >>> normalize_newlines('Lorem Ipsum\u2029') + 'Lorem Ipsum\n' + >>> normalize_newlines('Lorem Ipsum\r\n') + 'Lorem Ipsum\n' + >>> normalize_newlines('Lorem Ipsum\x85') + 'Lorem Ipsum\n' + """ + newlines = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029'] + pattern = '|'.join(newlines) + return re.sub(pattern, '\n', text) + + +def _nonblank(str): + return str and not str.startswith('#') + + +@functools.singledispatch +def yield_lines(iterable): + r""" + Yield valid lines of a string or iterable. + + >>> list(yield_lines('')) + [] + >>> list(yield_lines(['foo', 'bar'])) + ['foo', 'bar'] + >>> list(yield_lines('foo\nbar')) + ['foo', 'bar'] + >>> list(yield_lines('\nfoo\n#bar\nbaz #comment')) + ['foo', 'baz #comment'] + >>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n'])) + ['foo', 'bar', 'baz', 'bing'] + """ + return itertools.chain.from_iterable(map(yield_lines, iterable)) + + +@yield_lines.register(str) +def _(text): + return filter(_nonblank, map(str.strip, text.splitlines())) + + +def drop_comment(line): + """ + Drop comments. + + >>> drop_comment('foo # bar') + 'foo' + + A hash without a space may be in a URL. + + >>> drop_comment('http://example.com/foo#bar') + 'http://example.com/foo#bar' + """ + return line.partition(' #')[0] + + +def join_continuation(lines): + r""" + Join lines continued by a trailing backslash. + + >>> list(join_continuation(['foo \\', 'bar', 'baz'])) + ['foobar', 'baz'] + >>> list(join_continuation(['foo \\', 'bar', 'baz'])) + ['foobar', 'baz'] + >>> list(join_continuation(['foo \\', 'bar \\', 'baz'])) + ['foobarbaz'] + + Not sure why, but... + The character preceeding the backslash is also elided. + + >>> list(join_continuation(['goo\\', 'dly'])) + ['godly'] + + A terrible idea, but... + If no line is available to continue, suppress the lines. + + >>> list(join_continuation(['foo', 'bar\\', 'baz\\'])) + ['foo'] + """ + lines = iter(lines) + for item in lines: + while item.endswith('\\'): + try: + item = item[:-2].strip() + next(lines) + except StopIteration: + return + yield item diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/text/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/text/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d7e56c5753f09b1fed8b1d87d2aa399ee1cabb3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/jaraco/text/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__init__.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea38bef1f661e62d577b3c2207386d901d851c72 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__init__.py @@ -0,0 +1,4 @@ +from .more import * # noqa +from .recipes import * # noqa + +__version__ = '8.12.0' diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48cbe0c0d444d1cc0859a71ea2ed72e15baafc64 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/recipes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/recipes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b3bd43d70a51e444a56125004c22dca9fbbe704 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/recipes.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/more.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/more.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6a5cab25ad87ec414c3180611f33575308d54f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/more.py @@ -0,0 +1,4316 @@ +import warnings + +from collections import Counter, defaultdict, deque, abc +from collections.abc import Sequence +from functools import partial, reduce, wraps +from heapq import merge, heapify, heapreplace, heappop +from itertools import ( + chain, + compress, + count, + cycle, + dropwhile, + groupby, + islice, + repeat, + starmap, + takewhile, + tee, + zip_longest, +) +from math import exp, factorial, floor, log +from queue import Empty, Queue +from random import random, randrange, uniform +from operator import itemgetter, mul, sub, gt, lt, ge, le +from sys import hexversion, maxsize +from time import monotonic + +from .recipes import ( + consume, + flatten, + pairwise, + powerset, + take, + unique_everseen, +) + +__all__ = [ + 'AbortThread', + 'SequenceView', + 'UnequalIterablesError', + 'adjacent', + 'all_unique', + 'always_iterable', + 'always_reversible', + 'bucket', + 'callback_iter', + 'chunked', + 'chunked_even', + 'circular_shifts', + 'collapse', + 'collate', + 'combination_index', + 'consecutive_groups', + 'consumer', + 'count_cycle', + 'countable', + 'difference', + 'distinct_combinations', + 'distinct_permutations', + 'distribute', + 'divide', + 'duplicates_everseen', + 'duplicates_justseen', + 'exactly_n', + 'filter_except', + 'first', + 'groupby_transform', + 'ichunked', + 'ilen', + 'interleave', + 'interleave_evenly', + 'interleave_longest', + 'intersperse', + 'is_sorted', + 'islice_extended', + 'iterate', + 'last', + 'locate', + 'lstrip', + 'make_decorator', + 'map_except', + 'map_if', + 'map_reduce', + 'mark_ends', + 'minmax', + 'nth_or_last', + 'nth_permutation', + 'nth_product', + 'numeric_range', + 'one', + 'only', + 'padded', + 'partitions', + 'peekable', + 'permutation_index', + 'product_index', + 'raise_', + 'repeat_each', + 'repeat_last', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'sample', + 'seekable', + 'set_partitions', + 'side_effect', + 'sliced', + 'sort_together', + 'split_after', + 'split_at', + 'split_before', + 'split_into', + 'split_when', + 'spy', + 'stagger', + 'strip', + 'strictly_n', + 'substrings', + 'substrings_indexes', + 'time_limited', + 'unique_in_window', + 'unique_to_each', + 'unzip', + 'value_chain', + 'windowed', + 'windowed_complete', + 'with_iter', + 'zip_broadcast', + 'zip_equal', + 'zip_offset', +] + + +_marker = object() + + +def chunked(iterable, n, strict=False): + """Break *iterable* into lists of length *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) + [[1, 2, 3], [4, 5, 6]] + + By the default, the last yielded list will have fewer than *n* elements + if the length of *iterable* is not divisible by *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) + [[1, 2, 3], [4, 5, 6], [7, 8]] + + To use a fill-in value instead, see the :func:`grouper` recipe. + + If the length of *iterable* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + list is yielded. + + """ + iterator = iter(partial(take, n, iter(iterable)), []) + if strict: + if n is None: + raise ValueError('n must not be None when using strict mode.') + + def ret(): + for chunk in iterator: + if len(chunk) != n: + raise ValueError('iterable is not divisible by n.') + yield chunk + + return iter(ret()) + else: + return iterator + + +def first(iterable, default=_marker): + """Return the first item of *iterable*, or *default* if *iterable* is + empty. + + >>> first([0, 1, 2, 3]) + 0 + >>> first([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + + :func:`first` is useful when you have a generator of expensive-to-retrieve + values and want any arbitrary one. It is marginally shorter than + ``next(iter(iterable), default)``. + + """ + try: + return next(iter(iterable)) + except StopIteration as e: + if default is _marker: + raise ValueError( + 'first() was called on an empty iterable, and no ' + 'default value was provided.' + ) from e + return default + + +def last(iterable, default=_marker): + """Return the last item of *iterable*, or *default* if *iterable* is + empty. + + >>> last([0, 1, 2, 3]) + 3 + >>> last([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + try: + if isinstance(iterable, Sequence): + return iterable[-1] + # Work around https://bugs.python.org/issue38525 + elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0): + return next(reversed(iterable)) + else: + return deque(iterable, maxlen=1)[-1] + except (IndexError, TypeError, StopIteration): + if default is _marker: + raise ValueError( + 'last() was called on an empty iterable, and no default was ' + 'provided.' + ) + return default + + +def nth_or_last(iterable, n, default=_marker): + """Return the nth or the last item of *iterable*, + or *default* if *iterable* is empty. + + >>> nth_or_last([0, 1, 2, 3], 2) + 2 + >>> nth_or_last([0, 1], 2) + 1 + >>> nth_or_last([], 0, 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + return last(islice(iterable, n + 1), default=default) + + +class peekable: + """Wrap an iterator to allow lookahead and prepending elements. + + Call :meth:`peek` on the result to get the value that will be returned + by :func:`next`. This won't advance the iterator: + + >>> p = peekable(['a', 'b']) + >>> p.peek() + 'a' + >>> next(p) + 'a' + + Pass :meth:`peek` a default value to return that instead of raising + ``StopIteration`` when the iterator is exhausted. + + >>> p = peekable([]) + >>> p.peek('hi') + 'hi' + + peekables also offer a :meth:`prepend` method, which "inserts" items + at the head of the iterable: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> p.peek() + 11 + >>> list(p) + [11, 12, 1, 2, 3] + + peekables can be indexed. Index 0 is the item that will be returned by + :func:`next`, index 1 is the item after that, and so on: + The values up to the given index will be cached. + + >>> p = peekable(['a', 'b', 'c', 'd']) + >>> p[0] + 'a' + >>> p[1] + 'b' + >>> next(p) + 'a' + + Negative indexes are supported, but be aware that they will cache the + remaining items in the source iterator, which may require significant + storage. + + To check whether a peekable is exhausted, check its truth value: + + >>> p = peekable(['a', 'b']) + >>> if p: # peekable has items + ... list(p) + ['a', 'b'] + >>> if not p: # peekable is exhausted + ... list(p) + [] + + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self._cache = deque() + + def __iter__(self): + return self + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + """Return the item that will be next returned from ``next()``. + + Return ``default`` if there are no items left. If ``default`` is not + provided, raise ``StopIteration``. + + """ + if not self._cache: + try: + self._cache.append(next(self._it)) + except StopIteration: + if default is _marker: + raise + return default + return self._cache[0] + + def prepend(self, *items): + """Stack up items to be the next ones returned from ``next()`` or + ``self.peek()``. The items will be returned in + first in, first out order:: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> list(p) + [11, 12, 1, 2, 3] + + It is possible, by prepending items, to "resurrect" a peekable that + previously raised ``StopIteration``. + + >>> p = peekable([]) + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + >>> p.prepend(1) + >>> next(p) + 1 + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + + """ + self._cache.extendleft(reversed(items)) + + def __next__(self): + if self._cache: + return self._cache.popleft() + + return next(self._it) + + def _get_slice(self, index): + # Normalize the slice's arguments + step = 1 if (index.step is None) else index.step + if step > 0: + start = 0 if (index.start is None) else index.start + stop = maxsize if (index.stop is None) else index.stop + elif step < 0: + start = -1 if (index.start is None) else index.start + stop = (-maxsize - 1) if (index.stop is None) else index.stop + else: + raise ValueError('slice step cannot be zero') + + # If either the start or stop index is negative, we'll need to cache + # the rest of the iterable in order to slice from the right side. + if (start < 0) or (stop < 0): + self._cache.extend(self._it) + # Otherwise we'll need to find the rightmost index and cache to that + # point. + else: + n = min(max(start, stop) + 1, maxsize) + cache_len = len(self._cache) + if n >= cache_len: + self._cache.extend(islice(self._it, n - cache_len)) + + return list(self._cache)[index] + + def __getitem__(self, index): + if isinstance(index, slice): + return self._get_slice(index) + + cache_len = len(self._cache) + if index < 0: + self._cache.extend(self._it) + elif index >= cache_len: + self._cache.extend(islice(self._it, index + 1 - cache_len)) + + return self._cache[index] + + +def collate(*iterables, **kwargs): + """Return a sorted merge of the items from each of several already-sorted + *iterables*. + + >>> list(collate('ACDZ', 'AZ', 'JKL')) + ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] + + Works lazily, keeping only the next value from each iterable in memory. Use + :func:`collate` to, for example, perform a n-way mergesort of items that + don't fit in memory. + + If a *key* function is specified, the iterables will be sorted according + to its result: + + >>> key = lambda s: int(s) # Sort by numeric value, not by string + >>> list(collate(['1', '10'], ['2', '11'], key=key)) + ['1', '2', '10', '11'] + + + If the *iterables* are sorted in descending order, set *reverse* to + ``True``: + + >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) + [5, 4, 3, 2, 1, 0] + + If the elements of the passed-in iterables are out of order, you might get + unexpected results. + + On Python 3.5+, this function is an alias for :func:`heapq.merge`. + + """ + warnings.warn( + "collate is no longer part of more_itertools, use heapq.merge", + DeprecationWarning, + ) + return merge(*iterables, **kwargs) + + +def consumer(func): + """Decorator that automatically advances a PEP-342-style "reverse iterator" + to its first yield point so you don't have to call ``next()`` on it + manually. + + >>> @consumer + ... def tally(): + ... i = 0 + ... while True: + ... print('Thing number %s is %s.' % (i, (yield))) + ... i += 1 + ... + >>> t = tally() + >>> t.send('red') + Thing number 0 is red. + >>> t.send('fish') + Thing number 1 is fish. + + Without the decorator, you would have to call ``next(t)`` before + ``t.send()`` could be used. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + gen = func(*args, **kwargs) + next(gen) + return gen + + return wrapper + + +def ilen(iterable): + """Return the number of items in *iterable*. + + >>> ilen(x for x in range(1000000) if x % 3 == 0) + 333334 + + This consumes the iterable, so handle with care. + + """ + # This approach was selected because benchmarks showed it's likely the + # fastest of the known implementations at the time of writing. + # See GitHub tracker: #236, #230. + counter = count() + deque(zip(iterable, counter), maxlen=0) + return next(counter) + + +def iterate(func, start): + """Return ``start``, ``func(start)``, ``func(func(start))``, ... + + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + """ + while True: + yield start + start = func(start) + + +def with_iter(context_manager): + """Wrap an iterable in a ``with`` statement, so it closes once exhausted. + + For example, this will close the file when the iterator is exhausted:: + + upper_lines = (line.upper() for line in with_iter(open('foo'))) + + Any context manager which returns an iterable is a candidate for + ``with_iter``. + + """ + with context_manager as iterable: + yield from iterable + + +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. + + :func:`one` is useful for ensuring that an iterable contains only one item. + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. + + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: + + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. You may specify a different exception with the *too_long* + keyword: + + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 'too', + 'many', and perhaps more. + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. + + """ + it = iter(iterable) + + try: + first_value = next(it) + except StopIteration as e: + raise ( + too_short or ValueError('too few items in iterable (expected 1)') + ) from e + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def raise_(exception, *args): + raise exception(*args) + + +def strictly_n(iterable, n, too_short=None, too_long=None): + """Validate that *iterable* has exactly *n* items and return them if + it does. If it has fewer than *n* items, call function *too_short* + with those items. If it has more than *n* items, call function + *too_long* with the first ``n + 1`` items. + + >>> iterable = ['a', 'b', 'c', 'd'] + >>> n = 4 + >>> list(strictly_n(iterable, n)) + ['a', 'b', 'c', 'd'] + + By default, *too_short* and *too_long* are functions that raise + ``ValueError``. + + >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too few items in iterable (got 2) + + >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (got at least 3) + + You can instead supply functions that do something else. + *too_short* will be called with the number of items in *iterable*. + *too_long* will be called with `n + 1`. + + >>> def too_short(item_count): + ... raise RuntimeError + >>> it = strictly_n('abcd', 6, too_short=too_short) + >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + >>> def too_long(item_count): + ... print('The boss is going to hear about this') + >>> it = strictly_n('abcdef', 4, too_long=too_long) + >>> list(it) + The boss is going to hear about this + ['a', 'b', 'c', 'd'] + + """ + if too_short is None: + too_short = lambda item_count: raise_( + ValueError, + 'Too few items in iterable (got {})'.format(item_count), + ) + + if too_long is None: + too_long = lambda item_count: raise_( + ValueError, + 'Too many items in iterable (got at least {})'.format(item_count), + ) + + it = iter(iterable) + for i in range(n): + try: + item = next(it) + except StopIteration: + too_short(i) + return + else: + yield item + + try: + next(it) + except StopIteration: + pass + else: + too_long(n + 1) + + +def distinct_permutations(iterable, r=None): + """Yield successive distinct permutations of the elements in *iterable*. + + >>> sorted(distinct_permutations([1, 0, 1])) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + Equivalent to ``set(permutations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + Duplicate permutations arise when there are duplicated elements in the + input iterable. The number of items returned is + `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of + items input, and each `x_i` is the count of a distinct item in the input + sequence. + + If *r* is given, only the *r*-length permutations are yielded. + + >>> sorted(distinct_permutations([1, 0, 1], r=2)) + [(0, 1), (1, 0), (1, 1)] + >>> sorted(distinct_permutations(range(3), r=2)) + [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + + """ + # Algorithm: https://w.wiki/Qai + def _full(A): + while True: + # Yield the permutation we have + yield tuple(A) + + # Find the largest index i such that A[i] < A[i + 1] + for i in range(size - 2, -1, -1): + if A[i] < A[i + 1]: + break + # If no such index exists, this permutation is the last one + else: + return + + # Find the largest index j greater than j such that A[i] < A[j] + for j in range(size - 1, i, -1): + if A[i] < A[j]: + break + + # Swap the value of A[i] with that of A[j], then reverse the + # sequence from A[i + 1] to form the new permutation + A[i], A[j] = A[j], A[i] + A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1] + + # Algorithm: modified from the above + def _partial(A, r): + # Split A into the first r items and the last r items + head, tail = A[:r], A[r:] + right_head_indexes = range(r - 1, -1, -1) + left_tail_indexes = range(len(tail)) + + while True: + # Yield the permutation we have + yield tuple(head) + + # Starting from the right, find the first index of the head with + # value smaller than the maximum value of the tail - call it i. + pivot = tail[-1] + for i in right_head_indexes: + if head[i] < pivot: + break + pivot = head[i] + else: + return + + # Starting from the left, find the first value of the tail + # with a value greater than head[i] and swap. + for j in left_tail_indexes: + if tail[j] > head[i]: + head[i], tail[j] = tail[j], head[i] + break + # If we didn't find one, start from the right and find the first + # index of the head with a value greater than head[i] and swap. + else: + for j in right_head_indexes: + if head[j] > head[i]: + head[i], head[j] = head[j], head[i] + break + + # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)] + tail += head[: i - r : -1] # head[i + 1:][::-1] + i += 1 + head[i:], tail[:] = tail[: r - i], tail[r - i :] + + items = sorted(iterable) + + size = len(items) + if r is None: + r = size + + if 0 < r <= size: + return _full(items) if (r == size) else _partial(items, r) + + return iter(() if r else ((),)) + + +def intersperse(e, iterable, n=1): + """Intersperse filler element *e* among the items in *iterable*, leaving + *n* items between each filler element. + + >>> list(intersperse('!', [1, 2, 3, 4, 5])) + [1, '!', 2, '!', 3, '!', 4, '!', 5] + + >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2)) + [1, 2, None, 3, 4, None, 5] + + """ + if n == 0: + raise ValueError('n must be > 0') + elif n == 1: + # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, x_1, e, x_2... + return islice(interleave(repeat(e), iterable), 1, None) + else: + # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... + # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]... + # flatten(...) -> x_0, x_1, e, x_2, x_3... + filler = repeat([e]) + chunks = chunked(iterable, n) + return flatten(islice(interleave(filler, chunks), 1, None)) + + +def unique_to_each(*iterables): + """Return the elements from each of the input iterables that aren't in the + other input iterables. + + For example, suppose you have a set of packages, each with a set of + dependencies:: + + {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}} + + If you remove one package, which dependencies can also be removed? + + If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not + associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for + ``pkg_2``, and ``D`` is only needed for ``pkg_3``:: + + >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'}) + [['A'], ['C'], ['D']] + + If there are duplicates in one input iterable that aren't in the others + they will be duplicated in the output. Input order is preserved:: + + >>> unique_to_each("mississippi", "missouri") + [['p', 'p'], ['o', 'u', 'r']] + + It is assumed that the elements of each iterable are hashable. + + """ + pool = [list(it) for it in iterables] + counts = Counter(chain.from_iterable(map(set, pool))) + uniques = {element for element in counts if counts[element] == 1} + return [list(filter(uniques.__contains__, it)) for it in pool] + + +def windowed(seq, n, fillvalue=None, step=1): + """Return a sliding window of width *n* over the given iterable. + + >>> all_windows = windowed([1, 2, 3, 4, 5], 3) + >>> list(all_windows) + [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + + When the window is larger than the iterable, *fillvalue* is used in place + of missing values: + + >>> list(windowed([1, 2, 3], 4)) + [(1, 2, 3, None)] + + Each window will advance in increments of *step*: + + >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) + [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + + To slide into the iterable's items, use :func:`chain` to add filler items + to the left: + + >>> iterable = [1, 2, 3, 4] + >>> n = 3 + >>> padding = [None] * (n - 1) + >>> list(windowed(chain(padding, iterable), 3)) + [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] + """ + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + yield tuple() + return + if step < 1: + raise ValueError('step must be >= 1') + + window = deque(maxlen=n) + i = n + for _ in map(window.append, seq): + i -= 1 + if not i: + i = step + yield tuple(window) + + size = len(window) + if size < n: + yield tuple(chain(window, repeat(fillvalue, n - size))) + elif 0 < i < min(step, n): + window += (fillvalue,) * i + yield tuple(window) + + +def substrings(iterable): + """Yield all of the substrings of *iterable*. + + >>> [''.join(s) for s in substrings('more')] + ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more'] + + Note that non-string iterables can also be subdivided. + + >>> list(substrings([0, 1, 2])) + [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)] + + """ + # The length-1 substrings + seq = [] + for item in iter(iterable): + seq.append(item) + yield (item,) + seq = tuple(seq) + item_count = len(seq) + + # And the rest + for n in range(2, item_count + 1): + for i in range(item_count - n + 1): + yield seq[i : i + n] + + +def substrings_indexes(seq, reverse=False): + """Yield all substrings and their positions in *seq* + + The items yielded will be a tuple of the form ``(substr, i, j)``, where + ``substr == seq[i:j]``. + + This function only works for iterables that support slicing, such as + ``str`` objects. + + >>> for item in substrings_indexes('more'): + ... print(item) + ('m', 0, 1) + ('o', 1, 2) + ('r', 2, 3) + ('e', 3, 4) + ('mo', 0, 2) + ('or', 1, 3) + ('re', 2, 4) + ('mor', 0, 3) + ('ore', 1, 4) + ('more', 0, 4) + + Set *reverse* to ``True`` to yield the same items in the opposite order. + + + """ + r = range(1, len(seq) + 1) + if reverse: + r = reversed(r) + return ( + (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1) + ) + + +class bucket: + """Wrap *iterable* and return an object that buckets it iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. + + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. + + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. + else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + yield from self._cache.keys() + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) + + +def spy(iterable, n=1): + """Return a 2-tuple with a list containing the first *n* elements of + *iterable*, and an iterator with the same items as *iterable*. + This allows you to "look ahead" at the items in the iterable without + advancing it. + + There is one item in the list by default: + + >>> iterable = 'abcdefg' + >>> head, iterable = spy(iterable) + >>> head + ['a'] + >>> list(iterable) + ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + + You may use unpacking to retrieve items instead of lists: + + >>> (head,), iterable = spy('abcdefg') + >>> head + 'a' + >>> (first, second), iterable = spy('abcdefg', 2) + >>> first + 'a' + >>> second + 'b' + + The number of items requested can be larger than the number of items in + the iterable: + + >>> iterable = [1, 2, 3, 4, 5] + >>> head, iterable = spy(iterable, 10) + >>> head + [1, 2, 3, 4, 5] + >>> list(iterable) + [1, 2, 3, 4, 5] + + """ + it = iter(iterable) + head = take(n, it) + + return head.copy(), chain(head, it) + + +def interleave(*iterables): + """Return a new iterable yielding from each iterable in turn, + until the shortest is exhausted. + + >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7] + + For a version that doesn't terminate after the shortest iterable is + exhausted, see :func:`interleave_longest`. + + """ + return chain.from_iterable(zip(*iterables)) + + +def interleave_longest(*iterables): + """Return a new iterable yielding from each iterable in turn, + skipping any that are exhausted. + + >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7, 3, 8] + + This function produces the same output as :func:`roundrobin`, but may + perform better for some inputs (in particular when the number of iterables + is large). + + """ + i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker)) + return (x for x in i if x is not _marker) + + +def interleave_evenly(iterables, lengths=None): + """ + Interleave multiple iterables so that their elements are evenly distributed + throughout the output sequence. + + >>> iterables = [1, 2, 3, 4, 5], ['a', 'b'] + >>> list(interleave_evenly(iterables)) + [1, 2, 'a', 3, 4, 'b', 5] + + >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]] + >>> list(interleave_evenly(iterables)) + [1, 6, 4, 2, 7, 3, 8, 5] + + This function requires iterables of known length. Iterables without + ``__len__()`` can be used by manually specifying lengths with *lengths*: + + >>> from itertools import combinations, repeat + >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']] + >>> lengths = [4 * (4 - 1) // 2, 3] + >>> list(interleave_evenly(iterables, lengths=lengths)) + [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c'] + + Based on Bresenham's algorithm. + """ + if lengths is None: + try: + lengths = [len(it) for it in iterables] + except TypeError: + raise ValueError( + 'Iterable lengths could not be determined automatically. ' + 'Specify them with the lengths keyword.' + ) + elif len(iterables) != len(lengths): + raise ValueError('Mismatching number of iterables and lengths.') + + dims = len(lengths) + + # sort iterables by length, descending + lengths_permute = sorted( + range(dims), key=lambda i: lengths[i], reverse=True + ) + lengths_desc = [lengths[i] for i in lengths_permute] + iters_desc = [iter(iterables[i]) for i in lengths_permute] + + # the longest iterable is the primary one (Bresenham: the longest + # distance along an axis) + delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:] + iter_primary, iters_secondary = iters_desc[0], iters_desc[1:] + errors = [delta_primary // dims] * len(deltas_secondary) + + to_yield = sum(lengths) + while to_yield: + yield next(iter_primary) + to_yield -= 1 + # update errors for each secondary iterable + errors = [e - delta for e, delta in zip(errors, deltas_secondary)] + + # those iterables for which the error is negative are yielded + # ("diagonal step" in Bresenham) + for i, e in enumerate(errors): + if e < 0: + yield next(iters_secondary[i]) + to_yield -= 1 + errors[i] += delta_primary + + +def collapse(iterable, base_type=None, levels=None): + """Flatten an iterable with multiple levels of nesting (e.g., a list of + lists of tuples) into non-iterable types. + + >>> iterable = [(1, 2), ([3, 4], [[5], [6]])] + >>> list(collapse(iterable)) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and + will not be collapsed. + + To avoid collapsing other types, specify *base_type*: + + >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] + >>> list(collapse(iterable, base_type=tuple)) + ['ab', ('cd', 'ef'), 'gh', 'ij'] + + Specify *levels* to stop flattening after a certain level: + + >>> iterable = [('a', ['b']), ('c', ['d'])] + >>> list(collapse(iterable)) # Fully flattened + ['a', 'b', 'c', 'd'] + >>> list(collapse(iterable, levels=1)) # Only one level flattened + ['a', ['b'], 'c', ['d']] + + """ + + def walk(node, level): + if ( + ((levels is not None) and (level > levels)) + or isinstance(node, (str, bytes)) + or ((base_type is not None) and isinstance(node, base_type)) + ): + yield node + return + + try: + tree = iter(node) + except TypeError: + yield node + return + else: + for child in tree: + yield from walk(child, level + 1) + + yield from walk(iterable, 0) + + +def side_effect(func, iterable, chunk_size=None, before=None, after=None): + """Invoke *func* on each item in *iterable* (or on each *chunk_size* group + of items) before yielding the item. + + `func` must be a function that takes a single argument. Its return value + will be discarded. + + *before* and *after* are optional functions that take no arguments. They + will be executed before iteration starts and after it ends, respectively. + + `side_effect` can be used for logging, updating progress bars, or anything + that is not functionally "pure." + + Emitting a status message: + + >>> from more_itertools import consume + >>> func = lambda item: print('Received {}'.format(item)) + >>> consume(side_effect(func, range(2))) + Received 0 + Received 1 + + Operating on chunks of items: + + >>> pair_sums = [] + >>> func = lambda chunk: pair_sums.append(sum(chunk)) + >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2)) + [0, 1, 2, 3, 4, 5] + >>> list(pair_sums) + [1, 5, 9] + + Writing to a file-like object: + + >>> from io import StringIO + >>> from more_itertools import consume + >>> f = StringIO() + >>> func = lambda x: print(x, file=f) + >>> before = lambda: print(u'HEADER', file=f) + >>> after = f.close + >>> it = [u'a', u'b', u'c'] + >>> consume(side_effect(func, it, before=before, after=after)) + >>> f.closed + True + + """ + try: + if before is not None: + before() + + if chunk_size is None: + for item in iterable: + func(item) + yield item + else: + for chunk in chunked(iterable, chunk_size): + func(chunk) + yield from chunk + finally: + if after is not None: + after() + + +def sliced(seq, n, strict=False): + """Yield slices of length *n* from the sequence *seq*. + + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] + + By the default, the last yielded slice will have fewer than *n* elements + if the length of *seq* is not divisible by *n*: + + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + If the length of *seq* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + slice is yielded. + + This function will only work for iterables that support slicing. + For non-sliceable iterables, see :func:`chunked`. + + """ + iterator = takewhile(len, (seq[i : i + n] for i in count(0, n))) + if strict: + + def ret(): + for _slice in iterator: + if len(_slice) != n: + raise ValueError("seq is not divisible by n.") + yield _slice + + return iter(ret()) + else: + return iterator + + +def split_at(iterable, pred, maxsplit=-1, keep_separator=False): + """Yield lists of items from *iterable*, where each list is delimited by + an item where callable *pred* returns ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b')) + [['a'], ['c', 'd', 'c'], ['a']] + + >>> list(split_at(range(10), lambda n: n % 2 == 1)) + [[0], [2], [4], [6], [8], []] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2)) + [[0], [2], [4, 5, 6, 7, 8, 9]] + + By default, the delimiting items are not included in the output. + The include them, set *keep_separator* to ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True)) + [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']] + + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + if pred(item): + yield buf + if keep_separator: + yield [item] + if maxsplit == 1: + yield list(it) + return + buf = [] + maxsplit -= 1 + else: + buf.append(item) + yield buf + + +def split_before(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends just before + an item for which callable *pred* returns ``True``: + + >>> list(split_before('OneTwo', lambda s: s.isupper())) + [['O', 'n', 'e'], ['T', 'w', 'o']] + + >>> list(split_before(range(10), lambda n: n % 3 == 0)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + if pred(item) and buf: + yield buf + if maxsplit == 1: + yield [item] + list(it) + return + buf = [] + maxsplit -= 1 + buf.append(item) + if buf: + yield buf + + +def split_after(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends with an + item where callable *pred* returns ``True``: + + >>> list(split_after('one1two2', lambda s: s.isdigit())) + [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']] + + >>> list(split_after(range(10), lambda n: n % 3 == 0)) + [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + buf.append(item) + if pred(item) and buf: + yield buf + if maxsplit == 1: + yield list(it) + return + buf = [] + maxsplit -= 1 + if buf: + yield buf + + +def split_when(iterable, pred, maxsplit=-1): + """Split *iterable* into pieces based on the output of *pred*. + *pred* should be a function that takes successive pairs of items and + returns ``True`` if the iterable should be split in between them. + + For example, to find runs of increasing numbers, split the iterable when + element ``i`` is larger than element ``i + 1``: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y)) + [[1, 2, 3, 3], [2, 5], [2, 4], [2]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], + ... lambda x, y: x > y, maxsplit=2)) + [[1, 2, 3, 3], [2, 5], [2, 4, 2]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + it = iter(iterable) + try: + cur_item = next(it) + except StopIteration: + return + + buf = [cur_item] + for next_item in it: + if pred(cur_item, next_item): + yield buf + if maxsplit == 1: + yield [next_item] + list(it) + return + buf = [] + maxsplit -= 1 + + buf.append(next_item) + cur_item = next_item + + yield buf + + +def split_into(iterable, sizes): + """Yield a list of sequential items from *iterable* of length 'n' for each + integer 'n' in *sizes*. + + >>> list(split_into([1,2,3,4,5,6], [1,2,3])) + [[1], [2, 3], [4, 5, 6]] + + If the sum of *sizes* is smaller than the length of *iterable*, then the + remaining items of *iterable* will not be returned. + + >>> list(split_into([1,2,3,4,5,6], [2,3])) + [[1, 2], [3, 4, 5]] + + If the sum of *sizes* is larger than the length of *iterable*, fewer items + will be returned in the iteration that overruns *iterable* and further + lists will be empty: + + >>> list(split_into([1,2,3,4], [1,2,3,4])) + [[1], [2, 3], [4], []] + + When a ``None`` object is encountered in *sizes*, the returned list will + contain items up to the end of *iterable* the same way that itertools.slice + does: + + >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None])) + [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]] + + :func:`split_into` can be useful for grouping a series of items where the + sizes of the groups are not uniform. An example would be where in a row + from a table, multiple columns represent elements of the same feature + (e.g. a point represented by x,y,z) but, the format is not the same for + all columns. + """ + # convert the iterable argument into an iterator so its contents can + # be consumed by islice in case it is a generator + it = iter(iterable) + + for size in sizes: + if size is None: + yield list(it) + return + else: + yield list(islice(it, size)) + + +def padded(iterable, fillvalue=None, n=None, next_multiple=False): + """Yield the elements from *iterable*, followed by *fillvalue*, such that + at least *n* items are emitted. + + >>> list(padded([1, 2, 3], '?', 5)) + [1, 2, 3, '?', '?'] + + If *next_multiple* is ``True``, *fillvalue* will be emitted until the + number of items emitted is a multiple of *n*:: + + >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True)) + [1, 2, 3, 4, None, None] + + If *n* is ``None``, *fillvalue* will be emitted indefinitely. + + """ + it = iter(iterable) + if n is None: + yield from chain(it, repeat(fillvalue)) + elif n < 1: + raise ValueError('n must be at least 1') + else: + item_count = 0 + for item in it: + yield item + item_count += 1 + + remaining = (n - item_count) % n if next_multiple else n - item_count + for _ in range(remaining): + yield fillvalue + + +def repeat_each(iterable, n=2): + """Repeat each element in *iterable* *n* times. + + >>> list(repeat_each('ABC', 3)) + ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'] + """ + return chain.from_iterable(map(repeat, iterable, repeat(n))) + + +def repeat_last(iterable, default=None): + """After the *iterable* is exhausted, keep yielding its last element. + + >>> list(islice(repeat_last(range(3)), 5)) + [0, 1, 2, 2, 2] + + If the iterable is empty, yield *default* forever:: + + >>> list(islice(repeat_last(range(0), 42), 5)) + [42, 42, 42, 42, 42] + + """ + item = _marker + for item in iterable: + yield item + final = default if item is _marker else item + yield from repeat(final) + + +def distribute(n, iterable): + """Distribute the items from *iterable* among *n* smaller iterables. + + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + + If the length of *iterable* is smaller than *n*, then the last returned + iterables will be empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function uses :func:`itertools.tee` and may require significant + storage. If you need the order items in the smaller iterables to match the + original iterable, see :func:`divide`. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + children = tee(iterable, n) + return [islice(it, index, None, n) for index, it in enumerate(children)] + + +def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): + """Yield tuples whose elements are offset from *iterable*. + The amount by which the `i`-th item in each tuple is offset is given by + the `i`-th item in *offsets*. + + >>> list(stagger([0, 1, 2, 3])) + [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + >>> list(stagger(range(8), offsets=(0, 2, 4))) + [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)] + + By default, the sequence will end when the final element of a tuple is the + last item in the iterable. To continue until the first element of a tuple + is the last item in the iterable, set *longest* to ``True``:: + + >>> list(stagger([0, 1, 2, 3], longest=True)) + [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + children = tee(iterable, len(offsets)) + + return zip_offset( + *children, offsets=offsets, longest=longest, fillvalue=fillvalue + ) + + +class UnequalIterablesError(ValueError): + def __init__(self, details=None): + msg = 'Iterables have different lengths' + if details is not None: + msg += (': index 0 has length {}; index {} has length {}').format( + *details + ) + + super().__init__(msg) + + +def _zip_equal_generator(iterables): + for combo in zip_longest(*iterables, fillvalue=_marker): + for val in combo: + if val is _marker: + raise UnequalIterablesError() + yield combo + + +def _zip_equal(*iterables): + # Check whether the iterables are all the same size. + try: + first_size = len(iterables[0]) + for i, it in enumerate(iterables[1:], 1): + size = len(it) + if size != first_size: + break + else: + # If we didn't break out, we can use the built-in zip. + return zip(*iterables) + + # If we did break out, there was a mismatch. + raise UnequalIterablesError(details=(first_size, i, size)) + # If any one of the iterables didn't have a length, start reading + # them until one runs out. + except TypeError: + return _zip_equal_generator(iterables) + + +def zip_equal(*iterables): + """``zip`` the input *iterables* together, but raise + ``UnequalIterablesError`` if they aren't all the same length. + + >>> it_1 = range(3) + >>> it_2 = iter('abc') + >>> list(zip_equal(it_1, it_2)) + [(0, 'a'), (1, 'b'), (2, 'c')] + + >>> it_1 = range(3) + >>> it_2 = iter('abcd') + >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + more_itertools.more.UnequalIterablesError: Iterables have different + lengths + + """ + if hexversion >= 0x30A00A6: + warnings.warn( + ( + 'zip_equal will be removed in a future version of ' + 'more-itertools. Use the builtin zip function with ' + 'strict=True instead.' + ), + DeprecationWarning, + ) + + return _zip_equal(*iterables) + + +def zip_offset(*iterables, offsets, longest=False, fillvalue=None): + """``zip`` the input *iterables* together, but offset the `i`-th iterable + by the `i`-th item in *offsets*. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1))) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] + + This can be used as a lightweight alternative to SciPy or pandas to analyze + data sets in which some series have a lead or lag relationship. + + By default, the sequence will end when the shortest iterable is exhausted. + To continue until the longest iterable is exhausted, set *longest* to + ``True``. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True)) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + if len(iterables) != len(offsets): + raise ValueError("Number of iterables and offsets didn't match") + + staggered = [] + for it, n in zip(iterables, offsets): + if n < 0: + staggered.append(chain(repeat(fillvalue, -n), it)) + elif n > 0: + staggered.append(islice(it, n, None)) + else: + staggered.append(it) + + if longest: + return zip_longest(*staggered, fillvalue=fillvalue) + + return zip(*staggered) + + +def sort_together(iterables, key_list=(0,), key=None, reverse=False): + """Return the input iterables sorted together, with *key_list* as the + priority for sorting. All iterables are trimmed to the length of the + shortest one. + + This can be used like the sorting function in a spreadsheet. If each + iterable represents a column of data, the key list determines which + columns are used for sorting. + + By default, all iterables are sorted using the ``0``-th iterable:: + + >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')] + >>> sort_together(iterables) + [(1, 2, 3, 4), ('d', 'c', 'b', 'a')] + + Set a different key list to sort according to another iterable. + Specifying multiple keys dictates how ties are broken:: + + >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')] + >>> sort_together(iterables, key_list=(1, 2)) + [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + + To sort by a function of the elements of the iterable, pass a *key* + function. Its arguments are the elements of the iterables corresponding to + the key list:: + + >>> names = ('a', 'b', 'c') + >>> lengths = (1, 2, 3) + >>> widths = (5, 2, 1) + >>> def area(length, width): + ... return length * width + >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area) + [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)] + + Set *reverse* to ``True`` to sort in descending order. + + >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) + [(3, 2, 1), ('a', 'b', 'c')] + + """ + if key is None: + # if there is no key function, the key argument to sorted is an + # itemgetter + key_argument = itemgetter(*key_list) + else: + # if there is a key function, call it with the items at the offsets + # specified by the key function as arguments + key_list = list(key_list) + if len(key_list) == 1: + # if key_list contains a single item, pass the item at that offset + # as the only argument to the key function + key_offset = key_list[0] + key_argument = lambda zipped_items: key(zipped_items[key_offset]) + else: + # if key_list contains multiple items, use itemgetter to return a + # tuple of items, which we pass as *args to the key function + get_key_items = itemgetter(*key_list) + key_argument = lambda zipped_items: key( + *get_key_items(zipped_items) + ) + + return list( + zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse)) + ) + + +def unzip(iterable): + """The inverse of :func:`zip`, this function disaggregates the elements + of the zipped *iterable*. + + The ``i``-th iterable contains the ``i``-th element from each element + of the zipped iterable. The first element is used to to determine the + length of the remaining elements. + + >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> letters, numbers = unzip(iterable) + >>> list(letters) + ['a', 'b', 'c', 'd'] + >>> list(numbers) + [1, 2, 3, 4] + + This is similar to using ``zip(*iterable)``, but it avoids reading + *iterable* into memory. Note, however, that this function uses + :func:`itertools.tee` and thus may require significant storage. + + """ + head, iterable = spy(iter(iterable)) + if not head: + # empty iterable, e.g. zip([], [], []) + return () + # spy returns a one-length iterable as head + head = head[0] + iterables = tee(iterable, len(head)) + + def itemgetter(i): + def getter(obj): + try: + return obj[i] + except IndexError: + # basically if we have an iterable like + # iter([(1, 2, 3), (4, 5), (6,)]) + # the second unzipped iterable would fail at the third tuple + # since it would try to access tup[1] + # same with the third unzipped iterable and the second tuple + # to support these "improperly zipped" iterables, + # we create a custom itemgetter + # which just stops the unzipped iterables + # at first length mismatch + raise StopIteration + + return getter + + return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables)) + + +def divide(n, iterable): + """Divide the elements from *iterable* into *n* parts, maintaining + order. + + >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 2, 3] + >>> list(group_2) + [4, 5, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 2, 3], [4, 5], [6, 7]] + + If the length of the iterable is smaller than n, then the last returned + iterables will be empty: + + >>> children = divide(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function will exhaust the iterable before returning and may require + significant storage. If order is not important, see :func:`distribute`, + which does not first pull the iterable into memory. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + try: + iterable[:0] + except TypeError: + seq = tuple(iterable) + else: + seq = iterable + + q, r = divmod(len(seq), n) + + ret = [] + stop = 0 + for i in range(1, n + 1): + start = stop + stop += q + 1 if i <= r else q + ret.append(iter(seq[start:stop])) + + return ret + + +def always_iterable(obj, base_type=(str, bytes)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + +def adjacent(predicate, iterable, distance=1): + """Return an iterable over `(bool, item)` tuples where the `item` is + drawn from *iterable* and the `bool` indicates whether + that item satisfies the *predicate* or is adjacent to an item that does. + + For example, to find whether items are adjacent to a ``3``:: + + >>> list(adjacent(lambda x: x == 3, range(6))) + [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)] + + Set *distance* to change what counts as adjacent. For example, to find + whether items are two places away from a ``3``: + + >>> list(adjacent(lambda x: x == 3, range(6), distance=2)) + [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)] + + This is useful for contextualizing the results of a search function. + For example, a code comparison tool might want to identify lines that + have changed, but also surrounding lines to give the viewer of the diff + context. + + The predicate function will only be called once for each item in the + iterable. + + See also :func:`groupby_transform`, which can be used with this function + to group ranges of items with the same `bool` value. + + """ + # Allow distance=0 mainly for testing that it reproduces results with map() + if distance < 0: + raise ValueError('distance must be at least 0') + + i1, i2 = tee(iterable) + padding = [False] * distance + selected = chain(padding, map(predicate, i1), padding) + adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1)) + return zip(adjacent_to_selected, i2) + + +def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): + """An extension of :func:`itertools.groupby` that can apply transformations + to the grouped data. + + * *keyfunc* is a function computing a key value for each item in *iterable* + * *valuefunc* is a function that transforms the individual items from + *iterable* after grouping + * *reducefunc* is a function that transforms each group of items + + >>> iterable = 'aAAbBBcCC' + >>> keyfunc = lambda k: k.upper() + >>> valuefunc = lambda v: v.lower() + >>> reducefunc = lambda g: ''.join(g) + >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc)) + [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')] + + Each optional argument defaults to an identity function if not specified. + + :func:`groupby_transform` is useful when grouping elements of an iterable + using a separate iterable as the key. To do this, :func:`zip` the iterables + and pass a *keyfunc* that extracts the first element and a *valuefunc* + that extracts the second element:: + + >>> from operator import itemgetter + >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3] + >>> values = 'abcdefghi' + >>> iterable = zip(keys, values) + >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1)) + >>> [(k, ''.join(g)) for k, g in grouper] + [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')] + + Note that the order of items in the iterable is significant. + Only adjacent items are grouped together, so if you don't want any + duplicate groups, you should sort the iterable by the key function. + + """ + ret = groupby(iterable, keyfunc) + if valuefunc: + ret = ((k, map(valuefunc, g)) for k, g in ret) + if reducefunc: + ret = ((k, reducefunc(g)) for k, g in ret) + + return ret + + +class numeric_range(abc.Sequence, abc.Hashable): + """An extension of the built-in ``range()`` function whose arguments can + be any orderable numeric type. + + With only *stop* specified, *start* defaults to ``0`` and *step* + defaults to ``1``. The output items will match the type of *stop*: + + >>> list(numeric_range(3.5)) + [0.0, 1.0, 2.0, 3.0] + + With only *start* and *stop* specified, *step* defaults to ``1``. The + output items will match the type of *start*: + + >>> from decimal import Decimal + >>> start = Decimal('2.1') + >>> stop = Decimal('5.1') + >>> list(numeric_range(start, stop)) + [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')] + + With *start*, *stop*, and *step* specified the output items will match + the type of ``start + step``: + + >>> from fractions import Fraction + >>> start = Fraction(1, 2) # Start at 1/2 + >>> stop = Fraction(5, 2) # End at 5/2 + >>> step = Fraction(1, 2) # Count by 1/2 + >>> list(numeric_range(start, stop, step)) + [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)] + + If *step* is zero, ``ValueError`` is raised. Negative steps are supported: + + >>> list(numeric_range(3, -1, -1.0)) + [3.0, 2.0, 1.0, 0.0] + + Be aware of the limitations of floating point numbers; the representation + of the yielded numbers may be surprising. + + ``datetime.datetime`` objects can be used for *start* and *stop*, if *step* + is a ``datetime.timedelta`` object: + + >>> import datetime + >>> start = datetime.datetime(2019, 1, 1) + >>> stop = datetime.datetime(2019, 1, 3) + >>> step = datetime.timedelta(days=1) + >>> items = iter(numeric_range(start, stop, step)) + >>> next(items) + datetime.datetime(2019, 1, 1, 0, 0) + >>> next(items) + datetime.datetime(2019, 1, 2, 0, 0) + + """ + + _EMPTY_HASH = hash(range(0, 0)) + + def __init__(self, *args): + argc = len(args) + if argc == 1: + (self._stop,) = args + self._start = type(self._stop)(0) + self._step = type(self._stop - self._start)(1) + elif argc == 2: + self._start, self._stop = args + self._step = type(self._stop - self._start)(1) + elif argc == 3: + self._start, self._stop, self._step = args + elif argc == 0: + raise TypeError( + 'numeric_range expected at least ' + '1 argument, got {}'.format(argc) + ) + else: + raise TypeError( + 'numeric_range expected at most ' + '3 arguments, got {}'.format(argc) + ) + + self._zero = type(self._step)(0) + if self._step == self._zero: + raise ValueError('numeric_range() arg 3 must not be zero') + self._growing = self._step > self._zero + self._init_len() + + def __bool__(self): + if self._growing: + return self._start < self._stop + else: + return self._start > self._stop + + def __contains__(self, elem): + if self._growing: + if self._start <= elem < self._stop: + return (elem - self._start) % self._step == self._zero + else: + if self._start >= elem > self._stop: + return (self._start - elem) % (-self._step) == self._zero + + return False + + def __eq__(self, other): + if isinstance(other, numeric_range): + empty_self = not bool(self) + empty_other = not bool(other) + if empty_self or empty_other: + return empty_self and empty_other # True if both empty + else: + return ( + self._start == other._start + and self._step == other._step + and self._get_by_index(-1) == other._get_by_index(-1) + ) + else: + return False + + def __getitem__(self, key): + if isinstance(key, int): + return self._get_by_index(key) + elif isinstance(key, slice): + step = self._step if key.step is None else key.step * self._step + + if key.start is None or key.start <= -self._len: + start = self._start + elif key.start >= self._len: + start = self._stop + else: # -self._len < key.start < self._len + start = self._get_by_index(key.start) + + if key.stop is None or key.stop >= self._len: + stop = self._stop + elif key.stop <= -self._len: + stop = self._start + else: # -self._len < key.stop < self._len + stop = self._get_by_index(key.stop) + + return numeric_range(start, stop, step) + else: + raise TypeError( + 'numeric range indices must be ' + 'integers or slices, not {}'.format(type(key).__name__) + ) + + def __hash__(self): + if self: + return hash((self._start, self._get_by_index(-1), self._step)) + else: + return self._EMPTY_HASH + + def __iter__(self): + values = (self._start + (n * self._step) for n in count()) + if self._growing: + return takewhile(partial(gt, self._stop), values) + else: + return takewhile(partial(lt, self._stop), values) + + def __len__(self): + return self._len + + def _init_len(self): + if self._growing: + start = self._start + stop = self._stop + step = self._step + else: + start = self._stop + stop = self._start + step = -self._step + distance = stop - start + if distance <= self._zero: + self._len = 0 + else: # distance > 0 and step > 0: regular euclidean division + q, r = divmod(distance, step) + self._len = int(q) + int(r != self._zero) + + def __reduce__(self): + return numeric_range, (self._start, self._stop, self._step) + + def __repr__(self): + if self._step == 1: + return "numeric_range({}, {})".format( + repr(self._start), repr(self._stop) + ) + else: + return "numeric_range({}, {}, {})".format( + repr(self._start), repr(self._stop), repr(self._step) + ) + + def __reversed__(self): + return iter( + numeric_range( + self._get_by_index(-1), self._start - self._step, -self._step + ) + ) + + def count(self, value): + return int(value in self) + + def index(self, value): + if self._growing: + if self._start <= value < self._stop: + q, r = divmod(value - self._start, self._step) + if r == self._zero: + return int(q) + else: + if self._start >= value > self._stop: + q, r = divmod(self._start - value, -self._step) + if r == self._zero: + return int(q) + + raise ValueError("{} is not in numeric range".format(value)) + + def _get_by_index(self, i): + if i < 0: + i += self._len + if i < 0 or i >= self._len: + raise IndexError("numeric range object index out of range") + return self._start + i * self._step + + +def count_cycle(iterable, n=None): + """Cycle through the items from *iterable* up to *n* times, yielding + the number of completed cycles along with each item. If *n* is omitted the + process repeats indefinitely. + + >>> list(count_cycle('AB', 3)) + [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')] + + """ + iterable = tuple(iterable) + if not iterable: + return iter(()) + counter = count() if n is None else range(n) + return ((i, item) for i in counter for item in iterable) + + +def mark_ends(iterable): + """Yield 3-tuples of the form ``(is_first, is_last, item)``. + + >>> list(mark_ends('ABC')) + [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')] + + Use this when looping over an iterable to take special action on its first + and/or last items: + + >>> iterable = ['Header', 100, 200, 'Footer'] + >>> total = 0 + >>> for is_first, is_last, item in mark_ends(iterable): + ... if is_first: + ... continue # Skip the header + ... if is_last: + ... continue # Skip the footer + ... total += item + >>> print(total) + 300 + """ + it = iter(iterable) + + try: + b = next(it) + except StopIteration: + return + + try: + for i in count(): + a = b + b = next(it) + yield i == 0, False, a + + except StopIteration: + yield i == 0, True, a + + +def locate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(locate([0, 1, 1, 0, 1, 0, 0])) + [1, 2, 4] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item. + + >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b')) + [1, 3] + + If *window_size* is given, then the *pred* function will be called with + that many items. This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(locate(iterable, pred=pred, window_size=3)) + [1, 5, 9] + + Use with :func:`seekable` to find indexes and then retrieve the associated + items: + + >>> from itertools import count + >>> from more_itertools import seekable + >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count()) + >>> it = seekable(source) + >>> pred = lambda x: x > 100 + >>> indexes = locate(it, pred=pred) + >>> i = next(indexes) + >>> it.seek(i) + >>> next(it) + 106 + + """ + if window_size is None: + return compress(count(), map(pred, iterable)) + + if window_size < 1: + raise ValueError('window size must be at least 1') + + it = windowed(iterable, window_size, fillvalue=_marker) + return compress(count(), starmap(pred, it)) + + +def lstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the beginning + for which *pred* returns ``True``. + + For example, to remove a set of items from the start of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(lstrip(iterable, pred)) + [1, 2, None, 3, False, None] + + This function is analogous to to :func:`str.lstrip`, and is essentially + an wrapper for :func:`itertools.dropwhile`. + + """ + return dropwhile(pred, iterable) + + +def rstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the end + for which *pred* returns ``True``. + + For example, to remove a set of items from the end of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(rstrip(iterable, pred)) + [None, False, None, 1, 2, None, 3] + + This function is analogous to :func:`str.rstrip`. + + """ + cache = [] + cache_append = cache.append + cache_clear = cache.clear + for x in iterable: + if pred(x): + cache_append(x) + else: + yield from cache + cache_clear() + yield x + + +def strip(iterable, pred): + """Yield the items from *iterable*, but strip any from the + beginning and end for which *pred* returns ``True``. + + For example, to remove a set of items from both ends of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(strip(iterable, pred)) + [1, 2, None, 3] + + This function is analogous to :func:`str.strip`. + + """ + return rstrip(lstrip(iterable, pred), pred) + + +class islice_extended: + """An extension of :func:`itertools.islice` that supports negative values + for *stop*, *start*, and *step*. + + >>> iterable = iter('abcdefgh') + >>> list(islice_extended(iterable, -4, -1)) + ['e', 'f', 'g'] + + Slices with negative values require some caching of *iterable*, but this + function takes care to minimize the amount of memory required. + + For example, you can use a negative step with an infinite iterator: + + >>> from itertools import count + >>> list(islice_extended(count(), 110, 99, -2)) + [110, 108, 106, 104, 102, 100] + + You can also use slice notation directly: + + >>> iterable = map(str, count()) + >>> it = islice_extended(iterable)[10:20:2] + >>> list(it) + ['10', '12', '14', '16', '18'] + + """ + + def __init__(self, iterable, *args): + it = iter(iterable) + if args: + self._iterable = _islice_helper(it, slice(*args)) + else: + self._iterable = it + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterable) + + def __getitem__(self, key): + if isinstance(key, slice): + return islice_extended(_islice_helper(self._iterable, key)) + + raise TypeError('islice_extended.__getitem__ argument must be a slice') + + +def _islice_helper(it, s): + start = s.start + stop = s.stop + if s.step == 0: + raise ValueError('step argument must be a non-zero integer or None.') + step = s.step or 1 + + if step > 0: + start = 0 if (start is None) else start + + if start < 0: + # Consume all but the last -start items + cache = deque(enumerate(it, 1), maxlen=-start) + len_iter = cache[-1][0] if cache else 0 + + # Adjust start to be positive + i = max(len_iter + start, 0) + + # Adjust stop to be positive + if stop is None: + j = len_iter + elif stop >= 0: + j = min(stop, len_iter) + else: + j = max(len_iter + stop, 0) + + # Slice the cache + n = j - i + if n <= 0: + return + + for index, item in islice(cache, 0, n, step): + yield item + elif (stop is not None) and (stop < 0): + # Advance to the start position + next(islice(it, start, start), None) + + # When stop is negative, we have to carry -stop items while + # iterating + cache = deque(islice(it, -stop), maxlen=-stop) + + for index, item in enumerate(it): + cached_item = cache.popleft() + if index % step == 0: + yield cached_item + cache.append(item) + else: + # When both start and stop are positive we have the normal case + yield from islice(it, start, stop, step) + else: + start = -1 if (start is None) else start + + if (stop is not None) and (stop < 0): + # Consume all but the last items + n = -stop - 1 + cache = deque(enumerate(it, 1), maxlen=n) + len_iter = cache[-1][0] if cache else 0 + + # If start and stop are both negative they are comparable and + # we can just slice. Otherwise we can adjust start to be negative + # and then slice. + if start < 0: + i, j = start, stop + else: + i, j = min(start - len_iter, -1), None + + for index, item in list(cache)[i:j:step]: + yield item + else: + # Advance to the stop position + if stop is not None: + m = stop + 1 + next(islice(it, m, m), None) + + # stop is positive, so if start is negative they are not comparable + # and we need the rest of the items. + if start < 0: + i = start + n = None + # stop is None and start is positive, so we just need items up to + # the start index. + elif stop is None: + i = None + n = start + 1 + # Both stop and start are positive, so they are comparable. + else: + i = None + n = start - stop + if n <= 0: + return + + cache = list(islice(it, n)) + + yield from cache[i::step] + + +def always_reversible(iterable): + """An extension of :func:`reversed` that supports all iterables, not + just those which implement the ``Reversible`` or ``Sequence`` protocols. + + >>> print(*always_reversible(x for x in range(3))) + 2 1 0 + + If the iterable is already reversible, this function returns the + result of :func:`reversed()`. If the iterable is not reversible, + this function will cache the remaining items in the iterable and + yield them in reverse order, which may require significant storage. + """ + try: + return reversed(iterable) + except TypeError: + return reversed(list(iterable)) + + +def consecutive_groups(iterable, ordering=lambda x: x): + """Yield groups of consecutive items using :func:`itertools.groupby`. + The *ordering* function determines whether two items are adjacent by + returning their position. + + By default, the ordering function is the identity function. This is + suitable for finding runs of numbers: + + >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40] + >>> for group in consecutive_groups(iterable): + ... print(list(group)) + [1] + [10, 11, 12] + [20] + [30, 31, 32, 33] + [40] + + For finding runs of adjacent letters, try using the :meth:`index` method + of a string of letters: + + >>> from string import ascii_lowercase + >>> iterable = 'abcdfgilmnop' + >>> ordering = ascii_lowercase.index + >>> for group in consecutive_groups(iterable, ordering): + ... print(list(group)) + ['a', 'b', 'c', 'd'] + ['f', 'g'] + ['i'] + ['l', 'm', 'n', 'o', 'p'] + + Each group of consecutive items is an iterator that shares it source with + *iterable*. When an an output group is advanced, the previous group is + no longer available unless its elements are copied (e.g., into a ``list``). + + >>> iterable = [1, 2, 11, 12, 21, 22] + >>> saved_groups = [] + >>> for group in consecutive_groups(iterable): + ... saved_groups.append(list(group)) # Copy group elements + >>> saved_groups + [[1, 2], [11, 12], [21, 22]] + + """ + for k, g in groupby( + enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) + ): + yield map(itemgetter(1), g) + + +def difference(iterable, func=sub, *, initial=None): + """This function is the inverse of :func:`itertools.accumulate`. By default + it will compute the first difference of *iterable* using + :func:`operator.sub`: + + >>> from itertools import accumulate + >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10 + >>> list(difference(iterable)) + [0, 1, 2, 3, 4] + + *func* defaults to :func:`operator.sub`, but other functions can be + specified. They will be applied as follows:: + + A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ... + + For example, to do progressive division: + + >>> iterable = [1, 2, 6, 24, 120] + >>> func = lambda x, y: x // y + >>> list(difference(iterable, func)) + [1, 2, 3, 4, 5] + + If the *initial* keyword is set, the first element will be skipped when + computing successive differences. + + >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10) + >>> list(difference(it, initial=10)) + [1, 2, 3] + + """ + a, b = tee(iterable) + try: + first = [next(b)] + except StopIteration: + return iter([]) + + if initial is not None: + first = [] + + return chain(first, starmap(func, zip(b, a))) + + +class SequenceView(Sequence): + """Return a read-only view of the sequence object *target*. + + :class:`SequenceView` objects are analogous to Python's built-in + "dictionary view" types. They provide a dynamic view of a sequence's items, + meaning that when the sequence updates, so does the view. + + >>> seq = ['0', '1', '2'] + >>> view = SequenceView(seq) + >>> view + SequenceView(['0', '1', '2']) + >>> seq.append('3') + >>> view + SequenceView(['0', '1', '2', '3']) + + Sequence views support indexing, slicing, and length queries. They act + like the underlying sequence, except they don't allow assignment: + + >>> view[1] + '1' + >>> view[1:-1] + ['1', '2'] + >>> len(view) + 4 + + Sequence views are useful as an alternative to copying, as they don't + require (much) extra storage. + + """ + + def __init__(self, target): + if not isinstance(target, Sequence): + raise TypeError + self._target = target + + def __getitem__(self, index): + return self._target[index] + + def __len__(self): + return len(self._target) + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, repr(self._target)) + + +class seekable: + """Wrap an iterator to allow for seeking backward and forward. This + progressively caches the items in the source iterable so they can be + re-visited. + + Call :meth:`seek` with an index to seek to that position in the source + iterable. + + To "reset" an iterator, seek to ``0``: + + >>> from itertools import count + >>> it = seekable((str(n) for n in count())) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> it.seek(0) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> next(it) + '3' + + You can also seek forward: + + >>> it = seekable((str(n) for n in range(20))) + >>> it.seek(10) + >>> next(it) + '10' + >>> it.seek(20) # Seeking past the end of the source isn't a problem + >>> list(it) + [] + >>> it.seek(0) # Resetting works even after hitting the end + >>> next(it), next(it), next(it) + ('0', '1', '2') + + Call :meth:`peek` to look ahead one item without advancing the iterator: + + >>> it = seekable('1234') + >>> it.peek() + '1' + >>> list(it) + ['1', '2', '3', '4'] + >>> it.peek(default='empty') + 'empty' + + Before the iterator is at its end, calling :func:`bool` on it will return + ``True``. After it will return ``False``: + + >>> it = seekable('5678') + >>> bool(it) + True + >>> list(it) + ['5', '6', '7', '8'] + >>> bool(it) + False + + You may view the contents of the cache with the :meth:`elements` method. + That returns a :class:`SequenceView`, a view that updates automatically: + + >>> it = seekable((str(n) for n in range(10))) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> elements = it.elements() + >>> elements + SequenceView(['0', '1', '2']) + >>> next(it) + '3' + >>> elements + SequenceView(['0', '1', '2', '3']) + + By default, the cache grows as the source iterable progresses, so beware of + wrapping very large or infinite iterables. Supply *maxlen* to limit the + size of the cache (this of course limits how far back you can seek). + + >>> from itertools import count + >>> it = seekable((str(n) for n in count()), maxlen=2) + >>> next(it), next(it), next(it), next(it) + ('0', '1', '2', '3') + >>> list(it.elements()) + ['2', '3'] + >>> it.seek(0) + >>> next(it), next(it), next(it), next(it) + ('2', '3', '4', '5') + >>> next(it) + '6' + + """ + + def __init__(self, iterable, maxlen=None): + self._source = iter(iterable) + if maxlen is None: + self._cache = [] + else: + self._cache = deque([], maxlen) + self._index = None + + def __iter__(self): + return self + + def __next__(self): + if self._index is not None: + try: + item = self._cache[self._index] + except IndexError: + self._index = None + else: + self._index += 1 + return item + + item = next(self._source) + self._cache.append(item) + return item + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + try: + peeked = next(self) + except StopIteration: + if default is _marker: + raise + return default + if self._index is None: + self._index = len(self._cache) + self._index -= 1 + return peeked + + def elements(self): + return SequenceView(self._cache) + + def seek(self, index): + self._index = index + remainder = index - len(self._cache) + if remainder > 0: + consume(self, remainder) + + +class run_length: + """ + :func:`run_length.encode` compresses an iterable with run-length encoding. + It yields groups of repeated items with the count of how many times they + were repeated: + + >>> uncompressed = 'abbcccdddd' + >>> list(run_length.encode(uncompressed)) + [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + + :func:`run_length.decode` decompresses an iterable that was previously + compressed with run-length encoding. It yields the items of the + decompressed iterable: + + >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> list(run_length.decode(compressed)) + ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd'] + + """ + + @staticmethod + def encode(iterable): + return ((k, ilen(g)) for k, g in groupby(iterable)) + + @staticmethod + def decode(iterable): + return chain.from_iterable(repeat(k, n) for k, n in iterable) + + +def exactly_n(iterable, n, predicate=bool): + """Return ``True`` if exactly ``n`` items in the iterable are ``True`` + according to the *predicate* function. + + >>> exactly_n([True, True, False], 2) + True + >>> exactly_n([True, True, False], 1) + False + >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3) + True + + The iterable will be advanced until ``n + 1`` truthy items are encountered, + so avoid calling it on infinite iterables. + + """ + return len(take(n + 1, filter(predicate, iterable))) == n + + +def circular_shifts(iterable): + """Return a list of circular shifts of *iterable*. + + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + """ + lst = list(iterable) + return take(len(lst), windowed(cycle(lst), len(lst))) + + +def make_decorator(wrapping_func, result_index=0): + """Return a decorator version of *wrapping_func*, which is a function that + modifies an iterable. *result_index* is the position in that function's + signature where the iterable goes. + + This lets you use itertools on the "production end," i.e. at function + definition. This can augment what the function returns without changing the + function's code. + + For example, to produce a decorator version of :func:`chunked`: + + >>> from more_itertools import chunked + >>> chunker = make_decorator(chunked, result_index=0) + >>> @chunker(3) + ... def iter_range(n): + ... return iter(range(n)) + ... + >>> list(iter_range(9)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + To only allow truthy items to be returned: + + >>> truth_serum = make_decorator(filter, result_index=1) + >>> @truth_serum(bool) + ... def boolean_test(): + ... return [0, 1, '', ' ', False, True] + ... + >>> list(boolean_test()) + [1, ' ', True] + + The :func:`peekable` and :func:`seekable` wrappers make for practical + decorators: + + >>> from more_itertools import peekable + >>> peekable_function = make_decorator(peekable) + >>> @peekable_function() + ... def str_range(*args): + ... return (str(x) for x in range(*args)) + ... + >>> it = str_range(1, 20, 2) + >>> next(it), next(it), next(it) + ('1', '3', '5') + >>> it.peek() + '7' + >>> next(it) + '7' + + """ + # See https://sites.google.com/site/bbayles/index/decorator_factory for + # notes on how this works. + def decorator(*wrapping_args, **wrapping_kwargs): + def outer_wrapper(f): + def inner_wrapper(*args, **kwargs): + result = f(*args, **kwargs) + wrapping_args_ = list(wrapping_args) + wrapping_args_.insert(result_index, result) + return wrapping_func(*wrapping_args_, **wrapping_kwargs) + + return inner_wrapper + + return outer_wrapper + + return decorator + + +def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None): + """Return a dictionary that maps the items in *iterable* to categories + defined by *keyfunc*, transforms them with *valuefunc*, and + then summarizes them by category with *reducefunc*. + + *valuefunc* defaults to the identity function if it is unspecified. + If *reducefunc* is unspecified, no summarization takes place: + + >>> keyfunc = lambda x: x.upper() + >>> result = map_reduce('abbccc', keyfunc) + >>> sorted(result.items()) + [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])] + + Specifying *valuefunc* transforms the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> result = map_reduce('abbccc', keyfunc, valuefunc) + >>> sorted(result.items()) + [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])] + + Specifying *reducefunc* summarizes the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> reducefunc = sum + >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc) + >>> sorted(result.items()) + [('A', 1), ('B', 2), ('C', 3)] + + You may want to filter the input iterable before applying the map/reduce + procedure: + + >>> all_items = range(30) + >>> items = [x for x in all_items if 10 <= x <= 20] # Filter + >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1 + >>> categories = map_reduce(items, keyfunc=keyfunc) + >>> sorted(categories.items()) + [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])] + >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum) + >>> sorted(summaries.items()) + [(0, 90), (1, 75)] + + Note that all items in the iterable are gathered into a list before the + summarization step, which may require significant storage. + + The returned object is a :obj:`collections.defaultdict` with the + ``default_factory`` set to ``None``, such that it behaves like a normal + dictionary. + + """ + valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc + + ret = defaultdict(list) + for item in iterable: + key = keyfunc(item) + value = valuefunc(item) + ret[key].append(value) + + if reducefunc is not None: + for key, value_list in ret.items(): + ret[key] = reducefunc(value_list) + + ret.default_factory = None + return ret + + +def rlocate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``, starting from the right and moving left. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(rlocate([0, 1, 1, 0, 1, 0, 0])) # Truthy at 1, 2, and 4 + [4, 2, 1] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item: + + >>> iterable = iter('abcb') + >>> pred = lambda x: x == 'b' + >>> list(rlocate(iterable, pred)) + [3, 1] + + If *window_size* is given, then the *pred* function will be called with + that many items. This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(rlocate(iterable, pred=pred, window_size=3)) + [9, 5, 1] + + Beware, this function won't return anything for infinite iterables. + If *iterable* is reversible, ``rlocate`` will reverse it and search from + the right. Otherwise, it will search from the left and return the results + in reverse order. + + See :func:`locate` to for other example applications. + + """ + if window_size is None: + try: + len_iter = len(iterable) + return (len_iter - i - 1 for i in locate(reversed(iterable), pred)) + except TypeError: + pass + + return reversed(list(locate(iterable, pred, window_size))) + + +def replace(iterable, pred, substitutes, count=None, window_size=1): + """Yield the items from *iterable*, replacing the items for which *pred* + returns ``True`` with the items from the iterable *substitutes*. + + >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1] + >>> pred = lambda x: x == 0 + >>> substitutes = (2, 3) + >>> list(replace(iterable, pred, substitutes)) + [1, 1, 2, 3, 1, 1, 2, 3, 1, 1] + + If *count* is given, the number of replacements will be limited: + + >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0] + >>> pred = lambda x: x == 0 + >>> substitutes = [None] + >>> list(replace(iterable, pred, substitutes, count=2)) + [1, 1, None, 1, 1, None, 1, 1, 0] + + Use *window_size* to control the number of items passed as arguments to + *pred*. This allows for locating and replacing subsequences. + + >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5] + >>> window_size = 3 + >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred + >>> substitutes = [3, 4] # Splice in these items + >>> list(replace(iterable, pred, substitutes, window_size=window_size)) + [3, 4, 5, 3, 4, 5] + + """ + if window_size < 1: + raise ValueError('window_size must be at least 1') + + # Save the substitutes iterable, since it's used more than once + substitutes = tuple(substitutes) + + # Add padding such that the number of windows matches the length of the + # iterable + it = chain(iterable, [_marker] * (window_size - 1)) + windows = windowed(it, window_size) + + n = 0 + for w in windows: + # If the current window matches our predicate (and we haven't hit + # our maximum number of replacements), splice in the substitutes + # and then consume the following windows that overlap with this one. + # For example, if the iterable is (0, 1, 2, 3, 4...) + # and the window size is 2, we have (0, 1), (1, 2), (2, 3)... + # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2) + if pred(*w): + if (count is None) or (n < count): + n += 1 + yield from substitutes + consume(windows, window_size - 1) + continue + + # If there was no match (or we've reached the replacement limit), + # yield the first item from the window. + if w and (w[0] is not _marker): + yield w[0] + + +def partitions(iterable): + """Yield all possible order-preserving partitions of *iterable*. + + >>> iterable = 'abc' + >>> for part in partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['a', 'b', 'c'] + + This is unrelated to :func:`partition`. + + """ + sequence = list(iterable) + n = len(sequence) + for i in powerset(range(1, n)): + yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] + + +def set_partitions(iterable, k=None): + """ + Yield the set partitions of *iterable* into *k* parts. Set partitions are + not order-preserving. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable, 2): + ... print([''.join(p) for p in part]) + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + + + If *k* is not given, every set partition is generated. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + ['a', 'b', 'c'] + + """ + L = list(iterable) + n = len(L) + if k is not None: + if k < 1: + raise ValueError( + "Can't partition in a negative or zero number of groups" + ) + elif k > n: + return + + def set_partitions_helper(L, k): + n = len(L) + if k == 1: + yield [L] + elif n == k: + yield [[s] for s in L] + else: + e, *M = L + for p in set_partitions_helper(M, k - 1): + yield [[e], *p] + for p in set_partitions_helper(M, k): + for i in range(len(p)): + yield p[:i] + [[e] + p[i]] + p[i + 1 :] + + if k is None: + for k in range(1, n + 1): + yield from set_partitions_helper(L, k) + else: + yield from set_partitions_helper(L, k) + + +class time_limited: + """ + Yield items from *iterable* until *limit_seconds* have passed. + If the time limit expires before all items have been yielded, the + ``timed_out`` parameter will be set to ``True``. + + >>> from time import sleep + >>> def generator(): + ... yield 1 + ... yield 2 + ... sleep(0.2) + ... yield 3 + >>> iterable = time_limited(0.1, generator()) + >>> list(iterable) + [1, 2] + >>> iterable.timed_out + True + + Note that the time is checked before each item is yielded, and iteration + stops if the time elapsed is greater than *limit_seconds*. If your time + limit is 1 second, but it takes 2 seconds to generate the first item from + the iterable, the function will run for 2 seconds and not yield anything. + + """ + + def __init__(self, limit_seconds, iterable): + if limit_seconds < 0: + raise ValueError('limit_seconds must be positive') + self.limit_seconds = limit_seconds + self._iterable = iter(iterable) + self._start_time = monotonic() + self.timed_out = False + + def __iter__(self): + return self + + def __next__(self): + item = next(self._iterable) + if monotonic() - self._start_time > self.limit_seconds: + self.timed_out = True + raise StopIteration + + return item + + +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than one item, raise the exception given by *too_long*, + which is ``ValueError`` by default. + + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 1, 2, + and perhaps more.' + >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError + + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. + """ + it = iter(iterable) + first_value = next(it, default) + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def ichunked(iterable, n): + """Break *iterable* into sub-iterables with *n* elements each. + :func:`ichunked` is like :func:`chunked`, but it yields iterables + instead of lists. + + If the sub-iterables are read in order, the elements of *iterable* + won't be stored in memory. + If they are read out of order, :func:`itertools.tee` is used to cache + elements as necessary. + + >>> from itertools import count + >>> all_chunks = ichunked(count(), 4) + >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks) + >>> list(c_2) # c_1's elements have been cached; c_3's haven't been + [4, 5, 6, 7] + >>> list(c_1) + [0, 1, 2, 3] + >>> list(c_3) + [8, 9, 10, 11] + + """ + source = iter(iterable) + + while True: + # Check to see whether we're at the end of the source iterable + item = next(source, _marker) + if item is _marker: + return + + # Clone the source and yield an n-length slice + source, it = tee(chain([item], source)) + yield islice(it, n) + + # Advance the source iterable + consume(source, n) + + +def distinct_combinations(iterable, r): + """Yield the distinct combinations of *r* items taken from *iterable*. + + >>> list(distinct_combinations([0, 0, 1], 2)) + [(0, 0), (0, 1)] + + Equivalent to ``set(combinations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + """ + if r < 0: + raise ValueError('r must be non-negative') + elif r == 0: + yield () + return + pool = tuple(iterable) + generators = [unique_everseen(enumerate(pool), key=itemgetter(1))] + current_combo = [None] * r + level = 0 + while generators: + try: + cur_idx, p = next(generators[-1]) + except StopIteration: + generators.pop() + level -= 1 + continue + current_combo[level] = p + if level + 1 == r: + yield tuple(current_combo) + else: + generators.append( + unique_everseen( + enumerate(pool[cur_idx + 1 :], cur_idx + 1), + key=itemgetter(1), + ) + ) + level += 1 + + +def filter_except(validator, iterable, *exceptions): + """Yield the items from *iterable* for which the *validator* function does + not raise one of the specified *exceptions*. + + *validator* is called for each item in *iterable*. + It should be a function that accepts one argument and raises an exception + if that item is not valid. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(filter_except(int, iterable, ValueError, TypeError)) + ['1', '2', '4'] + + If an exception other than one given by *exceptions* is raised by + *validator*, it is raised like normal. + """ + for item in iterable: + try: + validator(item) + except exceptions: + pass + else: + yield item + + +def map_except(function, iterable, *exceptions): + """Transform each item from *iterable* with *function* and yield the + result, unless *function* raises one of the specified *exceptions*. + + *function* is called to transform each item in *iterable*. + It should accept one argument. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(map_except(int, iterable, ValueError, TypeError)) + [1, 2, 4] + + If an exception other than one given by *exceptions* is raised by + *function*, it is raised like normal. + """ + for item in iterable: + try: + yield function(item) + except exceptions: + pass + + +def map_if(iterable, pred, func, func_else=lambda x: x): + """Evaluate each item from *iterable* using *pred*. If the result is + equivalent to ``True``, transform the item with *func* and yield it. + Otherwise, transform the item with *func_else* and yield it. + + *pred*, *func*, and *func_else* should each be functions that accept + one argument. By default, *func_else* is the identity function. + + >>> from math import sqrt + >>> iterable = list(range(-5, 5)) + >>> iterable + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig')) + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig'] + >>> list(map_if(iterable, lambda x: x >= 0, + ... lambda x: f'{sqrt(x):.2f}', lambda x: None)) + [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00'] + """ + for item in iterable: + yield func(item) if pred(item) else func_else(item) + + +def _sample_unweighted(iterable, k): + # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li: + # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". + + # Fill up the reservoir (collection of samples) with the first `k` samples + reservoir = take(k, iterable) + + # Generate random number that's the largest in a sample of k U(0,1) numbers + # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic + W = exp(log(random()) / k) + + # The number of elements to skip before changing the reservoir is a random + # number with a geometric distribution. Sample it using random() and logs. + next_index = k + floor(log(random()) / log(1 - W)) + + for index, element in enumerate(iterable, k): + + if index == next_index: + reservoir[randrange(k)] = element + # The new W is the largest in a sample of k U(0, `old_W`) numbers + W *= exp(log(random()) / k) + next_index += floor(log(random()) / log(1 - W)) + 1 + + return reservoir + + +def _sample_weighted(iterable, k, weights): + # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : + # "Weighted random sampling with a reservoir". + + # Log-transform for numerical stability for weights that are small/large + weight_keys = (log(random()) / weight for weight in weights) + + # Fill up the reservoir (collection of samples) with the first `k` + # weight-keys and elements, then heapify the list. + reservoir = take(k, zip(weight_keys, iterable)) + heapify(reservoir) + + # The number of jumps before changing the reservoir is a random variable + # with an exponential distribution. Sample it using random() and logs. + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + + for weight, element in zip(weights, iterable): + if weight >= weights_to_skip: + # The notation here is consistent with the paper, but we store + # the weight-keys in log-space for better numerical stability. + smallest_weight_key, _ = reservoir[0] + t_w = exp(weight * smallest_weight_key) + r_2 = uniform(t_w, 1) # generate U(t_w, 1) + weight_key = log(r_2) / weight + heapreplace(reservoir, (weight_key, element)) + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + else: + weights_to_skip -= weight + + # Equivalent to [element for weight_key, element in sorted(reservoir)] + return [heappop(reservoir)[1] for _ in range(k)] + + +def sample(iterable, k, weights=None): + """Return a *k*-length list of elements chosen (without replacement) + from the *iterable*. Like :func:`random.sample`, but works on iterables + of unknown length. + + >>> iterable = range(100) + >>> sample(iterable, 5) # doctest: +SKIP + [81, 60, 96, 16, 4] + + An iterable with *weights* may also be given: + + >>> iterable = range(100) + >>> weights = (i * i + 1 for i in range(100)) + >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP + [79, 67, 74, 66, 78] + + The algorithm can also be used to generate weighted random permutations. + The relative weight of each item determines the probability that it + appears late in the permutation. + + >>> data = "abcdefgh" + >>> weights = range(1, len(data) + 1) + >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP + ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f'] + """ + if k == 0: + return [] + + iterable = iter(iterable) + if weights is None: + return _sample_unweighted(iterable, k) + else: + weights = iter(weights) + return _sample_weighted(iterable, k, weights) + + +def is_sorted(iterable, key=None, reverse=False, strict=False): + """Returns ``True`` if the items of iterable are in sorted order, and + ``False`` otherwise. *key* and *reverse* have the same meaning that they do + in the built-in :func:`sorted` function. + + >>> is_sorted(['1', '2', '3', '4', '5'], key=int) + True + >>> is_sorted([5, 4, 3, 1, 2], reverse=True) + False + + If *strict*, tests for strict sorting, that is, returns ``False`` if equal + elements are found: + + >>> is_sorted([1, 2, 2]) + True + >>> is_sorted([1, 2, 2], strict=True) + False + + The function returns ``False`` after encountering the first out-of-order + item. If there are no out-of-order items, the iterable is exhausted. + """ + + compare = (le if reverse else ge) if strict else (lt if reverse else gt) + it = iterable if key is None else map(key, iterable) + return not any(starmap(compare, pairwise(it))) + + +class AbortThread(BaseException): + pass + + +class callback_iter: + """Convert a function that uses callbacks to an iterator. + + Let *func* be a function that takes a `callback` keyword argument. + For example: + + >>> def func(callback=None): + ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]: + ... if callback: + ... callback(i, c) + ... return 4 + + + Use ``with callback_iter(func)`` to get an iterator over the parameters + that are delivered to the callback. + + >>> with callback_iter(func) as it: + ... for args, kwargs in it: + ... print(args) + (1, 'a') + (2, 'b') + (3, 'c') + + The function will be called in a background thread. The ``done`` property + indicates whether it has completed execution. + + >>> it.done + True + + If it completes successfully, its return value will be available + in the ``result`` property. + + >>> it.result + 4 + + Notes: + + * If the function uses some keyword argument besides ``callback``, supply + *callback_kwd*. + * If it finished executing, but raised an exception, accessing the + ``result`` property will raise the same exception. + * If it hasn't finished executing, accessing the ``result`` + property from within the ``with`` block will raise ``RuntimeError``. + * If it hasn't finished executing, accessing the ``result`` property from + outside the ``with`` block will raise a + ``more_itertools.AbortThread`` exception. + * Provide *wait_seconds* to adjust how frequently the it is polled for + output. + + """ + + def __init__(self, func, callback_kwd='callback', wait_seconds=0.1): + self._func = func + self._callback_kwd = callback_kwd + self._aborted = False + self._future = None + self._wait_seconds = wait_seconds + self._executor = __import__("concurrent.futures").futures.ThreadPoolExecutor(max_workers=1) + self._iterator = self._reader() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._aborted = True + self._executor.shutdown() + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterator) + + @property + def done(self): + if self._future is None: + return False + return self._future.done() + + @property + def result(self): + if not self.done: + raise RuntimeError('Function has not yet completed') + + return self._future.result() + + def _reader(self): + q = Queue() + + def callback(*args, **kwargs): + if self._aborted: + raise AbortThread('canceled by user') + + q.put((args, kwargs)) + + self._future = self._executor.submit( + self._func, **{self._callback_kwd: callback} + ) + + while True: + try: + item = q.get(timeout=self._wait_seconds) + except Empty: + pass + else: + q.task_done() + yield item + + if self._future.done(): + break + + remaining = [] + while True: + try: + item = q.get_nowait() + except Empty: + break + else: + q.task_done() + remaining.append(item) + q.join() + yield from remaining + + +def windowed_complete(iterable, n): + """ + Yield ``(beginning, middle, end)`` tuples, where: + + * Each ``middle`` has *n* items from *iterable* + * Each ``beginning`` has the items before the ones in ``middle`` + * Each ``end`` has the items after the ones in ``middle`` + + >>> iterable = range(7) + >>> n = 3 + >>> for beginning, middle, end in windowed_complete(iterable, n): + ... print(beginning, middle, end) + () (0, 1, 2) (3, 4, 5, 6) + (0,) (1, 2, 3) (4, 5, 6) + (0, 1) (2, 3, 4) (5, 6) + (0, 1, 2) (3, 4, 5) (6,) + (0, 1, 2, 3) (4, 5, 6) () + + Note that *n* must be at least 0 and most equal to the length of + *iterable*. + + This function will exhaust the iterable and may require significant + storage. + """ + if n < 0: + raise ValueError('n must be >= 0') + + seq = tuple(iterable) + size = len(seq) + + if n > size: + raise ValueError('n must be <= len(seq)') + + for i in range(size - n + 1): + beginning = seq[:i] + middle = seq[i : i + n] + end = seq[i + n :] + yield beginning, middle, end + + +def all_unique(iterable, key=None): + """ + Returns ``True`` if all the elements of *iterable* are unique (no two + elements are equal). + + >>> all_unique('ABCB') + False + + If a *key* function is specified, it will be used to make comparisons. + + >>> all_unique('ABCb') + True + >>> all_unique('ABCb', str.lower) + False + + The function returns as soon as the first non-unique element is + encountered. Iterables with a mix of hashable and unhashable items can + be used, but the function will be slower for unhashable items. + """ + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append + for element in map(key, iterable) if key else iterable: + try: + if element in seenset: + return False + seenset_add(element) + except TypeError: + if element in seenlist: + return False + seenlist_add(element) + return True + + +def nth_product(index, *args): + """Equivalent to ``list(product(*args))[index]``. + + The products of *args* can be ordered lexicographically. + :func:`nth_product` computes the product at sort position *index* without + computing the previous products. + + >>> nth_product(8, range(2), range(2), range(2), range(2)) + (1, 0, 0, 0) + + ``IndexError`` will be raised if the given *index* is invalid. + """ + pools = list(map(tuple, reversed(args))) + ns = list(map(len, pools)) + + c = reduce(mul, ns) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + result = [] + for pool, n in zip(pools, ns): + result.append(pool[index % n]) + index //= n + + return tuple(reversed(result)) + + +def nth_permutation(iterable, r, index): + """Equivalent to ``list(permutations(iterable, r))[index]``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`nth_permutation` + computes the subsequence at sort position *index* directly, without + computing the previous subsequences. + + >>> nth_permutation('ghijk', 2, 5) + ('h', 'i') + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = list(iterable) + n = len(pool) + + if r is None or r == n: + r, c = n, factorial(n) + elif not 0 <= r < n: + raise ValueError + else: + c = factorial(n) // factorial(n - r) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + if c == 0: + return tuple() + + result = [0] * r + q = index * factorial(n) // c if r < n else index + for d in range(1, n + 1): + q, i = divmod(q, d) + if 0 <= n - d < r: + result[n - d] = i + if q == 0: + break + + return tuple(map(pool.pop, result)) + + +def value_chain(*args): + """Yield all arguments passed to the function in the same order in which + they were passed. If an argument itself is iterable then iterate over its + values. + + >>> list(value_chain(1, 2, 3, [4, 5, 6])) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and are emitted + as-is: + + >>> list(value_chain('12', '34', ['56', '78'])) + ['12', '34', '56', '78'] + + + Multiple levels of nesting are not flattened. + + """ + for value in args: + if isinstance(value, (str, bytes)): + yield value + continue + try: + yield from value + except TypeError: + yield value + + +def product_index(element, *args): + """Equivalent to ``list(product(*args)).index(element)`` + + The products of *args* can be ordered lexicographically. + :func:`product_index` computes the first index of *element* without + computing the previous products. + + >>> product_index([8, 2], range(10), range(5)) + 42 + + ``ValueError`` will be raised if the given *element* isn't in the product + of *args*. + """ + index = 0 + + for x, pool in zip_longest(element, args, fillvalue=_marker): + if x is _marker or pool is _marker: + raise ValueError('element is not a product of args') + + pool = tuple(pool) + index = index * len(pool) + pool.index(x) + + return index + + +def combination_index(element, iterable): + """Equivalent to ``list(combinations(iterable, r)).index(element)`` + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. :func:`combination_index` computes the index of the + first *element*, without computing the previous combinations. + + >>> combination_index('adf', 'abcdefg') + 10 + + ``ValueError`` will be raised if the given *element* isn't one of the + combinations of *iterable*. + """ + element = enumerate(element) + k, y = next(element, (None, None)) + if k is None: + return 0 + + indexes = [] + pool = enumerate(iterable) + for n, x in pool: + if x == y: + indexes.append(n) + tmp, y = next(element, (None, None)) + if tmp is None: + break + else: + k = tmp + else: + raise ValueError('element is not a combination of iterable') + + n, _ = last(pool, default=(n, None)) + + # Python versiosn below 3.8 don't have math.comb + index = 1 + for i, j in enumerate(reversed(indexes), start=1): + j = n - j + if i <= j: + index += factorial(j) // (factorial(i) * factorial(j - i)) + + return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index + + +def permutation_index(element, iterable): + """Equivalent to ``list(permutations(iterable, r)).index(element)``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`permutation_index` + computes the index of the first *element* directly, without computing + the previous permutations. + + >>> permutation_index([1, 3, 2], range(5)) + 19 + + ``ValueError`` will be raised if the given *element* isn't one of the + permutations of *iterable*. + """ + index = 0 + pool = list(iterable) + for i, x in zip(range(len(pool), -1, -1), element): + r = pool.index(x) + index = index * i + r + del pool[r] + + return index + + +class countable: + """Wrap *iterable* and keep a count of how many items have been consumed. + + The ``items_seen`` attribute starts at ``0`` and increments as the iterable + is consumed: + + >>> iterable = map(str, range(10)) + >>> it = countable(iterable) + >>> it.items_seen + 0 + >>> next(it), next(it) + ('0', '1') + >>> list(it) + ['2', '3', '4', '5', '6', '7', '8', '9'] + >>> it.items_seen + 10 + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self.items_seen = 0 + + def __iter__(self): + return self + + def __next__(self): + item = next(self._it) + self.items_seen += 1 + + return item + + +def chunked_even(iterable, n): + """Break *iterable* into lists of approximately length *n*. + Items are distributed such the lengths of the lists differ by at most + 1 item. + + >>> iterable = [1, 2, 3, 4, 5, 6, 7] + >>> n = 3 + >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2 + [[1, 2, 3], [4, 5], [6, 7]] + >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1 + [[1, 2, 3], [4, 5, 6], [7]] + + """ + + len_method = getattr(iterable, '__len__', None) + + if len_method is None: + return _chunked_even_online(iterable, n) + else: + return _chunked_even_finite(iterable, len_method(), n) + + +def _chunked_even_online(iterable, n): + buffer = [] + maxbuf = n + (n - 2) * (n - 1) + for x in iterable: + buffer.append(x) + if len(buffer) == maxbuf: + yield buffer[:n] + buffer = buffer[n:] + yield from _chunked_even_finite(buffer, len(buffer), n) + + +def _chunked_even_finite(iterable, N, n): + if N < 1: + return + + # Lists are either size `full_size <= n` or `partial_size = full_size - 1` + q, r = divmod(N, n) + num_lists = q + (1 if r > 0 else 0) + q, r = divmod(N, num_lists) + full_size = q + (1 if r > 0 else 0) + partial_size = full_size - 1 + num_full = N - partial_size * num_lists + num_partial = num_lists - num_full + + buffer = [] + iterator = iter(iterable) + + # Yield num_full lists of full_size + for x in iterator: + buffer.append(x) + if len(buffer) == full_size: + yield buffer + buffer = [] + num_full -= 1 + if num_full <= 0: + break + + # Yield num_partial lists of partial_size + for x in iterator: + buffer.append(x) + if len(buffer) == partial_size: + yield buffer + buffer = [] + num_partial -= 1 + + +def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False): + """A version of :func:`zip` that "broadcasts" any scalar + (i.e., non-iterable) items into output tuples. + + >>> iterable_1 = [1, 2, 3] + >>> iterable_2 = ['a', 'b', 'c'] + >>> scalar = '_' + >>> list(zip_broadcast(iterable_1, iterable_2, scalar)) + [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')] + + The *scalar_types* keyword argument determines what types are considered + scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to + treat strings and byte strings as iterable: + + >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None)) + [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')] + + If the *strict* keyword argument is ``True``, then + ``UnequalIterablesError`` will be raised if any of the iterables have + different lengthss. + """ + + def is_scalar(obj): + if scalar_types and isinstance(obj, scalar_types): + return True + try: + iter(obj) + except TypeError: + return True + else: + return False + + size = len(objects) + if not size: + return + + iterables, iterable_positions = [], [] + scalars, scalar_positions = [], [] + for i, obj in enumerate(objects): + if is_scalar(obj): + scalars.append(obj) + scalar_positions.append(i) + else: + iterables.append(iter(obj)) + iterable_positions.append(i) + + if len(scalars) == size: + yield tuple(objects) + return + + zipper = _zip_equal if strict else zip + for item in zipper(*iterables): + new_item = [None] * size + + for i, elem in zip(iterable_positions, item): + new_item[i] = elem + + for i, elem in zip(scalar_positions, scalars): + new_item[i] = elem + + yield tuple(new_item) + + +def unique_in_window(iterable, n, key=None): + """Yield the items from *iterable* that haven't been seen recently. + *n* is the size of the lookback window. + + >>> iterable = [0, 1, 0, 2, 3, 0] + >>> n = 3 + >>> list(unique_in_window(iterable, n)) + [0, 1, 2, 3, 0] + + The *key* function, if provided, will be used to determine uniqueness: + + >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower())) + ['a', 'b', 'c', 'd', 'a'] + + The items in *iterable* must be hashable. + + """ + if n <= 0: + raise ValueError('n must be greater than 0') + + window = deque(maxlen=n) + uniques = set() + use_key = key is not None + + for item in iterable: + k = key(item) if use_key else item + if k in uniques: + continue + + if len(uniques) == n: + uniques.discard(window[0]) + + uniques.add(k) + window.append(k) + + yield item + + +def duplicates_everseen(iterable, key=None): + """Yield duplicate elements after their first appearance. + + >>> list(duplicates_everseen('mississippi')) + ['s', 'i', 's', 's', 'i', 'p', 'i'] + >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower)) + ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a'] + + This function is analagous to :func:`unique_everseen` and is subject to + the same performance considerations. + + """ + seen_set = set() + seen_list = [] + use_key = key is not None + + for element in iterable: + k = key(element) if use_key else element + try: + if k not in seen_set: + seen_set.add(k) + else: + yield element + except TypeError: + if k not in seen_list: + seen_list.append(k) + else: + yield element + + +def duplicates_justseen(iterable, key=None): + """Yields serially-duplicate elements after their first appearance. + + >>> list(duplicates_justseen('mississippi')) + ['s', 's', 'p'] + >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower)) + ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a'] + + This function is analagous to :func:`unique_justseen`. + + """ + return flatten( + map( + lambda group_tuple: islice_extended(group_tuple[1])[1:], + groupby(iterable, key), + ) + ) + + +def minmax(iterable_or_value, *others, key=None, default=_marker): + """Returns both the smallest and largest items in an iterable + or the largest of two or more arguments. + + >>> minmax([3, 1, 5]) + (1, 5) + + >>> minmax(4, 2, 6) + (2, 6) + + If a *key* function is provided, it will be used to transform the input + items for comparison. + + >>> minmax([5, 30], key=str) # '30' sorts before '5' + (30, 5) + + If a *default* value is provided, it will be returned if there are no + input items. + + >>> minmax([], default=(0, 0)) + (0, 0) + + Otherwise ``ValueError`` is raised. + + This function is based on the + `recipe `__ by + Raymond Hettinger and takes care to minimize the number of comparisons + performed. + """ + iterable = (iterable_or_value, *others) if others else iterable_or_value + + it = iter(iterable) + + try: + lo = hi = next(it) + except StopIteration as e: + if default is _marker: + raise ValueError( + '`minmax()` argument is an empty iterable. ' + 'Provide a `default` value to suppress this error.' + ) from e + return default + + # Different branches depending on the presence of key. This saves a lot + # of unimportant copies which would slow the "key=None" branch + # significantly down. + if key is None: + for x, y in zip_longest(it, it, fillvalue=lo): + if y < x: + x, y = y, x + if x < lo: + lo = x + if hi < y: + hi = y + + else: + lo_key = hi_key = key(lo) + + for x, y in zip_longest(it, it, fillvalue=lo): + + x_key, y_key = key(x), key(y) + + if y_key < x_key: + x, y, x_key, y_key = y, x, y_key, x_key + if x_key < lo_key: + lo, lo_key = x, x_key + if hi_key < y_key: + hi, hi_key = y, y_key + + return lo, hi diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/recipes.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/recipes.py new file mode 100644 index 0000000000000000000000000000000000000000..a2596423a4c3dbd15a357241477a0af0a531f9ec --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/recipes.py @@ -0,0 +1,698 @@ +"""Imported from the recipes section of the itertools documentation. + +All functions taken from the recipes section of the itertools library docs +[1]_. +Some backward-compatible usability improvements have been made. + +.. [1] http://docs.python.org/library/itertools.html#recipes + +""" +import warnings +from collections import deque +from itertools import ( + chain, + combinations, + count, + cycle, + groupby, + islice, + repeat, + starmap, + tee, + zip_longest, +) +import operator +from random import randrange, sample, choice + +__all__ = [ + 'all_equal', + 'before_and_after', + 'consume', + 'convolve', + 'dotproduct', + 'first_true', + 'flatten', + 'grouper', + 'iter_except', + 'ncycles', + 'nth', + 'nth_combination', + 'padnone', + 'pad_none', + 'pairwise', + 'partition', + 'powerset', + 'prepend', + 'quantify', + 'random_combination_with_replacement', + 'random_combination', + 'random_permutation', + 'random_product', + 'repeatfunc', + 'roundrobin', + 'sliding_window', + 'tabulate', + 'tail', + 'take', + 'triplewise', + 'unique_everseen', + 'unique_justseen', +] + + +def take(n, iterable): + """Return first *n* items of the iterable as a list. + + >>> take(3, range(10)) + [0, 1, 2] + + If there are fewer than *n* items in the iterable, all of them are + returned. + + >>> take(10, range(3)) + [0, 1, 2] + + """ + return list(islice(iterable, n)) + + +def tabulate(function, start=0): + """Return an iterator over the results of ``func(start)``, + ``func(start + 1)``, ``func(start + 2)``... + + *func* should be a function that accepts one integer argument. + + If *start* is not specified it defaults to 0. It will be incremented each + time the iterator is advanced. + + >>> square = lambda x: x ** 2 + >>> iterator = tabulate(square, -3) + >>> take(4, iterator) + [9, 4, 1, 0] + + """ + return map(function, count(start)) + + +def tail(n, iterable): + """Return an iterator over the last *n* items of *iterable*. + + >>> t = tail(3, 'ABCDEFG') + >>> list(t) + ['E', 'F', 'G'] + + """ + return iter(deque(iterable, maxlen=n)) + + +def consume(iterator, n=None): + """Advance *iterable* by *n* steps. If *n* is ``None``, consume it + entirely. + + Efficiently exhausts an iterator without returning values. Defaults to + consuming the whole iterator, but an optional second argument may be + provided to limit consumption. + + >>> i = (x for x in range(10)) + >>> next(i) + 0 + >>> consume(i, 3) + >>> next(i) + 4 + >>> consume(i) + >>> next(i) + Traceback (most recent call last): + File "", line 1, in + StopIteration + + If the iterator has fewer items remaining than the provided limit, the + whole iterator will be consumed. + + >>> i = (x for x in range(3)) + >>> consume(i, 5) + >>> next(i) + Traceback (most recent call last): + File "", line 1, in + StopIteration + + """ + # Use functions that consume iterators at C speed. + if n is None: + # feed the entire iterator into a zero-length deque + deque(iterator, maxlen=0) + else: + # advance to the empty slice starting at position n + next(islice(iterator, n, n), None) + + +def nth(iterable, n, default=None): + """Returns the nth item or a default value. + + >>> l = range(10) + >>> nth(l, 3) + 3 + >>> nth(l, 20, "zebra") + 'zebra' + + """ + return next(islice(iterable, n, None), default) + + +def all_equal(iterable): + """ + Returns ``True`` if all the elements are equal to each other. + + >>> all_equal('aaaa') + True + >>> all_equal('aaab') + False + + """ + g = groupby(iterable) + return next(g, True) and not next(g, False) + + +def quantify(iterable, pred=bool): + """Return the how many times the predicate is true. + + >>> quantify([True, False, True]) + 2 + + """ + return sum(map(pred, iterable)) + + +def pad_none(iterable): + """Returns the sequence of elements and then returns ``None`` indefinitely. + + >>> take(5, pad_none(range(3))) + [0, 1, 2, None, None] + + Useful for emulating the behavior of the built-in :func:`map` function. + + See also :func:`padded`. + + """ + return chain(iterable, repeat(None)) + + +padnone = pad_none + + +def ncycles(iterable, n): + """Returns the sequence elements *n* times + + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] + + """ + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def dotproduct(vec1, vec2): + """Returns the dot product of the two iterables. + + >>> dotproduct([10, 10], [20, 20]) + 400 + + """ + return sum(map(operator.mul, vec1, vec2)) + + +def flatten(listOfLists): + """Return an iterator flattening one level of nesting in a list of lists. + + >>> list(flatten([[0, 1], [2, 3]])) + [0, 1, 2, 3] + + See also :func:`collapse`, which can flatten multiple levels of nesting. + + """ + return chain.from_iterable(listOfLists) + + +def repeatfunc(func, times=None, *args): + """Call *func* with *args* repeatedly, returning an iterable over the + results. + + If *times* is specified, the iterable will terminate after that many + repetitions: + + >>> from operator import add + >>> times = 4 + >>> args = 3, 5 + >>> list(repeatfunc(add, times, *args)) + [8, 8, 8, 8] + + If *times* is ``None`` the iterable will not terminate: + + >>> from random import randrange + >>> times = None + >>> args = 1, 11 + >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP + [2, 4, 8, 1, 8, 4] + + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + + +def _pairwise(iterable): + """Returns an iterator of paired items, overlapping, from the original + + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`. + + """ + a, b = tee(iterable) + next(b, None) + yield from zip(a, b) + + +try: + from itertools import pairwise as itertools_pairwise +except ImportError: + pairwise = _pairwise +else: + + def pairwise(iterable): + yield from itertools_pairwise(iterable) + + pairwise.__doc__ = _pairwise.__doc__ + + +def grouper(iterable, n, fillvalue=None): + """Collect data into fixed-length chunks or blocks. + + >>> list(grouper('ABCDEFG', 3, 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + + """ + if isinstance(iterable, int): + warnings.warn( + "grouper expects iterable as first parameter", DeprecationWarning + ) + n, iterable = iterable, n + args = [iter(iterable)] * n + return zip_longest(fillvalue=fillvalue, *args) + + +def roundrobin(*iterables): + """Yields an item from each iterable, alternating between them. + + >>> list(roundrobin('ABC', 'D', 'EF')) + ['A', 'D', 'E', 'B', 'F', 'C'] + + This function produces the same output as :func:`interleave_longest`, but + may perform better for some inputs (in particular when the number of + iterables is small). + + """ + # Recipe credited to George Sakkis + pending = len(iterables) + nexts = cycle(iter(it).__next__ for it in iterables) + while pending: + try: + for next in nexts: + yield next() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + + +def partition(pred, iterable): + """ + Returns a 2-tuple of iterables derived from the input iterable. + The first yields the items that have ``pred(item) == False``. + The second yields the items that have ``pred(item) == True``. + + >>> is_odd = lambda x: x % 2 != 0 + >>> iterable = range(10) + >>> even_items, odd_items = partition(is_odd, iterable) + >>> list(even_items), list(odd_items) + ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + + If *pred* is None, :func:`bool` is used. + + >>> iterable = [0, 1, False, True, '', ' '] + >>> false_items, true_items = partition(None, iterable) + >>> list(false_items), list(true_items) + ([0, False, ''], [1, True, ' ']) + + """ + if pred is None: + pred = bool + + evaluations = ((pred(x), x) for x in iterable) + t1, t2 = tee(evaluations) + return ( + (x for (cond, x) in t1 if not cond), + (x for (cond, x) in t2 if cond), + ) + + +def powerset(iterable): + """Yields all possible subsets of the iterable. + + >>> list(powerset([1, 2, 3])) + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + + :func:`powerset` will operate on iterables that aren't :class:`set` + instances, so repeated elements in the input will produce repeated elements + in the output. Use :func:`unique_everseen` on the input to avoid generating + duplicates: + + >>> seq = [1, 1, 0] + >>> list(powerset(seq)) + [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)] + >>> from more_itertools import unique_everseen + >>> list(powerset(unique_everseen(seq))) + [(), (1,), (0,), (1, 0)] + + """ + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def unique_everseen(iterable, key=None): + """ + Yield unique elements, preserving order. + + >>> list(unique_everseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D'] + >>> list(unique_everseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'D'] + + Sequences with a mix of hashable and unhashable items can be used. + The function will be slower (i.e., `O(n^2)`) for unhashable items. + + Remember that ``list`` objects are unhashable - you can use the *key* + parameter to transform the list to a tuple (which is hashable) to + avoid a slowdown. + + >>> iterable = ([1, 2], [2, 3], [1, 2]) + >>> list(unique_everseen(iterable)) # Slow + [[1, 2], [2, 3]] + >>> list(unique_everseen(iterable, key=tuple)) # Faster + [[1, 2], [2, 3]] + + Similary, you may want to convert unhashable ``set`` objects with + ``key=frozenset``. For ``dict`` objects, + ``key=lambda x: frozenset(x.items())`` can be used. + + """ + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append + use_key = key is not None + + for element in iterable: + k = key(element) if use_key else element + try: + if k not in seenset: + seenset_add(k) + yield element + except TypeError: + if k not in seenlist: + seenlist_add(k) + yield element + + +def unique_justseen(iterable, key=None): + """Yields elements in order, ignoring serial duplicates + + >>> list(unique_justseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D', 'A', 'B'] + >>> list(unique_justseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'A', 'D'] + + """ + return map(next, map(operator.itemgetter(1), groupby(iterable, key))) + + +def iter_except(func, exception, first=None): + """Yields results from a function repeatedly until an exception is raised. + + Converts a call-until-exception interface to an iterator interface. + Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel + to end the loop. + + >>> l = [0, 1, 2] + >>> list(iter_except(l.pop, IndexError)) + [2, 1, 0] + + Multiple exceptions can be specified as a stopping condition: + + >>> l = [1, 2, 3, '...', 4, 5, 6] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [7, 6, 5] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [4, 3, 2] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [] + + """ + try: + if first is not None: + yield first() + while 1: + yield func() + except exception: + pass + + +def first_true(iterable, default=None, pred=None): + """ + Returns the first true value in the iterable. + + If no true value is found, returns *default* + + If *pred* is not None, returns the first item for which + ``pred(item) == True`` . + + >>> first_true(range(10)) + 1 + >>> first_true(range(10), pred=lambda x: x > 5) + 6 + >>> first_true(range(10), default='missing', pred=lambda x: x > 9) + 'missing' + + """ + return next(filter(pred, iterable), default) + + +def random_product(*args, repeat=1): + """Draw an item at random from each of the input iterables. + + >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP + ('c', 3, 'Z') + + If *repeat* is provided as a keyword argument, that many items will be + drawn from each iterable. + + >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP + ('a', 2, 'd', 3) + + This equivalent to taking a random selection from + ``itertools.product(*args, **kwarg)``. + + """ + pools = [tuple(pool) for pool in args] * repeat + return tuple(choice(pool) for pool in pools) + + +def random_permutation(iterable, r=None): + """Return a random *r* length permutation of the elements in *iterable*. + + If *r* is not specified or is ``None``, then *r* defaults to the length of + *iterable*. + + >>> random_permutation(range(5)) # doctest:+SKIP + (3, 4, 0, 1, 2) + + This equivalent to taking a random selection from + ``itertools.permutations(iterable, r)``. + + """ + pool = tuple(iterable) + r = len(pool) if r is None else r + return tuple(sample(pool, r)) + + +def random_combination(iterable, r): + """Return a random *r* length subsequence of the elements in *iterable*. + + >>> random_combination(range(5), 3) # doctest:+SKIP + (2, 3, 4) + + This equivalent to taking a random selection from + ``itertools.combinations(iterable, r)``. + + """ + pool = tuple(iterable) + n = len(pool) + indices = sorted(sample(range(n), r)) + return tuple(pool[i] for i in indices) + + +def random_combination_with_replacement(iterable, r): + """Return a random *r* length subsequence of elements in *iterable*, + allowing individual elements to be repeated. + + >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP + (0, 0, 1, 2, 2) + + This equivalent to taking a random selection from + ``itertools.combinations_with_replacement(iterable, r)``. + + """ + pool = tuple(iterable) + n = len(pool) + indices = sorted(randrange(n) for i in range(r)) + return tuple(pool[i] for i in indices) + + +def nth_combination(iterable, r, index): + """Equivalent to ``list(combinations(iterable, r))[index]``. + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. :func:`nth_combination` computes the subsequence at + sort position *index* directly, without computing the previous + subsequences. + + >>> nth_combination(range(5), 3, 5) + (0, 3, 4) + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = tuple(iterable) + n = len(pool) + if (r < 0) or (r > n): + raise ValueError + + c = 1 + k = min(r, n - r) + for i in range(1, k + 1): + c = c * (n - k + i) // i + + if index < 0: + index += c + + if (index < 0) or (index >= c): + raise IndexError + + result = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + + return tuple(result) + + +def prepend(value, iterator): + """Yield *value*, followed by the elements in *iterator*. + + >>> value = '0' + >>> iterator = ['1', '2', '3'] + >>> list(prepend(value, iterator)) + ['0', '1', '2', '3'] + + To prepend multiple values, see :func:`itertools.chain` + or :func:`value_chain`. + + """ + return chain([value], iterator) + + +def convolve(signal, kernel): + """Convolve the iterable *signal* with the iterable *kernel*. + + >>> signal = (1, 2, 3, 4, 5) + >>> kernel = [3, 2, 1] + >>> list(convolve(signal, kernel)) + [3, 8, 14, 20, 26, 14, 5] + + Note: the input arguments are not interchangeable, as the *kernel* + is immediately consumed and stored. + + """ + kernel = tuple(kernel)[::-1] + n = len(kernel) + window = deque([0], maxlen=n) * n + for x in chain(signal, repeat(0, n - 1)): + window.append(x) + yield sum(map(operator.mul, kernel, window)) + + +def before_and_after(predicate, it): + """A variant of :func:`takewhile` that allows complete access to the + remainder of the iterator. + + >>> it = iter('ABCdEfGhI') + >>> all_upper, remainder = before_and_after(str.isupper, it) + >>> ''.join(all_upper) + 'ABC' + >>> ''.join(remainder) # takewhile() would lose the 'd' + 'dEfGhI' + + Note that the first iterator must be fully consumed before the second + iterator can generate valid results. + """ + it = iter(it) + transition = [] + + def true_iterator(): + for elem in it: + if predicate(elem): + yield elem + else: + transition.append(elem) + return + + def remainder_iterator(): + yield from transition + yield from it + + return true_iterator(), remainder_iterator() + + +def triplewise(iterable): + """Return overlapping triplets from *iterable*. + + >>> list(triplewise('ABCDE')) + [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] + + """ + for (a, _), (b, c) in pairwise(pairwise(iterable)): + yield a, b, c + + +def sliding_window(iterable, n): + """Return a sliding window of width *n* over *iterable*. + + >>> list(sliding_window(range(6), 4)) + [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)] + + If *iterable* has fewer than *n* items, then nothing is yielded: + + >>> list(sliding_window(range(3), 4)) + [] + + For a variant with more features, see :func:`windowed`. + """ + it = iter(iterable) + window = deque(islice(it, n), maxlen=n) + if len(window) == n: + yield tuple(window) + for x in it: + window.append(x) + yield tuple(window) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__about__.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__about__.py new file mode 100644 index 0000000000000000000000000000000000000000..3551bc2d29846441299cf57b397b02fc164c99b9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__about__.py @@ -0,0 +1,26 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +__all__ = [ + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", +] + +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "21.3" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014-2019 %s" % __author__ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__init__.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c50c5dcfeeda2efed282200a5c5cc8c5f7542f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__init__.py @@ -0,0 +1,25 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from .__about__ import ( + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, +) + +__all__ = [ + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", +] diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/__about__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/__about__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e427a574e55c03409bf69b6232685d2dfeed7deb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/__about__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4824430f843f42e1a327500face3fcdd3c0208a6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_manylinux.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_manylinux.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab9ac789e05f151182d6971d66be992230f982e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_manylinux.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_musllinux.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_musllinux.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41e04ad83793535d3e43023fd2e10a37945dd532 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_musllinux.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_structures.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_structures.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81574b20b52f12aca9155c7bf99623e4cd1e9c57 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/_structures.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/markers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/markers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab55cb3b7ba81776768c34bd950f6b9924383b3c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/markers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/requirements.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/requirements.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4a7f43d60be2f27a673c29fe459cc19bf3741b0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/requirements.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/specifiers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/specifiers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba84e4ed76367921b612f1ae9de91930dfbf5883 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/specifiers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/tags.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/tags.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..182abfd7d0d42074b805ff6cbf7b3ce1e3f86eca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/tags.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a94eb65965dcc180877df9a111c71a5722ae63 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/version.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a39d6d0eb117741240cf882c01a59433384ad15 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/__pycache__/version.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_manylinux.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_manylinux.py new file mode 100644 index 0000000000000000000000000000000000000000..4c379aa6f69ff56c8f19612002c6e3e939ea6012 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_manylinux.py @@ -0,0 +1,301 @@ +import collections +import functools +import os +import re +import struct +import sys +import warnings +from typing import IO, Dict, Iterator, NamedTuple, Optional, Tuple + + +# Python does not provide platform information at sufficient granularity to +# identify the architecture of the running executable in some cases, so we +# determine it dynamically by reading the information from the running +# process. This only applies on Linux, which uses the ELF format. +class _ELFFileHeader: + # https://en.wikipedia.org/wiki/Executable_and_Linkable_Format#File_header + class _InvalidELFFileHeader(ValueError): + """ + An invalid ELF file header was found. + """ + + ELF_MAGIC_NUMBER = 0x7F454C46 + ELFCLASS32 = 1 + ELFCLASS64 = 2 + ELFDATA2LSB = 1 + ELFDATA2MSB = 2 + EM_386 = 3 + EM_S390 = 22 + EM_ARM = 40 + EM_X86_64 = 62 + EF_ARM_ABIMASK = 0xFF000000 + EF_ARM_ABI_VER5 = 0x05000000 + EF_ARM_ABI_FLOAT_HARD = 0x00000400 + + def __init__(self, file: IO[bytes]) -> None: + def unpack(fmt: str) -> int: + try: + data = file.read(struct.calcsize(fmt)) + result: Tuple[int, ...] = struct.unpack(fmt, data) + except struct.error: + raise _ELFFileHeader._InvalidELFFileHeader() + return result[0] + + self.e_ident_magic = unpack(">I") + if self.e_ident_magic != self.ELF_MAGIC_NUMBER: + raise _ELFFileHeader._InvalidELFFileHeader() + self.e_ident_class = unpack("B") + if self.e_ident_class not in {self.ELFCLASS32, self.ELFCLASS64}: + raise _ELFFileHeader._InvalidELFFileHeader() + self.e_ident_data = unpack("B") + if self.e_ident_data not in {self.ELFDATA2LSB, self.ELFDATA2MSB}: + raise _ELFFileHeader._InvalidELFFileHeader() + self.e_ident_version = unpack("B") + self.e_ident_osabi = unpack("B") + self.e_ident_abiversion = unpack("B") + self.e_ident_pad = file.read(7) + format_h = "H" + format_i = "I" + format_q = "Q" + format_p = format_i if self.e_ident_class == self.ELFCLASS32 else format_q + self.e_type = unpack(format_h) + self.e_machine = unpack(format_h) + self.e_version = unpack(format_i) + self.e_entry = unpack(format_p) + self.e_phoff = unpack(format_p) + self.e_shoff = unpack(format_p) + self.e_flags = unpack(format_i) + self.e_ehsize = unpack(format_h) + self.e_phentsize = unpack(format_h) + self.e_phnum = unpack(format_h) + self.e_shentsize = unpack(format_h) + self.e_shnum = unpack(format_h) + self.e_shstrndx = unpack(format_h) + + +def _get_elf_header() -> Optional[_ELFFileHeader]: + try: + with open(sys.executable, "rb") as f: + elf_header = _ELFFileHeader(f) + except (OSError, TypeError, _ELFFileHeader._InvalidELFFileHeader): + return None + return elf_header + + +def _is_linux_armhf() -> bool: + # hard-float ABI can be detected from the ELF header of the running + # process + # https://static.docs.arm.com/ihi0044/g/aaelf32.pdf + elf_header = _get_elf_header() + if elf_header is None: + return False + result = elf_header.e_ident_class == elf_header.ELFCLASS32 + result &= elf_header.e_ident_data == elf_header.ELFDATA2LSB + result &= elf_header.e_machine == elf_header.EM_ARM + result &= ( + elf_header.e_flags & elf_header.EF_ARM_ABIMASK + ) == elf_header.EF_ARM_ABI_VER5 + result &= ( + elf_header.e_flags & elf_header.EF_ARM_ABI_FLOAT_HARD + ) == elf_header.EF_ARM_ABI_FLOAT_HARD + return result + + +def _is_linux_i686() -> bool: + elf_header = _get_elf_header() + if elf_header is None: + return False + result = elf_header.e_ident_class == elf_header.ELFCLASS32 + result &= elf_header.e_ident_data == elf_header.ELFDATA2LSB + result &= elf_header.e_machine == elf_header.EM_386 + return result + + +def _have_compatible_abi(arch: str) -> bool: + if arch == "armv7l": + return _is_linux_armhf() + if arch == "i686": + return _is_linux_i686() + return arch in {"x86_64", "aarch64", "ppc64", "ppc64le", "s390x"} + + +# If glibc ever changes its major version, we need to know what the last +# minor version was, so we can build the complete list of all versions. +# For now, guess what the highest minor version might be, assume it will +# be 50 for testing. Once this actually happens, update the dictionary +# with the actual value. +_LAST_GLIBC_MINOR: Dict[int, int] = collections.defaultdict(lambda: 50) + + +class _GLibCVersion(NamedTuple): + major: int + minor: int + + +def _glibc_version_string_confstr() -> Optional[str]: + """ + Primary implementation of glibc_version_string using os.confstr. + """ + # os.confstr is quite a bit faster than ctypes.DLL. It's also less likely + # to be broken or missing. This strategy is used in the standard library + # platform module. + # https://github.com/python/cpython/blob/fcf1d003bf4f0100c/Lib/platform.py#L175-L183 + try: + # os.confstr("CS_GNU_LIBC_VERSION") returns a string like "glibc 2.17". + version_string = os.confstr("CS_GNU_LIBC_VERSION") + assert version_string is not None + _, version = version_string.split() + except (AssertionError, AttributeError, OSError, ValueError): + # os.confstr() or CS_GNU_LIBC_VERSION not available (or a bad value)... + return None + return version + + +def _glibc_version_string_ctypes() -> Optional[str]: + """ + Fallback implementation of glibc_version_string using ctypes. + """ + try: + import ctypes + except ImportError: + return None + + # ctypes.CDLL(None) internally calls dlopen(NULL), and as the dlopen + # manpage says, "If filename is NULL, then the returned handle is for the + # main program". This way we can let the linker do the work to figure out + # which libc our process is actually using. + # + # We must also handle the special case where the executable is not a + # dynamically linked executable. This can occur when using musl libc, + # for example. In this situation, dlopen() will error, leading to an + # OSError. Interestingly, at least in the case of musl, there is no + # errno set on the OSError. The single string argument used to construct + # OSError comes from libc itself and is therefore not portable to + # hard code here. In any case, failure to call dlopen() means we + # can proceed, so we bail on our attempt. + try: + process_namespace = ctypes.CDLL(None) + except OSError: + return None + + try: + gnu_get_libc_version = process_namespace.gnu_get_libc_version + except AttributeError: + # Symbol doesn't exist -> therefore, we are not linked to + # glibc. + return None + + # Call gnu_get_libc_version, which returns a string like "2.5" + gnu_get_libc_version.restype = ctypes.c_char_p + version_str: str = gnu_get_libc_version() + # py2 / py3 compatibility: + if not isinstance(version_str, str): + version_str = version_str.decode("ascii") + + return version_str + + +def _glibc_version_string() -> Optional[str]: + """Returns glibc version string, or None if not using glibc.""" + return _glibc_version_string_confstr() or _glibc_version_string_ctypes() + + +def _parse_glibc_version(version_str: str) -> Tuple[int, int]: + """Parse glibc version. + + We use a regexp instead of str.split because we want to discard any + random junk that might come after the minor version -- this might happen + in patched/forked versions of glibc (e.g. Linaro's version of glibc + uses version strings like "2.20-2014.11"). See gh-3588. + """ + m = re.match(r"(?P[0-9]+)\.(?P[0-9]+)", version_str) + if not m: + warnings.warn( + "Expected glibc version with 2 components major.minor," + " got: %s" % version_str, + RuntimeWarning, + ) + return -1, -1 + return int(m.group("major")), int(m.group("minor")) + + +@functools.lru_cache() +def _get_glibc_version() -> Tuple[int, int]: + version_str = _glibc_version_string() + if version_str is None: + return (-1, -1) + return _parse_glibc_version(version_str) + + +# From PEP 513, PEP 600 +def _is_compatible(name: str, arch: str, version: _GLibCVersion) -> bool: + sys_glibc = _get_glibc_version() + if sys_glibc < version: + return False + # Check for presence of _manylinux module. + try: + import _manylinux # noqa + except ImportError: + return True + if hasattr(_manylinux, "manylinux_compatible"): + result = _manylinux.manylinux_compatible(version[0], version[1], arch) + if result is not None: + return bool(result) + return True + if version == _GLibCVersion(2, 5): + if hasattr(_manylinux, "manylinux1_compatible"): + return bool(_manylinux.manylinux1_compatible) + if version == _GLibCVersion(2, 12): + if hasattr(_manylinux, "manylinux2010_compatible"): + return bool(_manylinux.manylinux2010_compatible) + if version == _GLibCVersion(2, 17): + if hasattr(_manylinux, "manylinux2014_compatible"): + return bool(_manylinux.manylinux2014_compatible) + return True + + +_LEGACY_MANYLINUX_MAP = { + # CentOS 7 w/ glibc 2.17 (PEP 599) + (2, 17): "manylinux2014", + # CentOS 6 w/ glibc 2.12 (PEP 571) + (2, 12): "manylinux2010", + # CentOS 5 w/ glibc 2.5 (PEP 513) + (2, 5): "manylinux1", +} + + +def platform_tags(linux: str, arch: str) -> Iterator[str]: + if not _have_compatible_abi(arch): + return + # Oldest glibc to be supported regardless of architecture is (2, 17). + too_old_glibc2 = _GLibCVersion(2, 16) + if arch in {"x86_64", "i686"}: + # On x86/i686 also oldest glibc to be supported is (2, 5). + too_old_glibc2 = _GLibCVersion(2, 4) + current_glibc = _GLibCVersion(*_get_glibc_version()) + glibc_max_list = [current_glibc] + # We can assume compatibility across glibc major versions. + # https://sourceware.org/bugzilla/show_bug.cgi?id=24636 + # + # Build a list of maximum glibc versions so that we can + # output the canonical list of all glibc from current_glibc + # down to too_old_glibc2, including all intermediary versions. + for glibc_major in range(current_glibc.major - 1, 1, -1): + glibc_minor = _LAST_GLIBC_MINOR[glibc_major] + glibc_max_list.append(_GLibCVersion(glibc_major, glibc_minor)) + for glibc_max in glibc_max_list: + if glibc_max.major == too_old_glibc2.major: + min_minor = too_old_glibc2.minor + else: + # For other glibc major versions oldest supported is (x, 0). + min_minor = -1 + for glibc_minor in range(glibc_max.minor, min_minor, -1): + glibc_version = _GLibCVersion(glibc_max.major, glibc_minor) + tag = "manylinux_{}_{}".format(*glibc_version) + if _is_compatible(tag, arch, glibc_version): + yield linux.replace("linux", tag) + # Handle the legacy manylinux1, manylinux2010, manylinux2014 tags. + if glibc_version in _LEGACY_MANYLINUX_MAP: + legacy_tag = _LEGACY_MANYLINUX_MAP[glibc_version] + if _is_compatible(legacy_tag, arch, glibc_version): + yield linux.replace("linux", legacy_tag) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_musllinux.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_musllinux.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac3059ba3c246b9a5a6fb8d14936bb07777191e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_musllinux.py @@ -0,0 +1,136 @@ +"""PEP 656 support. + +This module implements logic to detect if the currently running Python is +linked against musl, and what musl version is used. +""" + +import contextlib +import functools +import operator +import os +import re +import struct +import subprocess +import sys +from typing import IO, Iterator, NamedTuple, Optional, Tuple + + +def _read_unpacked(f: IO[bytes], fmt: str) -> Tuple[int, ...]: + return struct.unpack(fmt, f.read(struct.calcsize(fmt))) + + +def _parse_ld_musl_from_elf(f: IO[bytes]) -> Optional[str]: + """Detect musl libc location by parsing the Python executable. + + Based on: https://gist.github.com/lyssdod/f51579ae8d93c8657a5564aefc2ffbca + ELF header: https://refspecs.linuxfoundation.org/elf/gabi4+/ch4.eheader.html + """ + f.seek(0) + try: + ident = _read_unpacked(f, "16B") + except struct.error: + return None + if ident[:4] != tuple(b"\x7fELF"): # Invalid magic, not ELF. + return None + f.seek(struct.calcsize("HHI"), 1) # Skip file type, machine, and version. + + try: + # e_fmt: Format for program header. + # p_fmt: Format for section header. + # p_idx: Indexes to find p_type, p_offset, and p_filesz. + e_fmt, p_fmt, p_idx = { + 1: ("IIIIHHH", "IIIIIIII", (0, 1, 4)), # 32-bit. + 2: ("QQQIHHH", "IIQQQQQQ", (0, 2, 5)), # 64-bit. + }[ident[4]] + except KeyError: + return None + else: + p_get = operator.itemgetter(*p_idx) + + # Find the interpreter section and return its content. + try: + _, e_phoff, _, _, _, e_phentsize, e_phnum = _read_unpacked(f, e_fmt) + except struct.error: + return None + for i in range(e_phnum + 1): + f.seek(e_phoff + e_phentsize * i) + try: + p_type, p_offset, p_filesz = p_get(_read_unpacked(f, p_fmt)) + except struct.error: + return None + if p_type != 3: # Not PT_INTERP. + continue + f.seek(p_offset) + interpreter = os.fsdecode(f.read(p_filesz)).strip("\0") + if "musl" not in interpreter: + return None + return interpreter + return None + + +class _MuslVersion(NamedTuple): + major: int + minor: int + + +def _parse_musl_version(output: str) -> Optional[_MuslVersion]: + lines = [n for n in (n.strip() for n in output.splitlines()) if n] + if len(lines) < 2 or lines[0][:4] != "musl": + return None + m = re.match(r"Version (\d+)\.(\d+)", lines[1]) + if not m: + return None + return _MuslVersion(major=int(m.group(1)), minor=int(m.group(2))) + + +@functools.lru_cache() +def _get_musl_version(executable: str) -> Optional[_MuslVersion]: + """Detect currently-running musl runtime version. + + This is done by checking the specified executable's dynamic linking + information, and invoking the loader to parse its output for a version + string. If the loader is musl, the output would be something like:: + + musl libc (x86_64) + Version 1.2.2 + Dynamic Program Loader + """ + with contextlib.ExitStack() as stack: + try: + f = stack.enter_context(open(executable, "rb")) + except OSError: + return None + ld = _parse_ld_musl_from_elf(f) + if not ld: + return None + proc = subprocess.run([ld], stderr=subprocess.PIPE, universal_newlines=True) + return _parse_musl_version(proc.stderr) + + +def platform_tags(arch: str) -> Iterator[str]: + """Generate musllinux tags compatible to the current platform. + + :param arch: Should be the part of platform tag after the ``linux_`` + prefix, e.g. ``x86_64``. The ``linux_`` prefix is assumed as a + prerequisite for the current platform to be musllinux-compatible. + + :returns: An iterator of compatible musllinux tags. + """ + sys_musl = _get_musl_version(sys.executable) + if sys_musl is None: # Python not dynamically linked against musl. + return + for minor in range(sys_musl.minor, -1, -1): + yield f"musllinux_{sys_musl.major}_{minor}_{arch}" + + +if __name__ == "__main__": # pragma: no cover + import sysconfig + + plat = sysconfig.get_platform() + assert plat.startswith("linux-"), "not linux" + + print("plat:", plat) + print("musl:", _get_musl_version(sys.executable)) + print("tags:", end=" ") + for t in platform_tags(re.sub(r"[.-]", "_", plat.split("-", 1)[-1])): + print(t, end="\n ") diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_structures.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..90a6465f9682c886363eea5327dac64bf623a6ff --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/_structures.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/markers.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/markers.py new file mode 100644 index 0000000000000000000000000000000000000000..18769b09a8a34f1e7d63cc61e62cd128ff5f9484 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/markers.py @@ -0,0 +1,304 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import operator +import os +import platform +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from pkg_resources.extern.pyparsing import ( # noqa: N817 + Forward, + Group, + Literal as L, + ParseException, + ParseResults, + QuotedString, + ZeroOrMore, + stringEnd, + stringStart, +) + +from .specifiers import InvalidSpecifier, Specifier + +__all__ = [ + "InvalidMarker", + "UndefinedComparison", + "UndefinedEnvironmentName", + "Marker", + "default_environment", +] + +Operator = Callable[[str, str], bool] + + +class InvalidMarker(ValueError): + """ + An invalid marker was found, users should refer to PEP 508. + """ + + +class UndefinedComparison(ValueError): + """ + An invalid operation was attempted on a value that doesn't support it. + """ + + +class UndefinedEnvironmentName(ValueError): + """ + A name was attempted to be used that does not exist inside of the + environment. + """ + + +class Node: + def __init__(self, value: Any) -> None: + self.value = value + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}('{self}')>" + + def serialize(self) -> str: + raise NotImplementedError + + +class Variable(Node): + def serialize(self) -> str: + return str(self) + + +class Value(Node): + def serialize(self) -> str: + return f'"{self}"' + + +class Op(Node): + def serialize(self) -> str: + return str(self) + + +VARIABLE = ( + L("implementation_version") + | L("platform_python_implementation") + | L("implementation_name") + | L("python_full_version") + | L("platform_release") + | L("platform_version") + | L("platform_machine") + | L("platform_system") + | L("python_version") + | L("sys_platform") + | L("os_name") + | L("os.name") # PEP-345 + | L("sys.platform") # PEP-345 + | L("platform.version") # PEP-345 + | L("platform.machine") # PEP-345 + | L("platform.python_implementation") # PEP-345 + | L("python_implementation") # undocumented setuptools legacy + | L("extra") # PEP-508 +) +ALIASES = { + "os.name": "os_name", + "sys.platform": "sys_platform", + "platform.version": "platform_version", + "platform.machine": "platform_machine", + "platform.python_implementation": "platform_python_implementation", + "python_implementation": "platform_python_implementation", +} +VARIABLE.setParseAction(lambda s, l, t: Variable(ALIASES.get(t[0], t[0]))) + +VERSION_CMP = ( + L("===") | L("==") | L(">=") | L("<=") | L("!=") | L("~=") | L(">") | L("<") +) + +MARKER_OP = VERSION_CMP | L("not in") | L("in") +MARKER_OP.setParseAction(lambda s, l, t: Op(t[0])) + +MARKER_VALUE = QuotedString("'") | QuotedString('"') +MARKER_VALUE.setParseAction(lambda s, l, t: Value(t[0])) + +BOOLOP = L("and") | L("or") + +MARKER_VAR = VARIABLE | MARKER_VALUE + +MARKER_ITEM = Group(MARKER_VAR + MARKER_OP + MARKER_VAR) +MARKER_ITEM.setParseAction(lambda s, l, t: tuple(t[0])) + +LPAREN = L("(").suppress() +RPAREN = L(")").suppress() + +MARKER_EXPR = Forward() +MARKER_ATOM = MARKER_ITEM | Group(LPAREN + MARKER_EXPR + RPAREN) +MARKER_EXPR << MARKER_ATOM + ZeroOrMore(BOOLOP + MARKER_EXPR) + +MARKER = stringStart + MARKER_EXPR + stringEnd + + +def _coerce_parse_result(results: Union[ParseResults, List[Any]]) -> List[Any]: + if isinstance(results, ParseResults): + return [_coerce_parse_result(i) for i in results] + else: + return results + + +def _format_marker( + marker: Union[List[str], Tuple[Node, ...], str], first: Optional[bool] = True +) -> str: + + assert isinstance(marker, (list, tuple, str)) + + # Sometimes we have a structure like [[...]] which is a single item list + # where the single item is itself it's own list. In that case we want skip + # the rest of this function so that we don't get extraneous () on the + # outside. + if ( + isinstance(marker, list) + and len(marker) == 1 + and isinstance(marker[0], (list, tuple)) + ): + return _format_marker(marker[0]) + + if isinstance(marker, list): + inner = (_format_marker(m, first=False) for m in marker) + if first: + return " ".join(inner) + else: + return "(" + " ".join(inner) + ")" + elif isinstance(marker, tuple): + return " ".join([m.serialize() for m in marker]) + else: + return marker + + +_operators: Dict[str, Operator] = { + "in": lambda lhs, rhs: lhs in rhs, + "not in": lambda lhs, rhs: lhs not in rhs, + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + ">=": operator.ge, + ">": operator.gt, +} + + +def _eval_op(lhs: str, op: Op, rhs: str) -> bool: + try: + spec = Specifier("".join([op.serialize(), rhs])) + except InvalidSpecifier: + pass + else: + return spec.contains(lhs) + + oper: Optional[Operator] = _operators.get(op.serialize()) + if oper is None: + raise UndefinedComparison(f"Undefined {op!r} on {lhs!r} and {rhs!r}.") + + return oper(lhs, rhs) + + +class Undefined: + pass + + +_undefined = Undefined() + + +def _get_env(environment: Dict[str, str], name: str) -> str: + value: Union[str, Undefined] = environment.get(name, _undefined) + + if isinstance(value, Undefined): + raise UndefinedEnvironmentName( + f"{name!r} does not exist in evaluation environment." + ) + + return value + + +def _evaluate_markers(markers: List[Any], environment: Dict[str, str]) -> bool: + groups: List[List[bool]] = [[]] + + for marker in markers: + assert isinstance(marker, (list, tuple, str)) + + if isinstance(marker, list): + groups[-1].append(_evaluate_markers(marker, environment)) + elif isinstance(marker, tuple): + lhs, op, rhs = marker + + if isinstance(lhs, Variable): + lhs_value = _get_env(environment, lhs.value) + rhs_value = rhs.value + else: + lhs_value = lhs.value + rhs_value = _get_env(environment, rhs.value) + + groups[-1].append(_eval_op(lhs_value, op, rhs_value)) + else: + assert marker in ["and", "or"] + if marker == "or": + groups.append([]) + + return any(all(item) for item in groups) + + +def format_full_version(info: "sys._version_info") -> str: + version = "{0.major}.{0.minor}.{0.micro}".format(info) + kind = info.releaselevel + if kind != "final": + version += kind[0] + str(info.serial) + return version + + +def default_environment() -> Dict[str, str]: + iver = format_full_version(sys.implementation.version) + implementation_name = sys.implementation.name + return { + "implementation_name": implementation_name, + "implementation_version": iver, + "os_name": os.name, + "platform_machine": platform.machine(), + "platform_release": platform.release(), + "platform_system": platform.system(), + "platform_version": platform.version(), + "python_full_version": platform.python_version(), + "platform_python_implementation": platform.python_implementation(), + "python_version": ".".join(platform.python_version_tuple()[:2]), + "sys_platform": sys.platform, + } + + +class Marker: + def __init__(self, marker: str) -> None: + try: + self._markers = _coerce_parse_result(MARKER.parseString(marker)) + except ParseException as e: + raise InvalidMarker( + f"Invalid marker: {marker!r}, parse error at " + f"{marker[e.loc : e.loc + 8]!r}" + ) + + def __str__(self) -> str: + return _format_marker(self._markers) + + def __repr__(self) -> str: + return f"" + + def evaluate(self, environment: Optional[Dict[str, str]] = None) -> bool: + """Evaluate a marker. + + Return the boolean from evaluating the given marker against the + environment. environment is an optional argument to override all or + part of the determined environment. + + The environment is determined from the current Python process. + """ + current_environment = default_environment() + if environment is not None: + current_environment.update(environment) + + return _evaluate_markers(self._markers, current_environment) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/requirements.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/requirements.py new file mode 100644 index 0000000000000000000000000000000000000000..6af14ec4ce49e633d030611c26f0bd9beaf13e6a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/requirements.py @@ -0,0 +1,146 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import re +import string +import urllib.parse +from typing import List, Optional as TOptional, Set + +from pkg_resources.extern.pyparsing import ( # noqa + Combine, + Literal as L, + Optional, + ParseException, + Regex, + Word, + ZeroOrMore, + originalTextFor, + stringEnd, + stringStart, +) + +from .markers import MARKER_EXPR, Marker +from .specifiers import LegacySpecifier, Specifier, SpecifierSet + + +class InvalidRequirement(ValueError): + """ + An invalid requirement was found, users should refer to PEP 508. + """ + + +ALPHANUM = Word(string.ascii_letters + string.digits) + +LBRACKET = L("[").suppress() +RBRACKET = L("]").suppress() +LPAREN = L("(").suppress() +RPAREN = L(")").suppress() +COMMA = L(",").suppress() +SEMICOLON = L(";").suppress() +AT = L("@").suppress() + +PUNCTUATION = Word("-_.") +IDENTIFIER_END = ALPHANUM | (ZeroOrMore(PUNCTUATION) + ALPHANUM) +IDENTIFIER = Combine(ALPHANUM + ZeroOrMore(IDENTIFIER_END)) + +NAME = IDENTIFIER("name") +EXTRA = IDENTIFIER + +URI = Regex(r"[^ ]+")("url") +URL = AT + URI + +EXTRAS_LIST = EXTRA + ZeroOrMore(COMMA + EXTRA) +EXTRAS = (LBRACKET + Optional(EXTRAS_LIST) + RBRACKET)("extras") + +VERSION_PEP440 = Regex(Specifier._regex_str, re.VERBOSE | re.IGNORECASE) +VERSION_LEGACY = Regex(LegacySpecifier._regex_str, re.VERBOSE | re.IGNORECASE) + +VERSION_ONE = VERSION_PEP440 ^ VERSION_LEGACY +VERSION_MANY = Combine( + VERSION_ONE + ZeroOrMore(COMMA + VERSION_ONE), joinString=",", adjacent=False +)("_raw_spec") +_VERSION_SPEC = Optional((LPAREN + VERSION_MANY + RPAREN) | VERSION_MANY) +_VERSION_SPEC.setParseAction(lambda s, l, t: t._raw_spec or "") + +VERSION_SPEC = originalTextFor(_VERSION_SPEC)("specifier") +VERSION_SPEC.setParseAction(lambda s, l, t: t[1]) + +MARKER_EXPR = originalTextFor(MARKER_EXPR())("marker") +MARKER_EXPR.setParseAction( + lambda s, l, t: Marker(s[t._original_start : t._original_end]) +) +MARKER_SEPARATOR = SEMICOLON +MARKER = MARKER_SEPARATOR + MARKER_EXPR + +VERSION_AND_MARKER = VERSION_SPEC + Optional(MARKER) +URL_AND_MARKER = URL + Optional(MARKER) + +NAMED_REQUIREMENT = NAME + Optional(EXTRAS) + (URL_AND_MARKER | VERSION_AND_MARKER) + +REQUIREMENT = stringStart + NAMED_REQUIREMENT + stringEnd +# pkg_resources.extern.pyparsing isn't thread safe during initialization, so we do it eagerly, see +# issue #104 +REQUIREMENT.parseString("x[]") + + +class Requirement: + """Parse a requirement. + + Parse a given requirement string into its parts, such as name, specifier, + URL, and extras. Raises InvalidRequirement on a badly-formed requirement + string. + """ + + # TODO: Can we test whether something is contained within a requirement? + # If so how do we do that? Do we need to test against the _name_ of + # the thing as well as the version? What about the markers? + # TODO: Can we normalize the name and extra name? + + def __init__(self, requirement_string: str) -> None: + try: + req = REQUIREMENT.parseString(requirement_string) + except ParseException as e: + raise InvalidRequirement( + f'Parse error at "{ requirement_string[e.loc : e.loc + 8]!r}": {e.msg}' + ) + + self.name: str = req.name + if req.url: + parsed_url = urllib.parse.urlparse(req.url) + if parsed_url.scheme == "file": + if urllib.parse.urlunparse(parsed_url) != req.url: + raise InvalidRequirement("Invalid URL given") + elif not (parsed_url.scheme and parsed_url.netloc) or ( + not parsed_url.scheme and not parsed_url.netloc + ): + raise InvalidRequirement(f"Invalid URL: {req.url}") + self.url: TOptional[str] = req.url + else: + self.url = None + self.extras: Set[str] = set(req.extras.asList() if req.extras else []) + self.specifier: SpecifierSet = SpecifierSet(req.specifier) + self.marker: TOptional[Marker] = req.marker if req.marker else None + + def __str__(self) -> str: + parts: List[str] = [self.name] + + if self.extras: + formatted_extras = ",".join(sorted(self.extras)) + parts.append(f"[{formatted_extras}]") + + if self.specifier: + parts.append(str(self.specifier)) + + if self.url: + parts.append(f"@ {self.url}") + if self.marker: + parts.append(" ") + + if self.marker: + parts.append(f"; {self.marker}") + + return "".join(parts) + + def __repr__(self) -> str: + return f"" diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/specifiers.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/specifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..0e218a6f9f75ea2060a8b08d1f1a043fdad68df8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/specifiers.py @@ -0,0 +1,802 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import abc +import functools +import itertools +import re +import warnings +from typing import ( + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Pattern, + Set, + Tuple, + TypeVar, + Union, +) + +from .utils import canonicalize_version +from .version import LegacyVersion, Version, parse + +ParsedVersion = Union[Version, LegacyVersion] +UnparsedVersion = Union[Version, LegacyVersion, str] +VersionTypeVar = TypeVar("VersionTypeVar", bound=UnparsedVersion) +CallableOperator = Callable[[ParsedVersion, str], bool] + + +class InvalidSpecifier(ValueError): + """ + An invalid specifier was found, users should refer to PEP 440. + """ + + +class BaseSpecifier(metaclass=abc.ABCMeta): + @abc.abstractmethod + def __str__(self) -> str: + """ + Returns the str representation of this Specifier like object. This + should be representative of the Specifier itself. + """ + + @abc.abstractmethod + def __hash__(self) -> int: + """ + Returns a hash value for this Specifier like object. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Returns a boolean representing whether or not the two Specifier like + objects are equal. + """ + + @abc.abstractproperty + def prereleases(self) -> Optional[bool]: + """ + Returns whether or not pre-releases as a whole are allowed by this + specifier. + """ + + @prereleases.setter + def prereleases(self, value: bool) -> None: + """ + Sets whether or not pre-releases as a whole are allowed by this + specifier. + """ + + @abc.abstractmethod + def contains(self, item: str, prereleases: Optional[bool] = None) -> bool: + """ + Determines if the given item is contained within this specifier. + """ + + @abc.abstractmethod + def filter( + self, iterable: Iterable[VersionTypeVar], prereleases: Optional[bool] = None + ) -> Iterable[VersionTypeVar]: + """ + Takes an iterable of items and filters them so that only items which + are contained within this specifier are allowed in it. + """ + + +class _IndividualSpecifier(BaseSpecifier): + + _operators: Dict[str, str] = {} + _regex: Pattern[str] + + def __init__(self, spec: str = "", prereleases: Optional[bool] = None) -> None: + match = self._regex.search(spec) + if not match: + raise InvalidSpecifier(f"Invalid specifier: '{spec}'") + + self._spec: Tuple[str, str] = ( + match.group("operator").strip(), + match.group("version").strip(), + ) + + # Store whether or not this Specifier should accept prereleases + self._prereleases = prereleases + + def __repr__(self) -> str: + pre = ( + f", prereleases={self.prereleases!r}" + if self._prereleases is not None + else "" + ) + + return f"<{self.__class__.__name__}({str(self)!r}{pre})>" + + def __str__(self) -> str: + return "{}{}".format(*self._spec) + + @property + def _canonical_spec(self) -> Tuple[str, str]: + return self._spec[0], canonicalize_version(self._spec[1]) + + def __hash__(self) -> int: + return hash(self._canonical_spec) + + def __eq__(self, other: object) -> bool: + if isinstance(other, str): + try: + other = self.__class__(str(other)) + except InvalidSpecifier: + return NotImplemented + elif not isinstance(other, self.__class__): + return NotImplemented + + return self._canonical_spec == other._canonical_spec + + def _get_operator(self, op: str) -> CallableOperator: + operator_callable: CallableOperator = getattr( + self, f"_compare_{self._operators[op]}" + ) + return operator_callable + + def _coerce_version(self, version: UnparsedVersion) -> ParsedVersion: + if not isinstance(version, (LegacyVersion, Version)): + version = parse(version) + return version + + @property + def operator(self) -> str: + return self._spec[0] + + @property + def version(self) -> str: + return self._spec[1] + + @property + def prereleases(self) -> Optional[bool]: + return self._prereleases + + @prereleases.setter + def prereleases(self, value: bool) -> None: + self._prereleases = value + + def __contains__(self, item: str) -> bool: + return self.contains(item) + + def contains( + self, item: UnparsedVersion, prereleases: Optional[bool] = None + ) -> bool: + + # Determine if prereleases are to be allowed or not. + if prereleases is None: + prereleases = self.prereleases + + # Normalize item to a Version or LegacyVersion, this allows us to have + # a shortcut for ``"2.0" in Specifier(">=2") + normalized_item = self._coerce_version(item) + + # Determine if we should be supporting prereleases in this specifier + # or not, if we do not support prereleases than we can short circuit + # logic if this version is a prereleases. + if normalized_item.is_prerelease and not prereleases: + return False + + # Actually do the comparison to determine if this item is contained + # within this Specifier or not. + operator_callable: CallableOperator = self._get_operator(self.operator) + return operator_callable(normalized_item, self.version) + + def filter( + self, iterable: Iterable[VersionTypeVar], prereleases: Optional[bool] = None + ) -> Iterable[VersionTypeVar]: + + yielded = False + found_prereleases = [] + + kw = {"prereleases": prereleases if prereleases is not None else True} + + # Attempt to iterate over all the values in the iterable and if any of + # them match, yield them. + for version in iterable: + parsed_version = self._coerce_version(version) + + if self.contains(parsed_version, **kw): + # If our version is a prerelease, and we were not set to allow + # prereleases, then we'll store it for later in case nothing + # else matches this specifier. + if parsed_version.is_prerelease and not ( + prereleases or self.prereleases + ): + found_prereleases.append(version) + # Either this is not a prerelease, or we should have been + # accepting prereleases from the beginning. + else: + yielded = True + yield version + + # Now that we've iterated over everything, determine if we've yielded + # any values, and if we have not and we have any prereleases stored up + # then we will go ahead and yield the prereleases. + if not yielded and found_prereleases: + for version in found_prereleases: + yield version + + +class LegacySpecifier(_IndividualSpecifier): + + _regex_str = r""" + (?P(==|!=|<=|>=|<|>)) + \s* + (?P + [^,;\s)]* # Since this is a "legacy" specifier, and the version + # string can be just about anything, we match everything + # except for whitespace, a semi-colon for marker support, + # a closing paren since versions can be enclosed in + # them, and a comma since it's a version separator. + ) + """ + + _regex = re.compile(r"^\s*" + _regex_str + r"\s*$", re.VERBOSE | re.IGNORECASE) + + _operators = { + "==": "equal", + "!=": "not_equal", + "<=": "less_than_equal", + ">=": "greater_than_equal", + "<": "less_than", + ">": "greater_than", + } + + def __init__(self, spec: str = "", prereleases: Optional[bool] = None) -> None: + super().__init__(spec, prereleases) + + warnings.warn( + "Creating a LegacyVersion has been deprecated and will be " + "removed in the next major release", + DeprecationWarning, + ) + + def _coerce_version(self, version: UnparsedVersion) -> LegacyVersion: + if not isinstance(version, LegacyVersion): + version = LegacyVersion(str(version)) + return version + + def _compare_equal(self, prospective: LegacyVersion, spec: str) -> bool: + return prospective == self._coerce_version(spec) + + def _compare_not_equal(self, prospective: LegacyVersion, spec: str) -> bool: + return prospective != self._coerce_version(spec) + + def _compare_less_than_equal(self, prospective: LegacyVersion, spec: str) -> bool: + return prospective <= self._coerce_version(spec) + + def _compare_greater_than_equal( + self, prospective: LegacyVersion, spec: str + ) -> bool: + return prospective >= self._coerce_version(spec) + + def _compare_less_than(self, prospective: LegacyVersion, spec: str) -> bool: + return prospective < self._coerce_version(spec) + + def _compare_greater_than(self, prospective: LegacyVersion, spec: str) -> bool: + return prospective > self._coerce_version(spec) + + +def _require_version_compare( + fn: Callable[["Specifier", ParsedVersion, str], bool] +) -> Callable[["Specifier", ParsedVersion, str], bool]: + @functools.wraps(fn) + def wrapped(self: "Specifier", prospective: ParsedVersion, spec: str) -> bool: + if not isinstance(prospective, Version): + return False + return fn(self, prospective, spec) + + return wrapped + + +class Specifier(_IndividualSpecifier): + + _regex_str = r""" + (?P(~=|==|!=|<=|>=|<|>|===)) + (?P + (?: + # The identity operators allow for an escape hatch that will + # do an exact string match of the version you wish to install. + # This will not be parsed by PEP 440 and we cannot determine + # any semantic meaning from it. This operator is discouraged + # but included entirely as an escape hatch. + (?<====) # Only match for the identity operator + \s* + [^\s]* # We just match everything, except for whitespace + # since we are only testing for strict identity. + ) + | + (?: + # The (non)equality operators allow for wild card and local + # versions to be specified so we have to define these two + # operators separately to enable that. + (?<===|!=) # Only match for equals and not equals + + \s* + v? + (?:[0-9]+!)? # epoch + [0-9]+(?:\.[0-9]+)* # release + (?: # pre release + [-_\.]? + (a|b|c|rc|alpha|beta|pre|preview) + [-_\.]? + [0-9]* + )? + (?: # post release + (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*) + )? + + # You cannot use a wild card and a dev or local version + # together so group them with a | and make them optional. + (?: + (?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release + (?:\+[a-z0-9]+(?:[-_\.][a-z0-9]+)*)? # local + | + \.\* # Wild card syntax of .* + )? + ) + | + (?: + # The compatible operator requires at least two digits in the + # release segment. + (?<=~=) # Only match for the compatible operator + + \s* + v? + (?:[0-9]+!)? # epoch + [0-9]+(?:\.[0-9]+)+ # release (We have a + instead of a *) + (?: # pre release + [-_\.]? + (a|b|c|rc|alpha|beta|pre|preview) + [-_\.]? + [0-9]* + )? + (?: # post release + (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*) + )? + (?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release + ) + | + (?: + # All other operators only allow a sub set of what the + # (non)equality operators do. Specifically they do not allow + # local versions to be specified nor do they allow the prefix + # matching wild cards. + (?=": "greater_than_equal", + "<": "less_than", + ">": "greater_than", + "===": "arbitrary", + } + + @_require_version_compare + def _compare_compatible(self, prospective: ParsedVersion, spec: str) -> bool: + + # Compatible releases have an equivalent combination of >= and ==. That + # is that ~=2.2 is equivalent to >=2.2,==2.*. This allows us to + # implement this in terms of the other specifiers instead of + # implementing it ourselves. The only thing we need to do is construct + # the other specifiers. + + # We want everything but the last item in the version, but we want to + # ignore suffix segments. + prefix = ".".join( + list(itertools.takewhile(_is_not_suffix, _version_split(spec)))[:-1] + ) + + # Add the prefix notation to the end of our string + prefix += ".*" + + return self._get_operator(">=")(prospective, spec) and self._get_operator("==")( + prospective, prefix + ) + + @_require_version_compare + def _compare_equal(self, prospective: ParsedVersion, spec: str) -> bool: + + # We need special logic to handle prefix matching + if spec.endswith(".*"): + # In the case of prefix matching we want to ignore local segment. + prospective = Version(prospective.public) + # Split the spec out by dots, and pretend that there is an implicit + # dot in between a release segment and a pre-release segment. + split_spec = _version_split(spec[:-2]) # Remove the trailing .* + + # Split the prospective version out by dots, and pretend that there + # is an implicit dot in between a release segment and a pre-release + # segment. + split_prospective = _version_split(str(prospective)) + + # Shorten the prospective version to be the same length as the spec + # so that we can determine if the specifier is a prefix of the + # prospective version or not. + shortened_prospective = split_prospective[: len(split_spec)] + + # Pad out our two sides with zeros so that they both equal the same + # length. + padded_spec, padded_prospective = _pad_version( + split_spec, shortened_prospective + ) + + return padded_prospective == padded_spec + else: + # Convert our spec string into a Version + spec_version = Version(spec) + + # If the specifier does not have a local segment, then we want to + # act as if the prospective version also does not have a local + # segment. + if not spec_version.local: + prospective = Version(prospective.public) + + return prospective == spec_version + + @_require_version_compare + def _compare_not_equal(self, prospective: ParsedVersion, spec: str) -> bool: + return not self._compare_equal(prospective, spec) + + @_require_version_compare + def _compare_less_than_equal(self, prospective: ParsedVersion, spec: str) -> bool: + + # NB: Local version identifiers are NOT permitted in the version + # specifier, so local version labels can be universally removed from + # the prospective version. + return Version(prospective.public) <= Version(spec) + + @_require_version_compare + def _compare_greater_than_equal( + self, prospective: ParsedVersion, spec: str + ) -> bool: + + # NB: Local version identifiers are NOT permitted in the version + # specifier, so local version labels can be universally removed from + # the prospective version. + return Version(prospective.public) >= Version(spec) + + @_require_version_compare + def _compare_less_than(self, prospective: ParsedVersion, spec_str: str) -> bool: + + # Convert our spec to a Version instance, since we'll want to work with + # it as a version. + spec = Version(spec_str) + + # Check to see if the prospective version is less than the spec + # version. If it's not we can short circuit and just return False now + # instead of doing extra unneeded work. + if not prospective < spec: + return False + + # This special case is here so that, unless the specifier itself + # includes is a pre-release version, that we do not accept pre-release + # versions for the version mentioned in the specifier (e.g. <3.1 should + # not match 3.1.dev0, but should match 3.0.dev0). + if not spec.is_prerelease and prospective.is_prerelease: + if Version(prospective.base_version) == Version(spec.base_version): + return False + + # If we've gotten to here, it means that prospective version is both + # less than the spec version *and* it's not a pre-release of the same + # version in the spec. + return True + + @_require_version_compare + def _compare_greater_than(self, prospective: ParsedVersion, spec_str: str) -> bool: + + # Convert our spec to a Version instance, since we'll want to work with + # it as a version. + spec = Version(spec_str) + + # Check to see if the prospective version is greater than the spec + # version. If it's not we can short circuit and just return False now + # instead of doing extra unneeded work. + if not prospective > spec: + return False + + # This special case is here so that, unless the specifier itself + # includes is a post-release version, that we do not accept + # post-release versions for the version mentioned in the specifier + # (e.g. >3.1 should not match 3.0.post0, but should match 3.2.post0). + if not spec.is_postrelease and prospective.is_postrelease: + if Version(prospective.base_version) == Version(spec.base_version): + return False + + # Ensure that we do not allow a local version of the version mentioned + # in the specifier, which is technically greater than, to match. + if prospective.local is not None: + if Version(prospective.base_version) == Version(spec.base_version): + return False + + # If we've gotten to here, it means that prospective version is both + # greater than the spec version *and* it's not a pre-release of the + # same version in the spec. + return True + + def _compare_arbitrary(self, prospective: Version, spec: str) -> bool: + return str(prospective).lower() == str(spec).lower() + + @property + def prereleases(self) -> bool: + + # If there is an explicit prereleases set for this, then we'll just + # blindly use that. + if self._prereleases is not None: + return self._prereleases + + # Look at all of our specifiers and determine if they are inclusive + # operators, and if they are if they are including an explicit + # prerelease. + operator, version = self._spec + if operator in ["==", ">=", "<=", "~=", "==="]: + # The == specifier can include a trailing .*, if it does we + # want to remove before parsing. + if operator == "==" and version.endswith(".*"): + version = version[:-2] + + # Parse the version, and if it is a pre-release than this + # specifier allows pre-releases. + if parse(version).is_prerelease: + return True + + return False + + @prereleases.setter + def prereleases(self, value: bool) -> None: + self._prereleases = value + + +_prefix_regex = re.compile(r"^([0-9]+)((?:a|b|c|rc)[0-9]+)$") + + +def _version_split(version: str) -> List[str]: + result: List[str] = [] + for item in version.split("."): + match = _prefix_regex.search(item) + if match: + result.extend(match.groups()) + else: + result.append(item) + return result + + +def _is_not_suffix(segment: str) -> bool: + return not any( + segment.startswith(prefix) for prefix in ("dev", "a", "b", "rc", "post") + ) + + +def _pad_version(left: List[str], right: List[str]) -> Tuple[List[str], List[str]]: + left_split, right_split = [], [] + + # Get the release segment of our versions + left_split.append(list(itertools.takewhile(lambda x: x.isdigit(), left))) + right_split.append(list(itertools.takewhile(lambda x: x.isdigit(), right))) + + # Get the rest of our versions + left_split.append(left[len(left_split[0]) :]) + right_split.append(right[len(right_split[0]) :]) + + # Insert our padding + left_split.insert(1, ["0"] * max(0, len(right_split[0]) - len(left_split[0]))) + right_split.insert(1, ["0"] * max(0, len(left_split[0]) - len(right_split[0]))) + + return (list(itertools.chain(*left_split)), list(itertools.chain(*right_split))) + + +class SpecifierSet(BaseSpecifier): + def __init__( + self, specifiers: str = "", prereleases: Optional[bool] = None + ) -> None: + + # Split on , to break each individual specifier into it's own item, and + # strip each item to remove leading/trailing whitespace. + split_specifiers = [s.strip() for s in specifiers.split(",") if s.strip()] + + # Parsed each individual specifier, attempting first to make it a + # Specifier and falling back to a LegacySpecifier. + parsed: Set[_IndividualSpecifier] = set() + for specifier in split_specifiers: + try: + parsed.add(Specifier(specifier)) + except InvalidSpecifier: + parsed.add(LegacySpecifier(specifier)) + + # Turn our parsed specifiers into a frozen set and save them for later. + self._specs = frozenset(parsed) + + # Store our prereleases value so we can use it later to determine if + # we accept prereleases or not. + self._prereleases = prereleases + + def __repr__(self) -> str: + pre = ( + f", prereleases={self.prereleases!r}" + if self._prereleases is not None + else "" + ) + + return f"" + + def __str__(self) -> str: + return ",".join(sorted(str(s) for s in self._specs)) + + def __hash__(self) -> int: + return hash(self._specs) + + def __and__(self, other: Union["SpecifierSet", str]) -> "SpecifierSet": + if isinstance(other, str): + other = SpecifierSet(other) + elif not isinstance(other, SpecifierSet): + return NotImplemented + + specifier = SpecifierSet() + specifier._specs = frozenset(self._specs | other._specs) + + if self._prereleases is None and other._prereleases is not None: + specifier._prereleases = other._prereleases + elif self._prereleases is not None and other._prereleases is None: + specifier._prereleases = self._prereleases + elif self._prereleases == other._prereleases: + specifier._prereleases = self._prereleases + else: + raise ValueError( + "Cannot combine SpecifierSets with True and False prerelease " + "overrides." + ) + + return specifier + + def __eq__(self, other: object) -> bool: + if isinstance(other, (str, _IndividualSpecifier)): + other = SpecifierSet(str(other)) + elif not isinstance(other, SpecifierSet): + return NotImplemented + + return self._specs == other._specs + + def __len__(self) -> int: + return len(self._specs) + + def __iter__(self) -> Iterator[_IndividualSpecifier]: + return iter(self._specs) + + @property + def prereleases(self) -> Optional[bool]: + + # If we have been given an explicit prerelease modifier, then we'll + # pass that through here. + if self._prereleases is not None: + return self._prereleases + + # If we don't have any specifiers, and we don't have a forced value, + # then we'll just return None since we don't know if this should have + # pre-releases or not. + if not self._specs: + return None + + # Otherwise we'll see if any of the given specifiers accept + # prereleases, if any of them do we'll return True, otherwise False. + return any(s.prereleases for s in self._specs) + + @prereleases.setter + def prereleases(self, value: bool) -> None: + self._prereleases = value + + def __contains__(self, item: UnparsedVersion) -> bool: + return self.contains(item) + + def contains( + self, item: UnparsedVersion, prereleases: Optional[bool] = None + ) -> bool: + + # Ensure that our item is a Version or LegacyVersion instance. + if not isinstance(item, (LegacyVersion, Version)): + item = parse(item) + + # Determine if we're forcing a prerelease or not, if we're not forcing + # one for this particular filter call, then we'll use whatever the + # SpecifierSet thinks for whether or not we should support prereleases. + if prereleases is None: + prereleases = self.prereleases + + # We can determine if we're going to allow pre-releases by looking to + # see if any of the underlying items supports them. If none of them do + # and this item is a pre-release then we do not allow it and we can + # short circuit that here. + # Note: This means that 1.0.dev1 would not be contained in something + # like >=1.0.devabc however it would be in >=1.0.debabc,>0.0.dev0 + if not prereleases and item.is_prerelease: + return False + + # We simply dispatch to the underlying specs here to make sure that the + # given version is contained within all of them. + # Note: This use of all() here means that an empty set of specifiers + # will always return True, this is an explicit design decision. + return all(s.contains(item, prereleases=prereleases) for s in self._specs) + + def filter( + self, iterable: Iterable[VersionTypeVar], prereleases: Optional[bool] = None + ) -> Iterable[VersionTypeVar]: + + # Determine if we're forcing a prerelease or not, if we're not forcing + # one for this particular filter call, then we'll use whatever the + # SpecifierSet thinks for whether or not we should support prereleases. + if prereleases is None: + prereleases = self.prereleases + + # If we have any specifiers, then we want to wrap our iterable in the + # filter method for each one, this will act as a logical AND amongst + # each specifier. + if self._specs: + for spec in self._specs: + iterable = spec.filter(iterable, prereleases=bool(prereleases)) + return iterable + # If we do not have any specifiers, then we need to have a rough filter + # which will filter out any pre-releases, unless there are no final + # releases, and which will filter out LegacyVersion in general. + else: + filtered: List[VersionTypeVar] = [] + found_prereleases: List[VersionTypeVar] = [] + + item: UnparsedVersion + parsed_version: Union[Version, LegacyVersion] + + for item in iterable: + # Ensure that we some kind of Version class for this item. + if not isinstance(item, (LegacyVersion, Version)): + parsed_version = parse(item) + else: + parsed_version = item + + # Filter out any item which is parsed as a LegacyVersion + if isinstance(parsed_version, LegacyVersion): + continue + + # Store any item which is a pre-release for later unless we've + # already found a final version or we are accepting prereleases + if parsed_version.is_prerelease and not prereleases: + if not filtered: + found_prereleases.append(item) + else: + filtered.append(item) + + # If we've found no items except for pre-releases, then we'll go + # ahead and use the pre-releases + if not filtered and found_prereleases and prereleases is None: + return found_prereleases + + return filtered diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/tags.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/tags.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3d25a71c75c975291cf987001ecd6882d6417d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/tags.py @@ -0,0 +1,487 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import logging +import platform +import sys +import sysconfig +from importlib.machinery import EXTENSION_SUFFIXES +from typing import ( + Dict, + FrozenSet, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from . import _manylinux, _musllinux + +logger = logging.getLogger(__name__) + +PythonVersion = Sequence[int] +MacVersion = Tuple[int, int] + +INTERPRETER_SHORT_NAMES: Dict[str, str] = { + "python": "py", # Generic. + "cpython": "cp", + "pypy": "pp", + "ironpython": "ip", + "jython": "jy", +} + + +_32_BIT_INTERPRETER = sys.maxsize <= 2 ** 32 + + +class Tag: + """ + A representation of the tag triple for a wheel. + + Instances are considered immutable and thus are hashable. Equality checking + is also supported. + """ + + __slots__ = ["_interpreter", "_abi", "_platform", "_hash"] + + def __init__(self, interpreter: str, abi: str, platform: str) -> None: + self._interpreter = interpreter.lower() + self._abi = abi.lower() + self._platform = platform.lower() + # The __hash__ of every single element in a Set[Tag] will be evaluated each time + # that a set calls its `.disjoint()` method, which may be called hundreds of + # times when scanning a page of links for packages with tags matching that + # Set[Tag]. Pre-computing the value here produces significant speedups for + # downstream consumers. + self._hash = hash((self._interpreter, self._abi, self._platform)) + + @property + def interpreter(self) -> str: + return self._interpreter + + @property + def abi(self) -> str: + return self._abi + + @property + def platform(self) -> str: + return self._platform + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Tag): + return NotImplemented + + return ( + (self._hash == other._hash) # Short-circuit ASAP for perf reasons. + and (self._platform == other._platform) + and (self._abi == other._abi) + and (self._interpreter == other._interpreter) + ) + + def __hash__(self) -> int: + return self._hash + + def __str__(self) -> str: + return f"{self._interpreter}-{self._abi}-{self._platform}" + + def __repr__(self) -> str: + return f"<{self} @ {id(self)}>" + + +def parse_tag(tag: str) -> FrozenSet[Tag]: + """ + Parses the provided tag (e.g. `py3-none-any`) into a frozenset of Tag instances. + + Returning a set is required due to the possibility that the tag is a + compressed tag set. + """ + tags = set() + interpreters, abis, platforms = tag.split("-") + for interpreter in interpreters.split("."): + for abi in abis.split("."): + for platform_ in platforms.split("."): + tags.add(Tag(interpreter, abi, platform_)) + return frozenset(tags) + + +def _get_config_var(name: str, warn: bool = False) -> Union[int, str, None]: + value = sysconfig.get_config_var(name) + if value is None and warn: + logger.debug( + "Config variable '%s' is unset, Python ABI tag may be incorrect", name + ) + return value + + +def _normalize_string(string: str) -> str: + return string.replace(".", "_").replace("-", "_") + + +def _abi3_applies(python_version: PythonVersion) -> bool: + """ + Determine if the Python version supports abi3. + + PEP 384 was first implemented in Python 3.2. + """ + return len(python_version) > 1 and tuple(python_version) >= (3, 2) + + +def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> List[str]: + py_version = tuple(py_version) # To allow for version comparison. + abis = [] + version = _version_nodot(py_version[:2]) + debug = pymalloc = ucs4 = "" + with_debug = _get_config_var("Py_DEBUG", warn) + has_refcount = hasattr(sys, "gettotalrefcount") + # Windows doesn't set Py_DEBUG, so checking for support of debug-compiled + # extension modules is the best option. + # https://github.com/pypa/pip/issues/3383#issuecomment-173267692 + has_ext = "_d.pyd" in EXTENSION_SUFFIXES + if with_debug or (with_debug is None and (has_refcount or has_ext)): + debug = "d" + if py_version < (3, 8): + with_pymalloc = _get_config_var("WITH_PYMALLOC", warn) + if with_pymalloc or with_pymalloc is None: + pymalloc = "m" + if py_version < (3, 3): + unicode_size = _get_config_var("Py_UNICODE_SIZE", warn) + if unicode_size == 4 or ( + unicode_size is None and sys.maxunicode == 0x10FFFF + ): + ucs4 = "u" + elif debug: + # Debug builds can also load "normal" extension modules. + # We can also assume no UCS-4 or pymalloc requirement. + abis.append(f"cp{version}") + abis.insert( + 0, + "cp{version}{debug}{pymalloc}{ucs4}".format( + version=version, debug=debug, pymalloc=pymalloc, ucs4=ucs4 + ), + ) + return abis + + +def cpython_tags( + python_version: Optional[PythonVersion] = None, + abis: Optional[Iterable[str]] = None, + platforms: Optional[Iterable[str]] = None, + *, + warn: bool = False, +) -> Iterator[Tag]: + """ + Yields the tags for a CPython interpreter. + + The tags consist of: + - cp-- + - cp-abi3- + - cp-none- + - cp-abi3- # Older Python versions down to 3.2. + + If python_version only specifies a major version then user-provided ABIs and + the 'none' ABItag will be used. + + If 'abi3' or 'none' are specified in 'abis' then they will be yielded at + their normal position and not at the beginning. + """ + if not python_version: + python_version = sys.version_info[:2] + + interpreter = f"cp{_version_nodot(python_version[:2])}" + + if abis is None: + if len(python_version) > 1: + abis = _cpython_abis(python_version, warn) + else: + abis = [] + abis = list(abis) + # 'abi3' and 'none' are explicitly handled later. + for explicit_abi in ("abi3", "none"): + try: + abis.remove(explicit_abi) + except ValueError: + pass + + platforms = list(platforms or platform_tags()) + for abi in abis: + for platform_ in platforms: + yield Tag(interpreter, abi, platform_) + if _abi3_applies(python_version): + yield from (Tag(interpreter, "abi3", platform_) for platform_ in platforms) + yield from (Tag(interpreter, "none", platform_) for platform_ in platforms) + + if _abi3_applies(python_version): + for minor_version in range(python_version[1] - 1, 1, -1): + for platform_ in platforms: + interpreter = "cp{version}".format( + version=_version_nodot((python_version[0], minor_version)) + ) + yield Tag(interpreter, "abi3", platform_) + + +def _generic_abi() -> Iterator[str]: + abi = sysconfig.get_config_var("SOABI") + if abi: + yield _normalize_string(abi) + + +def generic_tags( + interpreter: Optional[str] = None, + abis: Optional[Iterable[str]] = None, + platforms: Optional[Iterable[str]] = None, + *, + warn: bool = False, +) -> Iterator[Tag]: + """ + Yields the tags for a generic interpreter. + + The tags consist of: + - -- + + The "none" ABI will be added if it was not explicitly provided. + """ + if not interpreter: + interp_name = interpreter_name() + interp_version = interpreter_version(warn=warn) + interpreter = "".join([interp_name, interp_version]) + if abis is None: + abis = _generic_abi() + platforms = list(platforms or platform_tags()) + abis = list(abis) + if "none" not in abis: + abis.append("none") + for abi in abis: + for platform_ in platforms: + yield Tag(interpreter, abi, platform_) + + +def _py_interpreter_range(py_version: PythonVersion) -> Iterator[str]: + """ + Yields Python versions in descending order. + + After the latest version, the major-only version will be yielded, and then + all previous versions of that major version. + """ + if len(py_version) > 1: + yield f"py{_version_nodot(py_version[:2])}" + yield f"py{py_version[0]}" + if len(py_version) > 1: + for minor in range(py_version[1] - 1, -1, -1): + yield f"py{_version_nodot((py_version[0], minor))}" + + +def compatible_tags( + python_version: Optional[PythonVersion] = None, + interpreter: Optional[str] = None, + platforms: Optional[Iterable[str]] = None, +) -> Iterator[Tag]: + """ + Yields the sequence of tags that are compatible with a specific version of Python. + + The tags consist of: + - py*-none- + - -none-any # ... if `interpreter` is provided. + - py*-none-any + """ + if not python_version: + python_version = sys.version_info[:2] + platforms = list(platforms or platform_tags()) + for version in _py_interpreter_range(python_version): + for platform_ in platforms: + yield Tag(version, "none", platform_) + if interpreter: + yield Tag(interpreter, "none", "any") + for version in _py_interpreter_range(python_version): + yield Tag(version, "none", "any") + + +def _mac_arch(arch: str, is_32bit: bool = _32_BIT_INTERPRETER) -> str: + if not is_32bit: + return arch + + if arch.startswith("ppc"): + return "ppc" + + return "i386" + + +def _mac_binary_formats(version: MacVersion, cpu_arch: str) -> List[str]: + formats = [cpu_arch] + if cpu_arch == "x86_64": + if version < (10, 4): + return [] + formats.extend(["intel", "fat64", "fat32"]) + + elif cpu_arch == "i386": + if version < (10, 4): + return [] + formats.extend(["intel", "fat32", "fat"]) + + elif cpu_arch == "ppc64": + # TODO: Need to care about 32-bit PPC for ppc64 through 10.2? + if version > (10, 5) or version < (10, 4): + return [] + formats.append("fat64") + + elif cpu_arch == "ppc": + if version > (10, 6): + return [] + formats.extend(["fat32", "fat"]) + + if cpu_arch in {"arm64", "x86_64"}: + formats.append("universal2") + + if cpu_arch in {"x86_64", "i386", "ppc64", "ppc", "intel"}: + formats.append("universal") + + return formats + + +def mac_platforms( + version: Optional[MacVersion] = None, arch: Optional[str] = None +) -> Iterator[str]: + """ + Yields the platform tags for a macOS system. + + The `version` parameter is a two-item tuple specifying the macOS version to + generate platform tags for. The `arch` parameter is the CPU architecture to + generate platform tags for. Both parameters default to the appropriate value + for the current system. + """ + version_str, _, cpu_arch = platform.mac_ver() + if version is None: + version = cast("MacVersion", tuple(map(int, version_str.split(".")[:2]))) + else: + version = version + if arch is None: + arch = _mac_arch(cpu_arch) + else: + arch = arch + + if (10, 0) <= version and version < (11, 0): + # Prior to Mac OS 11, each yearly release of Mac OS bumped the + # "minor" version number. The major version was always 10. + for minor_version in range(version[1], -1, -1): + compat_version = 10, minor_version + binary_formats = _mac_binary_formats(compat_version, arch) + for binary_format in binary_formats: + yield "macosx_{major}_{minor}_{binary_format}".format( + major=10, minor=minor_version, binary_format=binary_format + ) + + if version >= (11, 0): + # Starting with Mac OS 11, each yearly release bumps the major version + # number. The minor versions are now the midyear updates. + for major_version in range(version[0], 10, -1): + compat_version = major_version, 0 + binary_formats = _mac_binary_formats(compat_version, arch) + for binary_format in binary_formats: + yield "macosx_{major}_{minor}_{binary_format}".format( + major=major_version, minor=0, binary_format=binary_format + ) + + if version >= (11, 0): + # Mac OS 11 on x86_64 is compatible with binaries from previous releases. + # Arm64 support was introduced in 11.0, so no Arm binaries from previous + # releases exist. + # + # However, the "universal2" binary format can have a + # macOS version earlier than 11.0 when the x86_64 part of the binary supports + # that version of macOS. + if arch == "x86_64": + for minor_version in range(16, 3, -1): + compat_version = 10, minor_version + binary_formats = _mac_binary_formats(compat_version, arch) + for binary_format in binary_formats: + yield "macosx_{major}_{minor}_{binary_format}".format( + major=compat_version[0], + minor=compat_version[1], + binary_format=binary_format, + ) + else: + for minor_version in range(16, 3, -1): + compat_version = 10, minor_version + binary_format = "universal2" + yield "macosx_{major}_{minor}_{binary_format}".format( + major=compat_version[0], + minor=compat_version[1], + binary_format=binary_format, + ) + + +def _linux_platforms(is_32bit: bool = _32_BIT_INTERPRETER) -> Iterator[str]: + linux = _normalize_string(sysconfig.get_platform()) + if is_32bit: + if linux == "linux_x86_64": + linux = "linux_i686" + elif linux == "linux_aarch64": + linux = "linux_armv7l" + _, arch = linux.split("_", 1) + yield from _manylinux.platform_tags(linux, arch) + yield from _musllinux.platform_tags(arch) + yield linux + + +def _generic_platforms() -> Iterator[str]: + yield _normalize_string(sysconfig.get_platform()) + + +def platform_tags() -> Iterator[str]: + """ + Provides the platform tags for this installation. + """ + if platform.system() == "Darwin": + return mac_platforms() + elif platform.system() == "Linux": + return _linux_platforms() + else: + return _generic_platforms() + + +def interpreter_name() -> str: + """ + Returns the name of the running interpreter. + """ + name = sys.implementation.name + return INTERPRETER_SHORT_NAMES.get(name) or name + + +def interpreter_version(*, warn: bool = False) -> str: + """ + Returns the version of the running interpreter. + """ + version = _get_config_var("py_version_nodot", warn=warn) + if version: + version = str(version) + else: + version = _version_nodot(sys.version_info[:2]) + return version + + +def _version_nodot(version: PythonVersion) -> str: + return "".join(map(str, version)) + + +def sys_tags(*, warn: bool = False) -> Iterator[Tag]: + """ + Returns the sequence of tag triples for the running interpreter. + + The order of the sequence corresponds to priority order for the + interpreter, from most to least important. + """ + + interp_name = interpreter_name() + if interp_name == "cp": + yield from cpython_tags(warn=warn) + else: + yield from generic_tags() + + if interp_name == "pp": + yield from compatible_tags(interpreter="pp3") + else: + yield from compatible_tags() diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/utils.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bab11b80c60f10a4f3bccb12eb5b17c48a449767 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/utils.py @@ -0,0 +1,136 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import re +from typing import FrozenSet, NewType, Tuple, Union, cast + +from .tags import Tag, parse_tag +from .version import InvalidVersion, Version + +BuildTag = Union[Tuple[()], Tuple[int, str]] +NormalizedName = NewType("NormalizedName", str) + + +class InvalidWheelFilename(ValueError): + """ + An invalid wheel filename was found, users should refer to PEP 427. + """ + + +class InvalidSdistFilename(ValueError): + """ + An invalid sdist filename was found, users should refer to the packaging user guide. + """ + + +_canonicalize_regex = re.compile(r"[-_.]+") +# PEP 427: The build number must start with a digit. +_build_tag_regex = re.compile(r"(\d+)(.*)") + + +def canonicalize_name(name: str) -> NormalizedName: + # This is taken from PEP 503. + value = _canonicalize_regex.sub("-", name).lower() + return cast(NormalizedName, value) + + +def canonicalize_version(version: Union[Version, str]) -> str: + """ + This is very similar to Version.__str__, but has one subtle difference + with the way it handles the release segment. + """ + if isinstance(version, str): + try: + parsed = Version(version) + except InvalidVersion: + # Legacy versions cannot be normalized + return version + else: + parsed = version + + parts = [] + + # Epoch + if parsed.epoch != 0: + parts.append(f"{parsed.epoch}!") + + # Release segment + # NB: This strips trailing '.0's to normalize + parts.append(re.sub(r"(\.0)+$", "", ".".join(str(x) for x in parsed.release))) + + # Pre-release + if parsed.pre is not None: + parts.append("".join(str(x) for x in parsed.pre)) + + # Post-release + if parsed.post is not None: + parts.append(f".post{parsed.post}") + + # Development release + if parsed.dev is not None: + parts.append(f".dev{parsed.dev}") + + # Local version segment + if parsed.local is not None: + parts.append(f"+{parsed.local}") + + return "".join(parts) + + +def parse_wheel_filename( + filename: str, +) -> Tuple[NormalizedName, Version, BuildTag, FrozenSet[Tag]]: + if not filename.endswith(".whl"): + raise InvalidWheelFilename( + f"Invalid wheel filename (extension must be '.whl'): {filename}" + ) + + filename = filename[:-4] + dashes = filename.count("-") + if dashes not in (4, 5): + raise InvalidWheelFilename( + f"Invalid wheel filename (wrong number of parts): {filename}" + ) + + parts = filename.split("-", dashes - 2) + name_part = parts[0] + # See PEP 427 for the rules on escaping the project name + if "__" in name_part or re.match(r"^[\w\d._]*$", name_part, re.UNICODE) is None: + raise InvalidWheelFilename(f"Invalid project name: {filename}") + name = canonicalize_name(name_part) + version = Version(parts[1]) + if dashes == 5: + build_part = parts[2] + build_match = _build_tag_regex.match(build_part) + if build_match is None: + raise InvalidWheelFilename( + f"Invalid build number: {build_part} in '{filename}'" + ) + build = cast(BuildTag, (int(build_match.group(1)), build_match.group(2))) + else: + build = () + tags = parse_tag(parts[-1]) + return (name, version, build, tags) + + +def parse_sdist_filename(filename: str) -> Tuple[NormalizedName, Version]: + if filename.endswith(".tar.gz"): + file_stem = filename[: -len(".tar.gz")] + elif filename.endswith(".zip"): + file_stem = filename[: -len(".zip")] + else: + raise InvalidSdistFilename( + f"Invalid sdist filename (extension must be '.tar.gz' or '.zip'):" + f" {filename}" + ) + + # We are requiring a PEP 440 version, which cannot contain dashes, + # so we split on the last dash. + name_part, sep, version_part = file_stem.rpartition("-") + if not sep: + raise InvalidSdistFilename(f"Invalid sdist filename: {filename}") + + name = canonicalize_name(name_part) + version = Version(version_part) + return (name, version) diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/version.py b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/version.py new file mode 100644 index 0000000000000000000000000000000000000000..de9a09a4ed3b078b37e7490a6686f660ae935aca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/packaging/version.py @@ -0,0 +1,504 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import collections +import itertools +import re +import warnings +from typing import Callable, Iterator, List, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN"] + +InfiniteTypes = Union[InfinityType, NegativeInfinityType] +PrePostDevType = Union[InfiniteTypes, Tuple[str, int]] +SubLocalType = Union[InfiniteTypes, int, str] +LocalType = Union[ + NegativeInfinityType, + Tuple[ + Union[ + SubLocalType, + Tuple[SubLocalType, str], + Tuple[NegativeInfinityType, SubLocalType], + ], + ..., + ], +] +CmpKey = Tuple[ + int, Tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType +] +LegacyCmpKey = Tuple[int, Tuple[str, ...]] +VersionComparisonMethod = Callable[ + [Union[CmpKey, LegacyCmpKey], Union[CmpKey, LegacyCmpKey]], bool +] + +_Version = collections.namedtuple( + "_Version", ["epoch", "release", "dev", "pre", "post", "local"] +) + + +def parse(version: str) -> Union["LegacyVersion", "Version"]: + """ + Parse the given version string and return either a :class:`Version` object + or a :class:`LegacyVersion` object depending on if the given version is + a valid PEP 440 version or a legacy version. + """ + try: + return Version(version) + except InvalidVersion: + return LegacyVersion(version) + + +class InvalidVersion(ValueError): + """ + An invalid version was found, users should refer to PEP 440. + """ + + +class _BaseVersion: + _key: Union[CmpKey, LegacyCmpKey] + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. + def __lt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +class LegacyVersion(_BaseVersion): + def __init__(self, version: str) -> None: + self._version = str(version) + self._key = _legacy_cmpkey(self._version) + + warnings.warn( + "Creating a LegacyVersion has been deprecated and will be " + "removed in the next major release", + DeprecationWarning, + ) + + def __str__(self) -> str: + return self._version + + def __repr__(self) -> str: + return f"" + + @property + def public(self) -> str: + return self._version + + @property + def base_version(self) -> str: + return self._version + + @property + def epoch(self) -> int: + return -1 + + @property + def release(self) -> None: + return None + + @property + def pre(self) -> None: + return None + + @property + def post(self) -> None: + return None + + @property + def dev(self) -> None: + return None + + @property + def local(self) -> None: + return None + + @property + def is_prerelease(self) -> bool: + return False + + @property + def is_postrelease(self) -> bool: + return False + + @property + def is_devrelease(self) -> bool: + return False + + +_legacy_version_component_re = re.compile(r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE) + +_legacy_version_replacement_map = { + "pre": "c", + "preview": "c", + "-": "final-", + "rc": "c", + "dev": "@", +} + + +def _parse_version_parts(s: str) -> Iterator[str]: + for part in _legacy_version_component_re.split(s): + part = _legacy_version_replacement_map.get(part, part) + + if not part or part == ".": + continue + + if part[:1] in "0123456789": + # pad for numeric comparison + yield part.zfill(8) + else: + yield "*" + part + + # ensure that alpha/beta/candidate are before final + yield "*final" + + +def _legacy_cmpkey(version: str) -> LegacyCmpKey: + + # We hardcode an epoch of -1 here. A PEP 440 version can only have a epoch + # greater than or equal to 0. This will effectively put the LegacyVersion, + # which uses the defacto standard originally implemented by setuptools, + # as before all PEP 440 versions. + epoch = -1 + + # This scheme is taken from pkg_resources.parse_version setuptools prior to + # it's adoption of the packaging library. + parts: List[str] = [] + for part in _parse_version_parts(version.lower()): + if part.startswith("*"): + # remove "-" before a prerelease tag + if part < "*final": + while parts and parts[-1] == "*final-": + parts.pop() + + # remove trailing zeros from each series of numeric parts + while parts and parts[-1] == "00000000": + parts.pop() + + parts.append(part) + + return epoch, tuple(parts) + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
                                          # pre-release
+            [-_\.]?
+            (?P(a|b|c|rc|alpha|beta|pre|preview))
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+        (?P                                         # post release
+            (?:-(?P[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?Ppost|rev|r)
+                [-_\.]?
+                (?P[0-9]+)?
+            )
+        )?
+        (?P                                          # dev release
+            [-_\.]?
+            (?Pdev)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+    )
+    (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+
+class Version(_BaseVersion):
+
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+
+    def __init__(self, version: str) -> None:
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        return f""
+
+    def __str__(self) -> str:
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        _epoch: int = self._version.epoch
+        return _epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        _release: Tuple[int, ...] = self._version.release
+        return _release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        _pre: Optional[Tuple[str, int]] = self._version.pre
+        return _pre
+
+    @property
+    def post(self) -> Optional[int]:
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(
+    letter: str, number: Union[str, bytes, SupportsInt]
+) -> Optional[Tuple[str, int]]:
+
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: str) -> Optional[LocalType]:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[Tuple[SubLocalType]],
+) -> CmpKey:
+
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll use a reverse the list, drop all the now
+    # leading zeros until we come to something non zero, then take the rest
+    # re-reverse it back into the correct order and make it a tuple and use
+    # that for our sorting key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: PrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: PrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: PrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: LocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
diff --git a/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/INSTALLER
new file mode 100644
index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/METADATA b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..d641d4933580f4ed59351026af5ed09624ef6356
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/METADATA
@@ -0,0 +1,99 @@
+Metadata-Version: 2.4
+Name: rpds-py
+Version: 0.22.3
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Rust
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+License-File: LICENSE
+Summary: Python bindings to Rust's persistent data structures (rpds)
+Keywords: data structures,rust,persistent
+Author-email: Julian Berman 
+Requires-Python: >=3.9
+Description-Content-Type: text/x-rst; charset=UTF-8
+Project-URL: Documentation, https://rpds.readthedocs.io/
+Project-URL: Homepage, https://github.com/crate-py/rpds
+Project-URL: Issues, https://github.com/crate-py/rpds/issues/
+Project-URL: Funding, https://github.com/sponsors/Julian
+Project-URL: Tidelift, https://tidelift.com/subscription/pkg/pypi-rpds-py?utm_source=pypi-rpds-py&utm_medium=referral&utm_campaign=pypi-link
+Project-URL: Source, https://github.com/crate-py/rpds
+Project-URL: Upstream, https://github.com/orium/rpds
+
+===========
+``rpds.py``
+===========
+
+|PyPI| |Pythons| |CI|
+
+.. |PyPI| image:: https://img.shields.io/pypi/v/rpds-py.svg
+  :alt: PyPI version
+  :target: https://pypi.org/project/rpds-py/
+
+.. |Pythons| image:: https://img.shields.io/pypi/pyversions/rpds-py.svg
+  :alt: Supported Python versions
+  :target: https://pypi.org/project/rpds-py/
+
+.. |CI| image:: https://github.com/crate-py/rpds/workflows/CI/badge.svg
+  :alt: Build status
+  :target: https://github.com/crate-py/rpds/actions?query=workflow%3ACI
+
+.. |ReadTheDocs| image:: https://readthedocs.org/projects/referencing/badge/?version=stable&style=flat
+   :alt: ReadTheDocs status
+   :target: https://referencing.readthedocs.io/en/stable/
+
+
+Python bindings to the `Rust rpds crate `_ for persistent data structures.
+
+What's here is quite minimal (in transparency, it was written initially to support replacing ``pyrsistent`` in the `referencing library `_).
+If you see something missing (which is very likely), a PR is definitely welcome to add it.
+
+Installation
+------------
+
+The distribution on PyPI is named ``rpds.py`` (equivalently ``rpds-py``), and thus can be installed via e.g.:
+
+.. code:: sh
+
+    $ pip install rpds-py
+
+Note that if you install ``rpds-py`` from source, you will need a Rust toolchain installed, as it is a build-time dependency.
+An example of how to do so in a ``Dockerfile`` can be found `here `_.
+
+If you believe you are on a common platform which should have wheels built (i.e. and not need to compile from source), feel free to file an issue or pull request modifying the GitHub action used here to build wheels via ``maturin``.
+
+Usage
+-----
+
+Methods in general are named similarly to their ``rpds`` counterparts (rather than ``pyrsistent``\ 's conventions, though probably a full drop-in ``pyrsistent``\ -compatible wrapper module is a good addition at some point).
+
+.. code:: python
+
+    >>> from rpds import HashTrieMap, HashTrieSet, List
+
+    >>> m = HashTrieMap({"foo": "bar", "baz": "quux"})
+    >>> m.insert("spam", 37) == HashTrieMap({"foo": "bar", "baz": "quux", "spam": 37})
+    True
+    >>> m.remove("foo") == HashTrieMap({"baz": "quux"})
+    True
+
+    >>> s = HashTrieSet({"foo", "bar", "baz", "quux"})
+    >>> s.insert("spam") == HashTrieSet({"foo", "bar", "baz", "quux", "spam"})
+    True
+    >>> s.remove("foo") == HashTrieSet({"bar", "baz", "quux"})
+    True
+
+    >>> L = List([1, 3, 5])
+    >>> L.push_front(-1) == List([-1, 1, 3, 5])
+    True
+    >>> L.rest == List([3, 5])
+    True
+
diff --git a/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/RECORD b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/RECORD
new file mode 100644
index 0000000000000000000000000000000000000000..afdc27e055c6328e11247d68564629ca1371e3b8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/RECORD
@@ -0,0 +1,10 @@
+rpds/__init__.py,sha256=w3MgXW7lpTCICw0KXbw20QX573_kbsEnWIeMsCAugvM,99
+rpds/__init__.pyi,sha256=-hd1A1d1BqGxL3XZGI9SpW1-xZe2iJcDMU6CpiejgV4,2570
+rpds/__pycache__/__init__.cpython-311.pyc,,
+rpds/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rpds/rpds.cpython-311-x86_64-linux-gnu.so,sha256=re1u5b2IEJZWXL1UoGyctDKx7J5ptP3CnzfzLEVzvhY,1015312
+rpds_py-0.22.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+rpds_py-0.22.3.dist-info/METADATA,sha256=lxb6Yaex6SsszDmw10ziC7NxAbP_ykXJs8VYhdfiLvQ,4170
+rpds_py-0.22.3.dist-info/RECORD,,
+rpds_py-0.22.3.dist-info/WHEEL,sha256=SdZDzfXLqer-3JXZz2VXekP5HOiSU9GP0IKnKzv-Y64,129
+rpds_py-0.22.3.dist-info/licenses/LICENSE,sha256=MU5Okb47qpPA-0vMyeTpfNZD64ObBlr5IXgsIXX-mQk,1057
diff --git a/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/WHEEL
new file mode 100644
index 0000000000000000000000000000000000000000..483132cbc21b4c1c5466460d7de4b57ebf1045dd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/WHEEL
@@ -0,0 +1,4 @@
+Wheel-Version: 1.0
+Generator: maturin (1.7.8)
+Root-Is-Purelib: false
+Tag: cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64
diff --git a/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/licenses/LICENSE b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/licenses/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..119a1f205aa85f584e0dbc04d7b34deaebe9199d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/rpds_py-0.22.3.dist-info/licenses/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2023 Julian Berman
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/.venv/lib/python3.11/site-packages/tokenizers/tokenizers.abi3.so b/.venv/lib/python3.11/site-packages/tokenizers/tokenizers.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..2455c3764f202f36f1e14db2328c91cec2db0ed0
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/tokenizers/tokenizers.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8f7b59958ef21415c5e43f2ca293495c80e8cf4f72ca953b40393b8385900dd
+size 8942016
diff --git a/.venv/lib/python3.11/site-packages/transformers/feature_extraction_utils.py b/.venv/lib/python3.11/site-packages/transformers/feature_extraction_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e8007edbc0b782970887e83a08ba518de0f264a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/feature_extraction_utils.py
@@ -0,0 +1,702 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extraction saving/loading class for common feature extractors.
+"""
+
+import copy
+import json
+import os
+import warnings
+from collections import UserDict
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+
+from .dynamic_module_utils import custom_object_save
+from .utils import (
+    FEATURE_EXTRACTOR_NAME,
+    PushToHubMixin,
+    TensorType,
+    add_model_info_to_auto_map,
+    add_model_info_to_custom_pipelines,
+    cached_file,
+    copy_func,
+    download_url,
+    is_flax_available,
+    is_jax_tensor,
+    is_numpy_array,
+    is_offline_mode,
+    is_remote_url,
+    is_tf_available,
+    is_torch_available,
+    is_torch_device,
+    is_torch_dtype,
+    logging,
+    requires_backends,
+)
+
+
+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch  # noqa
+
+
+logger = logging.get_logger(__name__)
+
+PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]  # noqa: F821
+
+
+class BatchFeature(UserDict):
+    r"""
+    Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.
+
+    This class is derived from a python dictionary and can be used as a dictionary.
+
+    Args:
+        data (`dict`, *optional*):
+            Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
+            etc.).
+        tensor_type (`Union[None, str, TensorType]`, *optional*):
+            You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+            initialization.
+    """
+
+    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
+        super().__init__(data)
+        self.convert_to_tensors(tensor_type=tensor_type)
+
+    def __getitem__(self, item: str) -> Union[Any]:
+        """
+        If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
+        etc.).
+        """
+        if isinstance(item, str):
+            return self.data[item]
+        else:
+            raise KeyError("Indexing with integers is not available when using Python based feature extractors")
+
+    def __getattr__(self, item: str):
+        try:
+            return self.data[item]
+        except KeyError:
+            raise AttributeError
+
+    def __getstate__(self):
+        return {"data": self.data}
+
+    def __setstate__(self, state):
+        if "data" in state:
+            self.data = state["data"]
+
+    # Copied from transformers.tokenization_utils_base.BatchEncoding.keys
+    def keys(self):
+        return self.data.keys()
+
+    # Copied from transformers.tokenization_utils_base.BatchEncoding.values
+    def values(self):
+        return self.data.values()
+
+    # Copied from transformers.tokenization_utils_base.BatchEncoding.items
+    def items(self):
+        return self.data.items()
+
+    def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
+        if tensor_type is None:
+            return None, None
+
+        # Convert to TensorType
+        if not isinstance(tensor_type, TensorType):
+            tensor_type = TensorType(tensor_type)
+
+        # Get a function reference for the correct framework
+        if tensor_type == TensorType.TENSORFLOW:
+            if not is_tf_available():
+                raise ImportError(
+                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
+                )
+            import tensorflow as tf
+
+            as_tensor = tf.constant
+            is_tensor = tf.is_tensor
+        elif tensor_type == TensorType.PYTORCH:
+            if not is_torch_available():
+                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
+            import torch  # noqa
+
+            def as_tensor(value):
+                if isinstance(value, (list, tuple)) and len(value) > 0:
+                    if isinstance(value[0], np.ndarray):
+                        value = np.array(value)
+                    elif (
+                        isinstance(value[0], (list, tuple))
+                        and len(value[0]) > 0
+                        and isinstance(value[0][0], np.ndarray)
+                    ):
+                        value = np.array(value)
+                if isinstance(value, np.ndarray):
+                    return torch.from_numpy(value)
+                else:
+                    return torch.tensor(value)
+
+            is_tensor = torch.is_tensor
+        elif tensor_type == TensorType.JAX:
+            if not is_flax_available():
+                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
+            import jax.numpy as jnp  # noqa: F811
+
+            as_tensor = jnp.array
+            is_tensor = is_jax_tensor
+        else:
+
+            def as_tensor(value, dtype=None):
+                if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+                    value_lens = [len(val) for val in value]
+                    if len(set(value_lens)) > 1 and dtype is None:
+                        # we have a ragged list so handle explicitly
+                        value = as_tensor([np.asarray(val) for val in value], dtype=object)
+                return np.asarray(value, dtype=dtype)
+
+            is_tensor = is_numpy_array
+        return is_tensor, as_tensor
+
+    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+        """
+        Convert the inner content to tensors.
+
+        Args:
+            tensor_type (`str` or [`~utils.TensorType`], *optional*):
+                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+                `None`, no modification is done.
+        """
+        if tensor_type is None:
+            return self
+
+        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
+
+        # Do the tensor conversion in batch
+        for key, value in self.items():
+            try:
+                if not is_tensor(value):
+                    tensor = as_tensor(value)
+
+                    self[key] = tensor
+            except:  # noqa E722
+                if key == "overflowing_values":
+                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
+                raise ValueError(
+                    "Unable to create tensor, you should probably activate padding "
+                    "with 'padding=True' to have batched tensors with the same length."
+                )
+
+        return self
+
+    def to(self, *args, **kwargs) -> "BatchFeature":
+        """
+        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+        different `dtypes` and sending the `BatchFeature` to a different `device`.
+
+        Args:
+            args (`Tuple`):
+                Will be passed to the `to(...)` function of the tensors.
+            kwargs (`Dict`, *optional*):
+                Will be passed to the `to(...)` function of the tensors.
+                To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
+
+        Returns:
+            [`BatchFeature`]: The same instance after modification.
+        """
+        requires_backends(self, ["torch"])
+        import torch  # noqa
+
+        new_data = {}
+        device = kwargs.get("device")
+        non_blocking = kwargs.get("non_blocking", False)
+        # Check if the args are a device or a dtype
+        if device is None and len(args) > 0:
+            # device should be always the first argument
+            arg = args[0]
+            if is_torch_dtype(arg):
+                # The first argument is a dtype
+                pass
+            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+                device = arg
+            else:
+                # it's something else
+                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+        for k, v in self.items():
+            # check if v is a floating point
+            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
+                # cast and send to device
+                new_data[k] = v.to(*args, **kwargs)
+            elif isinstance(v, torch.Tensor) and device is not None:
+                new_data[k] = v.to(device=device, non_blocking=non_blocking)
+            else:
+                new_data[k] = v
+        self.data = new_data
+        return self
+
+
+class FeatureExtractionMixin(PushToHubMixin):
+    """
+    This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
+    extractors.
+    """
+
+    _auto_class = None
+
+    def __init__(self, **kwargs):
+        """Set elements of `kwargs` as attributes."""
+        # Pop "processor_class" as it should be saved as private attribute
+        self._processor_class = kwargs.pop("processor_class", None)
+        # Additional attributes without default values
+        for key, value in kwargs.items():
+            try:
+                setattr(self, key, value)
+            except AttributeError as err:
+                logger.error(f"Can't set {key} with value {value} for {self}")
+                raise err
+
+    def _set_processor_class(self, processor_class: str):
+        """Sets processor class as an attribute."""
+        self._processor_class = processor_class
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        **kwargs,
+    ):
+        r"""
+        Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
+        derived class of [`SequenceFeatureExtractor`].
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                This can be either:
+
+                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+                  huggingface.co.
+                - a path to a *directory* containing a feature extractor file saved using the
+                  [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
+                  `./my_model_directory/`.
+                - a path or url to a saved feature extractor JSON *file*, e.g.,
+                  `./my_model_directory/preprocessor_config.json`.
+            cache_dir (`str` or `os.PathLike`, *optional*):
+                Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force to (re-)download the feature extractor files and override the cached versions
+                if they exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+
+                
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+                
+
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                If `False`, then this function returns just the final feature extractor object. If `True`, then this
+                functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+                consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
+                `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+            kwargs (`Dict[str, Any]`, *optional*):
+                The values in kwargs of any keys which are feature extractor attributes will be used to override the
+                loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+                controlled by the `return_unused_kwargs` keyword parameter.
+
+        Returns:
+            A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`].
+
+        Examples:
+
+        ```python
+        # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a
+        # derived class: *Wav2Vec2FeatureExtractor*
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        )  # Download feature_extraction_config from huggingface.co and cache.
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "./test/saved_model/"
+        )  # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')*
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json")
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False
+        )
+        assert feature_extractor.return_attention_mask is False
+        feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained(
+            "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True
+        )
+        assert feature_extractor.return_attention_mask is False
+        assert unused_kwargs == {"foo": False}
+        ```"""
+        kwargs["cache_dir"] = cache_dir
+        kwargs["force_download"] = force_download
+        kwargs["local_files_only"] = local_files_only
+        kwargs["revision"] = revision
+
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+
+        return cls.from_dict(feature_extractor_dict, **kwargs)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+        """
+        Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if kwargs.get("token", None) is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            kwargs["token"] = use_auth_token
+
+        if os.path.isfile(save_directory):
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self)
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
+
+        self.to_json_file(output_feature_extractor_file)
+        logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=kwargs.get("token"),
+            )
+
+        return [output_feature_extractor_file]
+
+    @classmethod
+    def get_feature_extractor_dict(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        """
+        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+        feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+        Returns:
+            `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        subfolder = kwargs.pop("subfolder", None)
+        token = kwargs.pop("token", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
+        if os.path.isfile(pretrained_model_name_or_path):
+            resolved_feature_extractor_file = pretrained_model_name_or_path
+            is_local = True
+        elif is_remote_url(pretrained_model_name_or_path):
+            feature_extractor_file = pretrained_model_name_or_path
+            resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
+        else:
+            feature_extractor_file = FEATURE_EXTRACTOR_NAME
+            try:
+                # Load from local folder or from cache or download from model Hub and cache
+                resolved_feature_extractor_file = cached_file(
+                    pretrained_model_name_or_path,
+                    feature_extractor_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    subfolder=subfolder,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                )
+            except EnvironmentError:
+                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+                # the original exception.
+                raise
+            except Exception:
+                # For any other exception, we throw a generic error.
+                raise EnvironmentError(
+                    f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
+                    " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
+                )
+
+        try:
+            # Load feature_extractor dict
+            with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
+            feature_extractor_dict = json.loads(text)
+
+        except json.JSONDecodeError:
+            raise EnvironmentError(
+                f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
+            )
+
+        if is_local:
+            logger.info(f"loading configuration file {resolved_feature_extractor_file}")
+        else:
+            logger.info(
+                f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
+            )
+
+        if not is_local:
+            if "auto_map" in feature_extractor_dict:
+                feature_extractor_dict["auto_map"] = add_model_info_to_auto_map(
+                    feature_extractor_dict["auto_map"], pretrained_model_name_or_path
+                )
+            if "custom_pipelines" in feature_extractor_dict:
+                feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
+                    feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path
+                )
+
+        return feature_extractor_dict, kwargs
+
+    @classmethod
+    def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
+        """
+        Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
+        parameters.
+
+        Args:
+            feature_extractor_dict (`Dict[str, Any]`):
+                Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
+                retrieved from a pretrained checkpoint by leveraging the
+                [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
+            kwargs (`Dict[str, Any]`):
+                Additional parameters from which to initialize the feature extractor object.
+
+        Returns:
+            [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
+            parameters.
+        """
+        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+        # Update feature_extractor with kwargs if needed
+        to_remove = []
+        for key, value in kwargs.items():
+            if key in feature_extractor_dict:
+                feature_extractor_dict[key] = value
+                to_remove.append(key)
+        for key in to_remove:
+            kwargs.pop(key, None)
+
+        feature_extractor = cls(**feature_extractor_dict)
+
+        logger.info(f"Feature extractor {feature_extractor}")
+        if return_unused_kwargs:
+            return feature_extractor, kwargs
+        else:
+            return feature_extractor
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary. Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+        output["feature_extractor_type"] = self.__class__.__name__
+        if "mel_filters" in output:
+            del output["mel_filters"]
+        if "window" in output:
+            del output["window"]
+        return output
+
+    @classmethod
+    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
+        """
+        Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
+        a JSON file of parameters.
+
+        Args:
+            json_file (`str` or `os.PathLike`):
+                Path to the JSON file containing the parameters.
+
+        Returns:
+            A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
+            object instantiated from that JSON file.
+        """
+        with open(json_file, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        feature_extractor_dict = json.loads(text)
+        return cls(**feature_extractor_dict)
+
+    def to_json_string(self) -> str:
+        """
+        Serializes this instance to a JSON string.
+
+        Returns:
+            `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+        """
+        dictionary = self.to_dict()
+
+        for key, value in dictionary.items():
+            if isinstance(value, np.ndarray):
+                dictionary[key] = value.tolist()
+
+        # make sure private name "_processor_class" is correctly
+        # saved as "processor_class"
+        _processor_class = dictionary.pop("_processor_class", None)
+        if _processor_class is not None:
+            dictionary["processor_class"] = _processor_class
+
+        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this feature_extractor instance's parameters will be saved.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            writer.write(self.to_json_string())
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
+        """
+        Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+        in the library are already mapped with `AutoFeatureExtractor`.
+
+        
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
+                The auto class to register this new feature extractor with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+
+FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
+if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
+    FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
+        object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
+    )
diff --git a/.venv/lib/python3.11/site-packages/transformers/file_utils.py b/.venv/lib/python3.11/site-packages/transformers/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c13d506a613795492f73b3b6c9519096f254f16b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/file_utils.py
@@ -0,0 +1,133 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+File utilities: utilities related to download and cache models
+
+This module should not be update anymore and is only left for backward compatibility.
+"""
+
+from huggingface_hub import get_full_repo_name  # for backward compatibility
+from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY  # for backward compatibility
+
+from . import __version__
+
+# Backward compatibility imports, to make sure all those objects can be found in file_utils
+from .utils import (
+    CLOUDFRONT_DISTRIB_PREFIX,
+    CONFIG_NAME,
+    DUMMY_INPUTS,
+    DUMMY_MASK,
+    ENV_VARS_TRUE_AND_AUTO_VALUES,
+    ENV_VARS_TRUE_VALUES,
+    FEATURE_EXTRACTOR_NAME,
+    FLAX_WEIGHTS_NAME,
+    HF_MODULES_CACHE,
+    HUGGINGFACE_CO_PREFIX,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    MODEL_CARD_NAME,
+    MULTIPLE_CHOICE_DUMMY_INPUTS,
+    PYTORCH_PRETRAINED_BERT_CACHE,
+    PYTORCH_TRANSFORMERS_CACHE,
+    S3_BUCKET_PREFIX,
+    SENTENCEPIECE_UNDERLINE,
+    SPIECE_UNDERLINE,
+    TF2_WEIGHTS_NAME,
+    TF_WEIGHTS_NAME,
+    TORCH_FX_REQUIRED_VERSION,
+    TRANSFORMERS_CACHE,
+    TRANSFORMERS_DYNAMIC_MODULE_NAME,
+    USE_JAX,
+    USE_TF,
+    USE_TORCH,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    ContextManagers,
+    DummyObject,
+    EntryNotFoundError,
+    ExplicitEnum,
+    ModelOutput,
+    PaddingStrategy,
+    PushToHubMixin,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+    TensorType,
+    _LazyModule,
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    cached_property,
+    copy_func,
+    default_cache_path,
+    define_sagemaker_information,
+    get_cached_models,
+    get_file_from_repo,
+    get_torch_version,
+    has_file,
+    http_user_agent,
+    is_apex_available,
+    is_bs4_available,
+    is_coloredlogs_available,
+    is_datasets_available,
+    is_detectron2_available,
+    is_faiss_available,
+    is_flax_available,
+    is_ftfy_available,
+    is_g2p_en_available,
+    is_in_notebook,
+    is_ipex_available,
+    is_librosa_available,
+    is_offline_mode,
+    is_onnx_available,
+    is_pandas_available,
+    is_phonemizer_available,
+    is_protobuf_available,
+    is_psutil_available,
+    is_py3nvml_available,
+    is_pyctcdecode_available,
+    is_pytesseract_available,
+    is_pytorch_quantization_available,
+    is_rjieba_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
+    is_scipy_available,
+    is_sentencepiece_available,
+    is_seqio_available,
+    is_sklearn_available,
+    is_soundfile_available,
+    is_spacy_available,
+    is_speech_available,
+    is_tensor,
+    is_tensorflow_probability_available,
+    is_tf2onnx_available,
+    is_tf_available,
+    is_timm_available,
+    is_tokenizers_available,
+    is_torch_available,
+    is_torch_bf16_available,
+    is_torch_cuda_available,
+    is_torch_fx_available,
+    is_torch_fx_proxy,
+    is_torch_mps_available,
+    is_torch_tf32_available,
+    is_torch_xla_available,
+    is_torchaudio_available,
+    is_training_run_on_sagemaker,
+    is_vision_available,
+    replace_return_docstrings,
+    requires_backends,
+    to_numpy,
+    to_py_obj,
+    torch_only_method,
+)
diff --git a/.venv/lib/python3.11/site-packages/transformers/modeling_utils.py b/.venv/lib/python3.11/site-packages/transformers/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5453c1ac4de5b5d8a74a58246d5ab89a20dca7fe
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/modeling_utils.py
@@ -0,0 +1,5672 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import copy
+import functools
+import gc
+import importlib.metadata
+import inspect
+import itertools
+import json
+import os
+import re
+import shutil
+import tempfile
+import warnings
+from contextlib import contextmanager
+from dataclasses import dataclass
+from functools import partial, wraps
+from threading import Thread
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from zipfile import is_zipfile
+
+import torch
+from huggingface_hub import split_torch_state_dict_into_shards
+from packaging import version
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss, Identity
+from torch.utils.checkpoint import checkpoint
+
+from .activations import get_activation
+from .configuration_utils import PretrainedConfig
+from .dynamic_module_utils import custom_object_save
+from .generation import CompileConfig, GenerationConfig, GenerationMixin
+from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations.flash_attention import flash_attention_forward
+from .integrations.flex_attention import flex_attention_forward
+from .integrations.sdpa_attention import sdpa_attention_forward
+from .loss.loss_utils import LOSS_MAPPING
+from .pytorch_utils import (  # noqa: F401
+    Conv1D,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    id_tensor_storage,
+    prune_conv1d_layer,
+    prune_layer,
+    prune_linear_layer,
+    translate_to_torch_parallel_style,
+)
+from .quantizers import AutoHfQuantizer, HfQuantizer
+from .quantizers.quantizers_utils import get_module_from_name
+from .safetensors_conversion import auto_conversion
+from .utils import (
+    ACCELERATE_MIN_VERSION,
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    CONFIG_NAME,
+    DUMMY_INPUTS,
+    FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_NAME,
+    TF_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    ContextManagers,
+    ModelOutput,
+    PushToHubMixin,
+    cached_file,
+    copy_func,
+    download_url,
+    extract_commit_hash,
+    has_file,
+    is_accelerate_available,
+    is_bitsandbytes_available,
+    is_flash_attn_2_available,
+    is_offline_mode,
+    is_optimum_available,
+    is_peft_available,
+    is_remote_url,
+    is_safetensors_available,
+    is_torch_flex_attn_available,
+    is_torch_greater_or_equal,
+    is_torch_sdpa_available,
+    is_torch_xla_available,
+    logging,
+    replace_return_docstrings,
+    strtobool,
+)
+from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
+from .utils.import_utils import (
+    ENV_VARS_TRUE_VALUES,
+    is_sagemaker_mp_enabled,
+    is_torch_fx_proxy,
+    is_torchdynamo_compiling,
+)
+from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
+
+
+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
+
+
+if is_accelerate_available():
+    from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+    from accelerate.hooks import add_hook_to_module
+    from accelerate.utils import (
+        check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
+        find_tied_parameters,
+        get_balanced_memory,
+        get_max_memory,
+        load_offloaded_weights,
+        offload_weight,
+        save_offload_index,
+        set_module_tensor_to_device,
+    )
+
+    accelerate_version = version.parse(importlib.metadata.version("accelerate"))
+    if accelerate_version >= version.parse("0.31"):
+        from accelerate.utils.modeling import get_state_dict_from_offload
+
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import load_file as safe_load_file
+    from safetensors.torch import save_file as safe_save_file
+
+logger = logging.get_logger(__name__)
+
+
+_init_weights = True
+_is_quantized = False
+_is_ds_init_called = False
+
+
+def is_fsdp_enabled():
+    return (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
+    )
+
+
+def is_local_dist_rank_0():
+    return (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and int(os.environ.get("LOCAL_RANK", -1)) == 0
+    )
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+    from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+else:
+    IS_SAGEMAKER_MP_POST_1_10 = False
+
+if is_peft_available():
+    from .utils import find_adapter_config_file
+
+SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
+
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}
+
+
+@contextmanager
+def no_init_weights(_enable=True):
+    """
+    Context manager to globally disable weight initialization to speed up loading large models.
+
+    TODO(Patrick): Delete safety argument `_enable=True` at next major version. .
+    """
+    global _init_weights
+    old_init_weights = _init_weights
+
+    if _enable:
+        _init_weights = False
+
+        def _skip_init(*args, **kwargs):
+            pass
+
+        # # Save the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, _skip_init)
+    try:
+        yield
+    finally:
+        _init_weights = old_init_weights
+        if _enable:
+            # # Restore the original initialization functions
+            for name, init_func in TORCH_INIT_FUNCTIONS.items():
+                setattr(torch.nn.init, name, init_func)
+
+
+@contextmanager
+def set_quantized_state():
+    global _is_quantized
+    _is_quantized = True
+    try:
+        yield
+    finally:
+        _is_quantized = False
+
+
+# Skip recursive calls to deepspeed.zero.Init to avoid pinning errors.
+# This issue occurs with ZeRO stage 3 when using NVMe offloading.
+# For more details, refer to issue #34429.
+@contextmanager
+def set_zero3_state():
+    global _is_ds_init_called
+    _is_ds_init_called = True
+    try:
+        yield
+    finally:
+        _is_ds_init_called = False
+
+
+def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).device
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].device
+
+
+def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
+    """
+    Returns the first parameter dtype (can be non-floating) or asserts if none were found.
+    """
+    try:
+        return next(parameter.parameters()).dtype
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch > 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].dtype
+
+
+def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for t in parameter.parameters():
+        last_dtype = t.dtype
+        if t.is_floating_point():
+            # Adding fix for https://github.com/pytorch/xla/issues/4152
+            # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
+            # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
+            # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
+            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
+                return torch.bfloat16
+            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
+                if t.dtype == torch.float:
+                    return torch.bfloat16
+                if t.dtype == torch.double:
+                    return torch.float32
+            return t.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the first dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for tuple in gen:
+        last_tuple = tuple
+        if tuple[1].is_floating_point():
+            return tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype
+
+    # fallback to buffer dtype
+    for t in parameter.buffers():
+        last_dtype = t.dtype
+        if t.is_floating_point():
+            return t.dtype
+    return last_dtype
+
+
+def get_state_dict_float_dtype(state_dict):
+    """
+    Returns the first found floating dtype in `state_dict` or asserts if none were found.
+    """
+    for t in state_dict.values():
+        if t.is_floating_point():
+            return t.dtype
+
+    raise ValueError("couldn't find any floating point dtypes in state_dict")
+
+
+def get_state_dict_dtype(state_dict):
+    """
+    Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
+    """
+    for t in state_dict.values():
+        if t.is_floating_point():
+            return t.dtype
+
+    # if no floating dtype was found return whatever the first dtype is
+    else:
+        return next(state_dict.values()).dtype
+
+
+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(torch.float32)
+    4
+    ```
+    """
+    if dtype == torch.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
+
+
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such
+    as when loading in empty weights) by first checking
+    if the model explicitly disables it, then by ensuring that the state dict keys
+    are a subset of the model's parameters.
+
+    Note: We fully disable this if we are using `deepspeed`
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    if is_deepspeed_zero3_enabled():
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
+    return False
+
+
+def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
+    """
+    This is the same as
+    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
+    but for a sharded checkpoint.
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        model (`torch.nn.Module`): The model in which to load the checkpoint.
+        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
+        strict (`bool`, *optional`, defaults to `True`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+        prefer_safe (`bool`, *optional*, defaults to `False`)
+            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
+            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
+            - `missing_keys` is a list of str containing the missing keys
+            - `unexpected_keys` is a list of str containing the unexpected keys
+    """
+    # Load the index
+    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
+    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
+
+    index_present = os.path.isfile(index_file)
+    safe_index_present = os.path.isfile(safe_index_file)
+
+    if not index_present and not (safe_index_present and is_safetensors_available()):
+        filenames = (
+            (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
+        )
+        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
+
+    load_safe = False
+    if safe_index_present:
+        if prefer_safe:
+            if is_safetensors_available():
+                load_safe = True  # load safe due to preference
+            else:
+                logger.warning(
+                    f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
+                )
+        elif not index_present:
+            load_safe = True  # load safe since we have no other choice
+
+    load_index = safe_index_file if load_safe else index_file
+
+    with open(load_index, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    shard_files = list(set(index["weight_map"].values()))
+
+    # If strict=True, error before loading any of the state dicts.
+    loaded_keys = index["weight_map"].keys()
+    model_keys = model.state_dict().keys()
+    missing_keys = [key for key in model_keys if key not in loaded_keys]
+    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nMissing key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+    weights_only_kwarg = {"weights_only": True}
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
+
+    for shard_file in shard_files:
+        state_dict = loader(os.path.join(folder, shard_file))
+        model.load_state_dict(state_dict, strict=False)
+
+        # Make sure memory is freed before we load the next state dict.
+        del state_dict
+        gc.collect()
+
+    # Return the same thing as PyTorch load_state_dict function.
+    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+
+
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike],
+    is_quantized: bool = False,
+    map_location: Optional[Union[str, torch.device]] = None,
+    weights_only: bool = True,
+):
+    """
+    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+    """
+    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+        if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_pretrained` method."
+            )
+        return safe_load_file(checkpoint_file)
+    try:
+        if map_location is None:
+            if (
+                (
+                    is_deepspeed_zero3_enabled()
+                    and torch.distributed.is_initialized()
+                    and torch.distributed.get_rank() > 0
+                )
+                or (is_fsdp_enabled() and not is_local_dist_rank_0())
+            ) and not is_quantized:
+                map_location = "meta"
+            else:
+                map_location = "cpu"
+        extra_args = {}
+        # mmap can only be used with files serialized with zipfile-based format.
+        if (
+            isinstance(checkpoint_file, str)
+            and map_location != "meta"
+            and version.parse(torch.__version__) >= version.parse("2.1.0")
+            and is_zipfile(checkpoint_file)
+        ):
+            extra_args = {"mmap": True}
+        weights_only_kwarg = {"weights_only": weights_only}
+        return torch.load(
+            checkpoint_file,
+            map_location=map_location,
+            **weights_only_kwarg,
+            **extra_args,
+        )
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read(7) == "version":
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+                f"at '{checkpoint_file}'. "
+                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+            )
+
+
+def set_initialized_submodules(model, state_dict_keys):
+    """
+    Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
+    dict.
+    """
+    not_initialized_submodules = {}
+    for module_name, module in model.named_modules():
+        loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
+        # When checking if the root module is loaded all state_dict_keys must be used.
+        if module_name == "":
+            loaded_keys = set(state_dict_keys)
+        if loaded_keys.issuperset(module.state_dict()):
+            module._is_hf_initialized = True
+        else:
+            not_initialized_submodules[module_name] = module
+    return not_initialized_submodules
+
+
+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _get_tied_weight_keys(module: nn.Module, prefix=""):
+    tied_weight_keys = []
+    if getattr(module, "_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
+        tied_weight_keys.extend(names)
+    if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
+        tied_weight_keys.extend(names)
+    for name, submodule in module.named_children():
+        local_prefix = f"{prefix}.{name}" if prefix else name
+        tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
+    return tied_weight_keys
+
+
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
+
+
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
+
+
+def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+    # copy state_dict so _load_from_state_dict can modify it
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                # In sharded models, each shard has only part of the full state_dict, so only gather
+                # parameters that are in the current state_dict.
+                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+                if len(params_to_gather) > 0:
+                    # because zero3 puts placeholders in model params, this context
+                    # manager gathers (unpartitions) the params of the current layer, then loads from
+                    # the state dict and then re-partitions them again
+                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                        if torch.distributed.get_rank() == 0:
+                            module._load_from_state_dict(*args)
+            else:
+                module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
+
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
+    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict
+
+    return error_msgs
+
+
+def find_submodule_and_param_name(model, long_key, start_prefix):
+    """
+    A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
+    from the start of the key
+    """
+
+    if len(start_prefix) > 0 and long_key.startswith(start_prefix):
+        long_key = ".".join(long_key.split(".")[1:])
+
+    split_key = long_key.split(".")
+    submodule = model
+    while len(split_key) > 1:
+        if hasattr(submodule, split_key[0]):
+            submodule = getattr(submodule, split_key[0])
+            del split_key[0]
+        else:
+            submodule = None
+            break
+    if submodule == model:
+        submodule = None
+    return submodule, split_key[0]
+
+
+def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
+    """
+    Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params.
+
+    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
+    `bert.pooler.dense.weight`
+
+    """
+
+    # dematerialize param storage for keys that are going to be replaced by state_dict, by
+    # putting those on the meta device
+    for k in loaded_state_dict_keys:
+        submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
+        if submodule is not None:
+            # selectively switch to the meta device only those params/buffers that will
+            # be next replaced from state_dict. This a complex way to do p.to_("meta")
+            # since we have no in-place to_ for tensors.
+            new_val = getattr(submodule, param_name)
+            if isinstance(new_val, torch.nn.Parameter):
+                # isinstance returns False for Params on meta device, so switch after the check
+                new_val = torch.nn.Parameter(new_val.to("meta"))
+            else:
+                new_val = new_val.to("meta")
+            setattr(submodule, param_name, new_val)
+
+
+def _load_state_dict_into_meta_model(
+    model,
+    state_dict,
+    start_prefix,
+    expected_keys,
+    device_map=None,
+    offload_folder=None,
+    offload_index=None,
+    state_dict_folder=None,
+    state_dict_index=None,
+    dtype=None,
+    hf_quantizer=None,
+    is_safetensors=False,
+    keep_in_fp32_modules=None,
+    unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
+    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
+):
+    """
+    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
+    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
+    params back to the normal device, but only for `loaded_state_dict_keys`.
+
+    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
+    `bert.pooler.dense.weight`
+
+    """
+
+    # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
+    # - deepspeed zero 3 support
+    # - need to copy metadata if any - see _load_state_dict_into_model
+    # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
+
+    error_msgs = []
+
+    is_quantized = hf_quantizer is not None
+
+    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
+
+    for param_name, param in state_dict.items():
+        if param_name not in expected_keys:
+            continue
+
+        if param_name.startswith(start_prefix):
+            param_name = param_name[len(start_prefix) :]
+
+        module_name = param_name
+        set_module_kwargs = {}
+
+        # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
+        if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
+            if (
+                keep_in_fp32_modules is not None
+                and any(
+                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+                )
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+
+                # For backward compatibility with older versions of `accelerate`
+                # TODO: @sgugger replace this check with version check at the next `accelerate` release
+                if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
+                    set_module_kwargs["dtype"] = torch.float32
+            else:
+                param = param.to(dtype)
+
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+
+        old_param = model
+        splits = param_name.split(".")
+        for split in splits:
+            # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
+            old_param = getattr(old_param, split, None)
+            if old_param is None:
+                break
+
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
+        if old_param is not None:
+            if dtype is None:
+                param = param.to(old_param.dtype)
+
+            if old_param.is_contiguous():
+                param = param.contiguous()
+
+        set_module_kwargs["value"] = param
+
+        if device_map is None:
+            param_device = "cpu"
+        else:
+            # find next higher level module that is defined in device_map:
+            # bert.lm_head.weight -> bert.lm_head -> bert -> ''
+            while len(module_name) > 0 and module_name not in device_map:
+                module_name = ".".join(module_name.split(".")[:-1])
+            if module_name == "" and "" not in device_map:
+                # TODO: group all errors and raise at the end.
+                raise ValueError(f"{param_name} doesn't have any device set.")
+            param_device = device_map[module_name]
+
+        if param_device == "disk":
+            if not is_safetensors:
+                offload_index = offload_weight(param, param_name, offload_folder, offload_index)
+        elif param_device == "cpu" and state_dict_index is not None:
+            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+        elif (
+            not is_quantized
+            or (not hf_quantizer.requires_parameters_quantization)
+            or (
+                not hf_quantizer.check_quantized_param(
+                    model, param, param_name, state_dict, param_device=param_device, device_map=device_map
+                )
+            )
+        ):
+            if is_fsdp_enabled():
+                param_device = "cpu" if is_local_dist_rank_0() else "meta"
+
+            # For backward compatibility with older versions of `accelerate` and for non-quantized params
+            set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
+        else:
+            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
+            # and then cast it to CPU to avoid excessive memory usage on each GPU
+            # in comparison to the sharded model across GPUs.
+            if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+                module, tensor_name = get_module_from_name(model, param_name)
+                value = getattr(module, tensor_name)
+                param_to = "cpu"
+                if is_fsdp_enabled() and not is_local_dist_rank_0():
+                    param_to = "meta"
+                val_kwargs = {}
+                if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
+                    val_kwargs["requires_grad"] = False
+                value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
+                setattr(module, tensor_name, value)
+            # TODO: consider removing used param_parts from state_dict before return
+
+    return error_msgs, offload_index, state_dict_index
+
+
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+    if variant is not None:
+        splits = weights_name.split(".")
+        splits = splits[:-1] + [variant] + splits[-1:]
+        weights_name = ".".join(splits)
+
+    return weights_name
+
+
+class ModuleUtilsMixin:
+    """
+    A few utilities for `torch.nn.Modules`, to be used as a mixin.
+    """
+
+    @staticmethod
+    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
+        try:
+            import psutil
+        except ImportError:
+            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+        process = psutil.Process(os.getpid())
+        mem = process.memory_info()
+        module.mem_rss_pre_forward = mem.rss
+        return None
+
+    @staticmethod
+    def _hook_rss_memory_post_forward(module, *args, **kwargs):
+        try:
+            import psutil
+        except ImportError:
+            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+        process = psutil.Process(os.getpid())
+        mem = process.memory_info()
+        module.mem_rss_post_forward = mem.rss
+        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
+        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
+        return None
+
+    def add_memory_hooks(self):
+        """
+        Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
+
+        Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero
+        with `model.reset_memory_hooks_state()`.
+        """
+        for module in self.modules():
+            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
+            module.register_forward_hook(self._hook_rss_memory_post_forward)
+        self.reset_memory_hooks_state()
+
+    def reset_memory_hooks_state(self):
+        """
+        Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]).
+        """
+        for module in self.modules():
+            module.mem_rss_diff = 0
+            module.mem_rss_post_forward = 0
+            module.mem_rss_pre_forward = 0
+
+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        return get_parameter_device(self)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        return get_parameter_dtype(self)
+
+    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
+        """
+        Invert an attention mask (e.g., switches 0. and 1.).
+
+        Args:
+            encoder_attention_mask (`torch.Tensor`): An attention mask.
+
+        Returns:
+            `torch.Tensor`: The inverted attention mask.
+        """
+        if encoder_attention_mask.dim() == 3:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+        if encoder_attention_mask.dim() == 2:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+        # /transformer/transformer_layers.py#L270
+        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+        # encoder_extended_attention_mask.transpose(-1, -2))
+        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
+
+        return encoder_extended_attention_mask
+
+    @staticmethod
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
+        if device is not None:
+            warnings.warn(
+                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+            )
+        else:
+            device = attention_mask.device
+        batch_size, seq_length = input_shape
+        seq_ids = torch.arange(seq_length, device=device)
+        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+        # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+        # causal and attention masks must have same type with pytorch version < 1.3
+        causal_mask = causal_mask.to(attention_mask.dtype)
+
+        if causal_mask.shape[1] < attention_mask.shape[1]:
+            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+            causal_mask = torch.cat(
+                [
+                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                    causal_mask,
+                ],
+                axis=-1,
+            )
+
+        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        return extended_attention_mask
+
+    def get_extended_attention_mask(
+        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
+    ) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (`Tuple[int]`):
+                The shape of the input to the model.
+
+        Returns:
+            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
+        """
+        if dtype is None:
+            dtype = self.dtype
+
+        if not (attention_mask.dim() == 2 and self.config.is_decoder):
+            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+            if device is not None:
+                warnings.warn(
+                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+                )
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if self.config.is_decoder:
+                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
+                    input_shape, attention_mask, device
+                )
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
+        return extended_attention_mask
+
+    def get_head_mask(
+        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
+    ) -> Tensor:
+        """
+        Prepare the head mask if needed.
+
+        Args:
+            head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+            num_hidden_layers (`int`):
+                The number of hidden layers in the model.
+            is_attention_chunked (`bool`, *optional*, defaults to `False`):
+                Whether or not the attentions scores are computed by chunks or not.
+
+        Returns:
+            `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+            `[None]` for each layer.
+        """
+        if head_mask is not None:
+            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+            if is_attention_chunked is True:
+                head_mask = head_mask.unsqueeze(-1)
+        else:
+            head_mask = [None] * num_hidden_layers
+
+        return head_mask
+
+    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+        if head_mask.dim() == 1:
+            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+        elif head_mask.dim() == 2:
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
+        return head_mask
+
+    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+        """
+        Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+        Args:
+            only_trainable (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of trainable parameters
+
+            exclude_embeddings (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of non-embeddings parameters
+
+        Returns:
+            `int`: The number of parameters.
+        """
+
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
+            ]
+            total_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+        else:
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+                )
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
+        """
+        Helper function to estimate the total number of tokens from the model inputs.
+
+        Args:
+            inputs (`dict`): The model inputs.
+
+        Returns:
+            `int`: The total number of tokens.
+        """
+        if not hasattr(self, "warnings_issued"):
+            self.warnings_issued = {}
+        if self.main_input_name in input_dict:
+            return input_dict[self.main_input_name].numel()
+        elif "estimate_tokens" not in self.warnings_issued:
+            logger.warning(
+                "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
+            )
+            self.warnings_issued["estimate_tokens"] = True
+        return 0
+
+    def floating_point_ops(
+        self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
+    ) -> int:
+        """
+        Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
+        batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
+        tokens (valid if `12 * d_model << sequence_length`) as laid out in [this
+        paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter
+        re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths.
+
+        Args:
+            batch_size (`int`):
+                The batch size for the forward pass.
+
+            sequence_length (`int`):
+                The number of tokens in each line of the batch.
+
+            exclude_embeddings (`bool`, *optional*, defaults to `True`):
+                Whether or not to count embedding and softmax operations.
+
+        Returns:
+            `int`: The number of floating-point operations.
+        """
+
+        return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
+
+
+# TODO (joao): remove `GenerationMixin` inheritance in v4.50
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
+    r"""
+    Base class for all models.
+
+    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    downloading and saving models as well as a few methods common to all models to:
+
+        - resize the input embeddings,
+        - prune heads in the self-attention heads.
+
+    Class attributes (overridden by derived classes):
+
+        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+          for this model architecture.
+        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
+          taking as arguments:
+
+            - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+            - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+            - **path** (`str`) -- A path to the TensorFlow checkpoint.
+
+        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+          classes of the same architecture adding modules on top of the base model.
+        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+          models, `pixel_values` for vision models and `input_values` for speech models).
+    """
+
+    config_class = None
+    base_model_prefix = ""
+    main_input_name = "input_ids"
+    model_tags = None
+
+    _auto_class = None
+    _no_split_modules = None
+    _skip_keys_device_placement = None
+    _keep_in_fp32_modules = None
+
+    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
+    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
+    _keys_to_ignore_on_load_missing = None
+    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
+    # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
+    # warnings.
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
+    # trained, but which are either deterministic or tied variables)
+    _keys_to_ignore_on_save = None
+    # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
+    _tied_weights_keys = None
+
+    is_parallelizable = False
+    supports_gradient_checkpointing = False
+    _is_stateful = False
+
+    # Flash Attention 2 support
+    _supports_flash_attn_2 = False
+
+    # SDPA support
+    _supports_sdpa = False
+
+    # Flex Attention support
+    _supports_flex_attn = False
+
+    # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
+    _supports_cache_class = False
+    _supports_static_cache = False
+
+    # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
+    _supports_quantized_cache = False
+
+    # A tensor parallel plan to be applied to the model when TP is enabled. For
+    # top-level models, this attribute is currently defined in respective model
+    # code. For base models, this attribute comes from
+    # `config.base_model_tp_plan` during `post_init`.
+    _tp_plan = None
+
+    @property
+    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
+        """
+        `Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
+        """
+        return {"input_ids": torch.tensor(DUMMY_INPUTS)}
+
+    @property
+    def framework(self) -> str:
+        """
+        :str: Identifies that this is a PyTorch model.
+        """
+        return "pt"
+
+    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
+        super().__init__()
+        if not isinstance(config, PretrainedConfig):
+            raise ValueError(
+                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+                "`PretrainedConfig`. To create a model from a pretrained model use "
+                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        # Save config and origin of the pretrained weights if given in model
+        if not getattr(config, "_attn_implementation_autoset", False):
+            config = self._autoset_attn_implementation(
+                config, torch_dtype=torch.get_default_dtype(), check_device_map=False
+            )
+        self.config = config
+
+        # for initialization of the loss
+        loss_type = self.__class__.__name__
+        if loss_type not in LOSS_MAPPING:
+            loss_groups = f"({'|'.join(LOSS_MAPPING)})"
+            loss_type = re.findall(loss_groups, self.__class__.__name__)
+            if len(loss_type) > 0:
+                loss_type = loss_type[0]
+            else:
+                loss_type = None
+        self.loss_type = loss_type
+
+        self.name_or_path = config.name_or_path
+        self.warnings_issued = {}
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        # Overwrite the class attribute to make it an instance attribute, so models like
+        # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
+        # when a different component (e.g. language_model) is used.
+        self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+
+    def post_init(self):
+        """
+        A method executed at the end of each Transformer model initialization, to execute code that needs the model's
+        modules properly initialized (such as weight initialization).
+        """
+        self.init_weights()
+        self._backward_compatibility_gradient_checkpointing()
+        # If current model is a base model, attach `base_model_tp_plan` from config
+        if self.base_model is self:
+            self._tp_plan = self.config.base_model_tp_plan
+
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
+    def _backward_compatibility_gradient_checkpointing(self):
+        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
+            self.gradient_checkpointing_enable()
+            # Remove the attribute now that is has been consumed, so it's no saved in the config.
+            delattr(self.config, "gradient_checkpointing")
+
+    def add_model_tags(self, tags: Union[List[str], str]) -> None:
+        r"""
+        Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
+        not overwrite existing tags in the model.
+
+        Args:
+            tags (`Union[List[str], str]`):
+                The desired tags to inject in the model
+
+        Examples:
+
+        ```python
+        from transformers import AutoModel
+
+        model = AutoModel.from_pretrained("google-bert/bert-base-cased")
+
+        model.add_model_tags(["custom", "custom-bert"])
+
+        # Push the model to your namespace with the name "my-custom-bert".
+        model.push_to_hub("my-custom-bert")
+        ```
+        """
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if self.model_tags is None:
+            self.model_tags = []
+
+        for tag in tags:
+            if tag not in self.model_tags:
+                self.model_tags.append(tag)
+
+    @classmethod
+    def _from_config(cls, config, **kwargs):
+        """
+        All context managers that the model should be initialized under go here.
+
+        Args:
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
+        """
+        # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
+        # a warning is raised that dtype should be fp16. Since we never pass dtype from within
+        # modeling code, we can try to infer it here same way as done in `from_pretrained`
+        torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
+        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
+
+        # override default dtype if needed
+        dtype_orig = None
+        if torch_dtype is not None:
+            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+        config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
+
+        if config._attn_implementation_internal is not None:
+            # In this case, the config has been created with the attn_implementation set by the user, which we
+            # should respect.
+            attn_implementation = config._attn_implementation_internal
+        else:
+            attn_implementation = None
+
+        config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
+        if not getattr(config, "_attn_implementation_autoset", False):
+            config = cls._autoset_attn_implementation(
+                config,
+                use_flash_attention_2=use_flash_attention_2,
+                check_device_map=False,
+                torch_dtype=torch_dtype,
+            )
+
+        if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
+            import deepspeed
+
+            logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
+            # this immediately partitions the model across all gpus, to avoid the overhead in time
+            # and memory copying it on CPU or each GPU first
+            init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
+            with ContextManagers(init_contexts):
+                model = cls(config, **kwargs)
+
+        else:
+            model = cls(config, **kwargs)
+
+        # restore default dtype if it was modified
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
+        return model
+
+    @classmethod
+    def _autoset_attn_implementation(
+        cls,
+        config,
+        use_flash_attention_2: bool = False,
+        torch_dtype: Optional[torch.dtype] = None,
+        device_map: Optional[Union[str, Dict[str, int]]] = None,
+        check_device_map: bool = True,
+    ):
+        """
+        Automatically checks and dispatches to a default attention implementation. In order of priority:
+            1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
+            2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
+            3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+            4. The default model's implementation otherwise (`LlamaAttention` for example) .
+        """
+        # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
+        # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
+        # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
+        if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
+            if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
+                raise ValueError(
+                    f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
+                    ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
+                )
+
+            if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
+                "eager"
+            ] + list(ALL_ATTENTION_FUNCTIONS.keys()):
+                message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+                if cls._supports_flash_attn_2:
+                    message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+                if cls._supports_sdpa:
+                    message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
+                if cls._supports_flex_attn:
+                    message += (
+                        ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
+                    )
+                raise ValueError(message + ".")
+
+            # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
+            requested_attn_implementation = config._attn_implementation_internal
+
+        # Composite models consisting of several PretrainedModels have to specify attention impl as a dict
+        # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
+        # for all sub-models.
+        # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
+        # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
+        # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
+        for key in config.sub_configs.keys():
+            sub_config = getattr(config, key)
+            curr_attn_implementation = (
+                requested_attn_implementation
+                if not isinstance(requested_attn_implementation, dict)
+                else requested_attn_implementation.get(key, None)
+            )
+            sub_config._attn_implementation_internal = curr_attn_implementation
+
+        if use_flash_attention_2:
+            logger.warning_once(
+                'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
+            )
+            config._attn_implementation = "flash_attention_2"
+
+        if config._attn_implementation == "flash_attention_2":
+            cls._check_and_enable_flash_attn_2(
+                config,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+                hard_check_only=False,
+                check_device_map=check_device_map,
+            )
+        elif requested_attn_implementation == "flex_attention":
+            config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
+        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
+            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+            config = cls._check_and_enable_sdpa(
+                config,
+                hard_check_only=False if requested_attn_implementation is None else True,
+            )
+
+            if (
+                torch.version.hip is not None
+                and config._attn_implementation == "sdpa"
+                and torch.cuda.device_count() > 1
+            ):
+                logger.warning_once(
+                    "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+                )
+                torch.backends.cuda.enable_flash_sdp(False)
+        elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
+            config._attn_implementation = requested_attn_implementation
+        elif isinstance(requested_attn_implementation, dict):
+            config._attn_implementation = None
+        else:
+            config._attn_implementation = "eager"
+
+        config._attn_implementation_autoset = True
+        return config
+
+    @classmethod
+    def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
+        """
+        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
+        under specific dtype.
+
+        Args:
+            dtype (`torch.dtype`):
+                a floating dtype to set to.
+
+        Returns:
+            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
+            modified. If it wasn't, returns `None`.
+
+        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
+        `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
+        """
+        if not dtype.is_floating_point:
+            raise ValueError(
+                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
+            )
+
+        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
+        dtype_orig = torch.get_default_dtype()
+        torch.set_default_dtype(dtype)
+        return dtype_orig
+
+    @property
+    def base_model(self) -> nn.Module:
+        """
+        `torch.nn.Module`: The main body of the model.
+        """
+        return getattr(self, self.base_model_prefix, self)
+
+    @classmethod
+    def can_generate(cls) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Directly inherits `GenerationMixin` -> can generate
+        if "GenerationMixin" in str(cls.__bases__):
+            return True
+        # Model class overwrites `generate` (e.g. time series models) -> can generate
+        if str(cls.__name__) in str(cls.generate):
+            return True
+        # The class inherits from a class that can generate (recursive check) -> can generate
+        for base in cls.__bases__:
+            if not hasattr(base, "can_generate"):
+                continue
+            if "PreTrainedModel" not in str(base) and base.can_generate():
+                return True
+        # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+        # was how we detected whether a model could generate.
+        if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
+            logger.warning_once(
+                f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
+                "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+                "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
+                "to call `generate` and other related functions."
+                "\n  - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
+                "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
+                "\n  - If you are the owner of the model architecture code, please modify your model class such that "
+                "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
+                "\n  - If you are not the owner of the model architecture class, please contact the model code owner "
+                "to update it."
+            )
+            return True
+        # Otherwise, can't generate
+        return False
+
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls,
+        config,
+        torch_dtype: Optional[torch.dtype] = None,
+        device_map: Optional[Union[str, Dict[str, int]]] = None,
+        check_device_map: bool = True,
+        hard_check_only: bool = False,
+    ) -> PretrainedConfig:
+        """
+        Checks the availability of Flash Attention 2 and compatibility with the current model.
+
+        If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
+        """
+        if not cls._supports_flash_attn_2:
+            raise ValueError(
+                f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
+                f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
+                " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+            )
+
+        if not is_flash_attn_2_available():
+            preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
+            install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
+
+            if importlib.util.find_spec("flash_attn") is None:
+                raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
+
+            flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+            if torch.version.cuda:
+                if flash_attention_version < version.parse("2.1.0"):
+                    raise ImportError(
+                        f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
+                    )
+                elif not torch.cuda.is_available():
+                    raise ValueError(
+                        f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
+                    )
+                else:
+                    raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
+            elif torch.version.hip:
+                if flash_attention_version < version.parse("2.0.4"):
+                    raise ImportError(
+                        f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
+                    )
+                else:
+                    raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
+
+        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+
+        if _is_bettertransformer:
+            raise ValueError(
+                "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+            )
+
+        if torch_dtype is None:
+            logger.warning_once(
+                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
+            )
+        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+            logger.warning_once(
+                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
+            )
+
+        # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
+        # or the model may be initialized under the context manager `with torch.device("cuda"):`.
+        if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
+            if torch.cuda.is_available():
+                logger.warning_once(
+                    "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
+                    " after initializing it on CPU with `model.to('cuda')`."
+                )
+            else:
+                raise ValueError(
+                    "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
+                    "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+                    "or initialising the model on CPU and then moving it to GPU."
+                )
+        elif (
+            check_device_map
+            and device_map is not None
+            and isinstance(device_map, dict)
+            and ("cpu" in device_map.values() or "disk" in device_map.values())
+        ):
+            raise ValueError(
+                "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
+                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+            )
+        if not hard_check_only:
+            config._attn_implementation = "flash_attention_2"
+        return config
+
+    @classmethod
+    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+        """
+        Checks the availability of SDPA for a given model.
+
+        If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
+        """
+        if hard_check_only:
+            if not cls._supports_sdpa:
+                raise ValueError(
+                    f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
+                    " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
+                    ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
+                )
+            if not is_torch_sdpa_available():
+                raise ImportError(
+                    "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
+                )
+
+        if not is_torch_sdpa_available() or not cls._supports_sdpa:
+            return config
+
+        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+        if _is_bettertransformer:
+            return config
+
+        if not hard_check_only:
+            config._attn_implementation = "sdpa"
+        return config
+
+    @classmethod
+    def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+        """
+        Checks the availability of Flex Attention for a given model.
+
+        If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
+        """
+        if hard_check_only:
+            if not cls._supports_flex_attn:
+                raise ValueError(
+                    f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
+                    " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
+                    " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
+                    ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
+                    ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
+                )
+            if not is_torch_flex_attn_available():
+                raise ImportError(
+                    "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
+                )
+
+        if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
+            return config
+
+        if not hard_check_only:
+            config._attn_implementation = "flex_attention"
+
+        return config
+
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+        the model weights fixed.
+        """
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
+    def disable_input_require_grads(self):
+        """
+        Removes the `_require_grads_hook`.
+        """
+        self._require_grads_hook.remove()
+
+    def get_input_embeddings(self) -> nn.Module:
+        """
+        Returns the model's input embeddings.
+
+        Returns:
+            `nn.Module`: A torch module mapping vocabulary to hidden states.
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def set_input_embeddings(self, value: nn.Module):
+        """
+        Set model's input embeddings.
+
+        Args:
+            value (`nn.Module`): A module mapping vocabulary to hidden states.
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self) -> nn.Module:
+        """
+        Returns the model's output embeddings.
+
+        Returns:
+            `nn.Module`: A torch module mapping hidden states to vocabulary.
+        """
+        return None  # Overwrite for models with output embeddings
+
+    def _init_weights(self, module):
+        """
+        Initialize the weights. This method should be overridden by derived class and is
+        the only initialization method that will be called when loading a checkpoint
+        using `from_pretrained`. Any attempt to initialize outside of this function
+        will be useless as the torch.nn.init function are all replaced with skip.
+        """
+        pass
+
+    def _initialize_weights(self, module):
+        """
+        Initialize the weights if they are not already initialized.
+        """
+        if getattr(module, "_is_hf_initialized", False):
+            return
+        self._init_weights(module)
+        module._is_hf_initialized = True
+
+    def tie_weights(self):
+        """
+        Tie the weights between the input embeddings and the output embeddings.
+
+        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
+        weights instead.
+        """
+        if getattr(self.config, "tie_word_embeddings", True):
+            output_embeddings = self.get_output_embeddings()
+            if output_embeddings is not None:
+                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+
+        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
+            if hasattr(self, self.base_model_prefix):
+                self = getattr(self, self.base_model_prefix)
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder, self.base_model_prefix, "encoder"
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attributed not an instance member, therefore modifying it will modify the entire class
+            # Leading to issues on subsequent calls by different tests or subsequent calls.
+            self._dynamic_tied_weights_keys = tied_weights
+
+        for module in self.modules():
+            if hasattr(module, "_tie_weights"):
+                module._tie_weights()
+
+    @staticmethod
+    def _tie_encoder_decoder_weights(
+        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
+    ):
+        uninitialized_encoder_weights: List[str] = []
+        tied_weights: List[str] = []
+        if decoder.__class__ != encoder.__class__:
+            logger.info(
+                f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
+                " weights are correctly initialized."
+            )
+
+        def tie_encoder_to_decoder_recursively(
+            decoder_pointer: nn.Module,
+            encoder_pointer: nn.Module,
+            module_name: str,
+            base_encoder_name: str,
+            uninitialized_encoder_weights: List[str],
+            depth=0,
+            total_decoder_name="",
+            total_encoder_name="",
+        ):
+            assert isinstance(decoder_pointer, nn.Module) and isinstance(
+                encoder_pointer, nn.Module
+            ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
+            if hasattr(decoder_pointer, "weight"):
+                assert hasattr(encoder_pointer, "weight")
+                encoder_pointer.weight = decoder_pointer.weight
+                tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
+                if hasattr(decoder_pointer, "bias"):
+                    assert hasattr(encoder_pointer, "bias")
+                    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
+                    encoder_pointer.bias = decoder_pointer.bias
+                return
+
+            encoder_modules = encoder_pointer._modules
+            decoder_modules = decoder_pointer._modules
+            if len(decoder_modules) > 0:
+                assert (
+                    len(encoder_modules) > 0
+                ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+                all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
+                encoder_layer_pos = 0
+                for name, module in decoder_modules.items():
+                    if name.isdigit():
+                        encoder_name = str(int(name) + encoder_layer_pos)
+                        decoder_name = name
+                        if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
+                            encoder_modules
+                        ) != len(decoder_modules):
+                            # this can happen if the name corresponds to the position in a list module list of layers
+                            # in this case the decoder has added a cross-attention that the encoder does not have
+                            # thus skip this step and subtract one layer pos from encoder
+                            encoder_layer_pos -= 1
+                            continue
+                    elif name not in encoder_modules:
+                        continue
+                    elif depth > 500:
+                        raise ValueError(
+                            "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is"
+                            " a circular dependency between two or more `nn.Modules` of your model."
+                        )
+                    else:
+                        decoder_name = encoder_name = name
+                    tie_encoder_to_decoder_recursively(
+                        decoder_modules[decoder_name],
+                        encoder_modules[encoder_name],
+                        module_name + "/" + name,
+                        base_encoder_name,
+                        uninitialized_encoder_weights,
+                        depth=depth + 1,
+                        total_encoder_name=f"{total_encoder_name}.{encoder_name}",
+                        total_decoder_name=f"{total_decoder_name}.{decoder_name}",
+                    )
+                    all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+                uninitialized_encoder_weights += list(all_encoder_weights)
+
+        # tie weights recursively
+        tie_encoder_to_decoder_recursively(
+            decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
+        )
+
+        if len(uninitialized_encoder_weights) > 0:
+            logger.warning(
+                f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
+            )
+        return tied_weights
+
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+        """Tie or clone module weights depending of whether we are using TorchScript or not"""
+        if self.config.torchscript:
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+        else:
+            output_embeddings.weight = input_embeddings.weight
+
+        if getattr(output_embeddings, "bias", None) is not None:
+            output_embeddings.bias.data = nn.functional.pad(
+                output_embeddings.bias.data,
+                (
+                    0,
+                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
+                ),
+                "constant",
+                0,
+            )
+        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+            output_embeddings.out_features = input_embeddings.num_embeddings
+
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, PreTrainedModel):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
+    def resize_token_embeddings(
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
+    ) -> nn.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
+        """
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
+        if new_num_tokens is None and pad_to_multiple_of is None:
+            return model_embeds
+
+        # Since we are basically resuing the same old embeddings with new weight values, gathering is required
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
+                vocab_size = model_embeds.weight.shape[0]
+        else:
+            vocab_size = model_embeds.weight.shape[0]
+
+        # Update base model and current model config.
+        self.config.get_text_config().vocab_size = vocab_size
+        self.vocab_size = vocab_size
+
+        # Tie weights again if needed
+        self.tie_weights()
+
+        return model_embeds
+
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(
+            old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
+        )
+        if hasattr(old_embeddings, "_hf_hook"):
+            hook = old_embeddings._hf_hook
+            add_hook_to_module(new_embeddings, hook)
+        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
+        new_embeddings.requires_grad_(old_embeddings_requires_grad)
+        self.set_input_embeddings(new_embeddings)
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
+                    new_num_tokens = new_embeddings.weight.shape[0]
+            else:
+                new_num_tokens = new_embeddings.weight.shape[0]
+
+        # if word embeddings are not tied, make sure that lm head is resized as well
+        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
+            old_lm_head = self.get_output_embeddings()
+            if isinstance(old_lm_head, torch.nn.Embedding):
+                new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
+            else:
+                new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
+            if hasattr(old_lm_head, "_hf_hook"):
+                hook = old_lm_head._hf_hook
+                add_hook_to_module(new_lm_head, hook)
+            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
+            new_lm_head.requires_grad_(old_lm_head_requires_grad)
+            self.set_output_embeddings(new_lm_head)
+
+        return self.get_input_embeddings()
+
+    def _get_resized_embeddings(
+        self,
+        old_embeddings: nn.Embedding,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
+    ) -> nn.Embedding:
+        """
+        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
+        initialized vectors at the end. Reducing the size will remove vectors from the end
+
+        Args:
+            old_embeddings (`torch.nn.Embedding`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+                `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
+            `new_num_tokens` is `None`
+        """
+
+        if pad_to_multiple_of is not None:
+            if not isinstance(pad_to_multiple_of, int):
+                raise ValueError(
+                    f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
+                )
+            if new_num_tokens is None:
+                new_num_tokens = old_embeddings.weight.shape[0]
+            new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
+        else:
+            logger.info(
+                "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
+                f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
+                " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
+                " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
+            )
+
+        if new_num_tokens is None:
+            return old_embeddings
+
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
+                old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        else:
+            old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
+            return old_embeddings
+
+        if not isinstance(old_embeddings, nn.Embedding):
+            raise TypeError(
+                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
+                " should either use a different resize function or make sure that `old_embeddings` are an instance of"
+                f" {nn.Embedding}."
+            )
+
+        # Build new embeddings
+
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_embeddings = nn.Embedding(
+            new_num_tokens,
+            old_embedding_dim,
+            device=old_embeddings.weight.device,
+            dtype=old_embeddings.weight.dtype,
+        )
+
+        if new_num_tokens > old_num_tokens and not mean_resizing:
+            # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
+            self._init_weights(new_embeddings)
+
+        elif new_num_tokens > old_num_tokens and mean_resizing:
+            # initialize new embeddings  (in particular added tokens). The new embeddings will be initialized
+            # from a multivariate normal distribution that has old embeddings' mean and covariance.
+            # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+            logger.warning_once(
+                "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
+                "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
+                "To disable this, use `mean_resizing=False`"
+            )
+
+            added_num_tokens = new_num_tokens - old_num_tokens
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
+                    self._init_added_embeddings_weights_with_mean(
+                        old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+                    )
+            else:
+                self._init_added_embeddings_weights_with_mean(
+                    old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+                )
+
+        # Copy token embeddings from the previous weights
+
+        # numbers of tokens to copy
+        n = min(old_num_tokens, new_num_tokens)
+
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+        else:
+            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+
+        # Replace weights in old_embeddings and return to maintain the same embedding type.
+        # This ensures correct functionality when a Custom Embedding class is passed as input.
+        # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                old_embeddings.weight = new_embeddings.weight
+                old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
+
+                # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
+                # will be set to `None` in the resized embeddings.
+                if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
+                    old_embeddings.padding_idx = None
+        else:
+            old_embeddings.weight.data = new_embeddings.weight.data
+            old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
+            if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
+                old_embeddings.padding_idx = None
+
+        return old_embeddings
+
+    def _get_resized_lm_head(
+        self,
+        old_lm_head: nn.Linear,
+        new_num_tokens: Optional[int] = None,
+        transposed: Optional[bool] = False,
+        mean_resizing: bool = True,
+    ) -> nn.Linear:
+        """
+        Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
+        vectors at the end. Reducing the size will remove vectors from the end
+
+        Args:
+            old_lm_head (`torch.nn.Linear`):
+                Old lm head liner layer to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the linear matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+                `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
+                to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
+                vocab_size` else `vocab_size, lm_head_dim`.
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+
+        Return:
+            `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
+            `None`
+        """
+        if new_num_tokens is None:
+            return old_lm_head
+
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
+                old_num_tokens, old_lm_head_dim = (
+                    old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
+                )
+        else:
+            old_num_tokens, old_lm_head_dim = (
+                old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
+            )
+
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
+            return old_lm_head
+
+        if not isinstance(old_lm_head, nn.Linear):
+            raise TypeError(
+                f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
+                " should either use a different resize function or make sure that `old_lm_head` are an instance of"
+                f" {nn.Linear}."
+            )
+
+        # Build new lm head
+        new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
+        has_new_lm_head_bias = old_lm_head.bias is not None
+
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_lm_head = nn.Linear(
+            *new_lm_head_shape,
+            bias=has_new_lm_head_bias,
+            device=old_lm_head.weight.device,
+            dtype=old_lm_head.weight.dtype,
+        )
+
+        if new_num_tokens > old_num_tokens and not mean_resizing:
+            # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
+            self._init_weights(new_lm_head)
+
+        elif new_num_tokens > old_num_tokens and mean_resizing:
+            # initialize new lm_head weights (in particular added tokens). The new lm_head weights
+            # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance.
+            # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+            logger.warning_once(
+                "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
+                "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
+                "To disable this, use `mean_resizing=False`"
+            )
+
+            added_num_tokens = new_num_tokens - old_num_tokens
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                params = [old_lm_head.weight]
+                if has_new_lm_head_bias:
+                    params += [old_lm_head.bias]
+                with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
+                    self._init_added_lm_head_weights_with_mean(
+                        old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
+                    )
+                    if has_new_lm_head_bias:
+                        self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
+
+            else:
+                self._init_added_lm_head_weights_with_mean(
+                    old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
+                )
+                if has_new_lm_head_bias:
+                    self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
+
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                self._copy_lm_head_original_to_resized(
+                    new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+                )
+        else:
+            self._copy_lm_head_original_to_resized(
+                new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+            )
+
+        return new_lm_head
+
+    def _init_added_embeddings_weights_with_mean(
+        self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+    ):
+        old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
+        mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
+        old_centered_embeddings = old_embeddings_weight - mean_embeddings
+        covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
+
+        # Check if the covariance is positive definite.
+        eigenvalues = torch.linalg.eigvals(covariance)
+        is_covariance_psd = bool(
+            (covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all()
+        )
+        if is_covariance_psd:
+            # If covariances is positive definite, a distribution can be created. and we can sample new weights from it.
+            distribution = torch.distributions.multivariate_normal.MultivariateNormal(
+                mean_embeddings, covariance_matrix=1e-9 * covariance
+            )
+            new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
+                sample_shape=(added_num_tokens,)
+            ).to(old_embeddings.weight.dtype)
+        else:
+            # Otherwise, just initialize with the mean. because distribtion will not be created.
+            new_embeddings.weight.data[-1 * added_num_tokens :, :] = (
+                mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype)
+            )
+
+    def _init_added_lm_head_weights_with_mean(
+        self,
+        old_lm_head,
+        new_lm_head,
+        old_lm_head_dim,
+        old_num_tokens,
+        added_num_tokens,
+        transposed=False,
+    ):
+        if transposed:
+            # Transpose to the desired shape for the function.
+            new_lm_head.weight.data = new_lm_head.weight.data.T
+            old_lm_head.weight.data = old_lm_head.weight.data.T
+
+        # The same initilization logic as Embeddings.
+        self._init_added_embeddings_weights_with_mean(
+            old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens
+        )
+
+        if transposed:
+            # Transpose again to the correct shape.
+            new_lm_head.weight.data = new_lm_head.weight.data.T
+            old_lm_head.weight.data = old_lm_head.weight.data.T
+
+    def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
+        bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
+        bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
+        new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std)
+
+    def _copy_lm_head_original_to_resized(
+        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+    ):
+        # Copy old lm head weights to new lm head
+        if not transposed:
+            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+        else:
+            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
+
+        # Copy bias weights to new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        raise NotImplementedError(
+            f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
+            f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
+        )
+
+    def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
+        raise NotImplementedError(
+            f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
+            f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
+        )
+
+    def init_weights(self):
+        """
+        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
+        initialization logic in `_init_weights`.
+        """
+        # Prune heads if needed
+        if self.config.pruned_heads:
+            self.prune_heads(self.config.pruned_heads)
+
+        if _init_weights:
+            # Initialize weights
+            self.apply(self._initialize_weights)
+
+            # Tie weights should be skipped when not initializing all weights
+            # since from_pretrained(...) calls tie weights anyways
+            self.tie_weights()
+
+    def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
+        """
+        Prunes heads of the base model.
+
+        Arguments:
+            heads_to_prune (`Dict[int, List[int]]`):
+                Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
+                to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
+                layer 1 and heads 2 and 3 on layer 2.
+        """
+        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
+        for layer, heads in heads_to_prune.items():
+            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
+            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON
+
+        self.base_model._prune_heads(heads_to_prune)
+
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
+        """
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+
+        # For old GC format (transformers < 4.35.0) for models that live on the Hub
+        # we will fall back to the overwritten `_set_gradient_checkpointing` method
+        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+
+        if not _is_using_old_format:
+            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
+        else:
+            self.apply(partial(self._set_gradient_checkpointing, value=True))
+            logger.warning(
+                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
+                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
+            )
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+            # the gradients to make sure the gradient flows.
+            self.enable_input_require_grads()
+
+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
+        is_gradient_checkpointing_set = False
+
+        # Apply it on the top-level module in case the top-level modules supports it
+        # for example, LongT5Stack inherits from `PreTrainedModel`.
+        if hasattr(self, "gradient_checkpointing"):
+            self._gradient_checkpointing_func = gradient_checkpointing_func
+            self.gradient_checkpointing = enable
+            is_gradient_checkpointing_set = True
+
+        for module in self.modules():
+            if hasattr(module, "gradient_checkpointing"):
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
+                " `gradient_checkpointing` to modules of the model that uses checkpointing."
+            )
+
+    def gradient_checkpointing_disable(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self.supports_gradient_checkpointing:
+            # For old GC format (transformers < 4.35.0) for models that live on the Hub
+            # we will fall back to the overwritten `_set_gradient_checkpointing` methid
+            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+            if not _is_using_old_format:
+                self._set_gradient_checkpointing(enable=False)
+            else:
+                logger.warning(
+                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
+                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
+                )
+                self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            self.disable_input_require_grads()
+
+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        state_dict: Optional[dict] = None,
+        save_function: Callable = torch.save,
+        push_to_hub: bool = False,
+        max_shard_size: Union[int, str] = "5GB",
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
+        save_peft_format: bool = True,
+        **kwargs,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~PreTrainedModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training like
+                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+                the main process to avoid race conditions.
+            state_dict (nested dictionary of `torch.Tensor`):
+                The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
+                save parts of the model or if special precautions need to be taken when recovering the state dictionary
+                of a model (like when using model parallelism).
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                need to replace `torch.save` by another method.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+                We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
+                without CPU OOM issues.
+
+                
+
+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+                which will be bigger than `max_shard_size`.
+
+                
+
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format pytorch_model..bin.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            save_peft_format (`bool`, *optional*, defaults to `True`):
+                For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
+                keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can
+                disable this behaviours by setting `save_peft_format` to `False`.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
+
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        quantization_serializable = (
+            hf_quantizer is not None
+            and isinstance(hf_quantizer, HfQuantizer)
+            and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+        )
+
+        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+            raise ValueError(
+                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                " the logger on the traceback to understand the reason why the quantized model is not serializable."
+            )
+
+        if "save_config" in kwargs:
+            warnings.warn(
+                "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
+            )
+            is_main_process = kwargs.pop("save_config")
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        # Only save the model itself if we are using distributed training
+        model_to_save = unwrap_model(self)
+
+        # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+        # we currently don't use this setting automatically, but may start to use with v5
+        dtype = get_parameter_dtype(model_to_save)
+        model_to_save.config.torch_dtype = str(dtype).split(".")[1]
+
+        # Attach architecture to the config
+        model_to_save.config.architectures = [model_to_save.__class__.__name__]
+
+        # Unset attn implementation so it can be set to another one when loading back
+        model_to_save.config._attn_implementation_autoset = False
+
+        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self.config)
+
+        # Save the config
+        if is_main_process:
+            if not _hf_peft_config_loaded:
+                # If the model config has set attributes that should be in the generation config, move them there.
+                misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters()
+                if self.can_generate() and len(misplaced_generation_parameters) > 0:
+                    warnings.warn(
+                        "Moving the following attributes in the config to the generation config: "
+                        f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
+                        "generation parameters in the model config, as opposed to in the generation config.",
+                        UserWarning,
+                    )
+                    for param_name, param_value in misplaced_generation_parameters.items():
+                        setattr(model_to_save.generation_config, param_name, param_value)
+                        setattr(model_to_save.config, param_name, None)
+
+                model_to_save.config.save_pretrained(save_directory)
+            if self.can_generate():
+                model_to_save.generation_config.save_pretrained(save_directory)
+
+            if _hf_peft_config_loaded:
+                logger.info(
+                    "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
+                )
+                state_dict = model_to_save.get_adapter_state_dict()
+
+                if save_peft_format:
+                    logger.info(
+                        "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
+                    )
+                    peft_state_dict = {}
+                    for key, value in state_dict.items():
+                        peft_state_dict[f"base_model.model.{key}"] = value
+                    state_dict = peft_state_dict
+
+                active_adapter = self.active_adapters()
+
+                if len(active_adapter) > 1:
+                    raise ValueError(
+                        "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
+                        "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
+                    )
+                active_adapter = active_adapter[0]
+
+                current_peft_config = self.peft_config[active_adapter]
+                current_peft_config.save_pretrained(save_directory)
+
+        # for offloaded modules
+        module_map = {}
+
+        # Save the model
+        if state_dict is None:
+            # if any model parameters are offloaded, make module map
+            if (
+                hasattr(self, "hf_device_map")
+                and len(set(self.hf_device_map.values())) > 1
+                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
+            ):
+                warnings.warn(
+                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
+                )
+                for name, module in model_to_save.named_modules():
+                    if name == "":
+                        continue
+                    module_state_dict = module.state_dict()
+
+                    for key in module_state_dict:
+                        module_map[name + f".{key}"] = module
+            state_dict = model_to_save.state_dict()
+
+        # Translate state_dict from smp to hf if saving with smp >= 1.10
+        if IS_SAGEMAKER_MP_POST_1_10:
+            for smp_to_hf, _ in smp.state.module_manager.translate_functions:
+                state_dict = smp_to_hf(state_dict)
+
+        # Handle the case where some state_dict keys shouldn't be saved
+        if self._keys_to_ignore_on_save is not None:
+            for ignore_key in self._keys_to_ignore_on_save:
+                if ignore_key in state_dict.keys():
+                    del state_dict[ignore_key]
+
+        # Rename state_dict keys before saving to file. Do nothing unless overriden in a particular model.
+        # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
+        state_dict = self._fix_state_dict_keys_on_save(state_dict)
+
+        if safe_serialization:
+            # Safetensors does not allow tensor aliasing.
+            # We're going to remove aliases before saving
+            ptrs = collections.defaultdict(list)
+            for name, tensor in state_dict.items():
+                # Sometimes in the state_dict we have non-tensor objects.
+                # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                if isinstance(tensor, torch.Tensor):
+                    ptrs[id_tensor_storage(tensor)].append(name)
+                else:
+                    # In the non-tensor case, fall back to the pointer of the object itself
+                    ptrs[id(tensor)].append(name)
+
+            # These are all the pointers of shared tensors
+            if hasattr(self, "hf_device_map"):
+                # if the model has offloaded parameters, we must check using find_tied_parameters()
+                tied_params = find_tied_parameters(self)
+                if tied_params:
+                    tied_names = tied_params[0]
+                    shared_ptrs = {
+                        ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
+                    }
+                else:
+                    shared_ptrs = {}
+            else:
+                shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+            # Recursively descend to find tied weight keys
+            _tied_weights_keys = _get_tied_weight_keys(self)
+            error_names = []
+            to_delete_names = set()
+            for names in shared_ptrs.values():
+                # Removing the keys which are declared as known duplicates on
+                # load. This allows to make sure the name which is kept is consistent.
+                if _tied_weights_keys is not None:
+                    found = 0
+                    for name in sorted(names):
+                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+                        if matches_pattern and name in state_dict:
+                            found += 1
+                            if found < len(names):
+                                to_delete_names.add(name)
+            # We are entering a place where the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()
+
+            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+            # If the link between tensors was done at runtime then `from_pretrained` will not get
+            # the key back leading to random tensor. A proper warning will be shown
+            # during reload (if applicable), but since the file is not necessarily compatible with
+            # the config, better show a proper warning.
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
+                for name in known:
+                    del state_dict[name]
+                unknown = inames.difference(to_delete_names)
+                if len(unknown) > 1:
+                    error_names.append(unknown)
+
+            if shared_names:
+                error_names.append(set(shared_names))
+
+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
+                )
+
+        # Shard the model if it is too big.
+        if not _hf_peft_config_loaded:
+            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = _add_variant(weights_name, variant)
+        else:
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
+        # Save index if sharded
+        index = None
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+            # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(weights_no_suffix)
+                and os.path.isfile(full_filename)
+                and filename not in state_dict_split.filename_to_tensors.keys()
+                and is_main_process
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+        # Save the model
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        if module_map:
+            filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
+        for shard_file, tensors in filename_to_tensors:
+            shard = {}
+            for tensor in tensors:
+                shard[tensor] = state_dict[tensor].contiguous()
+                # delete reference, see https://github.com/huggingface/transformers/pull/34890
+                del state_dict[tensor]
+
+            # remake shard with onloaded parameters if necessary
+            if module_map:
+                if accelerate_version < version.parse("0.31"):
+                    raise ImportError(
+                        f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
+                        f"Please upgrade accelerate with `pip install -U accelerate`"
+                    )
+                # init state_dict for this shard
+                shard_state_dict = {name: "" for name in shard}
+                for module_name in shard:
+                    module = module_map[module_name]
+                    # update state dict with onloaded parameters
+                    shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
+
+                # assign shard to be the completed state dict
+                shard = shard_state_dict
+                del shard_state_dict
+                gc.collect()
+
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this enough.
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+            else:
+                save_function(shard, os.path.join(save_directory, shard_file))
+
+        del state_dict
+
+        if index is None:
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
+        else:
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )
+
+        if push_to_hub:
+            # Eventually create an empty model card
+            model_card = create_and_tag_model_card(
+                repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
+            )
+
+            # Update model card if needed:
+            model_card.save(os.path.join(save_directory, "README.md"))
+
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=token,
+            )
+
+    @wraps(PushToHubMixin.push_to_hub)
+    def push_to_hub(self, *args, **kwargs):
+        tags = self.model_tags if self.model_tags is not None else []
+
+        tags_kwargs = kwargs.get("tags", [])
+        if isinstance(tags_kwargs, str):
+            tags_kwargs = [tags_kwargs]
+
+        for tag in tags_kwargs:
+            if tag not in tags:
+                tags.append(tag)
+
+        if tags:
+            kwargs["tags"] = tags
+        return super().push_to_hub(*args, **kwargs)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
+
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+            raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        else:
+            return super().cuda(*args, **kwargs)
+
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
+        # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+            raise ValueError("`.to` is not supported for HQQ-quantized models.")
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
+                    " `dtype` by passing the correct `torch_dtype` argument."
+                )
+        return super().to(*args, **kwargs)
+
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been casted to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been casted to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
+    @classmethod
+    def from_pretrained(
+        cls: Type[SpecificPreTrainedModelType],
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: Optional[bool] = None,
+        weights_only: bool = True,
+        **kwargs,
+    ) -> SpecificPreTrainedModelType:
+        r"""
+        Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded
+        in using the `meta` device and brought into memory once an input is passed through that layer regardless of
+        `low_cpu_mem_usage`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
+                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+                    - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
+                      `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
+                      `True`.
+                    - `None` if you are both providing the configuration and state dictionary (resp. with keyword
+                      arguments `config` and `state_dict`).
+            model_args (sequence of positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
+                Can be either:
+
+                    - an instance of a class derived from [`PretrainedConfig`],
+                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
+
+                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+                be automatically loaded when:
+
+                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+                      model).
+                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+                      save directory.
+                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+                      configuration JSON file named *config.json* is found in the directory.
+            state_dict (`Dict[str, torch.Tensor]`, *optional*):
+                A state dictionary to use instead of a state dictionary loaded from saved weights file.
+
+                This option can be used if you want to create a model from a pretrained configuration but load your own
+                weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
+                [`~PreTrainedModel.from_pretrained`] is not a simpler option.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            from_tf (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a TensorFlow checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+                checkpoint with 3 labels).
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+                
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+                
+
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+            _fast_init(`bool`, *optional*, defaults to `True`):
+                Whether or not to disable fast initialization.
+
+                
+
+                One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
+                4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
+                [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
+
+                
+            attn_implementation (`str`, *optional*):
+                The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+
+            > Parameters for big model inference
+
+            low_cpu_mem_usage(`bool`, *optional*):
+                Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Generally should be combined with a `device_map` (such as `"auto"`) for best results.
+                This is an experimental feature and a subject to change at any moment.
+                
+                    If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without
+                    `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
+                    this should still be enabled if you are passing in a `device_map`.
+                
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
+                are:
+
+                1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
+                  `dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified
+                  - the model will get loaded in `torch.float` (fp32).
+
+                2. `"auto"` - A `torch_dtype` entry in the `config.json` file of the model will be
+                  attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
+                  the checkpoint that's of a floating point type and use that as `dtype`. This will load the model
+                  using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
+                  the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
+
+                3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
+
+                
+
+                For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
+                reach out to the authors and ask them to add this information to the model's card and to insert the
+                `torch_dtype` entry in `config.json` on the hub.
+
+                
+
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
+                like `1`) on which the model will be allocated, the device map will map the entire model to this
+                device. Passing `device_map = 0` means put the whole model on GPU 0.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
+                RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
+                `True` when there is some disk offload.
+            offload_buffers (`bool`, *optional*):
+                Whether or not to offload the buffers with the model parameters.
+            quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
+                A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
+                bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
+                `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes
+                quantizations and not preferred. consider inserting all such arguments into quantization_config
+                instead.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+            variant (`str`, *optional*):
+                If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+                ignored when using `from_tf` or `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+                is not installed, it will be set to `False`.
+
+            weights_only (`bool`, *optional*, defaults to `True`):
+                Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via torch.serialization.add_safe_globals().
+                When set to False, we can load wrapper tensor subclass weights.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+                automatically loaded:
+
+                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
+                      already been done)
+                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+                      corresponds to a configuration attribute will be used to override said attribute with the
+                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+                      will be passed to the underlying model's `__init__` function.
+
+        
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        
+
+        Examples:
+
+        ```python
+        >>> from transformers import BertConfig, BertModel
+
+        >>> # Download model and configuration from huggingface.co and cache.
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+        >>> model = BertModel.from_pretrained("./test/saved_model/")
+        >>> # Update configuration during loading.
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
+        >>> assert model.config.output_attentions == True
+        >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+        >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
+        >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
+        >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
+        ```
+
+        * `low_cpu_mem_usage` algorithm:
+
+        This is an experimental function that loads the model using ~1x model size CPU memory
+
+        Here is how it works:
+
+        1. save which state_dict keys we have
+        2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
+        3. after the model has been instantiated switch to the meta device all params/buffers that
+        are going to be replaced from the loaded state_dict
+        4. load state_dict 2nd time
+        5. replace the params/buffers from the state_dict
+
+        Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
+
+        """
+        state_dict = kwargs.pop("state_dict", None)
+        from_tf = kwargs.pop("from_tf", False)
+        from_flax = kwargs.pop("from_flax", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", None)
+        _ = kwargs.pop("mirror", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+        _fast_init = kwargs.pop("_fast_init", True)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
+        device_map = kwargs.pop("device_map", None)
+        max_memory = kwargs.pop("max_memory", None)
+        offload_folder = kwargs.pop("offload_folder", None)
+        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        offload_buffers = kwargs.pop("offload_buffers", False)
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
+        load_in_4bit = kwargs.pop("load_in_4bit", False)
+        quantization_config = kwargs.pop("quantization_config", None)
+        subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)
+        variant = kwargs.pop("variant", None)
+        adapter_kwargs = kwargs.pop("adapter_kwargs", {})
+        adapter_name = kwargs.pop("adapter_name", "default")
+        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
+        generation_config = kwargs.pop("generation_config", None)
+
+        gguf_file = kwargs.pop("gguf_file", None)
+        # Cache path to the GGUF file
+        gguf_path = None
+
+        tp_plan = kwargs.pop("tp_plan", None)
+        if tp_plan is not None and tp_plan != "auto":
+            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
+            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
+
+        if is_fsdp_enabled():
+            low_cpu_mem_usage = True
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
+            adapter_kwargs["token"] = token
+
+        if use_safetensors is None and not is_safetensors_available():
+            use_safetensors = False
+        if trust_remote_code is True:
+            logger.warning(
+                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
+                " ignored."
+            )
+
+        if gguf_file is not None and not is_accelerate_available():
+            raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
+
+        if commit_hash is None:
+            if not isinstance(config, PretrainedConfig):
+                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CONFIG_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_gated_repo=False,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            else:
+                commit_hash = getattr(config, "_commit_hash", None)
+
+        if is_peft_available():
+            _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
+
+            if _adapter_model_path is None:
+                _adapter_model_path = find_adapter_config_file(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    _commit_hash=commit_hash,
+                    **adapter_kwargs,
+                )
+            if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
+                with open(_adapter_model_path, "r", encoding="utf-8") as f:
+                    _adapter_model_path = pretrained_model_name_or_path
+                    pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+        else:
+            _adapter_model_path = None
+
+        # change device_map into a map if we passed an int, a str or a torch.device
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+            elif not low_cpu_mem_usage:
+                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+        if low_cpu_mem_usage:
+            if is_deepspeed_zero3_enabled():
+                raise ValueError(
+                    "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
+                )
+            elif not is_accelerate_available():
+                raise ImportError(
+                    f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+                )
+
+        # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
+        if load_in_4bit or load_in_8bit:
+            if quantization_config is not None:
+                raise ValueError(
+                    "You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing "
+                    "`quantization_config` argument at the same time."
+                )
+
+            # preparing BitsAndBytesConfig from kwargs
+            config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
+            config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
+            quantization_config, kwargs = BitsAndBytesConfig.from_dict(
+                config_dict=config_dict, return_unused_kwargs=True, **kwargs
+            )
+            logger.warning(
+                "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. "
+                "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead."
+            )
+
+        from_pt = not (from_tf | from_flax)
+
+        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        else:
+            # In case one passes a config to `from_pretrained` + "attn_implementation"
+            # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
+            # Please see: https://github.com/huggingface/transformers/issues/28038
+
+            # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
+            # we pop attn_implementation from the kwargs but this handles the case where users
+            # passes manually the config to `from_pretrained`.
+            config = copy.deepcopy(config)
+
+            kwarg_attn_imp = kwargs.pop("attn_implementation", None)
+            if kwarg_attn_imp is not None:
+                config._attn_implementation = kwarg_attn_imp
+
+            model_kwargs = kwargs
+
+        pre_quantized = getattr(config, "quantization_config", None) is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
+                    config.quantization_config, quantization_config
+                )
+            else:
+                config.quantization_config = quantization_config
+
+            hf_quantizer = AutoHfQuantizer.from_config(
+                config.quantization_config,
+                pre_quantized=pre_quantized,
+            )
+
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(
+                torch_dtype=torch_dtype,
+                from_tf=from_tf,
+                from_flax=from_flax,
+                device_map=device_map,
+                weights_only=weights_only,
+            )
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+            device_map = hf_quantizer.update_device_map(device_map)
+
+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.")
+        is_quantized = hf_quantizer is not None
+
+        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+        # index of the files.
+        is_sharded = False
+        sharded_metadata = None
+        # Load model
+        loading_info = None
+
+        # Keep in fp32 modules
+        keep_in_fp32_modules = None
+        use_keep_in_fp32_modules = False
+
+        if gguf_file is not None and hf_quantizer is not None:
+            raise ValueError(
+                "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
+            )
+
+        if pretrained_model_name_or_path is not None and gguf_file is None:
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+            is_local = os.path.isdir(pretrained_model_name_or_path)
+            if is_local:
+                if from_tf and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                ):
+                    # Load from a TF 1.0 checkpoint in priority if from_tf
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                elif from_tf and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                ):
+                    # Load from a TF 2.0 checkpoint in priority if from_tf
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                elif from_flax and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                ):
+                    # Load from a Flax checkpoint in priority if from_flax
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                elif use_safetensors is not False and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+                ):
+                    # Load from a safetensors checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                    )
+                elif use_safetensors is not False and os.path.isfile(
+                    os.path.join(
+                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                    )
+                ):
+                    # Load from a sharded safetensors checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                    )
+                    is_sharded = True
+                elif not use_safetensors and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+                ):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+                    )
+                elif not use_safetensors and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+                ):
+                    # Load from a sharded PyTorch checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                    )
+                    is_sharded = True
+                # At this stage we don't have a weight file so we will raise an error.
+                elif not use_safetensors and (
+                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
+                    or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
+                ):
+                    raise EnvironmentError(
+                        f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+                        f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
+                        " `from_tf=True` to load this model from those weights."
+                    )
+                elif not use_safetensors and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                ):
+                    raise EnvironmentError(
+                        f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+                        f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
+                        " to load this model from those weights."
+                    )
+                elif use_safetensors:
+                    raise EnvironmentError(
+                        f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
+                        f" {pretrained_model_name_or_path}."
+                    )
+                else:
+                    raise EnvironmentError(
+                        f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
+                        f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
+                        f" {pretrained_model_name_or_path}."
+                    )
+            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+                archive_file = pretrained_model_name_or_path
+                is_local = True
+            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
+                if not from_tf:
+                    raise ValueError(
+                        f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
+                        "from_tf to True to load from this checkpoint."
+                    )
+                archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
+                is_local = True
+            elif is_remote_url(pretrained_model_name_or_path):
+                filename = pretrained_model_name_or_path
+                resolved_archive_file = download_url(pretrained_model_name_or_path)
+            else:
+                # set correct filename
+                if from_tf:
+                    filename = TF2_WEIGHTS_NAME
+                elif from_flax:
+                    filename = FLAX_WEIGHTS_NAME
+                elif use_safetensors is not False:
+                    filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
+                else:
+                    filename = _add_variant(WEIGHTS_NAME, variant)
+
+                try:
+                    # Load from URL or cache if already cached
+                    cached_file_kwargs = {
+                        "cache_dir": cache_dir,
+                        "force_download": force_download,
+                        "proxies": proxies,
+                        "resume_download": resume_download,
+                        "local_files_only": local_files_only,
+                        "token": token,
+                        "user_agent": user_agent,
+                        "revision": revision,
+                        "subfolder": subfolder,
+                        "_raise_exceptions_for_gated_repo": False,
+                        "_raise_exceptions_for_missing_entries": False,
+                        "_commit_hash": commit_hash,
+                    }
+                    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                    # result when internet is up, the repo and revision exist, but the file does not.
+                    if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path,
+                            _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
+                            **cached_file_kwargs,
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+                        elif use_safetensors:
+                            if revision == "main":
+                                resolved_archive_file, revision, is_sharded = auto_conversion(
+                                    pretrained_model_name_or_path, **cached_file_kwargs
+                                )
+                            cached_file_kwargs["revision"] = revision
+                            if resolved_archive_file is None:
+                                raise EnvironmentError(
+                                    f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                    f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
+                                    "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
+                                    "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                                )
+                        else:
+                            # This repo has no safetensors file of any kind, we switch to PyTorch.
+                            filename = _add_variant(WEIGHTS_NAME, variant)
+                            resolved_archive_file = cached_file(
+                                pretrained_model_name_or_path, filename, **cached_file_kwargs
+                            )
+                    if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path,
+                            _add_variant(WEIGHTS_INDEX_NAME, variant),
+                            **cached_file_kwargs,
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+                    if not local_files_only and not is_offline_mode():
+                        if resolved_archive_file is not None:
+                            if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
+                                # If the PyTorch file was found, check if there is a safetensors file on the repository
+                                # If there is no safetensors file on the repositories, start an auto conversion
+                                safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                                has_file_kwargs = {
+                                    "revision": revision,
+                                    "proxies": proxies,
+                                    "token": token,
+                                    "cache_dir": cache_dir,
+                                    "local_files_only": local_files_only,
+                                }
+                                cached_file_kwargs = {
+                                    "cache_dir": cache_dir,
+                                    "force_download": force_download,
+                                    "resume_download": resume_download,
+                                    "local_files_only": local_files_only,
+                                    "user_agent": user_agent,
+                                    "subfolder": subfolder,
+                                    "_raise_exceptions_for_gated_repo": False,
+                                    "_raise_exceptions_for_missing_entries": False,
+                                    "_commit_hash": commit_hash,
+                                    **has_file_kwargs,
+                                }
+                                if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
+                                    Thread(
+                                        target=auto_conversion,
+                                        args=(pretrained_model_name_or_path,),
+                                        kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
+                                        name="Thread-auto_conversion",
+                                    ).start()
+                        else:
+                            # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
+                            # We try those to give a helpful error message.
+                            has_file_kwargs = {
+                                "revision": revision,
+                                "proxies": proxies,
+                                "token": token,
+                                "cache_dir": cache_dir,
+                                "local_files_only": local_files_only,
+                            }
+                            if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
+                                raise EnvironmentError(
+                                    f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                    f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
+                                    " Use `from_tf=True` to load this model from those weights."
+                                )
+                            elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
+                                raise EnvironmentError(
+                                    f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                    f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
+                                    " `from_flax=True` to load this model from those weights."
+                                )
+                            elif variant is not None and has_file(
+                                pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
+                            ):
+                                raise EnvironmentError(
+                                    f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                    f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                                    f" {variant}. Use `variant=None` to load this model from those weights."
+                                )
+                            else:
+                                raise EnvironmentError(
+                                    f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                    f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
+                                    f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
+                                )
+
+                except EnvironmentError:
+                    # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
+                    # to the original exception.
+                    raise
+                except Exception as e:
+                    # For any other exception, we throw a generic error.
+                    raise EnvironmentError(
+                        f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                        f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                        f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
+                        f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
+                    ) from e
+
+            if is_local:
+                logger.info(f"loading weights file {archive_file}")
+                resolved_archive_file = archive_file
+            else:
+                logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+        elif gguf_file:
+            from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
+
+            # Case 1: the GGUF file is present locally
+            if os.path.isfile(gguf_file):
+                gguf_path = gguf_file
+            # Case 2: The GGUF path is a location on the Hub
+            # Load from URL or cache if already cached
+            else:
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "token": token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_gated_repo": False,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
+
+                gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)
+
+            # we need a dummy model to help rename state_dict
+            with torch.device("meta"):
+                dummy_model = cls(config)
+            state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True, model_to_load=dummy_model)["tensors"]
+
+            resolved_archive_file = None
+            is_sharded = False
+        else:
+            resolved_archive_file = None
+
+        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+        if is_sharded:
+            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+                pretrained_model_name_or_path,
+                resolved_archive_file,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                token=token,
+                user_agent=user_agent,
+                revision=revision,
+                subfolder=subfolder,
+                _commit_hash=commit_hash,
+            )
+
+        if (
+            is_safetensors_available()
+            and isinstance(resolved_archive_file, str)
+            and resolved_archive_file.endswith(".safetensors")
+        ):
+            with safe_open(resolved_archive_file, framework="pt") as f:
+                metadata = f.metadata()
+
+            if metadata is None:
+                # Assume it's a pytorch checkpoint (introduced for timm checkpoints)
+                pass
+            elif metadata.get("format") == "pt":
+                pass
+            elif metadata.get("format") == "tf":
+                from_tf = True
+                logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
+            elif metadata.get("format") == "flax":
+                from_flax = True
+                logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
+            elif metadata.get("format") == "mlx":
+                # This is a mlx file, we assume weights are compatible with pt
+                pass
+            else:
+                raise ValueError(
+                    f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
+                )
+
+        from_pt = not (from_tf | from_flax)
+
+        # load pt weights early so that we know which dtype to init the model under
+
+        if from_pt:
+            if not is_sharded and state_dict is None:
+                # Time to load the checkpoint
+                state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
+
+            # set dtype to instantiate the model under:
+            # 1. If torch_dtype is not None, we use that dtype
+            # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+            #    weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+            # we also may have config.torch_dtype available, but we won't rely on it till v5
+            dtype_orig = None
+
+            if torch_dtype is not None:
+                if isinstance(torch_dtype, str):
+                    if torch_dtype == "auto":
+                        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
+                            torch_dtype = config.torch_dtype
+                            logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object")
+                        else:
+                            if is_sharded and "dtype" in sharded_metadata:
+                                torch_dtype = sharded_metadata["dtype"]
+                            elif not is_sharded:
+                                torch_dtype = get_state_dict_dtype(state_dict)
+                            else:
+                                one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
+                                torch_dtype = get_state_dict_dtype(one_state_dict)
+                                del one_state_dict  # free CPU memory
+                            logger.info(
+                                "Since the `torch_dtype` attribute can't be found in model's config object, "
+                                "will use torch_dtype={torch_dtype} as derived from model's weights"
+                            )
+                    elif hasattr(torch, torch_dtype):
+                        torch_dtype = getattr(torch, torch_dtype)
+                        for sub_config_key in config.sub_configs.keys():
+                            sub_config = getattr(config, sub_config_key)
+                            sub_config.torch_dtype = torch_dtype
+                elif isinstance(torch_dtype, torch.dtype):
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
+                elif isinstance(torch_dtype, dict):
+                    for key, curr_dtype in torch_dtype.items():
+                        if hasattr(config, key):
+                            value = getattr(config, key)
+                            value.torch_dtype = curr_dtype
+                    # main torch dtype for modules that aren't part of any sub-config
+                    torch_dtype = torch_dtype.get("")
+                    config.torch_dtype = torch_dtype
+                    if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
+                        torch_dtype = getattr(torch, torch_dtype)
+                    elif torch_dtype is None:
+                        torch_dtype = torch.float32
+                else:
+                    raise ValueError(
+                        f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
+                        f"for each sub-config in composite configs, but received {torch_dtype}"
+                    )
+
+                dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+            # Check if `_keep_in_fp32_modules` is not None
+            use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+                (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+            )
+
+            if is_sharded:
+                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+            else:
+                loaded_state_dict_keys = list(state_dict.keys())
+            if (
+                gguf_path is None
+                and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()))
+                and pretrained_model_name_or_path is not None
+            ):
+                # In case some weights need to be kept in float32 and accelerate is not installed,
+                # we later on want to take the path where state_dict is not None, that is the one
+                # that do not require accelerate.
+                state_dict = None
+
+        config.name_or_path = pretrained_model_name_or_path
+
+        # Instantiate model.
+        init_contexts = [no_init_weights(_enable=_fast_init)]
+        tp_device = None
+
+        if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
+            import deepspeed
+
+            logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
+            init_contexts = [
+                deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
+                set_zero3_state(),
+            ] + init_contexts
+        elif low_cpu_mem_usage:
+            if not is_accelerate_available():
+                raise ImportError(
+                    f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+                )
+            init_contexts.append(init_empty_weights())
+        elif tp_plan is not None:
+            if not torch.distributed.is_initialized():
+                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
+
+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
+            device_module = torch.get_device_module(device_type)
+            # Get device with index assuming equal number of devices per host
+            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
+            init_contexts.append(tp_device)
+
+        if is_deepspeed_zero3_enabled() and is_quantized:
+            init_contexts.append(set_quantized_state())
+
+        config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
+        if not getattr(config, "_attn_implementation_autoset", False):
+            config = cls._autoset_attn_implementation(
+                config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
+            )
+
+        with ContextManagers(init_contexts):
+            # Let's make sure we don't run the init function of buffer modules
+            model = cls(config, *model_args, **model_kwargs)
+
+        # make sure we use the model's config since the __init__ call might have copied it
+        config = model.config
+
+        # Check first if we are `from_pt`
+        if use_keep_in_fp32_modules:
+            if is_accelerate_available() and not is_deepspeed_zero3_enabled():
+                low_cpu_mem_usage = True
+            keep_in_fp32_modules = model._keep_in_fp32_modules
+        else:
+            keep_in_fp32_modules = []
+
+        if hf_quantizer is not None:
+            hf_quantizer.preprocess_model(
+                model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+            )
+
+            # We store the original dtype for quantized models as we cannot easily retrieve it
+            # once the weights have been quantized
+            # Note that once you have loaded a quantized model, you can't change its dtype so this will
+            # remain a single source of truth
+            config._pre_quantization_dtype = torch_dtype
+
+        if isinstance(device_map, str):
+            special_dtypes = {}
+
+            if hf_quantizer is not None:
+                special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+
+            special_dtypes.update(
+                {
+                    name: torch.float32
+                    for name, _ in model.named_parameters()
+                    if any(m in name for m in keep_in_fp32_modules)
+                }
+            )
+
+            target_dtype = torch_dtype
+
+            if hf_quantizer is not None:
+                target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
+            no_split_modules = model._get_no_split_modules(device_map)
+            if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+                raise ValueError(
+                    "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
+                    "'sequential'."
+                )
+
+            device_map_kwargs = {"no_split_module_classes": no_split_modules}
+            if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+                device_map_kwargs["special_dtypes"] = special_dtypes
+            elif len(special_dtypes) > 0:
+                logger.warning(
+                    "This model has some weights that should be kept in higher precision, you need to upgrade "
+                    "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+                )
+            if device_map != "sequential":
+                max_memory = get_balanced_memory(
+                    model,
+                    dtype=target_dtype,
+                    low_zero=(device_map == "balanced_low_0"),
+                    max_memory=max_memory,
+                    **device_map_kwargs,
+                )
+            else:
+                max_memory = get_max_memory(max_memory)
+            if hf_quantizer is not None:
+                max_memory = hf_quantizer.adjust_max_memory(max_memory)
+            device_map_kwargs["max_memory"] = max_memory
+
+            # Make sure tied weights are tied before creating the device map.
+            model.tie_weights()
+            device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+            if hf_quantizer is not None:
+                hf_quantizer.validate_environment(device_map=device_map)
+
+        elif device_map is not None:
+            model.tie_weights()
+            tied_params = find_tied_parameters(model)
+            # check if we don't have tied param in different devices
+            check_tied_parameters_on_same_device(tied_params, device_map)
+
+        if from_tf:
+            if resolved_archive_file.endswith(".index"):
+                # Load from a TensorFlow 1.X checkpoint - provided by original authors
+                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
+            else:
+                # Load from our TensorFlow 2.0 checkpoints
+                try:
+                    from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
+
+                    model, loading_info = load_tf2_checkpoint_in_pytorch_model(
+                        model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True
+                    )
+                except ImportError:
+                    logger.error(
+                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
+                        " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation"
+                        " instructions."
+                    )
+                    raise
+        elif from_flax:
+            try:
+                from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
+
+                model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+            except ImportError:
+                logger.error(
+                    "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see"
+                    " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for"
+                    " installation instructions."
+                )
+                raise
+        elif from_pt:
+            # restore default dtype
+            if dtype_orig is not None:
+                torch.set_default_dtype(dtype_orig)
+
+            load_contexts = []
+            # Make sure we load onto targeted device
+            if tp_device is not None:
+                load_contexts.append(tp_device)
+
+            with ContextManagers(load_contexts):
+                (
+                    model,
+                    missing_keys,
+                    unexpected_keys,
+                    mismatched_keys,
+                    offload_index,
+                    error_msgs,
+                ) = cls._load_pretrained_model(
+                    model,
+                    state_dict,
+                    loaded_state_dict_keys,  # XXX: rename?
+                    resolved_archive_file,
+                    pretrained_model_name_or_path,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    sharded_metadata=sharded_metadata,
+                    _fast_init=_fast_init,
+                    low_cpu_mem_usage=low_cpu_mem_usage,
+                    device_map=device_map,
+                    offload_folder=offload_folder,
+                    offload_state_dict=offload_state_dict,
+                    dtype=torch_dtype,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
+                    gguf_path=gguf_path,
+                    weights_only=weights_only,
+                )
+
+        # make sure token embedding weights are still tied if needed
+        model.tie_weights()
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate() and generation_config is not None:
+            logger.info("The user-defined `generation_config` will be used to override the default generation config.")
+            model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
+        elif model.can_generate() and pretrained_model_name_or_path is not None:
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
+        # Dispatch model with hooks on all devices if necessary
+        if device_map is not None:
+            device_map_kwargs = {
+                "device_map": device_map,
+                "offload_dir": offload_folder,
+                "offload_index": offload_index,
+                "offload_buffers": offload_buffers,
+            }
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+            # For HQQ method we force-set the hooks for single GPU envs
+            if (
+                "force_hooks" in inspect.signature(dispatch_model).parameters
+                and hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
+            ):
+                device_map_kwargs["force_hooks"] = True
+            if (
+                hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
+                and isinstance(device_map, dict)
+                and ("cpu" in device_map.values() or "disk" in device_map.values())
+            ):
+                device_map_kwargs["offload_buffers"] = True
+
+            if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
+                dispatch_model(model, **device_map_kwargs)
+
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model, config=config)
+            model.hf_quantizer = hf_quantizer
+
+        if _adapter_model_path is not None:
+            model.load_adapter(
+                _adapter_model_path,
+                adapter_name=adapter_name,
+                token=token,
+                adapter_kwargs=adapter_kwargs,
+            )
+
+        if output_loading_info:
+            if loading_info is None:
+                loading_info = {
+                    "missing_keys": missing_keys,
+                    "unexpected_keys": unexpected_keys,
+                    "mismatched_keys": mismatched_keys,
+                    "error_msgs": error_msgs,
+                }
+            return model, loading_info
+
+        if tp_plan is not None:
+            assert tp_device is not None, "tp_device not set!"
+            if not model.supports_tp_plan:
+                raise NotImplementedError("This model does not have a tensor parallel plan.")
+            # Assuming sharding the model onto the world
+            world_size = torch.distributed.get_world_size()
+            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
+            # Apply Tensor Parallelism
+            model.tensor_parallel(device_mesh)
+
+        return model
+
+    @staticmethod
+    def _fix_state_dict_key_on_load(key):
+        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
+
+        if "beta" in key:
+            return key.replace("beta", "bias")
+        if "gamma" in key:
+            return key.replace("gamma", "weight")
+
+        # to avoid logging parametrized weight norm renaming
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            if "weight_g" in key:
+                return key.replace("weight_g", "parametrizations.weight.original0")
+            if "weight_v" in key:
+                return key.replace("weight_v", "parametrizations.weight.original1")
+        else:
+            if "parametrizations.weight.original0" in key:
+                return key.replace("parametrizations.weight.original0", "weight_g")
+            if "parametrizations.weight.original1" in key:
+                return key.replace("parametrizations.weight.original1", "weight_v")
+        return key
+
+    @classmethod
+    def _fix_state_dict_keys_on_load(cls, state_dict):
+        """Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
+        Logs if any parameters have been renamed.
+        """
+
+        renamed_keys = {}
+        state_dict_keys = list(state_dict.keys())
+        for key in state_dict_keys:
+            new_key = cls._fix_state_dict_key_on_load(key)
+            if new_key != key:
+                state_dict[new_key] = state_dict.pop(key)
+
+            # add it once for logging
+            if "gamma" in key and "gamma" not in renamed_keys:
+                renamed_keys["gamma"] = (key, new_key)
+            if "beta" in key and "beta" not in renamed_keys:
+                renamed_keys["beta"] = (key, new_key)
+
+        if renamed_keys:
+            warning_msg = f"A pretrained model of type `{cls.__name__}` "
+            warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+            for old_key, new_key in renamed_keys.values():
+                warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+            warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+            logger.info_once(warning_msg)
+
+        return state_dict
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key):
+        """
+        Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
+        Do nothing by default, but can be overriden in particular models.
+        """
+        return key
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        """
+        Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
+        Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
+        """
+        return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
+
+    @classmethod
+    def _load_pretrained_model(
+        cls,
+        model,
+        state_dict,
+        loaded_keys,
+        resolved_archive_file,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes=False,
+        sharded_metadata=None,
+        _fast_init=True,
+        low_cpu_mem_usage=False,
+        device_map=None,
+        offload_folder=None,
+        offload_state_dict=None,
+        dtype=None,
+        hf_quantizer=None,
+        keep_in_fp32_modules=None,
+        gguf_path=None,
+        weights_only=True,
+    ):
+        is_safetensors = False
+        is_quantized = hf_quantizer is not None
+        state_dict_folder = None
+        state_dict_index = None
+
+        if device_map is not None and "disk" in device_map.values():
+            archive_file = (
+                resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
+            )
+            is_safetensors = archive_file.endswith(".safetensors")
+            if offload_folder is None and not is_safetensors:
+                raise ValueError(
+                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+                    " offers the weights in this format."
+                )
+            if offload_folder is not None:
+                os.makedirs(offload_folder, exist_ok=True)
+            if offload_state_dict is None:
+                offload_state_dict = True
+
+        is_sharded_safetensors = is_safetensors and sharded_metadata is not None
+
+        # tie the model weights before retrieving the state_dict
+        model.tie_weights()
+
+        # Retrieve missing & unexpected_keys
+        model_state_dict = model.state_dict()
+        expected_keys = list(model_state_dict.keys())
+        prefix = model.base_model_prefix
+
+        if hf_quantizer is not None:
+            expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
+
+        original_loaded_keys = loaded_keys
+        loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
+
+        if len(prefix) > 0:
+            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
+            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
+        else:
+            has_prefix_module = False
+            expects_prefix_module = False
+
+        # key re-naming operations are never done on the keys
+        # that are loaded, but always on the keys of the newly initialized model
+        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
+        add_prefix_to_model = has_prefix_module and not expects_prefix_module
+
+        if remove_prefix_from_model:
+            _prefix = f"{prefix}."
+            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
+            expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
+        elif add_prefix_to_model:
+            expected_keys = [".".join([prefix, s]) for s in expected_keys]
+
+        missing_keys = sorted(set(expected_keys) - set(loaded_keys))
+        unexpected_keys = set(loaded_keys) - set(expected_keys)
+
+        # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
+        # buffers
+        model_buffers = {n for n, _ in model.named_buffers()}
+        if remove_prefix_from_model:
+            model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
+        elif add_prefix_to_model:
+            model_buffers = {".".join([prefix, key]) for key in model_buffers}
+        unexpected_keys = sorted(unexpected_keys - model_buffers)
+
+        # Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
+        has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
+        if has_inv_freq_buffers:
+            unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
+
+        model.tie_weights()
+        if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model.state_dict().items():
+                id_tensor = id_tensor_storage(tensor)
+                ptrs[id_tensor].append(name)
+
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+        else:
+            # id function doesn't work for meta tensor so we need this function
+            tied_params = find_tied_parameters(model)
+
+        for group in tied_params:
+            if remove_prefix_from_model:
+                group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
+            elif add_prefix_to_model:
+                group = [".".join([prefix, key]) for key in group]
+            missing_in_group = [k for k in missing_keys if k in group]
+            if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
+                missing_keys = [k for k in missing_keys if k not in missing_in_group]
+
+        # Some models may have keys that are not in the state by design, removing them before needlessly warning
+        # the user.
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
+                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+        if hf_quantizer is not None:
+            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
+
+        # retrieve weights on meta device and put them back on CPU.
+        # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
+        if low_cpu_mem_usage:
+            for key in missing_keys:
+                if key in list(model_state_dict.keys()):
+                    key = key
+                elif f"{prefix}.{key}" in list(model_state_dict.keys()):
+                    key = f"{prefix}.{key}"
+                elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
+                    key = ".".join(key.split(".")[1:])
+                param = model_state_dict[key]
+
+                # upcast in fp32 if any
+                target_dtype = dtype
+                if (
+                    keep_in_fp32_modules is not None
+                    and dtype == torch.float16
+                    and any(
+                        module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+                    )
+                ):
+                    target_dtype = torch.float32
+
+                if param.device == torch.device("meta"):
+                    value = torch.empty(*param.size(), dtype=target_dtype)
+                    if (
+                        not is_quantized
+                        or (getattr(hf_quantizer, "requires_parameters_quantization", False))
+                        or not hf_quantizer.check_quantized_param(
+                            model, param_value=value, param_name=key, state_dict={}
+                        )
+                    ):
+                        set_module_tensor_to_device(model, key, "cpu", value)
+                    else:
+                        hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)
+
+        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
+        if _fast_init:
+            if not ignore_mismatched_sizes:
+                if remove_prefix_from_model:
+                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+                elif add_prefix_to_model:
+                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
+                else:
+                    _loaded_keys = loaded_keys
+                not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
+                # If we're about to tie the output embeds to the input embeds we don't need to init them
+                if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
+                    output_embeddings = model.get_output_embeddings()
+                    if output_embeddings is not None:
+                        # Still need to initialize if there is a bias term since biases are not tied.
+                        if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
+                            output_embeddings._is_hf_initialized = True
+            else:
+                not_initialized_submodules = dict(model.named_modules())
+            # This will only initialize submodules that are not marked as initialized by the line above.
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                not_initialized_parameters = list(
+                    set(
+                        itertools.chain.from_iterable(
+                            submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
+                        )
+                    )
+                )
+                with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
+                    model.apply(model._initialize_weights)
+            else:
+                model.apply(model._initialize_weights)
+
+        # Set some modules to fp32 if any
+        if keep_in_fp32_modules is not None:
+            for name, param in model.named_parameters():
+                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                    # param = param.to(torch.float32) does not work here as only in the local scope.
+                    param.data = param.data.to(torch.float32)
+
+        # Make sure we are able to load base models as well as derived models (with heads)
+        start_prefix = ""
+        model_to_load = model
+        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
+            start_prefix = cls.base_model_prefix + "."
+        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
+            model_to_load = getattr(model, cls.base_model_prefix)
+            base_model_expected_keys = list(model_to_load.state_dict().keys())
+            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
+                raise ValueError(
+                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
+                    "properly saved?"
+                )
+            if device_map is not None:
+                device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
+
+        def _find_mismatched_keys(
+            state_dict,
+            model_state_dict,
+            loaded_keys,
+            original_loaded_keys,
+            add_prefix_to_model,
+            remove_prefix_from_model,
+            ignore_mismatched_sizes,
+        ):
+            mismatched_keys = []
+            if ignore_mismatched_sizes:
+                for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
+                    # If the checkpoint is sharded, we may not have the key here.
+                    if checkpoint_key not in state_dict:
+                        continue
+                    if remove_prefix_from_model:
+                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                        model_key = f"{prefix}.{model_key}"
+                    elif add_prefix_to_model:
+                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                        model_key = ".".join(model_key.split(".")[1:])
+
+                    if (
+                        model_key in model_state_dict
+                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                    ):
+                        if (
+                            state_dict[checkpoint_key].shape[-1] == 1
+                            and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
+                        ):
+                            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
+                            # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
+                            pass
+                        else:
+                            mismatched_keys.append(
+                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                            )
+                            del state_dict[checkpoint_key]
+            return mismatched_keys
+
+        if resolved_archive_file is not None:
+            folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
+        else:
+            folder = None
+        if device_map is not None and is_safetensors:
+            param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
+            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
+            if sharded_metadata is None:
+                archive_file = (
+                    resolved_archive_file[0]
+                    if isinstance(resolved_archive_file, (list, tuple))
+                    else resolved_archive_file
+                )
+                weight_map = {p: archive_file for p in original_loaded_keys}
+            else:
+                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
+            offload_index = {
+                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
+                for p, f in weight_map.items()
+                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
+            }
+        else:
+            offload_index = None
+
+        if state_dict is not None:
+            # Whole checkpoint
+            mismatched_keys = _find_mismatched_keys(
+                state_dict,
+                model_state_dict,
+                loaded_keys,
+                original_loaded_keys,
+                add_prefix_to_model,
+                remove_prefix_from_model,
+                ignore_mismatched_sizes,
+            )
+
+            # For GGUF models `state_dict` is never set to None as the state dict is always small
+            if gguf_path or low_cpu_mem_usage:
+                fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+                error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                    model_to_load,
+                    fixed_state_dict,
+                    start_prefix,
+                    expected_keys,
+                    device_map=device_map,
+                    offload_folder=offload_folder,
+                    offload_index=offload_index,
+                    state_dict_folder=state_dict_folder,
+                    state_dict_index=state_dict_index,
+                    dtype=dtype,
+                    hf_quantizer=hf_quantizer,
+                    is_safetensors=is_safetensors,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
+                    unexpected_keys=unexpected_keys,
+                )
+            else:
+                # Sharded checkpoint or whole but low_cpu_mem_usage==True
+                assign_to_params_buffers = check_support_param_buffer_assignment(
+                    model_to_load, state_dict, start_prefix
+                )
+                fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+                error_msgs = _load_state_dict_into_model(
+                    model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                )
+
+        else:
+            # This should always be a list but, just to be sure.
+            if not isinstance(resolved_archive_file, list):
+                resolved_archive_file = [resolved_archive_file]
+
+            error_msgs = []
+            mismatched_keys = []
+            if not is_safetensors:
+                offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+            if offload_state_dict:
+                state_dict_folder = tempfile.mkdtemp()
+                state_dict_index = {}
+            else:
+                state_dict_folder = None
+                state_dict_index = None
+
+            if is_sharded_safetensors:
+                disk_only_shard_files = get_disk_only_shard_files(
+                    device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
+                )
+                disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
+            else:
+                disk_only_shard_files = []
+
+            if len(resolved_archive_file) > 1:
+                resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+            assign_to_params_buffers = None
+            for shard_file in resolved_archive_file:
+                # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
+                if shard_file in disk_only_shard_files:
+                    continue
+                map_location = None
+                if (
+                    device_map is not None
+                    and hf_quantizer is not None
+                    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+                    and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
+                ):
+                    map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+                state_dict = load_state_dict(
+                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+                )
+
+                # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+                # matching the weights in the model.
+                mismatched_keys += _find_mismatched_keys(
+                    state_dict,
+                    model_state_dict,
+                    loaded_keys,
+                    original_loaded_keys,
+                    add_prefix_to_model,
+                    remove_prefix_from_model,
+                    ignore_mismatched_sizes,
+                )
+                if low_cpu_mem_usage:
+                    if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
+                        for key, param in model_to_load.state_dict().items():
+                            if param.device == torch.device("meta"):
+                                set_module_tensor_to_device(
+                                    model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                )
+                    else:
+                        fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                            model_to_load,
+                            fixed_state_dict,
+                            start_prefix,
+                            expected_keys,
+                            device_map=device_map,
+                            offload_folder=offload_folder,
+                            offload_index=offload_index,
+                            state_dict_folder=state_dict_folder,
+                            state_dict_index=state_dict_index,
+                            dtype=dtype,
+                            hf_quantizer=hf_quantizer,
+                            is_safetensors=is_safetensors,
+                            keep_in_fp32_modules=keep_in_fp32_modules,
+                            unexpected_keys=unexpected_keys,
+                        )
+                        error_msgs += new_error_msgs
+                else:
+                    # Sharded checkpoint or whole but low_cpu_mem_usage==True
+                    if assign_to_params_buffers is None:
+                        assign_to_params_buffers = check_support_param_buffer_assignment(
+                            model_to_load, state_dict, start_prefix
+                        )
+                    fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+                    error_msgs += _load_state_dict_into_model(
+                        model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                    )
+
+                # force memory release
+                del state_dict
+                gc.collect()
+
+            if offload_index is not None and len(offload_index) > 0:
+                if model != model_to_load:
+                    # We need to add the prefix of the base model
+                    prefix = cls.base_model_prefix
+                    if not is_safetensors:
+                        for weight_name in offload_index:
+                            shutil.move(
+                                os.path.join(offload_folder, f"{weight_name}.dat"),
+                                os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
+                            )
+                    offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
+                if not is_safetensors:
+                    save_offload_index(offload_index, offload_folder)
+                    offload_index = None
+
+            if offload_state_dict:
+                # Load back temporarily offloaded state dict
+                load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
+                shutil.rmtree(state_dict_folder)
+
+        if len(error_msgs) > 0:
+            error_msg = "\n\t".join(error_msgs)
+            if "size mismatch" in error_msg:
+                error_msg += (
+                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                )
+            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+        if len(unexpected_keys) > 0:
+            archs = [] if model.config.architectures is None else model.config.architectures
+            warner = logger.warning if model.__class__.__name__ in archs else logger.info
+            warner(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
+        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
+
+    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
+        module_keys = {".".join(key.split(".")[:-1]) for key in names}
+
+        # torch.nn.ParameterList is a special case where two parameter keywords
+        # are appended to the module name, *e.g.* bert.special_embeddings.0
+        module_keys = module_keys.union(
+            {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
+        )
+
+        retrieved_modules = []
+        # retrieve all modules that has at least one missing weight name
+        for name, module in self.named_modules():
+            if remove_prefix:
+                _prefix = f"{self.base_model_prefix}."
+                name = name[len(_prefix) :] if name.startswith(_prefix) else name
+            elif add_prefix:
+                name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
+
+            if name in module_keys:
+                retrieved_modules.append(module)
+
+        return retrieved_modules
+
+    @staticmethod
+    def _load_pretrained_model_low_mem(
+        model,
+        loaded_state_dict_keys,
+        resolved_archive_file,
+        start_prefix="",
+        hf_quantizer=None,
+        pretrained_model_name_or_path=None,
+        weights_only=True,
+    ):
+        """
+        This is an experimental function that loads the model using ~1.x model size CPU memory
+
+        Before you call it do:
+
+        1. save which state_dict keys are available
+        2. drop state_dict before model is created, since the latter takes 1x model size memory
+
+        Here then we continue:
+
+        3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
+        4. load state_dict 2nd time
+        5. replace the params/buffers from the state_dict
+
+        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To
+        handle bitsandbytes, needs non-empty hf_quantizer argument.
+        """
+
+        _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
+        state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
+        expected_keys = loaded_state_dict_keys  # plug for missing expected_keys. TODO: replace with proper keys
+        fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
+        error_msgs = _load_state_dict_into_meta_model(
+            model,
+            fixed_state_dict,
+            start_prefix,
+            expected_keys=expected_keys,
+            hf_quantizer=hf_quantizer,
+        )
+        return error_msgs
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoModel"):
+        """
+        Register this class with a given auto class. This should only be used for custom models as the ones in the
+        library are already mapped with an auto class.
+
+        
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
+                The auto class to register this new model with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+    def to_bettertransformer(self) -> "PreTrainedModel":
+        """
+        Converts the model to use [PyTorch's native attention
+        implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to
+        Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a
+        subset of all Transformers models are supported.
+
+        PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested
+        tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
+        post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).
+
+        Returns:
+            [`PreTrainedModel`]: The model converted to BetterTransformer.
+        """
+        if not is_optimum_available():
+            raise ImportError("The package `optimum` is required to use Better Transformer.")
+
+        from optimum.version import __version__ as optimum_version
+
+        if version.parse(optimum_version) < version.parse("1.7.0"):
+            raise ImportError(
+                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
+            )
+
+        from optimum.bettertransformer import BetterTransformer
+
+        return BetterTransformer.transform(self)
+
+    def reverse_bettertransformer(self):
+        """
+        Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is
+        used, for example in order to save the model.
+
+        Returns:
+            [`PreTrainedModel`]: The model converted back to the original modeling.
+        """
+        if not is_optimum_available():
+            raise ImportError("The package `optimum` is required to use Better Transformer.")
+
+        from optimum.version import __version__ as optimum_version
+
+        if version.parse(optimum_version) < version.parse("1.7.0"):
+            raise ImportError(
+                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
+            )
+
+        from optimum.bettertransformer import BetterTransformer
+
+        return BetterTransformer.reverse(self)
+
+    def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
+        """
+        Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.
+        """
+
+        # Skip the check during tracing.
+        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
+            return
+
+        if (attention_mask is not None) or (self.config.pad_token_id is None):
+            return
+
+        # Check only the first and last input IDs to reduce overhead.
+        if self.config.pad_token_id in input_ids[:, [-1, 0]]:
+            warn_string = (
+                "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
+                "https://huggingface.co/docs/transformers/troubleshooting"
+                "#incorrect-output-when-padding-tokens-arent-masked."
+            )
+
+            # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
+            # attention_mask or not. In this case, we should still show a warning because this is a rare case.
+            if (
+                (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
+                or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
+                or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
+            ):
+                warn_string += (
+                    f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
+                    f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
+                    f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
+                )
+
+            logger.warning_once(warn_string)
+
+    @property
+    def supports_tp_plan(self):
+        """
+        Returns whether the model has a tensor parallelism plan.
+        """
+        if self._tp_plan is not None:
+            return True
+        # Check if base model has a TP plan
+        if getattr(self.base_model, "_tp_plan", None) is not None:
+            return True
+        return False
+
+    def tensor_parallel(self, device_mesh):
+        """
+        Tensor parallelize the model across the given device mesh.
+
+        Args:
+            device_mesh (`torch.distributed.DeviceMesh`):
+                The device mesh to use for tensor parallelism.
+        """
+        if not is_torch_greater_or_equal("2.5"):
+            raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
+
+        # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+        # No op if `_tp_plan` attribute does not exist under the module.
+        # This is a helper function to be used with `model.apply` to recursively
+        # parallelize a model.
+        def tplize(mod: torch.nn.Module) -> None:
+            tp_plan = getattr(mod, "_tp_plan", None)
+            if tp_plan is None:
+                return
+            logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
+            # In model configs, we use a neutral type (string) to specify
+            # parallel styles, here we translate them into torch TP types.
+            # Using tree_map because `tp_plan` is a dict.
+            tp_plan = torch.utils._pytree.tree_map(
+                translate_to_torch_parallel_style,
+                tp_plan,
+            )
+            # Apply TP to current module.
+            torch.distributed.tensor.parallel.parallelize_module(
+                mod,
+                device_mesh=device_mesh,
+                parallelize_plan=tp_plan,
+            )
+
+        # `apply` is a native method of `nn.Module` that recursively applies a
+        # function to every submodule.
+        self.apply(tplize)
+
+    @property
+    def loss_function(self):
+        if hasattr(self, "_loss_function"):
+            return self._loss_function
+
+        loss_type = getattr(self, "loss_type", None)
+
+        if loss_type is None or loss_type not in LOSS_MAPPING:
+            logger.warning_once(
+                f"`loss_type={loss_type}` was set in the config but it is unrecognised."
+                f"Using the default loss: `ForCausalLMLoss`."
+            )
+            loss_type = "ForCausalLM"
+        return LOSS_MAPPING[loss_type]
+
+    @loss_function.setter
+    def loss_function(self, value):
+        self._loss_function = value
+
+    def get_compiled_call(self, compile_config: CompileConfig):
+        """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
+        non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
+        want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
+        (where we want the speed-ups of compiled version with static shapes)."""
+        # Only reset it if not present or different from previous config
+        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        if (
+            not hasattr(self, "_compiled_call")
+            or getattr(self, "_last_compile_config", default_config) != compile_config
+        ):
+            self._last_compile_config = compile_config
+            self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
+        return self._compiled_call
+
+
+PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
+if PreTrainedModel.push_to_hub.__doc__ is not None:
+    PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
+        object="model", object_class="AutoModel", object_files="model file"
+    )
+
+
+class PoolerStartLogits(nn.Module):
+    """
+    Compute SQuAD start logits from sequence hidden states.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        Returns:
+            `torch.FloatTensor`: The start logits for SQuAD.
+        """
+        x = self.dense(hidden_states).squeeze(-1)
+
+        if p_mask is not None:
+            if get_parameter_dtype(self) == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
+class PoolerEndLogits(nn.Module):
+    """
+    Compute SQuAD end logits from sequence hidden states.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dense_1 = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        
+
+        Returns:
+            `torch.FloatTensor`: The end logits for SQuAD.
+        """
+        assert (
+            start_states is not None or start_positions is not None
+        ), "One of start_states, start_positions should be not None"
+        if start_positions is not None:
+            slen, hsz = hidden_states.shape[-2:]
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
+            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
+
+        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
+        x = self.activation(x)
+        x = self.LayerNorm(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        if p_mask is not None:
+            if get_parameter_dtype(self) == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
+class PoolerAnswerClass(nn.Module):
+    """
+    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+
+        
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        
+
+        Returns:
+            `torch.FloatTensor`: The SQuAD 2.0 answer class.
+        """
+        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
+        hsz = hidden_states.shape[-1]
+        assert (
+            start_states is not None or start_positions is not None
+        ), "One of start_states, start_positions should be not None"
+        if start_positions is not None:
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)
+
+        if cls_index is not None:
+            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
+        else:
+            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
+
+        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
+        x = self.activation(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        return x
+
+
+@dataclass
+class SquadHeadOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`].
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
+            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
+            losses.
+        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
+        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Indices for the top config.start_n_top start token possibilities (beam-search).
+        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
+            (beam-search).
+        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
+        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the `is_impossible` label of the answers.
+
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_top_log_probs: Optional[torch.FloatTensor] = None
+    start_top_index: Optional[torch.LongTensor] = None
+    end_top_log_probs: Optional[torch.FloatTensor] = None
+    end_top_index: Optional[torch.LongTensor] = None
+    cls_logits: Optional[torch.FloatTensor] = None
+
+
+class SQuADHead(nn.Module):
+    r"""
+    A SQuAD head inspired by XLNet.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.start_n_top = config.start_n_top
+        self.end_n_top = config.end_n_top
+
+        self.start_logits = PoolerStartLogits(config)
+        self.end_logits = PoolerEndLogits(config)
+        self.answer_class = PoolerAnswerClass(config)
+
+    @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+        is_impossible: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+        return_dict: bool = False,
+    ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                Final hidden states of the model on the sequence tokens.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Positions of the first token for the labeled span.
+            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Positions of the last token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+            is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Whether the question has a possible answer in the paragraph or not.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+        """
+        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, let's remove the dimension added by batch splitting
+            for x in (start_positions, end_positions, cls_index, is_impossible):
+                if x is not None and x.dim() > 1:
+                    x.squeeze_(-1)
+
+            # during training, compute the end logits based on the ground truth of the start position
+            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+            if cls_index is not None and is_impossible is not None:
+                # Predict answerability from the representation of CLS and START
+                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+                loss_fct_cls = nn.BCEWithLogitsLoss()
+                cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
+                total_loss += cls_loss * 0.5
+
+            return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
+
+        else:
+            # during inference, compute the end logits based on beam search
+            bsz, slen, hsz = hidden_states.size()
+            start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)
+
+            start_top_log_probs, start_top_index = torch.topk(
+                start_log_probs, self.start_n_top, dim=-1
+            )  # shape (bsz, start_n_top)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
+            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
+
+            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
+                start_states
+            )  # shape (bsz, slen, start_n_top, hsz)
+            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+            end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
+
+            end_top_log_probs, end_top_index = torch.topk(
+                end_log_probs, self.end_n_top, dim=1
+            )  # shape (bsz, end_n_top, start_n_top)
+            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+            if not return_dict:
+                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
+            else:
+                return SquadHeadOutput(
+                    start_top_log_probs=start_top_log_probs,
+                    start_top_index=start_top_index,
+                    end_top_log_probs=end_top_log_probs,
+                    end_top_index=end_top_index,
+                    cls_logits=cls_logits,
+                )
+
+
+class SequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
+
+        self.first_dropout = Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
+    """
+    Recursively unwraps a model from potential containers (as used in distributed training).
+
+    Args:
+        model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
+    """
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
+    else:
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
+        else:
+            return model
+
+
+def expand_device_map(device_map, param_names, start_prefix):
+    """
+    Expand a device map to return the correspondance parameter name to device.
+    """
+    new_device_map = {}
+    param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
+    """
+    Returns the list of shard files containing only weights offloaded to disk.
+    """
+
+    weight_map = {
+        p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
+    }
+    files_content = collections.defaultdict(list)
+    for weight_name, filename in weight_map.items():
+        while len(weight_name) > 0 and weight_name not in device_map:
+            weight_name = ".".join(weight_name.split(".")[:-1])
+        files_content[filename].append(device_map[weight_name])
+
+    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
+
+
+ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {}
+
+ALL_ATTENTION_FUNCTIONS.update(
+    {
+        "flash_attention_2": flash_attention_forward,
+        "flex_attention": flex_attention_forward,
+        "sdpa": sdpa_attention_forward,
+    }
+)
diff --git a/.venv/lib/python3.11/site-packages/transformers/trainer.py b/.venv/lib/python3.11/site-packages/transformers/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7108b12b9048eb8ae937cb04df3916d7ed02a02
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/trainer.py
@@ -0,0 +1,5170 @@
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
+"""
+
+import contextlib
+import copy
+import functools
+import glob
+import importlib.metadata
+import inspect
+import json
+import math
+import os
+import random
+import re
+import shutil
+import sys
+import tempfile
+import time
+import warnings
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
+
+# Integrations must be imported before ML frameworks:
+# isort: off
+from .integrations import (
+    get_reporting_integration_callbacks,
+    hp_params,
+)
+
+# isort: on
+
+import huggingface_hub.utils as hf_hub_utils
+import numpy as np
+import torch
+import torch.distributed as dist
+from huggingface_hub import ModelCard, create_repo, upload_folder
+from packaging import version
+from torch import nn
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+
+from . import __version__
+from .configuration_utils import PretrainedConfig
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .debug_utils import DebugOption, DebugUnderflowOverflow
+from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+from .feature_extraction_utils import FeatureExtractionMixin
+from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
+from .image_processing_utils import BaseImageProcessor
+from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from .integrations.tpu import tpu_spmd_dataloader
+from .modelcard import TrainingSummary
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .models.auto.modeling_auto import (
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_MAPPING_NAMES,
+)
+from .optimization import Adafactor, get_scheduler
+from .processing_utils import ProcessorMixin
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_2_3,
+)
+from .tokenization_utils_base import PreTrainedTokenizerBase
+from .trainer_callback import (
+    CallbackHandler,
+    DefaultFlowCallback,
+    ExportableState,
+    PrinterCallback,
+    ProgressCallback,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+)
+from .trainer_pt_utils import (
+    DistributedTensorGatherer,
+    EvalLoopContainer,
+    IterableDatasetShard,
+    LabelSmoother,
+    LayerWiseDummyOptimizer,
+    LengthGroupedSampler,
+    SequentialDistributedSampler,
+    distributed_broadcast_scalars,
+    distributed_concat,
+    find_batch_size,
+    get_model_param_count,
+    get_module_class_from_name,
+    get_parameter_names,
+    nested_concat,
+    nested_detach,
+    nested_numpify,
+    nested_xla_mesh_reduce,
+    reissue_pt_warnings,
+    remove_dummy_checkpoint,
+)
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    BestRun,
+    EvalLoopOutput,
+    EvalPrediction,
+    HPSearchBackend,
+    HubStrategy,
+    PredictionOutput,
+    RemoveColumnsCollator,
+    SaveStrategy,
+    TrainerMemoryTracker,
+    TrainOutput,
+    check_target_module_exists,
+    default_compute_objective,
+    denumpify_detensorize,
+    enable_full_determinism,
+    find_executable_batch_size,
+    get_last_checkpoint,
+    has_length,
+    neftune_post_forward_hook,
+    number_of_arguments,
+    seed_worker,
+    set_seed,
+    speed_metrics,
+)
+from .training_args import OptimizerNames, ParallelMode, TrainingArguments
+from .utils import (
+    ADAPTER_CONFIG_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    CONFIG_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    XLA_FSDPV2_MIN_VERSION,
+    PushInProgress,
+    PushToHubMixin,
+    can_return_loss,
+    find_labels,
+    is_accelerate_available,
+    is_apex_available,
+    is_bitsandbytes_available,
+    is_datasets_available,
+    is_galore_torch_available,
+    is_grokadamw_available,
+    is_in_notebook,
+    is_ipex_available,
+    is_liger_kernel_available,
+    is_lomo_available,
+    is_peft_available,
+    is_safetensors_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
+    is_schedulefree_available,
+    is_torch_compile_available,
+    is_torch_mlu_available,
+    is_torch_mps_available,
+    is_torch_musa_available,
+    is_torch_neuroncore_available,
+    is_torch_npu_available,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    is_torchao_available,
+    logging,
+    strtobool,
+)
+from .utils.deprecation import deprecate_kwarg
+from .utils.quantization_config import QuantizationMethod
+
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+if is_in_notebook():
+    from .utils.notebook import NotebookProgressCallback
+
+    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
+
+if is_apex_available():
+    from apex import amp
+
+if is_datasets_available():
+    import datasets
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+    from torch_xla import __version__ as XLA_VERSION
+
+    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
+    if IS_XLA_FSDPV2_POST_2_2:
+        import torch_xla.distributed.spmd as xs
+        import torch_xla.runtime as xr
+else:
+    IS_XLA_FSDPV2_POST_2_2 = False
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+    from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+    IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+if is_safetensors_available():
+    import safetensors.torch
+
+if is_peft_available():
+    from peft import PeftModel
+
+
+if is_accelerate_available():
+    from accelerate import Accelerator, skip_first_batches
+    from accelerate import __version__ as accelerate_version
+    from accelerate.state import AcceleratorState
+    from accelerate.utils import (
+        DistributedDataParallelKwargs,
+        DistributedType,
+        load_fsdp_model,
+        load_fsdp_optimizer,
+        save_fsdp_model,
+        save_fsdp_optimizer,
+    )
+
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]
+
+    if is_deepspeed_available():
+        from accelerate.utils import DeepSpeedSchedulerWrapper
+
+if is_accelerate_available("0.28.0"):
+    from accelerate.utils import DataLoaderConfiguration
+
+
+def _is_peft_model(model):
+    if is_peft_available():
+        classes_to_check = (PeftModel,) if is_peft_available() else ()
+        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
+        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
+            from peft import PeftMixedModel
+
+            classes_to_check = (*classes_to_check, PeftMixedModel)
+        return isinstance(model, classes_to_check)
+    return False
+
+
+def _get_fsdp_ckpt_kwargs():
+    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
+    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
+        return {"adapter_only": True}
+    else:
+        return {}
+
+
+def safe_globals():
+    # Starting from version 2.4 PyTorch introduces a check for the objects loaded
+    # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
+    # a default and requires allowlisting of objects being loaded.
+    # See: https://github.com/pytorch/pytorch/pull/137602
+    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
+    # See: https://github.com/huggingface/accelerate/pull/3036
+    if version.parse(torch.__version__).release < version.parse("2.6").release:
+        return contextlib.nullcontext()
+
+    np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
+    allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
+    # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
+    # all versions of numpy
+    allowlist += [type(np.dtype(np.uint32))]
+
+    return torch.serialization.safe_globals(allowlist)
+
+
+if TYPE_CHECKING:
+    import optuna
+
+    if is_datasets_available():
+        import datasets
+
+logger = logging.get_logger(__name__)
+
+
+# Name of the files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
+class Trainer:
+    """
+    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
+
+    Args:
+        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
+            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
+
+            
+
+            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
+            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
+            models.
+
+            
+
+        args ([`TrainingArguments`], *optional*):
+            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
+            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
+        data_collator (`DataCollator`, *optional*):
+            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+            default to [`default_data_collator`] if no `processing_class` is provided, an instance of
+            [`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer.
+        train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
+            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
+            `model.forward()` method are automatically removed.
+
+            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
+            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
+            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
+            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+            sets the seed of the RNGs used.
+        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*):
+             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+             dataset prepending the dictionary key to the metric name.
+        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+            Processing class used to process the data. If provided, will be used to automatically process the inputs
+            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+            reuse the fine-tuned model.
+            This supercedes the `tokenizer` argument, which is now deprecated.
+        model_init (`Callable[[], PreTrainedModel]`, *optional*):
+            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
+            from a new instance of the model as given by this function.
+
+            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
+            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
+            inner layers, dropout probabilities etc).
+        compute_loss_func (`Callable`, *optional*):
+            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`].
+        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
+            a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
+            `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
+            after the last eval batch to signal that the function needs to calculate and return the global summary
+            statistics rather than accumulating the batch-level statistics
+        callbacks (List of [`TrainerCallback`], *optional*):
+            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+            detailed in [here](callback).
+
+            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
+        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use.
+            Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+            A function that preprocess the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+            by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+
+    Important attributes:
+
+        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
+          subclass.
+        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
+          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
+          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
+        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+          data parallelism, this means some of the model layers are split on different GPUs).
+        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+          to `False` if model parallel or deepspeed is used, or if the default
+          `TrainingArguments.place_model_on_device` is overridden to return `False` .
+        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
+          in `train`)
+
+    """
+
+    # Those are used as methods of the Trainer in examples.
+    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
+
+    @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, nn.Module] = None,
+        args: TrainingArguments = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_loss_func: Optional[Callable] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        optimizer_cls_and_kwargs: Optional[Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] = None,
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+    ):
+        if args is None:
+            output_dir = "tmp_trainer"
+            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
+            args = TrainingArguments(output_dir=output_dir)
+        if args.batch_eval_metrics and compute_metrics is not None:
+            if "compute_result" not in inspect.signature(compute_metrics).parameters.keys():
+                raise ValueError(
+                    "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
+                    " boolean argument which will be triggered after the last batch of the eval set to signal that the"
+                    " summary statistics should be returned by the function."
+                )
+        if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None:
+            raise ValueError(
+                f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
+            )
+        if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
+            if args.metric_for_best_model is None:
+                raise ValueError(
+                    "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
+                )
+
+        self.args = args
+        self.compute_loss_func = compute_loss_func
+        # Seed must be set before instantiating the model when using model
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+
+        self.hp_name = None
+        self.deepspeed = None
+        self.is_in_train = False
+
+        self.create_accelerator_and_postprocess()
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+        self._memory_tracker.start()
+
+        # set the correct log level depending on the node
+        log_level = args.get_process_log_level()
+        logging.set_verbosity(log_level)
+
+        # force device and distributed setup init explicitly
+        args._setup_devices
+
+        if model is None:
+            if model_init is not None:
+                self.model_init = model_init
+                model = self.call_model_init()
+            else:
+                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+        else:
+            if model_init is not None:
+                warnings.warn(
+                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+                    " release.",
+                    FutureWarning,
+                )
+            self.model_init = model_init
+
+        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
+            raise ValueError(
+                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
+                "computes hidden states and does not accept any labels. You should choose a model with a head "
+                "suitable for your task like any of the `AutoModelForXxx` listed at "
+                "https://huggingface.co/docs/transformers/model_doc/auto"
+            )
+
+        if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
+            self.is_model_parallel = True
+        else:
+            self.is_model_parallel = False
+
+        if getattr(model, "hf_device_map", None) is not None:
+            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
+            if len(devices) > 1:
+                self.is_model_parallel = True
+            elif len(devices) == 1:
+                self.is_model_parallel = self.args.device != torch.device(devices[0])
+            else:
+                self.is_model_parallel = False
+
+            # warn users
+            if self.is_model_parallel:
+                logger.info(
+                    "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+                    " to `True` to avoid any unexpected behavior such as device placement mismatching."
+                )
+
+        if self.args.use_liger_kernel:
+            if is_liger_kernel_available():
+                from liger_kernel.transformers import _apply_liger_kernel_to_instance
+
+                if isinstance(model, PreTrainedModel):
+                    # Patch the model with liger kernels. Use the default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model)
+                else:
+                    logger.warning(
+                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
+                    )
+            else:
+                raise ImportError(
+                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
+                    "Please install it with `pip install liger-kernel`"
+                )
+
+        _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
+            model, "_hf_peft_config_loaded", False
+        )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )
+
+        _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
+            model.hf_quantizer, "is_qat_trainable", False
+        )
+
+        # Filter out quantized + compiled models
+        if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
+            raise ValueError(
+                "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
+            )
+
+        # At this stage the model is already loaded
+        if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
+            raise ValueError(
+                "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
+                " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
+                " for more details"
+            )
+        elif _is_quantized_and_base_model and not _quantization_method_supports_training:
+            raise ValueError(
+                f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+                " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+                f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
+            )
+
+        self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
+        if len(args.fsdp) > 0:
+            if self.is_deepspeed_enabled:
+                raise ValueError(
+                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+                )
+            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
+                raise ValueError("Using fsdp only works in distributed training.")
+
+        # one place to sort out whether to place the model on device or not
+        # postpone switching model to cuda when:
+        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
+        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+        #    and we only use deepspeed for training at the moment
+        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
+        # 4. FSDP - same as MP
+        self.place_model_on_device = args.place_model_on_device
+        if (
+            self.is_model_parallel
+            or self.is_deepspeed_enabled
+            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
+            or self.is_fsdp_xla_enabled
+            or self.is_fsdp_enabled
+        ):
+            self.place_model_on_device = False
+
+        default_collator = (
+            DataCollatorWithPadding(processing_class)
+            if processing_class is not None
+            and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
+            else default_data_collator
+        )
+        self.data_collator = data_collator if data_collator is not None else default_collator
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.processing_class = processing_class
+
+        # Bnb Quantized models doesn't support `.to` operation.
+        if (
+            self.place_model_on_device
+            and not getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+        ):
+            self._move_model_to_device(model, args.device)
+
+        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
+        if self.is_model_parallel:
+            self.args._n_gpu = 1
+
+        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
+        self.model_wrapped = model
+        self.model = model
+
+        # Just in case the model was wrapped outside of the `Trainer`
+        unwrapped_model = self.accelerator.unwrap_model(model)
+        model_forward = (
+            unwrapped_model.forward
+            if not _is_peft_model(unwrapped_model)
+            else unwrapped_model.get_base_model().forward
+        )
+        forward_params = inspect.signature(model_forward).parameters
+
+        # Check if the model has explicit setup for loss kwargs,
+        # if not, check if `**kwargs` are in model.forward
+        if hasattr(model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        else:
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )
+
+        self.neftune_noise_alpha = args.neftune_noise_alpha
+
+        self.compute_metrics = compute_metrics
+        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
+        self.optimizer, self.lr_scheduler = optimizers
+        self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
+        if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
+            raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
+        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
+            raise RuntimeError(
+                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
+                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+            )
+        if is_torch_xla_available() and self.optimizer is not None:
+            for param in self.model.parameters():
+                model_device = param.device
+                break
+            for param_group in self.optimizer.param_groups:
+                if len(param_group["params"]) > 0:
+                    optimizer_device = param_group["params"][0].device
+                    break
+            if model_device != optimizer_device:
+                raise ValueError(
+                    "The model and the optimizer parameters are not on the same device, which probably means you"
+                    " created an optimizer around your model **before** putting on the device and passing it to the"
+                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
+                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
+                )
+        if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
+            self.optimizer is not None or self.lr_scheduler is not None
+        ):
+            raise RuntimeError(
+                "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
+                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+            )
+        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+        self.callback_handler = CallbackHandler(
+            callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+        )
+        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+
+        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
+        self._loggers_initialized = False
+
+        # Create distant repo and output directory if needed
+        self.hub_model_id = None
+        if self.args.push_to_hub:
+            self.init_hf_repo()
+        if self.args.should_save:
+            os.makedirs(self.args.output_dir, exist_ok=True)
+
+        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
+
+        if args.max_steps > 0 and args.num_train_epochs > 0:
+            logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
+            raise ValueError(
+                "The train_dataset does not implement __len__, max_steps has to be specified. "
+                "The number of steps needs to be known in advance for the learning rate scheduler."
+            )
+
+        if (
+            train_dataset is not None
+            and isinstance(train_dataset, torch.utils.data.IterableDataset)
+            and args.group_by_length
+        ):
+            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
+
+        self._signature_columns = None
+
+        # Mixed precision setup
+        self.use_apex = False
+        self.use_cpu_amp = False
+
+        # Mixed precision setup for SageMaker Model Parallel
+        if is_sagemaker_mp_enabled():
+            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+            if args.bf16:
+                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+
+            if IS_SAGEMAKER_MP_POST_1_10:
+                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+                if args.fp16 != smp.state.cfg.fp16:
+                    logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        f"but FP16 provided in trainer argument is {args.fp16}, "
+                        f"setting to {smp.state.cfg.fp16}"
+                    )
+                    args.fp16 = smp.state.cfg.fp16
+            else:
+                # smp < 1.10 does not support fp16 in trainer.
+                if hasattr(smp.state.cfg, "fp16"):
+                    logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+                    )
+        if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
+            if args.device == torch.device("cpu"):
+                if args.fp16:
+                    if not is_torch_greater_or_equal_than_2_3:
+                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+                else:
+                    args.half_precision_backend = "cpu_amp"
+            logger.info(f"Using {args.half_precision_backend} half precision backend")
+
+        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
+            # deepspeed and SageMaker Model Parallel manage their own half precision
+            if args.half_precision_backend == "cpu_amp":
+                self.use_cpu_amp = True
+                self.amp_dtype = torch.bfloat16
+            elif args.half_precision_backend == "apex":
+                if not is_apex_available():
+                    raise ImportError(
+                        "Using FP16 with APEX but APEX is not installed, please refer to"
+                        " https://www.github.com/nvidia/apex."
+                    )
+                self.use_apex = True
+
+        # Label smoothing
+        if self.args.label_smoothing_factor != 0:
+            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
+        else:
+            self.label_smoother = None
+
+        self.control = TrainerControl()
+
+        self.state = TrainerState(
+            is_local_process_zero=self.is_local_process_zero(),
+            is_world_process_zero=self.is_world_process_zero(),
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ],
+        )
+        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+        # returned to 0 every time flos need to be logged
+        self.current_flos = 0
+        self.hp_search_backend = None
+        default_label_names = find_labels(self.model.__class__)
+        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+        self.can_return_loss = can_return_loss(self.model.__class__)
+        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+        # Internal variables to help with automatic batch size reduction
+        self._train_batch_size = args.train_batch_size
+        self._created_lr_scheduler = False
+
+        # very last
+        self._memory_tracker.stop_and_update_metrics()
+
+        # torch.compile
+        if args.torch_compile and not is_torch_compile_available():
+            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
+
+        self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
+        if self.is_fsdp_xla_v2_enabled:
+            if not IS_XLA_FSDPV2_POST_2_2:
+                raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
+            # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
+            # Tensor axis is just a placeholder where it will not be used in FSDPv2.
+            num_devices = xr.global_runtime_device_count()
+            xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
+
+    @property
+    def tokenizer(self) -> Optional[PreTrainedTokenizerBase]:
+        logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.")
+        return self.processing_class
+
+    @tokenizer.setter
+    def tokenizer(self, processing_class) -> None:
+        logger.warning(
+            "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead."
+        )
+        self.processing_class = processing_class
+
+    def _activate_neftune(self, model):
+        r"""
+        Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
+        https://arxiv.org/abs/2310.05914
+        """
+        unwrapped_model = self.accelerator.unwrap_model(model)
+
+        if _is_peft_model(unwrapped_model):
+            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+        else:
+            embeddings = unwrapped_model.get_input_embeddings()
+
+        del unwrapped_model
+
+        embeddings.neftune_noise_alpha = self.neftune_noise_alpha
+        hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
+        self.neftune_hook_handle = hook_handle
+        return model
+
+    def _deactivate_neftune(self, model):
+        """
+        Deactivates the neftune method. Make sure to call `_activate_neftune` first.
+        """
+        if not hasattr(self, "neftune_hook_handle"):
+            raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
+
+        unwrapped_model = self.accelerator.unwrap_model(model)
+
+        if _is_peft_model(unwrapped_model):
+            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+        else:
+            embeddings = unwrapped_model.get_input_embeddings()
+
+        self.neftune_hook_handle.remove()
+        del embeddings.neftune_noise_alpha, unwrapped_model
+
+    def add_callback(self, callback):
+        """
+        Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+           callback (`type` or [`~transformers.TrainerCallback]`):
+               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+               first case, will instantiate a member of that class.
+        """
+        self.callback_handler.add_callback(callback)
+
+    def pop_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.
+
+        If the callback is not found, returns `None` (and no error is raised).
+
+        Args:
+           callback (`type` or [`~transformers.TrainerCallback]`):
+               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+               first case, will pop the first member of that class found in the list of callbacks.
+
+        Returns:
+            [`~transformers.TrainerCallback`]: The callback removed, if found.
+        """
+        return self.callback_handler.pop_callback(callback)
+
+    def remove_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+           callback (`type` or [`~transformers.TrainerCallback]`):
+               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+               first case, will remove the first member of that class found in the list of callbacks.
+        """
+        self.callback_handler.remove_callback(callback)
+
+    def _move_model_to_device(self, model, device):
+        model = model.to(device)
+        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+            model.tie_weights()
+
+    def _set_signature_columns_if_needed(self):
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            model_to_inspect = self.model
+            if _is_peft_model(self.model):
+                if hasattr(self.model, "get_base_model"):
+                    model_to_inspect = self.model.get_base_model()
+                else:
+                    # PeftMixedModel do not provide a `get_base_model` method
+                    model_to_inspect = self.model.base_model.model
+            signature = inspect.signature(model_to_inspect.forward)
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids, the default data collator handles that.
+            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+        if not self.args.remove_unused_columns:
+            return dataset
+        self._set_signature_columns_if_needed()
+        signature_columns = self._signature_columns
+
+        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+        if len(ignored_columns) > 0:
+            dset_description = "" if description is None else f"in the {description} set"
+            logger.info(
+                f"The following columns {dset_description} don't have a corresponding argument in "
+                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
+                " you can safely ignore this message."
+            )
+
+        columns = [k for k in signature_columns if k in dataset.column_names]
+        if len(columns) == 0:
+            raise ValueError(
+                "No columns in the dataset match the model's forward method signature. "
+                f"The following columns have been ignored: [{', '.join(ignored_columns)}]. "
+                "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`."
+            )
+
+        if version.parse(datasets.__version__) < version.parse("1.4.0"):
+            dataset.set_format(
+                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
+            )
+            return dataset
+        else:
+            return dataset.remove_columns(ignored_columns)
+
+    def _get_collator_with_removed_columns(
+        self, data_collator: Callable, description: Optional[str] = None
+    ) -> Callable:
+        """Wrap the data collator in a callable removing unused columns."""
+        if not self.args.remove_unused_columns:
+            return data_collator
+        self._set_signature_columns_if_needed()
+        signature_columns = self._signature_columns
+
+        remove_columns_collator = RemoveColumnsCollator(
+            data_collator=data_collator,
+            signature_columns=signature_columns,
+            logger=logger,
+            description=description,
+            model_name=self.model.__class__.__name__,
+        )
+        return remove_columns_collator
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.train_dataset is None or not has_length(self.train_dataset):
+            return None
+
+        # Build the sampler.
+        if self.args.group_by_length:
+            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
+                lengths = (
+                    self.train_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in self.train_dataset.column_names
+                    else None
+                )
+            else:
+                lengths = None
+            model_input_name = (
+                self.processing_class.model_input_names[0] if self.processing_class is not None else None
+            )
+            return LengthGroupedSampler(
+                self.args.train_batch_size * self.args.gradient_accumulation_steps,
+                dataset=self.train_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
+
+        else:
+            return RandomSampler(self.train_dataset)
+
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training [`~torch.utils.data.DataLoader`].
+
+        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+        training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        train_dataset = self.train_dataset
+        data_collator = self.data_collator
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+        dataloader_params = {
+            "batch_size": self._train_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_train_sampler()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+        if eval_dataset is None or not has_length(eval_dataset):
+            return None
+        # Build the sampler.
+
+        # Deprecated code
+        if self.args.use_legacy_prediction_loop:
+            if is_torch_xla_available():
+                return SequentialDistributedSampler(
+                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+                )
+            elif is_sagemaker_mp_enabled():
+                return SequentialDistributedSampler(
+                    eval_dataset,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                    batch_size=self.args.per_device_eval_batch_size,
+                )
+            else:
+                return SequentialSampler(eval_dataset)
+
+        if self.args.group_by_length:
+            if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+                lengths = (
+                    eval_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in eval_dataset.column_names
+                    else None
+                )
+            else:
+                lengths = None
+            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+            return LengthGroupedSampler(
+                self.args.eval_batch_size,
+                dataset=eval_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
+
+        if self.args.world_size <= 1:
+            return SequentialSampler(eval_dataset)
+        else:
+            return None
+
+    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
+        """
+        Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
+        """
+        if eval_dataset is None and self.eval_dataset is None:
+            raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+        # If we have persistent workers, don't do a fork bomb especially as eval datasets
+        # don't change during training
+        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+        if (
+            hasattr(self, "_eval_dataloaders")
+            and dataloader_key in self._eval_dataloaders
+            and self.args.dataloader_persistent_workers
+        ):
+            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+
+        eval_dataset = (
+            self.eval_dataset[eval_dataset]
+            if isinstance(eval_dataset, str)
+            else eval_dataset
+            if eval_dataset is not None
+            else self.eval_dataset
+        )
+        data_collator = self.data_collator
+
+        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
+
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+        # accelerator.free_memory() will destroy the references, so
+        # we need to store the non-prepared version
+        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+        if self.args.dataloader_persistent_workers:
+            if hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders[dataloader_key] = eval_dataloader
+            else:
+                self._eval_dataloaders = {dataloader_key: eval_dataloader}
+
+        return self.accelerator.prepare(eval_dataloader)
+
+    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
+        """
+        Returns the test [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            test_dataset (`torch.utils.data.Dataset`, *optional*):
+                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+                `model.forward()` method are automatically removed. It must implement `__len__`.
+        """
+        data_collator = self.data_collator
+
+        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+            test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
+
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+        # We use the same batch_size as for eval.
+        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        """
+        Setup the optimizer and the learning rate scheduler.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
+        `create_scheduler`) in a subclass.
+        """
+        self.create_optimizer()
+        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+            optimizer = self.optimizer.optimizer
+        else:
+            optimizer = self.optimizer
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+    def get_decay_parameter_names(self, model) -> List[str]:
+        """
+        Get all parameter names that weight decay will be applied to
+
+        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+        apply to those modules since this function only filter out instance of nn.LayerNorm
+        """
+        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
+        decay_parameters = [name for name in decay_parameters if "bias" not in name]
+        return decay_parameters
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+        """
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+        if self.optimizer is None:
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            if self.optimizer_cls_and_kwargs is not None:
+                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+            else:
+                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
+
+            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+            # e.g. for GaLore optimizer.
+            if "params" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+            # e.g. for LOMO optimizer.
+            if "model" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
+            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+            # to avoid arguments conflicts.
+            if "optimizer_dict" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
+
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+            if optimizer_cls.__name__ == "Adam8bit":
+                import bitsandbytes
+
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                skipped = 0
+                for module in opt_model.modules():
+                    if isinstance(module, nn.Embedding):
+                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                        logger.info(f"skipped {module}: {skipped/2**20}M params")
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                logger.info(f"skipped: {skipped/2**20}M params")
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
+
+    def get_num_trainable_parameters(self):
+        """
+        Get the number of trainable parameters.
+        """
+        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+    def get_learning_rates(self):
+        """
+        Returns the learning rate of each parameter from self.optimizer.
+        """
+        if self.optimizer is None:
+            raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
+        return [group["lr"] for group in self.optimizer.param_groups]
+
+    def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
+        """
+        Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
+
+        Args:
+            param (`str` or `torch.nn.parameter.Parameter`, *optional*):
+                The parameter for which optimizer group needs to be returned.
+        """
+        if self.optimizer is None:
+            raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
+        if param is not None:
+            for group in self.optimizer.param_groups:
+                if param in group["params"]:
+                    return group
+        return [group["params"] for group in self.optimizer.param_groups]
+
+    @staticmethod
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> Tuple[Any, Any]:
+        """
+        Returns the optimizer class and optimizer parameters based on the training arguments.
+
+        Args:
+            args (`transformers.training_args.TrainingArguments`):
+                The training arguments for the training session.
+
+        """
+
+        # parse args.optim_args
+        optim_args = {}
+        if args.optim_args:
+            for mapping in args.optim_args.replace(" ", "").split(","):
+                key, value = mapping.split("=")
+                optim_args[key] = value
+
+        optimizer_kwargs = {"lr": args.learning_rate}
+
+        adam_kwargs = {
+            "betas": (args.adam_beta1, args.adam_beta2),
+            "eps": args.adam_epsilon,
+        }
+        if args.optim == OptimizerNames.ADAFACTOR:
+            optimizer_cls = Adafactor
+            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim == OptimizerNames.ADAMW_HF:
+            from .optimization import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
+            from torch.optim import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+                optimizer_kwargs.update({"fused": True})
+        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
+            try:
+                from torch_xla.amp.syncfree import AdamW
+
+                optimizer_cls = AdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
+        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+            try:
+                from apex.optimizers import FusedAdam
+
+                optimizer_cls = FusedAdam
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim in [
+            OptimizerNames.ADAMW_BNB,
+            OptimizerNames.ADAMW_8BIT,
+            OptimizerNames.PAGED_ADAMW,
+            OptimizerNames.PAGED_ADAMW_8BIT,
+            OptimizerNames.ADEMAMIX,
+            OptimizerNames.ADEMAMIX_8BIT,
+            OptimizerNames.PAGED_ADEMAMIX,
+            OptimizerNames.PAGED_ADEMAMIX_8BIT,
+            OptimizerNames.LION,
+            OptimizerNames.LION_8BIT,
+            OptimizerNames.PAGED_LION,
+            OptimizerNames.PAGED_LION_8BIT,
+            OptimizerNames.RMSPROP_BNB,
+            OptimizerNames.RMSPROP_8BIT,
+            OptimizerNames.RMSPROP_32BIT,
+        ]:
+            try:
+                from bitsandbytes.optim import AdamW, Lion, RMSprop
+
+                is_paged = False
+                optim_bits = 32
+                optimizer_cls = None
+                additional_optim_kwargs = adam_kwargs
+                if "paged" in args.optim:
+                    is_paged = True
+                if "8bit" in args.optim:
+                    optim_bits = 8
+                if "adam" in args.optim:
+                    optimizer_cls = AdamW
+                elif "lion" in args.optim:
+                    optimizer_cls = Lion
+                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+                elif "rmsprop" in args.optim:
+                    optimizer_cls = RMSprop
+                    # Above we pass all `adam_kwargs` to the optimizer, here
+                    # we only pass `optim_args` which can be passed by the user.
+                    additional_optim_kwargs = optim_args
+                elif "ademamix" in args.optim:
+                    if is_bitsandbytes_available() and version.parse(
+                        importlib.metadata.version("bitsandbytes")
+                    ) < version.parse("0.44.0"):
+                        raise ValueError(
+                            "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
+                            "Please install `bitsandbytes` >= 0.44.0."
+                        )
+
+                    from bitsandbytes.optim import AdEMAMix
+
+                    optimizer_cls = AdEMAMix
+                    additional_optim_kwargs = {
+                        "betas": (
+                            float(optim_args.get("beta1", args.adam_beta1)),
+                            float(optim_args.get("beta2", args.adam_beta2)),
+                            float(optim_args.get("beta3", 0.9999)),
+                        ),
+                        "alpha": float(optim_args.get("alpha", 5.0)),
+                        "eps": float(optim_args.get("eps", args.adam_epsilon)),
+                    }
+
+                    if "t_alpha" in optim_args:
+                        additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
+
+                    if "t_beta3" in optim_args:
+                        additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
+
+                bnb_kwargs = {"optim_bits": optim_bits}
+                if "rmsprop" not in args.optim:
+                    bnb_kwargs["is_paged"] = is_paged
+
+                optimizer_kwargs.update(additional_optim_kwargs)
+                optimizer_kwargs.update(bnb_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!")
+            if is_bitsandbytes_available() and version.parse(
+                importlib.metadata.version("bitsandbytes")
+            ) < version.parse("0.41.1"):
+                logger.warning(
+                    "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. "
+                    "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers."
+                )
+        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
+            try:
+                from torchdistx.optimizers import AnyPrecisionAdamW
+
+                optimizer_cls = AnyPrecisionAdamW
+                optimizer_kwargs.update(adam_kwargs)
+
+                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
+                optimizer_kwargs.update(
+                    {
+                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
+                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
+                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
+                        "compensation_buffer_dtype": getattr(
+                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
+                        ),
+                    }
+                )
+            except ImportError:
+                raise ValueError("Please install https://github.com/pytorch/torchdistx")
+        elif args.optim == OptimizerNames.SGD:
+            optimizer_cls = torch.optim.SGD
+        elif args.optim == OptimizerNames.ADAGRAD:
+            optimizer_cls = torch.optim.Adagrad
+        elif args.optim == OptimizerNames.RMSPROP:
+            optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [
+            OptimizerNames.GALORE_ADAMW,
+            OptimizerNames.GALORE_ADAMW_8BIT,
+            OptimizerNames.GALORE_ADAFACTOR,
+            OptimizerNames.GALORE_ADAMW_LAYERWISE,
+            OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
+            OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
+        ]:
+            if not is_galore_torch_available():
+                raise ImportError(
+                    "You need to install `galore_torch` in order to use GaLore optimizers"
+                    " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
+                )
+            from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
+
+            optimizer_mapping = {
+                OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
+                OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
+            }
+
+            optimizer_cls = optimizer_mapping[args.optim]
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
+
+            logger.warning(
+                "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
+            )
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            galore_params = []
+            galore_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                galore_params.append(module.weight)
+                galore_params_names.append(module_name + ".weight")
+
+            if len(galore_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
+                )
+
+            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
+
+            galore_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 0.25)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
+
+            # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
+            param_groups = [
+                {"params": non_galore_params},
+                {"params": galore_params, **galore_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
+
+                optimizer_dict = {}
+                for param in non_galore_params:
+                    param_groups = [{"params": [param]}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                for param in galore_params:
+                    param_groups = [{"params": [param], **galore_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if param.grad is not None:
+                        optimizer_dict[param].step()
+                        optimizer_dict[param].zero_grad()
+
+                for param in model.parameters():
+                    if param.requires_grad:
+                        param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
+
+            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
+                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            if not is_lomo_available():
+                raise ImportError(
+                    "You need to install `lomo_optim` in order to use LOMO optimizers"
+                    " install it with `pip install lomo-optim`"
+                )
+            if not is_accelerate_available("0.30.0"):
+                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")
+
+            if model is None:
+                raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
+
+            from lomo_optim import AdaLomo, Lomo
+
+            if "ada" in args.optim:
+                optimizer_cls = AdaLomo
+            else:
+                optimizer_cls = Lomo
+
+            optimizer_kwargs.update({"model": model})
+        elif args.optim == OptimizerNames.GROKADAMW:
+            if not is_grokadamw_available():
+                raise ValueError("Please install grokadamw with `pip install grokadamw`")
+
+            from grokadamw import GrokAdamW
+
+            optimizer_cls = GrokAdamW
+            optimizer_kwargs.update(
+                {
+                    "alpha_init": float(optim_args.get("alpha_init", 0.98)),
+                    "lamb": float(optim_args.get("lamb", 2.0)),
+                    "gamma": float(optim_args.get("gamma", 0.1)),
+                    "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
+                    "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
+                }
+            )
+        elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+            if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
+                "0.4.0"
+            ):
+                raise ImportError(
+                    "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
+                    "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
+                )
+            if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"):
+                raise ImportError(
+                    "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
+                    "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
+                )
+            from torchao.prototype.low_bit_optim import AdamW4bit
+
+            optimizer_cls = AdamW4bit
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim in [
+            OptimizerNames.SCHEDULE_FREE_ADAMW,
+            OptimizerNames.SCHEDULE_FREE_SGD,
+        ]:
+            if not is_schedulefree_available():
+                raise ImportError(
+                    "You need to install `schedulefree` in order to use schedulefree optimizers"
+                    " install it with `pip install schedulefree`"
+                )
+            if not is_accelerate_available("0.30.0"):
+                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
+            from schedulefree import AdamWScheduleFree, SGDScheduleFree
+
+            additional_optim_kwargs = {}
+            if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
+                optimizer_cls = AdamWScheduleFree
+                additional_optim_kwargs = adam_kwargs
+            elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
+                optimizer_cls = SGDScheduleFree
+            else:
+                raise ValueError("Invalid schedulefree optimizer")
+            additional_optim_kwargs["weight_decay"] = args.weight_decay
+            additional_optim_kwargs["warmup_steps"] = args.warmup_steps
+            additional_optim_kwargs.update(
+                {
+                    "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
+                    "r": float(optim_args.get("r", 0.0)),
+                }
+            )
+            optimizer_kwargs.update(additional_optim_kwargs)
+        else:
+            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+        return optimizer_cls, optimizer_kwargs
+
+    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+        """
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+        """
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_scheduler(
+                self.args.lr_scheduler_type,
+                optimizer=self.optimizer if optimizer is None else optimizer,
+                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
+                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
+            )
+            self._created_lr_scheduler = True
+        return self.lr_scheduler
+
+    def num_examples(self, dataloader: DataLoader) -> int:
+        """
+        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
+        dataloader.dataset does not exist or has no length, estimates as best it can
+        """
+        try:
+            dataset = dataloader.dataset
+            # Special case for IterableDatasetShard, we need to dig deeper
+            if isinstance(dataset, IterableDatasetShard):
+                return len(dataloader.dataset.dataset)
+            return len(dataloader.dataset)
+        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
+            return len(dataloader) * self.args.per_device_train_batch_size
+
+    @staticmethod
+    def num_tokens(train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
+        """
+        Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader.
+        """
+        train_tokens = 0
+        try:
+            for batch in train_dl:
+                tokens = batch["input_ids"].numel()
+                if max_steps is not None:
+                    return tokens * max_steps
+                train_tokens += tokens
+        except KeyError:
+            logger.warning("Cannot get num_tokens from dataloader")
+        return train_tokens
+
+    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
+        """HP search setup code"""
+        self._trial = trial
+
+        if self.hp_search_backend is None or trial is None:
+            return
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            params = self.hp_space(trial)
+        elif self.hp_search_backend == HPSearchBackend.RAY:
+            params = trial
+            params.pop("wandb", None)
+        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
+        elif self.hp_search_backend == HPSearchBackend.WANDB:
+            params = trial
+
+        for key, value in params.items():
+            if not hasattr(self.args, key):
+                logger.warning(
+                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+                    " `TrainingArguments`."
+                )
+                continue
+            old_attr = getattr(self.args, key, None)
+            # Casting value to the proper type
+            if old_attr is not None:
+                value = type(old_attr)(value)
+
+            setattr(self.args, key, value)
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            logger.info(f"Trial: {trial.params}")
+        if self.hp_search_backend == HPSearchBackend.SIGOPT:
+            logger.info(f"SigOpt Assignments: {trial.assignments}")
+        if self.hp_search_backend == HPSearchBackend.WANDB:
+            logger.info(f"W&B Sweep parameters: {trial}")
+        if self.is_deepspeed_enabled:
+            if self.args.deepspeed is None:
+                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+            self.accelerator.free_memory()
+
+            # Rebuild the deepspeed config to reflect the updated training parameters
+            from accelerate.utils import DeepSpeedPlugin
+
+            from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
+            self.args.hf_deepspeed_config.trainer_config_process(self.args)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+
+            # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
+            # Simply calling `_reset_state` is enough and doesn't need a version pin.
+            AcceleratorState()._reset_state()
+
+        self.create_accelerator_and_postprocess()
+
+    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
+        if self.hp_search_backend is None or trial is None:
+            return
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            import optuna
+
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
+                trial.report(self.objective, step)
+                if trial.should_prune():
+                    self.callback_handler.on_train_end(self.args, self.state, self.control)
+                    raise optuna.TrialPruned()
+        elif self.hp_search_backend == HPSearchBackend.RAY:
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)
+
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            # Update the `TrainerControl` state to where we are currently
+            self.state.stateful_callbacks["TrainerControl"] = self.control.state()
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+
+    def call_model_init(self, trial=None):
+        model_init_argcount = number_of_arguments(self.model_init)
+        if model_init_argcount == 0:
+            model = self.model_init()
+        elif model_init_argcount == 1:
+            model = self.model_init(trial)
+        else:
+            raise RuntimeError("model_init should have 0 or 1 argument.")
+
+        if model is None:
+            raise RuntimeError("model_init should not return None.")
+
+        return model
+
+    def torch_jit_model_eval(self, model, dataloader, training=False):
+        if not training:
+            if dataloader is None:
+                logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
+                return model
+            example_batch = next(iter(dataloader))
+            example_batch = self._prepare_inputs(example_batch)
+            try:
+                jit_model = copy.copy(model)
+                jit_model.eval()
+                original_forward = jit_model.__dict__.pop("_original_forward", None)
+                # remove mixed precision hooks from the model
+                if original_forward:
+                    jit_model.forward = original_forward
+                with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
+                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
+                        if isinstance(example_batch, dict):
+                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
+                        else:
+                            jit_model = torch.jit.trace(
+                                jit_model,
+                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},
+                                strict=False,
+                            )
+                    else:
+                        jit_inputs = []
+                        for key in example_batch:
+                            example_tensor = torch.ones_like(example_batch[key])
+                            jit_inputs.append(example_tensor)
+                        jit_inputs = tuple(jit_inputs)
+                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
+                jit_model = torch.jit.freeze(jit_model)
+                with torch.no_grad():
+                    jit_model(**example_batch)
+                    jit_model(**example_batch)
+                model = jit_model
+                self.use_cpu_amp = False
+            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
+                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
+
+        return model
+
+    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
+        if not is_ipex_available():
+            raise ImportError(
+                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
+                " to https://github.com/intel/intel-extension-for-pytorch."
+            )
+
+        import intel_extension_for_pytorch as ipex
+
+        if not training:
+            model.eval()
+            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
+            # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
+            model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
+        else:
+            if not model.training:
+                model.train()
+            model, self.optimizer = ipex.optimize(
+                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
+            )
+
+        return model
+
+    def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
+        attributes_map = {
+            "logging_steps": "logging_steps",
+            "eval_steps": "eval_steps",
+            "save_steps": "save_steps",
+        }
+
+        has_warning = False
+        warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
+        for arg_attr, state_attr in attributes_map.items():
+            arg_value = getattr(training_args, arg_attr, None)
+            state_value = getattr(trainer_state, state_attr, None)
+
+            if arg_value is not None and state_value is not None and arg_value != state_value:
+                warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
+                has_warning = True
+
+        # train bs is special as we need to account for multi-GPU
+        train_bs_args = training_args.per_device_train_batch_size
+        train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
+
+        if train_bs_args != train_bs_state:
+            warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
+            has_warning = True
+
+        if has_warning:
+            logger.warning_once(warning_str)
+
+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.use_ipex:
+            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
+            model = self.ipex_optimize_model(model, training, dtype=dtype)
+
+        if is_sagemaker_mp_enabled():
+            # Wrapping the base model twice in a DistributedModel will raise an error.
+            if isinstance(self.model_wrapped, smp.model.DistributedModel):
+                return self.model_wrapped
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
+        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
+        if self.accelerator.unwrap_model(model) is not model:
+            return model
+
+        # Mixed precision training with apex (torch < 1.6)
+        if self.use_apex and training:
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
+        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
+            model = nn.DataParallel(model)
+
+        if self.args.jit_mode_eval:
+            start_time = time.time()
+            model = self.torch_jit_model_eval(model, dataloader, training)
+            self.jit_compilation_time = round(time.time() - start_time, 4)
+
+        # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        if not training:
+            return model
+
+        # Distributed training (should be after apex fp16 initialization)
+        # Distributed training using PyTorch FSDP
+        if self.is_fsdp_xla_enabled:
+            try:
+                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+                from torch_xla.distributed.fsdp import checkpoint_module
+                from torch_xla.distributed.fsdp.wrap import (
+                    size_based_auto_wrap_policy,
+                    transformer_auto_wrap_policy,
+                )
+
+                if self.is_fsdp_xla_v2_enabled:
+                    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
+                        SpmdFullyShardedDataParallel as FSDPv2,
+                    )
+            except ImportError:
+                raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
+            auto_wrap_policy = None
+            auto_wrapper_callable = None
+            default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
+            fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+            )
+
+            if self.args.fsdp_config["min_num_params"] > 0:
+                auto_wrap_policy = functools.partial(
+                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
+                )
+            elif fsdp_transformer_layer_cls_to_wrap is not None:
+                transformer_cls_to_wrap = set()
+                for layer_class in fsdp_transformer_layer_cls_to_wrap:
+                    transformer_cls = get_module_class_from_name(model, layer_class)
+                    if transformer_cls is None:
+                        raise Exception("Could not find the transformer layer class to wrap in the model.")
+                    else:
+                        transformer_cls_to_wrap.add(transformer_cls)
+
+                auto_wrap_policy = functools.partial(
+                    transformer_auto_wrap_policy,
+                    # Transformer layer class to wrap
+                    transformer_layer_cls=transformer_cls_to_wrap,
+                )
+            fsdp_kwargs = self.args.xla_fsdp_config
+            if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
+                if model.config.use_cache:
+                    logger.warning_once(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+                    )
+                    model.config.use_cache = False
+
+                # Apply gradient checkpointing to auto-wrapped sub-modules if specified
+                def auto_wrapper_callable(m, *args, **kwargs):
+                    target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
+                    return target_cls(checkpoint_module(m), *args, **kwargs)
+
+            # Wrap the base model with an outer FSDP wrapper
+            if self.is_fsdp_xla_v2_enabled:
+
+                def shard_output(output, mesh):
+                    from .modeling_outputs import CausalLMOutputWithPast
+
+                    real_output = None
+                    if isinstance(output, torch.Tensor):
+                        real_output = output
+                    elif isinstance(output, tuple):
+                        real_output = output[0]
+                    elif isinstance(output, CausalLMOutputWithPast):
+                        real_output = output.logits
+
+                    if real_output is None:
+                        raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
+                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                self.model = model = FSDPv2(
+                    model,
+                    shard_output=shard_output,
+                    auto_wrap_policy=auto_wrap_policy,
+                    auto_wrapper_callable=auto_wrapper_callable,
+                )
+            else:
+                self.model = model = FSDP(
+                    model,
+                    auto_wrap_policy=auto_wrap_policy,
+                    auto_wrapper_callable=auto_wrapper_callable,
+                    **fsdp_kwargs,
+                )
+
+            # Patch `xm.optimizer_step` should not reduce gradients in this case,
+            # as FSDP does not need gradient reduction over sharded parameters.
+            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+                loss = optimizer.step(**optimizer_args)
+                if barrier:
+                    xm.mark_step()
+                return loss
+
+            xm.optimizer_step = patched_optimizer_step
+        elif is_sagemaker_dp_enabled():
+            model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+            )
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            if is_torch_neuroncore_available():
+                return model
+            kwargs = {}
+            if self.args.ddp_find_unused_parameters is not None:
+                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
+            else:
+                kwargs["find_unused_parameters"] = True
+
+            if self.args.ddp_bucket_cap_mb is not None:
+                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+
+            if self.args.ddp_broadcast_buffers is not None:
+                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
+
+            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
+
+        return model
+
+    def train(
+        self,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
+        trial: Union["optuna.Trial", Dict[str, Any]] = None,
+        ignore_keys_for_eval: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        """
+        Main training entry point.
+
+        Args:
+            resume_from_checkpoint (`str` or `bool`, *optional*):
+                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
+                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
+            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+                The trial run or the hyperparameter dictionary for hyperparameter search.
+            ignore_keys_for_eval (`List[str]`, *optional*)
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions for evaluation during the training.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments used to hide deprecated arguments
+        """
+        if resume_from_checkpoint is False:
+            resume_from_checkpoint = None
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        args = self.args
+
+        self.is_in_train = True
+
+        # Attach NEFTune hooks if necessary
+        if self.neftune_noise_alpha is not None:
+            self.model = self._activate_neftune(self.model)
+
+        # do_train is not a reliable argument, as it might not be set and .train() still called, so
+        # the following is a workaround:
+        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel:
+            self._move_model_to_device(self.model, args.device)
+
+        if "model_path" in kwargs:
+            resume_from_checkpoint = kwargs.pop("model_path")
+            warnings.warn(
+                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+                "instead.",
+                FutureWarning,
+            )
+        if len(kwargs) > 0:
+            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+        # This might change the seed so needs to run first.
+        self._hp_search_setup(trial)
+        self._train_batch_size = self.args.train_batch_size
+
+        # Model re-init
+        model_reloaded = False
+        if self.model_init is not None:
+            # Seed must be set before instantiating the model when using model_init.
+            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+            self.model = self.call_model_init(trial)
+            model_reloaded = True
+            # Reinitializes optimizer and scheduler
+            self.optimizer, self.lr_scheduler = None, None
+
+        # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+        if resume_from_checkpoint is not None:
+            if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
+                self._load_from_checkpoint(resume_from_checkpoint)
+            # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            if state.train_batch_size is not None:
+                self._train_batch_size = state.train_batch_size
+
+        # If model was re-initialized, put it on the right device and update self.model_wrapped
+        if model_reloaded:
+            if self.place_model_on_device:
+                self._move_model_to_device(self.model, args.device)
+            self.model_wrapped = self.model
+
+        inner_training_loop = find_executable_batch_size(
+            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+        )
+        if args.push_to_hub:
+            try:
+                # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
+                hf_hub_utils.disable_progress_bars()
+                return inner_training_loop(
+                    args=args,
+                    resume_from_checkpoint=resume_from_checkpoint,
+                    trial=trial,
+                    ignore_keys_for_eval=ignore_keys_for_eval,
+                )
+            finally:
+                hf_hub_utils.enable_progress_bars()
+        else:
+            return inner_training_loop(
+                args=args,
+                resume_from_checkpoint=resume_from_checkpoint,
+                trial=trial,
+                ignore_keys_for_eval=ignore_keys_for_eval,
+            )
+
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        self.accelerator.free_memory()
+        self._train_batch_size = batch_size
+        if self.args.auto_find_batch_size:
+            if self.state.train_batch_size != self._train_batch_size:
+                from accelerate.utils import release_memory
+
+                (self.model_wrapped,) = release_memory(self.model_wrapped)
+                self.model_wrapped = self.model
+
+                # Check for DeepSpeed *after* the intial pass and modify the config
+                if self.is_deepspeed_enabled:
+                    # Temporarily unset `self.args.train_batch_size`
+                    original_bs = self.args.per_device_train_batch_size
+                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
+                    self.propagate_args_to_deepspeed(True)
+                    self.args.per_device_train_batch_size = original_bs
+            self.state.train_batch_size = self._train_batch_size
+        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+        # Data loader and number of training steps
+        train_dataloader = self.get_train_dataloader()
+        if self.is_fsdp_xla_v2_enabled:
+            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+
+        # Setting up training control variables:
+        # number of training epochs: num_train_epochs
+        # number of training steps per epoch: num_update_steps_per_epoch
+        # total number of training steps to execute: max_steps
+        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+        len_dataloader = None
+        num_train_tokens = None
+        if has_length(train_dataloader):
+            len_dataloader = len(train_dataloader)
+            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+            num_examples = self.num_examples(train_dataloader)
+            if args.max_steps > 0:
+                max_steps = args.max_steps
+                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+                    args.max_steps % num_update_steps_per_epoch > 0
+                )
+                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+                # the best we can do.
+                num_train_samples = args.max_steps * total_train_batch_size
+                if args.include_tokens_per_second:
+                    num_train_tokens = (
+                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+                    )
+            else:
+                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+                num_train_epochs = math.ceil(args.num_train_epochs)
+                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+                if args.include_tokens_per_second:
+                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
+        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
+            max_steps = args.max_steps
+            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+            num_train_epochs = sys.maxsize
+            num_update_steps_per_epoch = max_steps
+            num_examples = total_train_batch_size * args.max_steps
+            num_train_samples = args.max_steps * total_train_batch_size
+            if args.include_tokens_per_second:
+                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+        else:
+            raise ValueError(
+                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+                f" {args.max_steps}"
+            )
+
+        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+            if self.args.n_gpu > 1:
+                # nn.DataParallel(model) replicates the model, creating new variables and module
+                # references registered here no longer work on other gpus, breaking the module
+                raise ValueError(
+                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+                    " (torchrun or torch.distributed.launch (deprecated))."
+                )
+            else:
+                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
+
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+
+        # We need to reset the scheduler, as its parameters may be different on subsequent calls
+        if self._created_lr_scheduler:
+            self.lr_scheduler = None
+            self._created_lr_scheduler = False
+
+        if self.is_deepspeed_enabled:
+            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+        if not delay_optimizer_creation:
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+        self.state = TrainerState(
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ]
+        )
+        self.state.is_hyper_param_search = trial is not None
+        self.state.train_batch_size = self._train_batch_size
+
+        # Compute absolute values for logging, eval, and save if given as ratio
+        if args.logging_steps is not None:
+            if args.logging_steps < 1:
+                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+            else:
+                self.state.logging_steps = args.logging_steps
+        if args.eval_steps is not None:
+            if args.eval_steps < 1:
+                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+            else:
+                self.state.eval_steps = args.eval_steps
+        if args.save_steps is not None:
+            if args.save_steps < 1:
+                self.state.save_steps = math.ceil(max_steps * args.save_steps)
+            else:
+                self.state.save_steps = args.save_steps
+
+        # Activate gradient checkpointing if needed
+        if args.gradient_checkpointing:
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
+
+        model = self._wrap_model(self.model_wrapped)
+
+        # as the model is wrapped, don't use `accelerator.prepare`
+        # this is for unhandled cases such as
+        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+        use_accelerator_prepare = True if model is self.model else False
+
+        if use_accelerator_prepare and self.is_fsdp_enabled:
+            # In case of auto_find_batch_size=True
+            # Remove FSDP wrapping from sub-models.
+            self.model = unwrap_model(self.model, recursive=True)
+
+        if delay_optimizer_creation:
+            if use_accelerator_prepare:
+                # configure fsdp plugin for qlora if any
+                self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+        # prepare using `accelerator` prepare
+        if use_accelerator_prepare:
+            self.model.train()
+            if hasattr(self.lr_scheduler, "step"):
+                if self.use_apex:
+                    model = self.accelerator.prepare(self.model)
+                else:
+                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            else:
+                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.lr_scheduler
+                )
+        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            # In this case we are in DDP + LOMO, which should be supported
+            self.optimizer = self.accelerator.prepare(self.optimizer)
+
+        if self.is_fsdp_enabled:
+            self.model = self.model_wrapped = model
+
+        # for the rest of this function `model` is the outside model, whether it was wrapped or not
+        if model is not self.model:
+            self.model_wrapped = model
+
+        # backward compatibility
+        if self.is_deepspeed_enabled:
+            self.deepspeed = self.model_wrapped
+
+        # ckpt loading
+        if resume_from_checkpoint is not None:
+            if self.is_deepspeed_enabled:
+                deepspeed_load_checkpoint(
+                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+                )
+            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
+                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
+
+        # Check if saved optimizer or scheduler states exist
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
+        # important: at this point:
+        # self.model         is the Transformers Model
+        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
+        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
+
+        # Train!
+        logger.info("***** Running training *****")
+        logger.info(f"  Num examples = {num_examples:,}")
+        logger.info(f"  Num Epochs = {num_train_epochs:,}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+        if self.args.per_device_train_batch_size != self._train_batch_size:
+            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
+        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization steps = {max_steps:,}")
+        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
+
+        self.state.epoch = 0
+        start_time = time.time()
+        epochs_trained = 0
+        steps_trained_in_current_epoch = 0
+        steps_trained_progress_bar = None
+
+        # Check if continuing training from a checkpoint
+        if resume_from_checkpoint is not None and os.path.isfile(
+            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+        ):
+            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            self.compare_trainer_and_checkpoint_args(self.args, self.state)
+            self._load_callback_state()
+            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+            if not args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0
+
+            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {self.state.global_step}")
+            if not args.ignore_data_skip:
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first"
+                    f" {steps_trained_in_current_epoch} batches in the first epoch."
+                )
+
+        # Update the references
+        self.callback_handler.model = self.model
+        self.callback_handler.optimizer = self.optimizer
+        self.callback_handler.lr_scheduler = self.lr_scheduler
+        self.callback_handler.train_dataloader = train_dataloader
+        if self.hp_name is not None and self._trial is not None:
+            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+            # parameter to Train when using DDP.
+            self.state.trial_name = self.hp_name(self._trial)
+        if trial is not None:
+            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
+            self.state.trial_params = hp_params(assignments)
+        else:
+            self.state.trial_params = None
+        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+        # to set this after the load.
+        self.state.max_steps = max_steps
+        self.state.num_train_epochs = num_train_epochs
+        self.state.is_local_process_zero = self.is_local_process_zero()
+        self.state.is_world_process_zero = self.is_world_process_zero()
+
+        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+        tr_loss = torch.tensor(0.0).to(args.device)
+        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+        self._total_loss_scalar = 0.0
+        self._globalstep_last_logged = self.state.global_step
+        model.zero_grad()
+        grad_norm: Optional[float] = None
+        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+        if args.eval_on_start:
+            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
+        for epoch in range(epochs_trained, num_train_epochs):
+            epoch_dataloader = train_dataloader
+            if hasattr(epoch_dataloader, "set_epoch"):
+                epoch_dataloader.set_epoch(epoch)
+
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if args.past_index >= 0:
+                self._past = None
+
+            steps_in_epoch = (
+                len(epoch_dataloader)
+                if len_dataloader is not None
+                else args.max_steps * args.gradient_accumulation_steps
+            )
+            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+                self._load_rng_state(resume_from_checkpoint)
+
+            rng_to_sync = False
+            steps_skipped = 0
+            if steps_trained_in_current_epoch > 0:
+                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+                steps_skipped = steps_trained_in_current_epoch
+                steps_trained_in_current_epoch = 0
+                rng_to_sync = True
+
+            step = -1
+            epoch_iterator = iter(epoch_dataloader)
+            # We chunkify the epoch iterator into gradient accumulation steps `n` batches
+            remainder = num_examples % args.gradient_accumulation_steps
+            if remainder == 0:
+                remainder = args.gradient_accumulation_steps
+            update_step = -1
+            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
+            for _ in range(total_updates):
+                update_step += 1
+                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
+                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
+                for i, inputs in enumerate(batch_samples):
+                    step += 1
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
+                    # Since we perform prefetching, we need to manually set sync_gradients
+                    if not do_sync_step:
+                        self.accelerator.gradient_state._set_sync_gradients(False)
+                    else:
+                        self.accelerator.gradient_state._set_sync_gradients(True)
+
+                    if self.args.include_num_input_tokens_seen:
+                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
+                        if main_input_name not in inputs:
+                            logger.warning(
+                                "Tried to track the number of tokens seen, however the current model is "
+                                "not configured properly to know what item is the input. To fix this, add "
+                                "a `main_input_name` attribute to the model class you are using."
+                            )
+                        else:
+                            input_tokens = inputs[main_input_name].numel()
+                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
+                            self.state.num_input_tokens_seen += (
+                                self.accelerator.gather(input_tokens).sum().cpu().item()
+                            )
+                    if rng_to_sync:
+                        self._load_rng_state(resume_from_checkpoint)
+                        rng_to_sync = False
+
+                    # Skip past any already trained steps if resuming training
+                    if steps_trained_in_current_epoch > 0:
+                        steps_trained_in_current_epoch -= 1
+                        if steps_trained_progress_bar is not None:
+                            steps_trained_progress_bar.update(1)
+                        if steps_trained_in_current_epoch == 0:
+                            self._load_rng_state(resume_from_checkpoint)
+                        continue
+                    elif steps_trained_progress_bar is not None:
+                        steps_trained_progress_bar.close()
+                        steps_trained_progress_bar = None
+
+                    if step % args.gradient_accumulation_steps == 0:
+                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+                    # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
+                    context = (
+                        functools.partial(self.accelerator.no_sync, model=model)
+                        if i != len(batch_samples) - 1
+                        and self.accelerator.distributed_type != DistributedType.DEEPSPEED
+                        else contextlib.nullcontext
+                    )
+                    with context():
+                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+
+                    if (
+                        args.logging_nan_inf_filter
+                        and not is_torch_xla_available()
+                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+                    ):
+                        # if loss is nan or inf simply add the average of previous logged losses
+                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+                    else:
+                        if tr_loss.device != tr_loss_step.device:
+                            raise ValueError(
+                                f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
+                            )
+                        tr_loss = tr_loss + tr_loss_step
+
+                    self.current_flos += float(self.floating_point_ops(inputs))
+
+                    if do_sync_step:
+                        # Since we perform prefetching, we need to manually set sync_gradients to True
+                        self.accelerator.gradient_state._set_sync_gradients(True)
+
+                        # Gradient clipping
+                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
+                            # deepspeed does its own clipping
+
+                            if is_sagemaker_mp_enabled() and args.fp16:
+                                _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+                            elif self.use_apex:
+                                # Revert to normal clipping otherwise, handling Apex or full precision
+                                _grad_norm = nn.utils.clip_grad_norm_(
+                                    amp.master_params(self.optimizer),
+                                    args.max_grad_norm,
+                                )
+                            else:
+                                _grad_norm = self.accelerator.clip_grad_norm_(
+                                    model.parameters(),
+                                    args.max_grad_norm,
+                                )
+
+                            if (
+                                is_accelerate_available()
+                                and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                            ):
+                                grad_norm = model.get_global_grad_norm()
+                                # In some cases the grad norm may not return a float
+                                if hasattr(grad_norm, "item"):
+                                    grad_norm = grad_norm.item()
+                            else:
+                                grad_norm = _grad_norm
+
+                        self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
+
+                        self.optimizer.step()
+
+                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
+                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+                        if optimizer_was_run:
+                            # Delay optimizer scheduling until metrics are generated
+                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                                self.lr_scheduler.step()
+
+                        model.zero_grad()
+                        self.state.global_step += 1
+                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+                        self._maybe_log_save_evaluate(
+                            tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
+                        )
+                    else:
+                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+                    # PyTorch/XLA relies on the data loader to insert the mark_step for
+                    # each step. Since we are breaking the loop early, we need to manually
+                    # insert the mark_step here.
+                    if self.control.should_epoch_stop or self.control.should_training_stop:
+                        if is_torch_xla_available():
+                            xm.mark_step()
+                        break
+                # We also need to break out of the nested loop
+                if self.control.should_epoch_stop or self.control.should_training_stop:
+                    if is_torch_xla_available():
+                        xm.mark_step()
+                    break
+            if step < 0:
+                logger.warning(
+                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
+                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+                    f" num_steps ({max_steps}) higher than the number of available samples."
+                )
+                self.control.should_training_stop = True
+
+            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
+
+            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+                if is_torch_xla_available():
+                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+                    xm.master_print(met.metrics_report())
+                else:
+                    logger.warning(
+                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+                        "configured. Check your training configuration if this is unexpected."
+                    )
+            if self.control.should_training_stop:
+                break
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
+        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+            # Wait for everyone to get here so we are sure the model has been saved by process 0.
+            if is_torch_xla_available():
+                xm.rendezvous("load_best_model_at_end")
+            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
+                dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()
+
+            self._load_best_model()
+
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()
+        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
+        train_loss = self._total_loss_scalar / effective_global_step
+
+        metrics = speed_metrics(
+            "train",
+            start_time,
+            num_samples=num_train_samples,
+            num_steps=self.state.max_steps,
+            num_tokens=num_train_tokens,
+        )
+        self.store_flos()
+        metrics["total_flos"] = self.state.total_flos
+        metrics["train_loss"] = train_loss
+
+        self.is_in_train = False
+
+        self._memory_tracker.stop_and_update_metrics(metrics)
+
+        self.log(metrics)
+
+        run_dir = self._get_output_dir(trial)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+            for checkpoint in checkpoints_sorted:
+                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
+                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                    shutil.rmtree(checkpoint, ignore_errors=True)
+
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+        # Wait for the checkpoint to be uploaded.
+        self._finish_current_push()
+
+        # After training we make sure to retrieve back the original forward pass method
+        # for the embedding layer by removing the forward post hook.
+        if self.neftune_noise_alpha is not None:
+            self._deactivate_neftune(self.model)
+
+        return TrainOutput(self.state.global_step, train_loss, metrics)
+
+    def _get_output_dir(self, trial):
+        if self.hp_search_backend is not None and trial is not None:
+            if self.hp_search_backend == HPSearchBackend.OPTUNA:
+                run_id = trial.number
+            elif self.hp_search_backend == HPSearchBackend.RAY:
+                import ray.train
+
+                run_id = ray.train.get_context().get_trial_id()
+            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+                run_id = trial.id
+            elif self.hp_search_backend == HPSearchBackend.WANDB:
+                import wandb
+
+                run_id = wandb.run.id
+            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+            run_dir = os.path.join(self.args.output_dir, run_name)
+        else:
+            run_dir = self.args.output_dir
+        return run_dir
+
+    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+        if model is None:
+            model = self.model
+
+        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
+        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
+        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
+        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        # if multiple adapters exist, they get saved in sub directories
+        adapter_subdirs = (
+            [
+                folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+                and (
+                    os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
+                    or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
+                )
+            ]
+            if os.path.isdir(resume_from_checkpoint)
+            else []
+        )
+
+        if is_fsdp_ckpt and not self.is_fsdp_enabled:
+            raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
+
+        if not (
+            any(
+                os.path.isfile(f)
+                for f in [
+                    weights_file,
+                    safe_weights_file,
+                    weights_index_file,
+                    safe_weights_index_file,
+                    adapter_weights_file,
+                    adapter_safe_weights_file,
+                ]
+            )
+            or is_fsdp_ckpt
+            or adapter_subdirs
+        ):
+            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+        logger.info(f"Loading model from {resume_from_checkpoint}.")
+
+        if os.path.isfile(config_file):
+            config = PretrainedConfig.from_json_file(config_file)
+            checkpoint_version = config.transformers_version
+            if checkpoint_version is not None and checkpoint_version != __version__:
+                logger.warning(
+                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+                    f"Transformers but your current version is {__version__}. This is not recommended and could "
+                    "yield to errors or unwanted behaviors."
+                )
+
+        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
+            weights_only_kwarg = {"weights_only": True}
+            # If the model is on the GPU, it still works!
+            if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+                    # If the 'user_content.pt' file exists, load with the new smp api.
+                    # Checkpoint must have been saved with the new smp api.
+                    smp.resume_from_checkpoint(
+                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+                    )
+                else:
+                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                    # Checkpoint must have been saved with the old smp api.
+                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
+                        logger.warning(
+                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
+                        )
+                    state_dict = torch.load(
+                        weights_file,
+                        map_location="cpu",
+                        **weights_only_kwarg,
+                    )
+                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+                    # release memory
+                    del state_dict
+            elif self.is_fsdp_enabled:
+                load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin,
+                    self.accelerator,
+                    model,
+                    resume_from_checkpoint,
+                    **_get_fsdp_ckpt_kwargs(),
+                )
+            else:
+                # We load the model state dict on the CPU to avoid an OOM error.
+                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
+                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
+                else:
+                    state_dict = torch.load(
+                        weights_file,
+                        map_location="cpu",
+                        **weights_only_kwarg,
+                    )
+
+                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                # which takes *args instead of **kwargs
+                load_result = model.load_state_dict(state_dict, False)
+                # release memory
+                del state_dict
+                self._issue_warnings_after_load(load_result)
+
+        # Load adapters following PR # 24096
+        elif _is_peft_model(model):
+            # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
+            # TODO: in the future support only specific min PEFT versions
+            if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
+                model, "load_adapter"
+            ):
+                if os.path.exists(resume_from_checkpoint):
+                    # For BC for older PEFT versions
+                    if hasattr(model, "active_adapters"):
+                        active_adapters = model.active_adapters
+                        if len(active_adapters) > 1:
+                            logger.warning("Multiple active adapters detected will only consider the first adapter")
+                        active_adapter = active_adapters[0]
+                    else:
+                        active_adapter = model.active_adapter
+
+                    if adapter_subdirs:
+                        for subdir_name in adapter_subdirs:
+                            peft_id = os.path.join(resume_from_checkpoint, subdir_name)
+                            model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
+                        model.set_adapter(active_adapter)
+                    else:
+                        model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
+                else:
+                    logger.warning(
+                        "The intermediate checkpoints of PEFT may not be saved correctly, "
+                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                    )
+            else:
+                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+        else:
+            # We load the sharded checkpoint
+            load_result = load_sharded_checkpoint(
+                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
+            )
+            if not is_sagemaker_mp_enabled():
+                self._issue_warnings_after_load(load_result)
+
+    def _load_best_model(self):
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+
+        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.is_deepspeed_enabled:
+            deepspeed_load_checkpoint(
+                self.model_wrapped,
+                self.state.best_model_checkpoint,
+                load_module_strict=not _is_peft_model(self.model),
+            )
+        elif self.is_fsdp_enabled:
+            load_result = load_fsdp_model(
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                model,
+                self.state.best_model_checkpoint,
+                **_get_fsdp_ckpt_kwargs(),
+            )
+        elif (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(best_adapter_model_path)
+            or os.path.exists(best_safe_adapter_model_path)
+        ):
+            has_been_loaded = True
+            weights_only_kwarg = {"weights_only": True}
+            if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+                    # If the 'user_content.pt' file exists, load with the new smp api.
+                    # Checkpoint must have been saved with the new smp api.
+                    smp.resume_from_checkpoint(
+                        path=self.state.best_model_checkpoint,
+                        tag=WEIGHTS_NAME,
+                        partial=False,
+                        load_optimizer=False,
+                    )
+                else:
+                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                    # Checkpoint must have been saved with the old smp api.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(
+                            best_model_path,
+                            map_location="cpu",
+                            **weights_only_kwarg,
+                        )
+
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+            else:
+                if _is_peft_model(model):
+                    # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
+                    # TODO: in the future support only specific min PEFT versions
+                    if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
+                        model, "load_adapter"
+                    ):
+                        # For BC for older PEFT versions
+                        if hasattr(model, "active_adapters"):
+                            active_adapter = model.active_adapters[0]
+                            if len(model.active_adapters) > 1:
+                                logger.warning("Detected multiple active adapters, will only consider the first one")
+                        else:
+                            active_adapter = model.active_adapter
+
+                        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+                            try:
+                                model.load_adapter(self.state.best_model_checkpoint, active_adapter)
+                            except RuntimeError as exc:
+                                if model.peft_config[active_adapter].is_prompt_learning:
+                                    # for context: https://github.com/huggingface/peft/issues/2256
+                                    msg = (
+                                        "When using prompt learning PEFT methods such as "
+                                        f"{model.peft_config[active_adapter].peft_type.value}, setting "
+                                        "load_best_model_at_end=True can lead to errors, it is recommended "
+                                        "to set this to False and to load the model manually from the checkpoint "
+                                        "directory using PeftModel.from_pretrained(base_model, ) after training "
+                                        "has finished."
+                                    )
+                                    raise RuntimeError(msg) from exc
+                                else:
+                                    raise
+                            # Load_adapter has no return value present, modify it when appropriate.
+                            from torch.nn.modules.module import _IncompatibleKeys
+
+                            load_result = _IncompatibleKeys([], [])
+                        else:
+                            logger.warning(
+                                "The intermediate checkpoints of PEFT may not be saved correctly, "
+                                f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                                "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                            )
+                            has_been_loaded = False
+                    else:
+                        logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                        has_been_loaded = False
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(
+                            best_model_path,
+                            map_location="cpu",
+                            **weights_only_kwarg,
+                        )
+
+                    # If the model is on the GPU, it still works!
+                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                    # which takes *args instead of **kwargs
+                    load_result = model.load_state_dict(state_dict, False)
+                if not is_sagemaker_mp_enabled() and has_been_loaded:
+                    self._issue_warnings_after_load(load_result)
+        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
+            os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
+        ):
+            load_result = load_sharded_checkpoint(
+                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+            )
+            if not is_sagemaker_mp_enabled():
+                self._issue_warnings_after_load(load_result)
+        else:
+            logger.warning(
+                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                "on multiple nodes, you should activate `--save_on_each_node`."
+            )
+
+    def _issue_warnings_after_load(self, load_result):
+        if len(load_result.missing_keys) != 0:
+            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
+                self.model._keys_to_ignore_on_save
+            ):
+                self.model.tie_weights()
+            else:
+                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
+        if len(load_result.unexpected_keys) != 0:
+            logger.warning(
+                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
+            )
+
+    def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
+        metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+        self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+        # Run delayed LR scheduler now that metrics are populated
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
+            metric_to_check = self.args.metric_for_best_model
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+            try:
+                self.lr_scheduler.step(metrics[metric_to_check])
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
+                    f"which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. "
+                    f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
+                    f"consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+        return metrics
+
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
+        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+            if is_torch_xla_available():
+                xm.mark_step()
+
+            logs: Dict[str, float] = {}
+
+            # all_gather + mean() to get average loss over all processes
+            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+            # reset tr_loss to zero
+            tr_loss -= tr_loss
+
+            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+            logs["learning_rate"] = self._get_learning_rate()
+
+            self._total_loss_scalar += tr_loss_scalar
+            self._globalstep_last_logged = self.state.global_step
+            self.store_flos()
+
+            self.log(logs, start_time)
+
+        metrics = None
+        if self.control.should_evaluate:
+            metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == SaveStrategy.BEST:
+                self.control.should_save = is_new_best_metric
+
+        if self.control.should_save:
+            self._save_checkpoint(model, trial)
+            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+    def _load_rng_state(self, checkpoint):
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        if self.args.world_size > 1:
+            process_index = self.args.process_index
+            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        with safe_globals():
+            checkpoint_rng_state = torch.load(rng_file)
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+        if torch.cuda.is_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+            else:
+                try:
+                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
+        if is_torch_xla_available():
+            xm.set_rng_state(checkpoint_rng_state["xla"])
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
+            else:
+                try:
+                    torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
+        if is_torch_mlu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
+            else:
+                try:
+                    torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
+            else:
+                try:
+                    torch.musa.set_rng_state(checkpoint_rng_state["musa"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
+
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                run_dir = self._get_output_dir(trial=trial)
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
+    def _save_checkpoint(self, model, trial):
+        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+        # want to save except FullyShardedDDP.
+        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+        # Save model checkpoint
+        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+        if self.hp_search_backend is None and trial is None:
+            self.store_flos()
+
+        run_dir = self._get_output_dir(trial=trial)
+        output_dir = os.path.join(run_dir, checkpoint_folder)
+        self.save_model(output_dir, _internal_call=True)
+
+        if not self.args.save_only_model:
+            # Save optimizer and scheduler
+            self._save_optimizer_and_scheduler(output_dir)
+            # Save RNG state
+            self._save_rng_state(output_dir)
+
+        # Save the Trainer state
+        if self.args.should_save:
+            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
+            for cb in [
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ]:
+                cb_name = cb.__class__.__name__
+                cb_state = cb.state()
+                if isinstance(self.state.stateful_callbacks[cb_name], list):
+                    self.state.stateful_callbacks[cb_name].append(cb_state)
+                else:
+                    self.state.stateful_callbacks[cb_name] = cb_state
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+        if self.args.push_to_hub:
+            self._push_from_checkpoint(output_dir)
+
+        # Maybe delete some older checkpoints.
+        if self.args.should_save:
+            # Solely rely on numerical checkpoint id for rotation.
+            # mtime is not reliable especially on some fuse fs in cloud environments.
+            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
+
+    def _save_rng_state(self, output_dir):
+        # Save RNG state in non-distributed training
+        rng_states = {
+            "python": random.getstate(),
+            "numpy": np.random.get_state(),
+            "cpu": torch.random.get_rng_state(),
+        }
+        if torch.cuda.is_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
+                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+            else:
+                rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+        if is_torch_xla_available():
+            rng_states["xla"] = xm.get_rng_state()
+
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["npu"] = torch.npu.random.get_rng_state_all()
+            else:
+                rng_states["npu"] = torch.npu.random.get_rng_state()
+
+        if is_torch_mlu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
+            else:
+                rng_states["mlu"] = torch.mlu.random.get_rng_state()
+
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["musa"] = torch.musa.get_rng_state_all()
+            else:
+                rng_states["musa"] = torch.musa.get_rng_state()
+
+        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
+        # not yet exist.
+        os.makedirs(output_dir, exist_ok=True)
+
+        if self.args.world_size <= 1:
+            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+        else:
+            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
+
+    def _save_optimizer_and_scheduler(self, output_dir):
+        if is_torch_xla_available():
+            xm.rendezvous("saving_optimizer_states")
+            if self.is_fsdp_xla_v1_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
+                xm.save(
+                    optm,
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
+                    master_only=False,
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+        elif is_sagemaker_mp_enabled():
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            smp.barrier()
+            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+        elif self.is_deepspeed_enabled:
+            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
+            # config `stage3_gather_16bit_weights_on_model_save` is True
+            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+            )
+            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+            else:
+                self.model_wrapped.save_checkpoint(output_dir)
+        elif self.is_fsdp_enabled:
+            # save fsdp specific ckpt for resuming from ckpt
+            save_fsdp_model(
+                self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
+            )
+            save_fsdp_optimizer(
+                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
+            )
+        elif self.args.should_save:
+            # deepspeed.save_checkpoint above saves model/optim/sched
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+
+        # Save SCHEDULER & SCALER
+        is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
+            self.lr_scheduler, DeepSpeedSchedulerWrapper
+        )
+        if (
+            self.args.should_save
+            and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
+            and not is_torch_xla_available()
+        ):
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+            reissue_pt_warnings(caught_warnings)
+
+    def _load_optimizer_and_scheduler(self, checkpoint):
+        """If optimizer and scheduler states exist, load them."""
+        if checkpoint is None:
+            return
+
+        if self.is_deepspeed_enabled:
+            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+            if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                reissue_pt_warnings(caught_warnings)
+            return
+
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else (
+                os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+                or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
+                or (
+                    os.path.isdir(checkpoint)
+                    and any(
+                        OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
+                        for folder_name in os.listdir(checkpoint)
+                        if os.path.isdir(os.path.join(checkpoint, folder_name))
+                    )
+                )
+            )
+        )
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            if self.is_fsdp_xla_v1_enabled
+            else checkpoint_file_exists
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+            # Load in optimizer and scheduler states
+            if is_torch_xla_available():
+                # On TPU we have to take some extra precautions to properly load the states on the right device.
+                if self.is_fsdp_xla_v1_enabled:
+                    optimizer_state = torch.load(
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
+                    )
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
+                else:
+                    optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
+                reissue_pt_warnings(caught_warnings)
+
+                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+                self.optimizer.load_state_dict(optimizer_state)
+                self.lr_scheduler.load_state_dict(lr_scheduler_state)
+            else:
+                if is_sagemaker_mp_enabled():
+                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+                        # Optimizer checkpoint was saved with smp >= 1.10
+                        def opt_load_hook(mod, opt):
+                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    else:
+                        # Optimizer checkpoint was saved with smp < 1.10
+                        def opt_load_hook(mod, opt):
+                            if IS_SAGEMAKER_MP_POST_1_10:
+                                opt.load_state_dict(
+                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+                                )
+                            else:
+                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
+                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state
+                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
+                    if self.is_fsdp_enabled:
+                        load_fsdp_optimizer(
+                            self.accelerator.state.fsdp_plugin,
+                            self.accelerator,
+                            self.optimizer,
+                            self.model,
+                            checkpoint,
+                            **_get_fsdp_ckpt_kwargs(),
+                        )
+                    else:
+                        self.optimizer.load_state_dict(
+                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                        )
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                reissue_pt_warnings(caught_warnings)
+
+    def _load_callback_state(self):
+        """If callback states exist and were passed in, restore their states if enabled"""
+        if not self.args.restore_callback_states_from_checkpoint:
+            return
+        # Callback states are stored in stateful_callbacks
+        not_found = []
+        new_callbacks = []
+        original_callbacks = self.callback_handler.callbacks + [self.control]
+        for stored_callback, data in self.state.stateful_callbacks.items():
+            if not isinstance(data, list):
+                data = [data]
+            if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
+                # We can load/restore from multiple callbacks of the same type.
+                duplicates = [
+                    callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
+                ]
+                for callback, callback_data in zip(duplicates, data):
+                    args = callback_data.get("args", {})
+                    attributes = callback_data.get("attributes", {})
+                    new_callback = type(callback)(**args)
+                    for attribute, value in attributes.items():
+                        setattr(new_callback, attribute, value)
+                    if isinstance(callback, TrainerControl):
+                        # Specifically for restoring the `control` state
+                        self.control = new_callback
+                    else:
+                        new_callbacks.append(new_callback)
+                    # We remove the existing callback and add it to the list of new callbacks
+                    self.callback_handler.remove_callback(type(new_callback))
+                logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
+            else:
+                not_found.append(stored_callback)
+        if len(not_found) > 0:
+            logger.warning(
+                f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
+            )
+        for callback in new_callbacks:
+            self.callback_handler.add_callback(callback)
+
+    def hyperparameter_search(
+        self,
+        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
+        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
+        n_trials: int = 20,
+        direction: Union[str, List[str]] = "minimize",
+        backend: Optional[Union["str", HPSearchBackend]] = None,
+        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
+        **kwargs,
+    ) -> Union[BestRun, List[BestRun]]:
+        """
+        Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
+        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
+        the sum of all metrics otherwise.
+
+        
+
+        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
+        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
+        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
+        optimizer/scheduler.
+
+        
+
+        Args:
+            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
+                A function that defines the hyperparameter search space. Will default to
+                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
+                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
+            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
+                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
+                method. Will default to [`~trainer_utils.default_compute_objective`].
+            n_trials (`int`, *optional*, defaults to 100):
+                The number of trial runs to test.
+            direction (`str` or `List[str]`, *optional*, defaults to `"minimize"`):
+                If it's single objective optimization, direction is `str`, can be `"minimize"` or `"maximize"`, you
+                should pick `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or
+                several metrics. If it's multi objectives optimization, direction is `List[str]`, can be List of
+                `"minimize"` and `"maximize"`, you should pick `"minimize"` when optimizing the validation loss,
+                `"maximize"` when optimizing one or several metrics.
+            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
+                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+                on which one is installed. If all are installed, will default to optuna.
+            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
+                A function that defines the trial/run name. Will default to None.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments for each backend:
+
+                - `optuna`: parameters from
+                  [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
+                  and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from
+                  [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize)
+                - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run).
+                  If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available).
+                  If `progress_reporter` is not set in the `kwargs`,
+                  [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used.
+                - `sigopt`: the parameter `proxies` from
+                  [sigopt.Connection.set_proxies](https://docs.sigopt.com/support/faq#how-do-i-use-sigopt-with-a-proxy).
+
+        Returns:
+            [`trainer_utils.BestRun` or `List[trainer_utils.BestRun]`]: All the information about the best run or best
+            runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray
+            backend.
+        """
+        if backend is None:
+            backend = default_hp_search_backend()
+        backend = HPSearchBackend(backend)
+        backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
+        backend_obj.ensure_available()
+        self.hp_search_backend = backend
+        if self.model_init is None:
+            raise RuntimeError(
+                "To use hyperparameter search, you need to pass your model through a model_init function."
+            )
+
+        self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
+        self.hp_name = hp_name
+        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
+
+        best_run = backend_obj.run(self, n_trials, direction, **kwargs)
+
+        self.hp_search_backend = None
+        return best_run
+
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        """
+        Log `logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+            start_time (`Optional[float]`):
+                The start of training.
+        """
+        if self.state.epoch is not None:
+            logs["epoch"] = self.state.epoch
+        if self.args.include_num_input_tokens_seen:
+            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
+            if start_time is not None:
+                speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen)
+
+        output = {**logs, **{"step": self.state.global_step}}
+        self.state.log_history.append(output)
+        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
+
+    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
+        """
+        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+        """
+        if isinstance(data, Mapping):
+            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+        elif isinstance(data, (tuple, list)):
+            return type(data)(self._prepare_input(v) for v in data)
+        elif isinstance(data, torch.Tensor):
+            kwargs = {"device": self.args.device}
+            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
+                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
+                # embedding. Other models such as wav2vec2's inputs are already float and thus
+                # may need special handling to match the dtypes of the model
+                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
+            return data.to(**kwargs)
+        return data
+
+    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+        """
+        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
+        handling potential state.
+        """
+        inputs = self._prepare_input(inputs)
+        if len(inputs) == 0:
+            raise ValueError(
+                "The batch received was empty, your model won't be able to train on it. Double-check that your "
+                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+            )
+        if self.args.past_index >= 0 and self._past is not None:
+            inputs["mems"] = self._past
+
+        return inputs
+
+    def compute_loss_context_manager(self):
+        """
+        A helper wrapper to group together context managers.
+        """
+        return self.autocast_smart_context_manager()
+
+    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
+        """
+        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
+        arguments, depending on the situation.
+        """
+        if self.use_cpu_amp:
+            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+        else:
+            ctx_manager = contextlib.nullcontext()
+
+        return ctx_manager
+
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+    ) -> torch.Tensor:
+        """
+        Perform a training step on a batch of inputs.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to train.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+
+        Return:
+            `torch.Tensor`: The tensor with training loss on this batch.
+        """
+        model.train()
+        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+            self.optimizer.train()
+
+        inputs = self._prepare_inputs(inputs)
+        if is_sagemaker_mp_enabled():
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            return loss_mb.reduce_mean().detach().to(self.args.device)
+
+        with self.compute_loss_context_manager():
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
+            else:
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+                loss = loss / self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+
+            return loss.detach()
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+        Subclass and override for custom behavior.
+        """
+        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
+        if self.model_accepts_loss_kwargs:
+            loss_kwargs = {}
+            if num_items_in_batch is not None:
+                loss_kwargs["num_items_in_batch"] = num_items_in_batch
+            inputs = {**inputs, **loss_kwargs}
+        outputs = model(**inputs)
+        # Save past state if it exists
+        # TODO: this needs to be fixed and made cleaner later.
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
+
+        if labels is not None:
+            unwrapped_model = self.accelerator.unwrap_model(model)
+            if _is_peft_model(unwrapped_model):
+                model_name = unwrapped_model.base_model.model._get_name()
+            else:
+                model_name = unwrapped_model._get_name()
+            # User-defined compute_loss function
+            if self.compute_loss_func is not None:
+                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+                loss = self.label_smoother(outputs, labels, shift_labels=True)
+            else:
+                loss = self.label_smoother(outputs, labels)
+        else:
+            if isinstance(outputs, dict) and "loss" not in outputs:
+                raise ValueError(
+                    "The model did not return a loss from the inputs, only the following keys: "
+                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+                )
+            # We don't use .loss here since the model may return tuples instead of ModelOutput.
+            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
+            loss *= self.accelerator.num_processes
+
+        return (loss, outputs) if return_outputs else loss
+
+    def is_local_process_zero(self) -> bool:
+        """
+        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
+        machines) main process.
+        """
+        return self.args.local_process_index == 0
+
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on several
+        machines, this is only going to be `True` for one process).
+        """
+        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+        # process index.
+        if is_sagemaker_mp_enabled():
+            return smp.rank() == 0
+        else:
+            return self.args.process_index == 0
+
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+        """
+        Will save the model, so you can reload it using `from_pretrained()`.
+
+        Will only save from the main process.
+        """
+
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        if is_torch_xla_available():
+            self._save_tpu(output_dir)
+        elif is_sagemaker_mp_enabled():
+            # Calling the state_dict needs to be done on the wrapped model and on all processes.
+            os.makedirs(output_dir, exist_ok=True)
+            state_dict = self.model_wrapped.state_dict()
+            if self.args.should_save:
+                self._save(output_dir, state_dict=state_dict)
+            if IS_SAGEMAKER_MP_POST_1_10:
+                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+                Path(os.path.join(output_dir, "user_content.pt")).touch()
+        elif self.is_fsdp_enabled:
+            if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
+                version.parse(accelerate_version) > version.parse("0.24.1")
+            ):
+                state_dict = self.accelerator.get_state_dict(self.model)
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
+        elif self.is_deepspeed_enabled:
+            try:
+                state_dict = self.accelerator.get_state_dict(self.deepspeed)
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                if self.args.should_save:
+                    self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model_wrapped.save_checkpoint(output_dir)
+
+        elif self.args.should_save:
+            self._save(output_dir)
+
+        # Push to the Hub when `save_model` is called by the user.
+        if self.args.push_to_hub and not _internal_call:
+            self.push_to_hub(commit_message="Model save")
+
+    def _save_tpu(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+
+        logger.info(f"Saving model checkpoint to {output_dir}")
+        model = self.model
+        xm.mark_step()
+
+        if xm.is_master_ordinal(local=False):
+            os.makedirs(output_dir, exist_ok=True)
+            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        supported_classes = (PushToHubMixin,)
+        xm.rendezvous("saving_checkpoint")
+        if self.is_fsdp_xla_v1_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if self.args.should_save:
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+                    save_model=False,
+                )
+                model = model.module.module
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        elif not isinstance(model, supported_classes):
+            if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                self.accelerator.unwrap_model(model).save_pretrained(
+                    output_dir,
+                    is_main_process=self.args.should_save,
+                    state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
+                    save_function=xm.save,
+                    safe_serialization=self.args.save_safetensors,
+                )
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                state_dict = xm._maybe_convert_to_cpu(model.state_dict())
+                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            model.save_pretrained(
+                output_dir,
+                is_main_process=self.args.should_save,
+                save_function=xm.save,
+                safe_serialization=self.args.save_safetensors,
+                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
+            )
+        if self.processing_class is not None and self.args.should_save:
+            self.processing_class.save_pretrained(output_dir)
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are the process zero, so we don't check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+
+        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, supported_classes):
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+
+            if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+                self.accelerator.unwrap_model(self.model).save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(
+                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+                    )
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            self.model.save_pretrained(
+                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+            )
+
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+    def store_flos(self):
+        # Storing the number of floating-point operations that went into the model
+        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            self.state.total_flos += (
+                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+            )
+            self.current_flos = 0
+        else:
+            self.state.total_flos += self.current_flos
+            self.current_flos = 0
+
+    def _sorted_checkpoints(
+        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+    ) -> List[str]:
+        ordering_and_checkpoint_path = []
+
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+        for path in glob_checkpoints:
+            if use_mtime:
+                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+            else:
+                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+                if regex_match is not None and regex_match.groups() is not None:
+                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+        # Make sure we don't delete the best model.
+        if (
+            self.state.best_model_checkpoint is not None
+            and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
+        ):
+            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+            for i in range(best_model_index, len(checkpoints_sorted) - 2):
+                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+        return checkpoints_sorted
+
+    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+            return
+
+        # Check if we should delete older checkpoint(s)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+        if len(checkpoints_sorted) <= self.args.save_total_limit:
+            return
+
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+        # we don't do to allow resuming.
+        save_total_limit = self.args.save_total_limit
+        if (
+            self.state.best_model_checkpoint is not None
+            and self.args.save_total_limit == 1
+            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+        ):
+            save_total_limit = 2
+
+        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+        for checkpoint in checkpoints_to_be_deleted:
+            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+            shutil.rmtree(checkpoint, ignore_errors=True)
+
+    def evaluate(
+        self,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> Dict[str, float]:
+        """
+        Run evaluation and returns metrics.
+
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+
+        You can also subclass and override this method to inject custom behavior.
+
+        Args:
+            eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]), *optional*):
+                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+                not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
+                evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
+                `__len__` method.
+
+                
+
+                If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
+                separate evaluations on each dataset. This can be useful to monitor how training affects other
+                datasets or simply to get a more fine-grained evaluation.
+                When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
+                of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
+                `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the
+                loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
+
+                
+
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "eval_bleu" if the prefix is "eval" (default)
+
+        Returns:
+            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+            dictionary also contains the epoch number which comes from the training state.
+        """
+        # handle multipe eval datasets
+        override = eval_dataset is not None
+        eval_dataset = eval_dataset if override else self.eval_dataset
+        if isinstance(eval_dataset, dict):
+            metrics = {}
+            for eval_dataset_name, _eval_dataset in eval_dataset.items():
+                dataset_metrics = self.evaluate(
+                    eval_dataset=_eval_dataset if override else eval_dataset_name,
+                    ignore_keys=ignore_keys,
+                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+                )
+                metrics.update(dataset_metrics)
+            return metrics
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        if self.is_fsdp_xla_v2_enabled:
+            eval_dataloader = tpu_spmd_dataloader(eval_dataloader)
+
+        start_time = time.time()
+
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
+            eval_dataloader,
+            description="Evaluation",
+            # No point gathering the predictions if there are no metrics, otherwise we defer to
+            # self.args.prediction_loss_only
+            prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
+        )
+
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
+
+        self.log(output.metrics)
+
+        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+            xm.master_print(met.metrics_report())
+
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+        return output.metrics
+
+    def predict(
+        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
+    ) -> PredictionOutput:
+        """
+        Run prediction and returns predictions and potential metrics.
+
+        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+        will also return metrics, like in `evaluate()`.
+
+        Args:
+            test_dataset (`Dataset`):
+                Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the
+                `model.forward()` method are automatically removed. Has to implement the method `__len__`
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "test_bleu" if the prefix is "test" (default)
+
+        
+
+        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
+        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
+        one array. The padding index is -100.
+
+        
+
+        Returns: *NamedTuple* A namedtuple with the following keys:
+
+            - predictions (`np.ndarray`): The predictions on `test_dataset`.
+            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+              labels).
+        """
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        test_dataloader = self.get_test_dataloader(test_dataset)
+        start_time = time.time()
+
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
+            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+        )
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
+
+        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
+
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+        args = self.args
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+        # if eval is called w/o train, handle model prep here
+        if self.is_deepspeed_enabled and self.deepspeed is None:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+        if len(self.accelerator._models) == 0 and model is self.model:
+            start_time = time.time()
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+            self.model_preparation_time = round(time.time() - start_time, 4)
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = self.args.eval_batch_size
+
+        logger.info(f"\n***** Running {description} *****")
+        if has_length(dataloader):
+            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
+        else:
+            logger.info("  Num examples: Unknown")
+        logger.info(f"  Batch size = {batch_size}")
+
+        model.eval()
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
+
+        self.callback_handler.eval_dataloader = dataloader
+        # Do this before wrapping.
+        eval_dataset = getattr(dataloader, "dataset", None)
+
+        if args.past_index >= 0:
+            self._past = None
+
+        # Initialize containers
+        all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+
+        metrics = None
+        eval_set_kwargs = {}
+
+        # Will be useful when we have an iterable dataset so don't know its length.
+        observed_num_examples = 0
+
+        # Main evaluation loop
+        for step, inputs in enumerate(dataloader):
+            # Update the observed num examples
+            observed_batch_size = find_batch_size(inputs)
+            if observed_batch_size is not None:
+                observed_num_examples += observed_batch_size
+                # For batch samplers, batch_size is not known by the dataloader in advance.
+                if batch_size is None:
+                    batch_size = observed_batch_size
+
+            # Prediction step
+            losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            main_input_name = getattr(self.model, "main_input_name", "input_ids")
+            inputs_decode = (
+                self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+            )
+
+            if is_torch_xla_available():
+                xm.mark_step()
+
+            # Update containers
+            if losses is not None:
+                losses = self.gather_function((losses.repeat(batch_size)))
+                all_losses.add(losses)
+            if inputs_decode is not None:
+                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
+                inputs_decode = self.gather_function((inputs_decode))
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_inputs.add(inputs_decode)
+            if labels is not None:
+                # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+            if logits is not None:
+                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
+                if self.preprocess_logits_for_metrics is not None:
+                    logits = self.preprocess_logits_for_metrics(logits, labels)
+                logits = self.gather_function((logits))
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_preds.add(logits)
+            if labels is not None:
+                labels = self.gather_function((labels))
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_labels.add(labels)
+
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            if self.args.batch_eval_metrics:
+                if self.compute_metrics is not None and logits is not None and labels is not None:
+                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
+                    batch_kwargs = {}
+                    batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
+                    batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
+                    metrics = self.compute_metrics(
+                        EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
+                        compute_result=is_last_step,
+                    )
+
+                del losses, logits, labels, inputs
+                torch.cuda.empty_cache()
+
+            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+            elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+                all_losses.to_cpu_and_numpy()
+                all_preds.to_cpu_and_numpy()
+                all_labels.to_cpu_and_numpy()
+                all_inputs.to_cpu_and_numpy()
+
+                del losses, logits, labels, inputs
+                torch.cuda.empty_cache()
+
+        # After all calls to `.gather_function`, reset to `gather_for_metrics`:
+        self.gather_function = self.accelerator.gather_for_metrics
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        all_losses = all_losses.get_arrays()
+        all_preds = all_preds.get_arrays()
+        all_labels = all_labels.get_arrays()
+        all_inputs = all_inputs.get_arrays()
+
+        # Number of samples
+        if has_length(eval_dataset):
+            num_samples = len(eval_dataset)
+        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+        # methods. Therefore we need to make sure it also has the attribute.
+        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
+            num_samples = eval_dataset.num_examples
+        else:
+            if has_length(dataloader):
+                num_samples = self.num_examples(dataloader)
+            else:  # both len(dataloader.dataset) and len(dataloader) fail
+                num_samples = observed_num_examples
+        if num_samples == 0 and observed_num_examples > 0:
+            num_samples = observed_num_examples
+
+        # Metrics!
+        if (
+            self.compute_metrics is not None
+            and all_preds is not None
+            and all_labels is not None
+            and not self.args.batch_eval_metrics
+        ):
+            eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
+            eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
+            metrics = self.compute_metrics(
+                EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
+            )
+        elif metrics is None:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if isinstance(all_losses, list) and all_losses:
+            metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
+        elif isinstance(all_losses, np.ndarray):
+            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+        if hasattr(self, "jit_compilation_time"):
+            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+        if hasattr(self, "model_preparation_time"):
+            metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+
+    def _nested_gather(self, tensors, name=None):
+        """
+        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+        concatenating them to `gathered`
+        """
+        if tensors is None:
+            return
+        if is_torch_xla_available():
+            if name is None:
+                name = "nested_gather"
+            tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
+        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
+            self.args.distributed_state is None and self.args.local_rank != -1
+        ):
+            tensors = distributed_concat(tensors)
+        return tensors
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+
+        Return:
+            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+            logits and labels (each being optional).
+        """
+        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        # For CLIP-like models capable of returning loss values.
+        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
+        # is `True` in `model.forward`.
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
+        inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+        if has_labels or loss_without_labels:
+            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
+        with torch.no_grad():
+            if is_sagemaker_mp_enabled():
+                raw_outputs = smp_forward_only(model, inputs)
+                if has_labels or loss_without_labels:
+                    if isinstance(raw_outputs, dict):
+                        loss_mb = raw_outputs["loss"]
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        loss_mb = raw_outputs[0]
+                        logits_mb = raw_outputs[1:]
+
+                    loss = loss_mb.reduce_mean().detach().cpu()
+                    logits = smp_nested_concat(logits_mb)
+                else:
+                    loss = None
+                    if isinstance(raw_outputs, dict):
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                    else:
+                        logits_mb = raw_outputs
+                    logits = smp_nested_concat(logits_mb)
+            else:
+                if has_labels or loss_without_labels:
+                    with self.compute_loss_context_manager():
+                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    loss = loss.mean().detach()
+
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        logits = outputs[1:]
+                else:
+                    loss = None
+                    with self.compute_loss_context_manager():
+                        outputs = model(**inputs)
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                    else:
+                        logits = outputs
+                    # TODO: this needs to be fixed and made cleaner later.
+                    if self.args.past_index >= 0:
+                        self._past = outputs[self.args.past_index - 1]
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        logits = nested_detach(logits)
+        if len(logits) == 1:
+            logits = logits[0]
+
+        return (loss, logits, labels)
+
+    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
+        """
+        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
+        operations for every backward + forward pass. If using another model, either implement such a method in the
+        model or subclass and override this method.
+
+        Args:
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+        Returns:
+            `int`: The number of floating-point operations.
+        """
+        if hasattr(self.model, "floating_point_ops"):
+            return self.model.floating_point_ops(inputs)
+        else:
+            return 0
+
+    def init_hf_repo(self, token: Optional[str] = None):
+        """
+        Initializes a git repo in `self.args.hub_model_id`.
+        """
+        # Only on process zero
+        if not self.is_world_process_zero():
+            return
+
+        if self.args.hub_model_id is None:
+            repo_name = Path(self.args.output_dir).absolute().name
+        else:
+            repo_name = self.args.hub_model_id
+
+        token = token if token is not None else self.args.hub_token
+        repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
+        self.hub_model_id = repo_url.repo_id
+        self.push_in_progress = None
+
+    def create_model_card(
+        self,
+        language: Optional[str] = None,
+        license: Optional[str] = None,
+        tags: Union[str, List[str], None] = None,
+        model_name: Optional[str] = None,
+        finetuned_from: Optional[str] = None,
+        tasks: Union[str, List[str], None] = None,
+        dataset_tags: Union[str, List[str], None] = None,
+        dataset: Union[str, List[str], None] = None,
+        dataset_args: Union[str, List[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            language (`str`, *optional*):
+                The language of the model (if applicable)
+            license (`str`, *optional*):
+                The license of the model. Will default to the license of the pretrained model used, if the original
+                model given to the `Trainer` comes from a repo on the Hub.
+            tags (`str` or `List[str]`, *optional*):
+                Some tags to be included in the metadata of the model card.
+            model_name (`str`, *optional*):
+                The name of the model.
+            finetuned_from (`str`, *optional*):
+                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+                of the original model given to the `Trainer` (if it comes from the Hub).
+            tasks (`str` or `List[str]`, *optional*):
+                One or several task identifiers, to be included in the metadata of the model card.
+            dataset_tags (`str` or `List[str]`, *optional*):
+                One or several dataset tags, to be included in the metadata of the model card.
+            dataset (`str` or `List[str]`, *optional*):
+                One or several dataset identifiers, to be included in the metadata of the model card.
+            dataset_args (`str` or `List[str]`, *optional*):
+               One or several dataset arguments, to be included in the metadata of the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        model_card_filepath = os.path.join(self.args.output_dir, "README.md")
+        is_peft_library = False
+        if os.path.exists(model_card_filepath):
+            library_name = ModelCard.load(model_card_filepath).data.get("library_name")
+            is_peft_library = library_name == "peft"
+
+            # Append existing tags in `tags`
+            existing_tags = ModelCard.load(model_card_filepath).data.tags
+            if tags is not None and existing_tags is not None:
+                if isinstance(tags, str):
+                    tags = [tags]
+                for tag in existing_tags:
+                    if tag not in tags:
+                        tags.append(tag)
+
+        training_summary = TrainingSummary.from_trainer(
+            self,
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset_tags=dataset_tags,
+            dataset=dataset,
+            dataset_args=dataset_args,
+        )
+        model_card = training_summary.to_model_card()
+        with open(model_card_filepath, "w") as f:
+            f.write(model_card)
+
+        if is_peft_library:
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+
+    def _push_from_checkpoint(self, checkpoint_folder):
+        # Only push from one node.
+        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
+            return
+        # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
+        if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
+            return
+
+        output_dir = self.args.output_dir
+        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
+        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+        #  Add sharded checkpoints if we have an index
+        for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+            index_path = os.path.join(checkpoint_folder, index_file)
+            if os.path.isfile(index_path):
+                modeling_files.append(index_file)
+                with open(index_path) as f:
+                    index = json.loads(f.read())
+                shard_files = list(set(index["weight_map"].values()))
+                modeling_files.extend(shard_files)
+        if is_peft_available():
+            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
+        for modeling_file in modeling_files:
+            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
+                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
+        # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure.
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+        # Same for the training arguments
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+        if self.args.save_strategy == SaveStrategy.STEPS:
+            commit_message = f"Training in progress, step {self.state.global_step}"
+        else:
+            commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
+
+        model_push_job = upload_folder(
+            repo_id=self.hub_model_id,
+            folder_path=output_dir,
+            commit_message=commit_message,
+            token=self.args.hub_token,
+            run_as_future=True,
+            ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
+        )
+
+        push_jobs = [model_push_job]
+
+        if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
+            path_in_repo = (
+                "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
+            )
+            checkpoint_push = upload_folder(
+                repo_id=self.hub_model_id,
+                folder_path=checkpoint_folder,
+                path_in_repo=path_in_repo,
+                commit_message=commit_message + ", checkpoint",
+                token=self.args.hub_token,
+                run_as_future=True,
+            )
+            push_jobs.append(checkpoint_push)
+
+        if self.push_in_progress is None or self.push_in_progress.is_done():
+            self.push_in_progress = PushInProgress(push_jobs)
+        else:
+            self.push_in_progress.jobs.extend(push_jobs)
+
+    def _finish_current_push(self):
+        if not hasattr(self, "push_in_progress"):
+            return
+        if self.push_in_progress is not None and not self.push_in_progress.is_done():
+            logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
+            self.push_in_progress.wait_until_done()
+
+    def push_to_hub(
+        self,
+        commit_message: Optional[str] = "End of training",
+        blocking: bool = True,
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
+
+        Parameters:
+            commit_message (`str`, *optional*, defaults to `"End of training"`):
+                Message to commit while pushing.
+            blocking (`bool`, *optional*, defaults to `True`):
+                Whether the function should return only when the `git push` has finished.
+            token (`str`, *optional*, defaults to `None`):
+                Token with write permission to overwrite Trainer's original args.
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the "main" branch.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to [`~Trainer.create_model_card`].
+
+        Returns:
+            The URL of the repository where the model was pushed if `blocking=False`, or a `Future` object tracking the
+            progress of the commit if `blocking=True`.
+        """
+        model_name = kwargs.pop("model_name", None)
+        if model_name is None and self.args.should_save:
+            if self.args.hub_model_id is None:
+                model_name = Path(self.args.output_dir).name
+            else:
+                model_name = self.args.hub_model_id.split("/")[-1]
+        token = token if token is not None else self.args.hub_token
+
+        # In case the user calls this method with args.push_to_hub = False
+        if self.hub_model_id is None:
+            self.init_hf_repo(token=token)
+
+        # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
+        # self.args.should_save.
+        self.save_model(_internal_call=True)
+
+        # Only push from one node.
+        if not self.is_world_process_zero():
+            return
+
+        # Add additional tags in the case the model has already some tags and users pass
+        # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
+        # from all models since Trainer does not call `model.push_to_hub`.
+        if getattr(self.model, "model_tags", None) is not None:
+            if "tags" not in kwargs:
+                kwargs["tags"] = []
+
+            # If it is a string, convert it to a list
+            if isinstance(kwargs["tags"], str):
+                kwargs["tags"] = [kwargs["tags"]]
+
+            for model_tag in self.model.model_tags:
+                if model_tag not in kwargs["tags"]:
+                    kwargs["tags"].append(model_tag)
+
+        self.create_model_card(model_name=model_name, **kwargs)
+
+        # Wait for the current upload to be finished.
+        self._finish_current_push()
+        return upload_folder(
+            repo_id=self.hub_model_id,
+            folder_path=self.args.output_dir,
+            commit_message=commit_message,
+            token=token,
+            run_as_future=not blocking,
+            ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
+            revision=revision,
+        )
+
+    #
+    # Deprecated code
+    #
+
+    def prediction_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+        args = self.args
+
+        if not has_length(dataloader):
+            raise ValueError("dataloader must implement a working __len__")
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+        # if eval is called w/o train, handle model prep here
+        if self.is_deepspeed_enabled and self.deepspeed is None:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled or self.is_fsdp_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = (
+            dataloader.total_batch_size
+            if getattr(dataloader, "_is_accelerate_prepared", False)
+            else dataloader.batch_size
+        )
+
+        if batch_size is None:
+            raise ValueError(
+                "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
+            )
+
+        num_examples = self.num_examples(dataloader)
+        logger.info(f"\n***** Running {description} *****")
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Batch size = {batch_size}")
+
+        losses_host: torch.Tensor = None
+        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
+        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
+        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
+        metrics: Optional[dict] = None
+        eval_set_kwargs: dict = {}
+
+        world_size = max(1, args.world_size)
+
+        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+        if not prediction_loss_only:
+            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
+            # a batch size to the sampler)
+            make_multiple_of = None
+            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
+                make_multiple_of = dataloader.sampler.batch_size
+            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+
+        model.eval()
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
+
+        if args.past_index >= 0:
+            self._past = None
+
+        self.callback_handler.eval_dataloader = dataloader
+
+        for step, inputs in enumerate(dataloader):
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            main_input_name = getattr(self.model, "main_input_name", "input_ids")
+            inputs_decode = (
+                self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+            )
+
+            if loss is not None:
+                losses = loss.repeat(batch_size)
+                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+            if logits is not None:
+                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+            if labels is not None:
+                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+            if inputs_decode is not None:
+                inputs_host = (
+                    inputs_decode
+                    if inputs_host is None
+                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+                )
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            if self.args.batch_eval_metrics:
+                if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
+                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
+                    batch_kwargs = {}
+                    batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
+                    batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
+                    metrics = self.compute_metrics(
+                        EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
+                        compute_result=is_last_step,
+                    )
+
+            if self.args.batch_eval_metrics or (
+                args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
+            ):
+                # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+                if not prediction_loss_only:
+                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+                # Set back to None to begin a new accumulation
+                del losses_host, preds_host, labels_host, inputs_host
+                torch.cuda.empty_cache()
+                losses_host, preds_host, labels_host, inputs_host = None, None, None, None
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+        if not prediction_loss_only:
+            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+        eval_loss = eval_losses_gatherer.finalize()
+        preds = preds_gatherer.finalize() if not prediction_loss_only else None
+        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
+
+        if (
+            self.compute_metrics is not None
+            and preds is not None
+            and label_ids is not None
+            and not self.args.batch_eval_metrics
+        ):
+            eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
+            eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
+            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
+        elif metrics is None:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if eval_loss is not None:
+            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
+
+    def _gather_and_numpify(self, tensors, name):
+        """
+        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+        concatenating them to `gathered`
+        """
+        if tensors is None:
+            return
+        if is_torch_xla_available():
+            tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            tensors = distributed_concat(tensors)
+
+        return nested_numpify(tensors)
+
+    def _add_sm_patterns_to_gitignore(self) -> None:
+        """Add SageMaker Checkpointing patterns to .gitignore file."""
+        # Make sure we only do this on the main process
+        if not self.is_world_process_zero():
+            return
+
+        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
+
+        # Get current .gitignore content
+        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
+            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
+                current_content = f.read()
+        else:
+            current_content = ""
+
+        # Add the patterns to .gitignore
+        content = current_content
+        for pattern in patterns:
+            if pattern not in content:
+                if content.endswith("\n"):
+                    content += pattern
+                else:
+                    content += f"\n{pattern}"
+
+        # Write the .gitignore file if it has changed
+        if content != current_content:
+            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
+                logger.debug(f"Writing .gitignore file. Content: {content}")
+                f.write(content)
+
+        self.repo.git_add(".gitignore")
+
+        # avoid race condition with git status
+        time.sleep(0.5)
+
+        if not self.repo.is_repo_clean():
+            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
+            self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # We explicitly don't rely on the `Accelerator` to do gradient accumulation
+        grad_acc_kwargs = {}
+        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
+            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
+
+        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
+        if "num_steps" in grad_acc_kwargs:
+            if self.args.gradient_accumulation_steps > 1:
+                # raise because we do not know which setting is intended.
+                raise ValueError(
+                    "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
+                    "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
+                )
+            else:
+                self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
+
+        accelerator_config = self.args.accelerator_config.to_dict()
+
+        if is_accelerate_available("0.28.0"):
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=accelerator_config.pop("split_batches"),
+                dispatch_batches=accelerator_config.pop("dispatch_batches"),
+                even_batches=accelerator_config.pop("even_batches"),
+                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
+            )
+            if is_accelerate_available("1.1.0"):
+                dataloader_config.data_seed = self.args.data_seed
+
+        non_blocking = accelerator_config.pop("non_blocking")
+        if not is_accelerate_available("0.30.0"):
+            if non_blocking:
+                raise ImportError(
+                    "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
+                )
+        else:
+            if non_blocking and not self.args.dataloader_pin_memory:
+                logger.warning(
+                    "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
+                )
+            dataloader_config.non_blocking = non_blocking
+        # this would have been updated above, no need for it anymore
+        accelerator_config.pop("gradient_accumulation_kwargs")
+
+        args = {
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+        }
+        if is_accelerate_available("0.28.0"):
+            args["dataloader_config"] = dataloader_config
+        else:
+            args.update(accelerator_config)
+
+        # create accelerator object
+        self.accelerator = Accelerator(**args)
+        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
+        self.gather_function = self.accelerator.gather_for_metrics
+
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+            fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
+                "activation_checkpointing", fsdp_plugin.activation_checkpointing
+            )
+            if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+                raise ValueError(
+                    "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
+                    "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
+                    "when using FSDP."
+                )
+
+        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+            self.propagate_args_to_deepspeed()
+
+        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
+        if (
+            self.args.save_only_model
+            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
+            and self.args.load_best_model_at_end
+        ):
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
+
+        # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
+        if (
+            self.is_deepspeed_enabled
+            and self.accelerator.state.deepspeed_plugin.zero_stage == 3
+            and self.args.auto_find_batch_size
+        ):
+            raise ValueError(
+                "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
+            )
+
+    def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
+        """
+        Sets values in the deepspeed plugin based on the Trainer args
+        """
+        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+        ds_plugin = self.accelerator.state.deepspeed_plugin
+
+        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+        ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
+
+    def _fsdp_qlora_plugin_updates(self):
+        if self.is_fsdp_enabled and _is_peft_model(self.model):
+            from peft import LoraConfig
+            from peft.utils.other import fsdp_auto_wrap_policy
+
+            if isinstance(self.model.active_peft_config, LoraConfig):
+                fsdp_plugin = self.accelerator.state.fsdp_plugin
+                fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+            if (
+                getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+                and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
+                and version.parse(accelerate_version) > version.parse("0.27.0")
+            ):
+                fsdp_plugin.set_mixed_precision(
+                    self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
+                )
+
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        batch_samples = []
+        num_items_in_batch = None
+        for _ in range(num_batches):
+            try:
+                batch_samples += [next(epoch_iterator)]
+            except StopIteration:
+                break
+
+        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
+            # For now we don't support object detection
+            try:
+                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
+            except (TypeError, AttributeError):
+                pass
+
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
+
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.item()
+
+        return batch_samples, num_items_in_batch