diff --git a/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/INSTALLER b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/METADATA b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..559eb3ebd454aad3270983ebe8972c59a9a6d71b --- /dev/null +++ b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/METADATA @@ -0,0 +1,256 @@ +Metadata-Version: 2.4 +Name: fsspec +Version: 2025.9.0 +Summary: File-system specification +Project-URL: Changelog, https://filesystem-spec.readthedocs.io/en/latest/changelog.html +Project-URL: Documentation, https://filesystem-spec.readthedocs.io/en/latest/ +Project-URL: Homepage, https://github.com/fsspec/filesystem_spec +Maintainer-email: Martin Durant +License-Expression: BSD-3-Clause +License-File: LICENSE +Keywords: file +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Requires-Python: >=3.9 +Provides-Extra: abfs +Requires-Dist: adlfs; extra == 'abfs' +Provides-Extra: adl +Requires-Dist: adlfs; extra == 'adl' +Provides-Extra: arrow +Requires-Dist: pyarrow>=1; extra == 'arrow' +Provides-Extra: dask +Requires-Dist: dask; extra == 'dask' +Requires-Dist: distributed; extra == 'dask' +Provides-Extra: dev +Requires-Dist: pre-commit; extra == 'dev' +Requires-Dist: ruff>=0.5; extra == 'dev' +Provides-Extra: doc +Requires-Dist: numpydoc; extra == 'doc' +Requires-Dist: sphinx; extra == 'doc' +Requires-Dist: sphinx-design; extra == 'doc' +Requires-Dist: sphinx-rtd-theme; extra == 'doc' +Requires-Dist: yarl; extra == 'doc' +Provides-Extra: dropbox +Requires-Dist: dropbox; extra == 'dropbox' +Requires-Dist: dropboxdrivefs; extra == 'dropbox' +Requires-Dist: requests; extra == 'dropbox' +Provides-Extra: entrypoints +Provides-Extra: full +Requires-Dist: adlfs; extra == 'full' +Requires-Dist: aiohttp!=4.0.0a0,!=4.0.0a1; extra == 'full' +Requires-Dist: dask; extra == 'full' +Requires-Dist: distributed; extra == 'full' +Requires-Dist: dropbox; extra == 'full' +Requires-Dist: dropboxdrivefs; extra == 'full' +Requires-Dist: fusepy; extra == 'full' +Requires-Dist: gcsfs; extra == 'full' +Requires-Dist: libarchive-c; extra == 'full' +Requires-Dist: ocifs; extra == 'full' +Requires-Dist: panel; extra == 'full' +Requires-Dist: paramiko; extra == 'full' +Requires-Dist: pyarrow>=1; extra == 'full' +Requires-Dist: pygit2; extra == 'full' +Requires-Dist: requests; extra == 'full' +Requires-Dist: s3fs; extra == 'full' +Requires-Dist: smbprotocol; extra == 'full' +Requires-Dist: tqdm; extra == 'full' +Provides-Extra: fuse +Requires-Dist: fusepy; extra == 'fuse' +Provides-Extra: gcs +Requires-Dist: gcsfs; extra == 'gcs' +Provides-Extra: git +Requires-Dist: pygit2; extra == 'git' +Provides-Extra: github +Requires-Dist: requests; extra == 'github' +Provides-Extra: gs +Requires-Dist: gcsfs; extra == 'gs' +Provides-Extra: gui +Requires-Dist: panel; extra == 'gui' +Provides-Extra: hdfs 
+Requires-Dist: pyarrow>=1; extra == 'hdfs' +Provides-Extra: http +Requires-Dist: aiohttp!=4.0.0a0,!=4.0.0a1; extra == 'http' +Provides-Extra: libarchive +Requires-Dist: libarchive-c; extra == 'libarchive' +Provides-Extra: oci +Requires-Dist: ocifs; extra == 'oci' +Provides-Extra: s3 +Requires-Dist: s3fs; extra == 's3' +Provides-Extra: sftp +Requires-Dist: paramiko; extra == 'sftp' +Provides-Extra: smb +Requires-Dist: smbprotocol; extra == 'smb' +Provides-Extra: ssh +Requires-Dist: paramiko; extra == 'ssh' +Provides-Extra: test +Requires-Dist: aiohttp!=4.0.0a0,!=4.0.0a1; extra == 'test' +Requires-Dist: numpy; extra == 'test' +Requires-Dist: pytest; extra == 'test' +Requires-Dist: pytest-asyncio!=0.22.0; extra == 'test' +Requires-Dist: pytest-benchmark; extra == 'test' +Requires-Dist: pytest-cov; extra == 'test' +Requires-Dist: pytest-mock; extra == 'test' +Requires-Dist: pytest-recording; extra == 'test' +Requires-Dist: pytest-rerunfailures; extra == 'test' +Requires-Dist: requests; extra == 'test' +Provides-Extra: test-downstream +Requires-Dist: aiobotocore<3.0.0,>=2.5.4; extra == 'test-downstream' +Requires-Dist: dask[dataframe,test]; extra == 'test-downstream' +Requires-Dist: moto[server]<5,>4; extra == 'test-downstream' +Requires-Dist: pytest-timeout; extra == 'test-downstream' +Requires-Dist: xarray; extra == 'test-downstream' +Provides-Extra: test-full +Requires-Dist: adlfs; extra == 'test-full' +Requires-Dist: aiohttp!=4.0.0a0,!=4.0.0a1; extra == 'test-full' +Requires-Dist: cloudpickle; extra == 'test-full' +Requires-Dist: dask; extra == 'test-full' +Requires-Dist: distributed; extra == 'test-full' +Requires-Dist: dropbox; extra == 'test-full' +Requires-Dist: dropboxdrivefs; extra == 'test-full' +Requires-Dist: fastparquet; extra == 'test-full' +Requires-Dist: fusepy; extra == 'test-full' +Requires-Dist: gcsfs; extra == 'test-full' +Requires-Dist: jinja2; extra == 'test-full' +Requires-Dist: kerchunk; extra == 'test-full' +Requires-Dist: libarchive-c; extra == 'test-full' +Requires-Dist: lz4; extra == 'test-full' +Requires-Dist: notebook; extra == 'test-full' +Requires-Dist: numpy; extra == 'test-full' +Requires-Dist: ocifs; extra == 'test-full' +Requires-Dist: pandas; extra == 'test-full' +Requires-Dist: panel; extra == 'test-full' +Requires-Dist: paramiko; extra == 'test-full' +Requires-Dist: pyarrow; extra == 'test-full' +Requires-Dist: pyarrow>=1; extra == 'test-full' +Requires-Dist: pyftpdlib; extra == 'test-full' +Requires-Dist: pygit2; extra == 'test-full' +Requires-Dist: pytest; extra == 'test-full' +Requires-Dist: pytest-asyncio!=0.22.0; extra == 'test-full' +Requires-Dist: pytest-benchmark; extra == 'test-full' +Requires-Dist: pytest-cov; extra == 'test-full' +Requires-Dist: pytest-mock; extra == 'test-full' +Requires-Dist: pytest-recording; extra == 'test-full' +Requires-Dist: pytest-rerunfailures; extra == 'test-full' +Requires-Dist: python-snappy; extra == 'test-full' +Requires-Dist: requests; extra == 'test-full' +Requires-Dist: smbprotocol; extra == 'test-full' +Requires-Dist: tqdm; extra == 'test-full' +Requires-Dist: urllib3; extra == 'test-full' +Requires-Dist: zarr; extra == 'test-full' +Requires-Dist: zstandard; (python_version < '3.14') and extra == 'test-full' +Provides-Extra: tqdm +Requires-Dist: tqdm; extra == 'tqdm' +Description-Content-Type: text/markdown + +# filesystem_spec + +[![PyPI version](https://badge.fury.io/py/fsspec.svg)](https://pypi.python.org/pypi/fsspec/) +[![Anaconda-Server 
Badge](https://anaconda.org/conda-forge/fsspec/badges/version.svg)](https://anaconda.org/conda-forge/fsspec) +![Build](https://github.com/fsspec/filesystem_spec/workflows/CI/badge.svg) +[![Docs](https://readthedocs.org/projects/filesystem-spec/badge/?version=latest)](https://filesystem-spec.readthedocs.io/en/latest/?badge=latest) + +A specification for pythonic filesystems. + +## Install + +```bash +pip install fsspec +``` + +installs the base fsspec. Various optional features require additional dependencies, specified as extras; +e.g. `pip install fsspec[ssh]` will install the dependencies for `ssh` backend support. +Use `pip install fsspec[full]` to install all known extra dependencies. + +An up-to-date package is also provided through the conda-forge distribution: + +```bash +conda install -c conda-forge fsspec +``` + + +## Purpose + +To produce a template or specification for a file-system interface that specific implementations should follow, +so that applications making use of them can rely on a common behaviour and not have to worry about the +internal implementation decisions of any given backend. Many such implementations are included in this package, +or in sister projects such as `s3fs` and `gcsfs`. + +In addition, if the specification is well designed, then additional functionality, such as a key-value store or FUSE +mounting of the file-system implementation, may be available for all implementations "for free".
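For illustration, a minimal sketch of this uniform interface using the built-in `memory` backend; the paths here are arbitrary:

```python
import fsspec

# Any registered backend is obtained the same way; "memory" needs no credentials.
fs = fsspec.filesystem("memory")
fs.pipe("/demo/hello.txt", b"hello world")   # write raw bytes to a path
print(fs.cat("/demo/hello.txt"))             # b'hello world'
print(fs.ls("/demo", detail=False))          # ['/demo/hello.txt']

# URL-style access works the same way for s3://, gcs://, etc. once the
# corresponding backend package is installed.
with fsspec.open("memory://demo/hello.txt", "rb") as f:
    print(f.read())
```

The same calls run unchanged against `s3fs`, `gcsfs`, or any other implementation, which is exactly the common behaviour the specification is meant to guarantee.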
+ +## Documentation + +Please refer to [RTD](https://filesystem-spec.readthedocs.io/en/latest/?badge=latest). + +## Develop + +fsspec uses GitHub Actions for CI. Environment files can be found +in the "ci/" directory. Note that the main environment is called "py38", +but the installed Python version is expected to be adjustable at +CI runtime. For local use, pick a version suitable for you. + +```bash +# For a new environment (mamba / conda). +mamba create -n fsspec -c conda-forge python=3.9 -y +conda activate fsspec + +# Standard dev install with docs and tests. +pip install -e ".[dev,doc,test]" + +# Full tests except for downstream +pip install s3fs +pip uninstall s3fs +pip install -e .[dev,doc,test_full] +pip install s3fs --no-deps +pytest -v + +# Downstream tests. +sh install_s3fs.sh +# Windows PowerShell. +install_s3fs.sh +``` + +### Testing + +Tests can be run in the dev environment, if activated, via ``pytest fsspec``. + +The full fsspec suite requires a system-level docker, docker-compose, and fuse +installation. If only making changes to one backend implementation, it is +not generally necessary to run all tests locally. + +Contributors are expected to ensure that any change to fsspec does not +cause issues or regressions either for other fsspec-related packages such +as gcsfs and s3fs, or for downstream users of fsspec. The "downstream" CI +run and corresponding environment file run a set of tests from the dask +test suite, and very minimal tests against pandas and zarr from the +test_downstream.py module in this repo. + +### Code Formatting + +fsspec uses [Black](https://black.readthedocs.io/en/stable) to ensure +a consistent code format throughout the project. +Run ``black fsspec`` from the root of the filesystem_spec repository to +auto-format your code. Additionally, many editors have plugins that will apply +``black`` as you edit files. ``black`` is included in the ``tox`` environments. + +Optionally, you may wish to set up [pre-commit hooks](https://pre-commit.com) to +automatically run ``black`` when you make a git commit. +Run ``pre-commit install --install-hooks`` from the root of the +filesystem_spec repository to set up pre-commit hooks. ``black`` will now be run +before you commit, reformatting any changed files. You can format without +committing via ``pre-commit run`` or skip these checks with ``git commit +--no-verify``. + +## Support + +Work on this repository is supported in part by: + +"Anaconda, Inc. - Advancing AI through open source." + +anaconda logo diff --git a/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/RECORD b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..220c76c0922de71a4228eaf28e511826b6a2d53d --- /dev/null +++ b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/RECORD @@ -0,0 +1,117 @@ +fsspec-2025.9.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +fsspec-2025.9.0.dist-info/METADATA,sha256=VmMoGluoRhQXlQigYs9kzwlXfPIg1KBkRL7V2F5O2B0,10397 +fsspec-2025.9.0.dist-info/RECORD,, +fsspec-2025.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87 +fsspec-2025.9.0.dist-info/licenses/LICENSE,sha256=LcNUls5TpzB5FcAIqESq1T53K0mzTN0ARFBnaRQH7JQ,1513 +fsspec/__init__.py,sha256=L7qwNBU1iMNQd8Of87HYSNFT9gWlNMSESaJC8fY0AaQ,2053 +fsspec/__pycache__/__init__.cpython-39.pyc,, +fsspec/__pycache__/_version.cpython-39.pyc,, +fsspec/__pycache__/archive.cpython-39.pyc,, +fsspec/__pycache__/asyn.cpython-39.pyc,, +fsspec/__pycache__/caching.cpython-39.pyc,, +fsspec/__pycache__/callbacks.cpython-39.pyc,, +fsspec/__pycache__/compression.cpython-39.pyc,, +fsspec/__pycache__/config.cpython-39.pyc,, +fsspec/__pycache__/conftest.cpython-39.pyc,, +fsspec/__pycache__/core.cpython-39.pyc,, +fsspec/__pycache__/dircache.cpython-39.pyc,, +fsspec/__pycache__/exceptions.cpython-39.pyc,, +fsspec/__pycache__/fuse.cpython-39.pyc,, +fsspec/__pycache__/generic.cpython-39.pyc,, +fsspec/__pycache__/gui.cpython-39.pyc,, +fsspec/__pycache__/json.cpython-39.pyc,, +fsspec/__pycache__/mapping.cpython-39.pyc,, +fsspec/__pycache__/parquet.cpython-39.pyc,, +fsspec/__pycache__/registry.cpython-39.pyc,, +fsspec/__pycache__/spec.cpython-39.pyc,, +fsspec/__pycache__/transaction.cpython-39.pyc,, +fsspec/__pycache__/utils.cpython-39.pyc,, +fsspec/_version.py,sha256=LkyV4dUHpfGx8N-SvSE0eARRvdCyg8wWXOl3qM2ZAZ4,710 +fsspec/archive.py,sha256=vM6t_lgV6lBWbBYwpm3S4ofBQFQxUPr5KkDQrrQcQro,2411 +fsspec/asyn.py,sha256=mE55tO_MmGcxD14cUuaiS3veAqo0h6ZqANfnUuCN3sk,36365 +fsspec/caching.py,sha256=86uSgPa5E55b28XEhuC-dMcKAxJtZZnpQqnHTwaF3hI,34294 +fsspec/callbacks.py,sha256=BDIwLzK6rr_0V5ch557fSzsivCElpdqhXr5dZ9Te-EE,9210 +fsspec/compression.py,sha256=gBK2MV_oTFVW2XDq8bZVbYQKYrl6JDUou6_-kyvmxuk,5086 +fsspec/config.py,sha256=LF4Zmu1vhJW7Je9Q-cwkRc3xP7Rhyy7Xnwj26Z6sv2g,4279 +fsspec/conftest.py,sha256=fVfx-NLrH_OZS1TIpYNoPzM7efEcMoL62reHOdYeFCA,1245 +fsspec/core.py,sha256=1tLctwr7sF1VO3djc_UkjhJ8IAEy0TUMH_bb07Sw17E,23828 +fsspec/dircache.py,sha256=YzogWJrhEastHU7vWz-cJiJ7sdtLXFXhEpInGKd4EcM,2717 +fsspec/exceptions.py,sha256=pauSLDMxzTJMOjvX1WEUK0cMyFkrFxpWJsyFywav7A8,331 +fsspec/fuse.py,sha256=Q-3NOOyLqBfYa4Db5E19z_ZY36zzYHtIs1mOUasItBQ,10177 +fsspec/generic.py,sha256=K-b03ifKidHUo99r8nz2pB6oGyf88RtTKahCuBF9ZVU,13409 +fsspec/gui.py,sha256=CQ7QsrTpaDlWSLNOpwNoJc7khOcYXIZxmrAJN9bHWQU,14002 +fsspec/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +fsspec/implementations/__pycache__/__init__.cpython-39.pyc,, +fsspec/implementations/__pycache__/arrow.cpython-39.pyc,,
+fsspec/implementations/__pycache__/asyn_wrapper.cpython-39.pyc,, +fsspec/implementations/__pycache__/cache_mapper.cpython-39.pyc,, +fsspec/implementations/__pycache__/cache_metadata.cpython-39.pyc,, +fsspec/implementations/__pycache__/cached.cpython-39.pyc,, +fsspec/implementations/__pycache__/dask.cpython-39.pyc,, +fsspec/implementations/__pycache__/data.cpython-39.pyc,, +fsspec/implementations/__pycache__/dbfs.cpython-39.pyc,, +fsspec/implementations/__pycache__/dirfs.cpython-39.pyc,, +fsspec/implementations/__pycache__/ftp.cpython-39.pyc,, +fsspec/implementations/__pycache__/gist.cpython-39.pyc,, +fsspec/implementations/__pycache__/git.cpython-39.pyc,, +fsspec/implementations/__pycache__/github.cpython-39.pyc,, +fsspec/implementations/__pycache__/http.cpython-39.pyc,, +fsspec/implementations/__pycache__/http_sync.cpython-39.pyc,, +fsspec/implementations/__pycache__/jupyter.cpython-39.pyc,, +fsspec/implementations/__pycache__/libarchive.cpython-39.pyc,, +fsspec/implementations/__pycache__/local.cpython-39.pyc,, +fsspec/implementations/__pycache__/memory.cpython-39.pyc,, +fsspec/implementations/__pycache__/reference.cpython-39.pyc,, +fsspec/implementations/__pycache__/sftp.cpython-39.pyc,, +fsspec/implementations/__pycache__/smb.cpython-39.pyc,, +fsspec/implementations/__pycache__/tar.cpython-39.pyc,, +fsspec/implementations/__pycache__/webhdfs.cpython-39.pyc,, +fsspec/implementations/__pycache__/zip.cpython-39.pyc,, +fsspec/implementations/arrow.py,sha256=721Dikne_lV_0tlgk9jyKmHL6W-5MT0h2LKGvOYQTPI,8623 +fsspec/implementations/asyn_wrapper.py,sha256=fox9yjsEu7NCgzdAZJYfNALtUnFkIc_QmeKzaSllZho,3679 +fsspec/implementations/cache_mapper.py,sha256=W4wlxyPxZbSp9ItJ0pYRVBMh6bw9eFypgP6kUYuuiI4,2421 +fsspec/implementations/cache_metadata.py,sha256=rddh5-0SXIeyWCPpBpOFcaAyWoPyeYmFfeubEWt-nRM,8536 +fsspec/implementations/cached.py,sha256=TETvCyf0x-Ak8Y4uiuvIKx2IFYOzvcV0LMUIt4AoJzM,35168 +fsspec/implementations/dask.py,sha256=CXZbJzIVOhKV8ILcxuy3bTvcacCueAbyQxmvAkbPkrk,4466 +fsspec/implementations/data.py,sha256=LDLczxRh8h7x39Zjrd-GgzdQHr78yYxDlrv2C9Uxb5E,1658 +fsspec/implementations/dbfs.py,sha256=1cvvC6KBWOb8pBVpc01xavVbEPXO1xsgZvPD7H73M9k,16217 +fsspec/implementations/dirfs.py,sha256=f1sGnQ9Vf0xTxrXo4jDeBy4Qfq3RTqAEemqBSeb0hwY,12108 +fsspec/implementations/ftp.py,sha256=bzL_TgH77nMMtTMewRGkbq4iObSHGu7YoMRCXBH4nrc,11639 +fsspec/implementations/gist.py,sha256=Ost985hmFr50KsA-QD0shY3hP4KX5qJ9rb5C-X4ehK8,8341 +fsspec/implementations/git.py,sha256=qBDWMz5LNllPqVjr5jf_1FuNha4P5lyQI3IlhYg-wUE,3731 +fsspec/implementations/github.py,sha256=aCsZL8UvXZgdkcB1RUs3DdLeNrjLKcFsFYeQFDWbBFo,11653 +fsspec/implementations/http.py,sha256=f_542L40574onk0tiOsSBABeNq1TXHYUyEBpCeO6vhA,30435 +fsspec/implementations/http_sync.py,sha256=UydDqSdUBdhiJ1KufzV8rKGrTftFR4QmNV0safILb8g,30133 +fsspec/implementations/jupyter.py,sha256=B2uj7OEm7yIk-vRSsO37_ND0t0EBvn4B-Su43ibN4Pg,3811 +fsspec/implementations/libarchive.py,sha256=5_I2DiLXwQ1JC8x-K7jXu-tBwhO9dj7tFLnb0bTnVMQ,7102 +fsspec/implementations/local.py,sha256=DQeK7jRGv4_mJAweLKALO5WzIIkjXxZ_jRvwQ_xadSA,16936 +fsspec/implementations/memory.py,sha256=Kc6TZSbZ4tdi-6cE5ttEPIgMyq9aAt6cDdVLFRTJvf8,10488 +fsspec/implementations/reference.py,sha256=npYj49AmR8rmON9t_BLpfEXqhgsardUeynamqyraOXo,48704 +fsspec/implementations/sftp.py,sha256=fMY9XZcmpjszQ2tCqO_TPaJesaeD_Dv7ptYzgUPGoO0,5631 +fsspec/implementations/smb.py,sha256=5fhu8h06nOLBPh2c48aT7WBRqh9cEcbIwtyu06wTjec,15236 +fsspec/implementations/tar.py,sha256=dam78Tp_CozybNqCY2JYgGBS3Uc9FuJUAT9oB0lolOs,4111 
+fsspec/implementations/webhdfs.py,sha256=G9wGywj7BkZk4Mu9zXu6HaDlEqX4F8Gw1i4k46CP_-o,16769 +fsspec/implementations/zip.py,sha256=9LBMHPft2OutJl2Ft-r9u_z3GptLkc2n91ur2A3bCbg,6072 +fsspec/json.py,sha256=3BfNSQ96MB4Xao_ocjheINeqZM2ev7oljUzR5XmNXrE,3814 +fsspec/mapping.py,sha256=m2ndB_gtRBXYmNJg0Ie1-BVR75TFleHmIQBzC-yWhjU,8343 +fsspec/parquet.py,sha256=6ibAmG527L5JNFS0VO8BDNlxHdA3bVYqdByeiFgpUVM,19448 +fsspec/registry.py,sha256=epoYryFFzDWjbkQJfh6xkF3nEu8RTiOzV3-voi8Pshs,12048 +fsspec/spec.py,sha256=7cOUe5PC5Uyf56HtGBUHEoym8ktPj-BI8G4HR8Xd_C8,77298 +fsspec/tests/abstract/__init__.py,sha256=4xUJrv7gDgc85xAOz1p-V_K1hrsdMWTSa0rviALlJk8,10181 +fsspec/tests/abstract/__pycache__/__init__.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/common.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/copy.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/get.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/mv.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/open.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/pipe.cpython-39.pyc,, +fsspec/tests/abstract/__pycache__/put.cpython-39.pyc,, +fsspec/tests/abstract/common.py,sha256=1GQwNo5AONzAnzZj0fWgn8NJPLXALehbsuGxS3FzWVU,4973 +fsspec/tests/abstract/copy.py,sha256=gU5-d97U3RSde35Vp4RxPY4rWwL744HiSrJ8IBOp9-8,19967 +fsspec/tests/abstract/get.py,sha256=vNR4HztvTR7Cj56AMo7_tx7TeYz1Jgr_2Wb8Lv-UiBY,20755 +fsspec/tests/abstract/mv.py,sha256=k8eUEBIrRrGMsBY5OOaDXdGnQUKGwDIfQyduB6YD3Ns,1982 +fsspec/tests/abstract/open.py,sha256=Fi2PBPYLbRqysF8cFm0rwnB41kMdQVYjq8cGyDXp3BU,329 +fsspec/tests/abstract/pipe.py,sha256=LFzIrLCB5GLXf9rzFKJmE8AdG7LQ_h4bJo70r8FLPqM,402 +fsspec/tests/abstract/put.py,sha256=7aih17OKB_IZZh1Mkq1eBDIjobhtMQmI8x-Pw-S_aZk,21201 +fsspec/transaction.py,sha256=xliRG6U2Zf3khG4xcw9WiB-yAoqJSHEGK_VjHOdtgo0,2398 +fsspec/utils.py,sha256=HC8RFbb7KpEDedsYxExvWvsTObEuUcuuWxd0B_MyGpo,22995 diff --git a/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/WHEEL b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..12228d414b6cfed7c39d3781c85c63256a1d7fb5 --- /dev/null +++ b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.27.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/licenses/LICENSE b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..67590a5e5be5a5a2dde3fe53a7512e404a896c22 --- /dev/null +++ b/phivenv/Lib/site-packages/fsspec-2025.9.0.dist-info/licenses/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2018, Martin Durant +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a99949e89d9d8e6f60288095a529a41a44f501a Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/_version.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/_version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7cec164bf710171f7ffae9a7963a541e6c41750 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/_version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/archive.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/archive.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0c09d3b444a46f880da4da7a71c71b7fc53543f Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/archive.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/asyn.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/asyn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af8dd31919aa95693b9aa61cc78f74a2bb306072 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/asyn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/caching.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/caching.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..826973df6e094343eee84f687ec633fd2bf62edf Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/caching.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/callbacks.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/callbacks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4666816b7acc9433fed92a08d5a9ce63143ca562 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/callbacks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/compression.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/compression.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d111636603a489496d15616a17beabe09b617f1 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/compression.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/config.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4692e7cb729c822cc1e98e8cbd492ce066dd4855 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/conftest.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/conftest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..192d643352d8fbccc846665b2678c5e466cf9c75 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/conftest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f43634688f0418cce5990f3577cb6c1de3b52781 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/dircache.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/dircache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6443241f2708b07053d4d123ecd52add8eafda8 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/dircache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/exceptions.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/exceptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a777732ef4af3ab07226614cf8424672f698535 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/exceptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/fuse.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5488a8a4a647f9ad202deee5583ea4cb3a86baa2 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/fuse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/generic.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/generic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f78948d58360dce69231b51f706c3ffa6f42e065 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/generic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/gui.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/gui.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7411b7ddf6d3ffed2ede47331d3c96272eee826 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/gui.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/json.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/json.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51b38a9f2dd6a58e6bebc41c9bc8218f796d5a43 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/json.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/mapping.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/mapping.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b7d1a06ee271565ae1b2d11cb0a0075a8dd4cd0 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/mapping.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/parquet.cpython-39.pyc 
b/phivenv/Lib/site-packages/fsspec/__pycache__/parquet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbdc8693e82748bb1cdf241084071366fe7b2616 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/parquet.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/registry.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5c12c0c3dd95440040fd10fb2aa717536863e5e Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/spec.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/spec.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3497a63d7059429530238e3b717820ac8b51feab Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/spec.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/transaction.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/transaction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..474038fa04bb5a312c5b2c38a18aa15d7ddcf882 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/transaction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96677d280b7eb139371226fd1c6212bccaa71f30 Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/tests/abstract/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/tests/abstract/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fe2e8fb924637fd7f3c9f033f24e7d1493147bd Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/tests/abstract/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/fsspec/tests/abstract/__pycache__/put.cpython-39.pyc b/phivenv/Lib/site-packages/fsspec/tests/abstract/__pycache__/put.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79461048fe50255d673c48e8a86a9239a5fec1ec Binary files /dev/null and b/phivenv/Lib/site-packages/fsspec/tests/abstract/__pycache__/put.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/__init__.py b/phivenv/Lib/site-packages/functorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2790da74333710de2bb6e2b99528fe1b85e2e2 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from torch._functorch.deprecated import ( + combine_state_for_ensemble, + functionalize, + grad, + grad_and_value, + hessian, + jacfwd, + jacrev, + jvp, + make_functional, + make_functional_with_buffers, + vjp, + vmap, +) + +# utilities. Maybe these should go in their own namespace in the future? 
+from torch._functorch.make_functional import ( + FunctionalModule, + FunctionalModuleWithBuffers, +) + +# Was never documented +from torch._functorch.python_key import make_fx + + +# Top-level APIs. Please think carefully before adding something to the +# top-level namespace: +# - private helper functions should go into torch._functorch +# - very experimental things should go into functorch.experimental +# - compilation related things should go into functorch.compile + + +__version__ = torch.__version__ diff --git a/phivenv/Lib/site-packages/functorch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fd0691821d88065279736c7888bf30cd6dff282 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/_src/__init__.py b/phivenv/Lib/site-packages/functorch/_src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/functorch/_src/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/_src/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27a6099c363e0af4298f968123a16acf9893d3b7 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/_src/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/_src/aot_autograd/__init__.py b/phivenv/Lib/site-packages/functorch/_src/aot_autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bef6245f7a0f900c71b5350dcec65c19c260cd78 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/_src/aot_autograd/__init__.py @@ -0,0 +1,8 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. +from torch._functorch.aot_autograd import ( + aot_autograd_decompositions, + KNOWN_TYPES, + PytreeThunk, +) diff --git a/phivenv/Lib/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..893d8594e6f1371446bf6bd0afc411a9b8369bdd Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/_src/eager_transforms/__init__.py b/phivenv/Lib/site-packages/functorch/_src/eager_transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37df7424c4056760388c6b0d17288e4571fde359 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/_src/eager_transforms/__init__.py @@ -0,0 +1,7 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. 
+from torch._functorch.eager_transforms import ( + _assert_wrapped_functional, + _unwrap_functional_tensor, +) diff --git a/phivenv/Lib/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dc63311e4f0d03bf5732755656e9f60b947a43b Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/_src/make_functional/__init__.py b/phivenv/Lib/site-packages/functorch/_src/make_functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c96ce28510b5172edbc3941b232c2221e75397de --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/_src/make_functional/__init__.py @@ -0,0 +1,4 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. +from torch._functorch.make_functional import _swap_state diff --git a/phivenv/Lib/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c5d72e4fb8e95de509d74e786414cda8bd356ac Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/_src/vmap/__init__.py b/phivenv/Lib/site-packages/functorch/_src/vmap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec274ba24ad6b3f28798c42bfe92c8b14670a51e --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/_src/vmap/__init__.py @@ -0,0 +1,16 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. 
+from torch._functorch.vmap import ( + _add_batch_dim, + _broadcast_to_and_flatten, + _create_batched_inputs, + _get_name, + _process_batched_inputs, + _remove_batch_dim, + _unwrap_batched, + _validate_and_get_batch_size, + Tensor, + tree_flatten, + tree_unflatten, +) diff --git a/phivenv/Lib/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8e8c1fa3985c2192151aa40a8723d3a39da5466 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/compile/__init__.py b/phivenv/Lib/site-packages/functorch/compile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cff66e3aeef73f96273b6c4c77c74e3f1153f4b5 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/compile/__init__.py @@ -0,0 +1,30 @@ +from torch._functorch import config +from torch._functorch.aot_autograd import ( + aot_function, + aot_module, + aot_module_simplified, + compiled_function, + compiled_module, + get_aot_compilation_context, + get_aot_graph_name, + get_graph_being_compiled, + make_boxed_compiler, + make_boxed_func, +) +from torch._functorch.compilers import ( + debug_compile, + default_decompositions, + draw_graph_compile, + memory_efficient_fusion, + nnc_jit, + nop, + print_compile, + ts_compile, +) +from torch._functorch.fx_minifier import minifier +from torch._functorch.partitioners import ( + default_partition, + draw_graph, + min_cut_rematerialization_partition, +) +from torch._functorch.python_key import pythonkey_decompose diff --git a/phivenv/Lib/site-packages/functorch/compile/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/compile/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5efffb096110fb5171e964a5ae776288275400dd Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/compile/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__init__.py b/phivenv/Lib/site-packages/functorch/dim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bed08f3e1b38e425716b7961e272bc06d7782a --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/__init__.py @@ -0,0 +1,177 @@ +import functorch._C +import torch +from functorch._C import dim as _C + +from .tree_map import tree_flatten, tree_map +from .wrap_type import wrap_type + + +_C._patch_tensor_class() +dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists + + +class DimensionMismatchError(Exception): + pass + + +class DimensionBindError(Exception): + pass + + +from . import op_properties + + +# use dict to avoid writing C++ bindings for set +pointwise = dict.fromkeys(op_properties.pointwise, True) + +use_c = True +if not use_c: + from . import reference + + +class _Tensor: + # fast path around slow wrapping/unwrapping logic for simple queries used + # by the implementation...
+ + @property + def dims(self): + return tuple(d for d in self._levels if isinstance(d, Dim)) + + def dim(self): + return self.ndim + + if use_c: + __torch_function__ = classmethod(_C.__torch_function__) + expand = _C._instancemethod(_C.expand) + else: + __torch_function__ = reference.__torch_function__ + expand = reference.expand + + index = _C._instancemethod(_C.index) + + def __repr__(self): + tensor, levels, ndim = self._tensor, self._levels, self.ndim + return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}" + + +TensorLike = (_Tensor, torch.Tensor) + + +class Dim(_C.Dim, _Tensor): + # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence. + # Tensor defines format, but we want to print Dims with special formatting + __format__ = object.__format__ + + +class Tensor(_Tensor, _C.Tensor): + if not use_c: + from_batched = staticmethod(_C.Tensor_from_batched) + from_positional = staticmethod(_C.Tensor_from_positional) + sum = _C._instancemethod(_C.Tensor_sum) + + +def cat(tensors, dim, new_dim): + n = dims() + return stack(tensors, n, dim).index([n, dim], new_dim) + + +if use_c: + _wrap = _C._wrap + + def _def(name, *args, **kwargs): + orig = getattr(torch.Tensor, name) + setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) + + t__getitem__ = _C._instancemethod(_C.__getitem__) + stack = _C.stack + split = _C._instancemethod(_C.split) +else: + _wrap, _def = reference._wrap, reference._def + t__getitem__ = reference.t__getitem__ + stack = reference.stack + split = reference.split + +# note: there is no python reference +t__setitem__ = _C._instancemethod(_C.__setitem__) +# this is patched in the C API because otherwise torch.Tensor will +# no longer be considered a sequence and things will break +# torch.Tensor.__getitem__ = t__getitem__ + +_Tensor.__getitem__ = t__getitem__ +# torch.Tensor.__setitem__ = t__setitem__ +_Tensor.__setitem__ = t__setitem__ + +torch.Tensor.split = split +_Tensor.split = split +torch.Tensor.expand = _C._instancemethod(_C.expand) +torch.Tensor.index = _C._instancemethod(_C.index) +wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__) +del _Tensor.ndim + +if use_c: + _Tensor.order = _C._instancemethod(_C.order) +else: + _Tensor.order = reference.positional + +_def("mean") +_def("sum") +_def("all") +_def("amax") +_def("amin") +_def("aminmax") +_def("any") +_def("count_nonzero") +_def("logsumexp") +_def("nanmean") +_def("nansum") +_def("prod") +_def("std", keepdim_offset=2) +_def("var", keepdim_offset=2) +_def("max", single_dim=True) +_def("min", single_dim=True) +_def("argmax", single_dim=True) +_def("argmin", single_dim=True) +_def("kthvalue", single_dim=True) +_def("median", single_dim=True) +_def("nanmedian", single_dim=True) +_def("mode", single_dim=True) +_def("sort", reduce=False) +_def("argsort", reduce=False) +_def("unbind", single_dim=True) +_def("chunk", dim_offset=1, reduce=False) +_def("cummax", single_dim=True, reduce=False) +_def("cummin", single_dim=True, reduce=False) +_def("cumprod", single_dim=True, reduce=False) +_def("cumprod_", single_dim=True, reduce=False) +_def("cumsum", single_dim=True, reduce=False) +_def("cumsum_", single_dim=True, reduce=False) +_def("logcumsumexp", single_dim=True, reduce=False) +_def("renorm", dim_offset=1, single_dim=True, reduce=False) +_def("softmax", single_dim=True, reduce=False) +softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False) + +# stuff to 
handle in the future, because they require special +# binding logic for dims +# cross +# diag_embed +# diagonal +# diagonal_scatter +# diff +# nanquantile +# quantile +# roll +# rot90 +# topk (new dims on output) +# should these all be subsumed by inplace indexing? +# index_add_ +# index_add +# index_copy +# index_copy_ +# index_fill +# index_fill_ +# index_select +# scatter +# scatter_ +# scatter_add +# scatter_add_ +# scatter_reduce diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93b6ce4d1edfaf2774535a296e4d07e3bae6e9bb Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..131aa192bfa761d8751e8e411c0b5e73e36abb78 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4a8de5d0ca8fcf162a20cc54d9583f2ae224dde Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/dim.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/dim.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60ecffbcb7dc59e8c3ecf82c818ca840f7594211 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/dim.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/magic_trace.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/magic_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52f04feb28225e3c984c40938e4e5b119245ef9c Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/magic_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/op_properties.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/op_properties.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0281782de5141a7ac7d497d7a10dc3f242672b01 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/op_properties.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/reference.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/reference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ac29c55a19be0887b96262cd80481c0bb2d58de Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/reference.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/tree_map.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/tree_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f7c7e6c009fd6536f047beef5e9a68e4345acb0 Binary files /dev/null and
b/phivenv/Lib/site-packages/functorch/dim/__pycache__/tree_map.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/__pycache__/wrap_type.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/dim/__pycache__/wrap_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e369d49bb64efe7f4f1de7362317e5c64321ee7 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/dim/__pycache__/wrap_type.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/dim/batch_tensor.py b/phivenv/Lib/site-packages/functorch/dim/batch_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..234e731e825bd0a55a80a1a55d90ea43b0586080 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/batch_tensor.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from contextlib import contextmanager + +from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers + + +_enabled = False + + +@contextmanager +def _enable_layers(dims): + global _enabled + assert not _enabled + input = sorted((d._level, d.size) for d in dims if not isinstance(d, int)) + n = len(input) + try: + _vmap_add_layers(input) + _enabled = True + yield + finally: + _enabled = False + _vmap_remove_layers(n) diff --git a/phivenv/Lib/site-packages/functorch/dim/delayed_mul_tensor.py b/phivenv/Lib/site-packages/functorch/dim/delayed_mul_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..2aefbe0bf0010c0e6fd410290163100e0d843c21 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/delayed_mul_tensor.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from . 
import _Tensor, Tensor +from .reference import _dims, _enable_layers, llist, ltuple + + +class DelayedMulTensor(_Tensor): + def __init__(self, lhs, rhs): + self._lhs, self._rhs = lhs, rhs + self._data = None + self._levels_data = None + self._has_device = lhs._has_device or rhs._has_device + self._batchtensor_data = None + self._tensor_data = None + + @property + def _levels(self): + if self._levels_data is None: + levels = llist(self._lhs._levels) + for l in self._rhs._levels: + if l not in levels: + levels.append(l) + self._levels_data = ltuple(levels) + return self._levels_data + + @property + def _batchtensor(self): + if self._batchtensor_data is None: + with _enable_layers(self._levels): + print("bt multiply fallback") + self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor + return self._batchtensor_data + + @property + def _tensor(self): + if self._tensor_data is None: + self._tensor_data = Tensor.from_batched( + self._batchtensor, self._has_device + )._tensor + return self._tensor_data + + @property + def ndim(self): + return self._batchtensor.ndim + + @property + def dims(self): + return ltuple(super().dims) + + def sum(self, dim): + dims = _dims(dim, 0, False, False) + n = ord("a") + all_levels = self._levels + + def to_char(d): + return chr(n + all_levels.index(d)) + + plhs, levelslhs = self._lhs._tensor, self._lhs._levels + prhs, levelsrhs = self._rhs._tensor, self._rhs._levels + new_levels = [l for l in self._levels if l not in dims] + fmt = "".join( + [ + *(to_char(d) for d in levelslhs), + ",", + *(to_char(d) for d in levelsrhs), + "->", + *(to_char(d) for d in new_levels), + ] + ) + result_data = torch.einsum(fmt, (plhs, prhs)) + return Tensor.from_positional(result_data, new_levels, True) diff --git a/phivenv/Lib/site-packages/functorch/dim/dim.py b/phivenv/Lib/site-packages/functorch/dim/dim.py new file mode 100644 index 0000000000000000000000000000000000000000..809a78f5731ccfabab3b337107188f278188ae80 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/dim.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import dis +import inspect +from dataclasses import dataclass +from typing import Union + +from . import DimList + + +_vmap_levels = [] + + +@dataclass +class LevelInfo: + level: int + alive: bool = True + + +class Dim: + def __init__(self, name: str, size: Union[None, int] = None): + self.name = name + self._size = None + self._vmap_level = None + if size is not None: + self.size = size + + def __del__(self): + if self._vmap_level is not None: + _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821 + while ( + not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821 + ): + _vmap_decrement_nesting() # noqa: F821 + _vmap_levels.pop() + + @property + def size(self): + assert self.is_bound + return self._size + + @size.setter + def size(self, size: int): + from . 
import DimensionBindError + + if self._size is None: + self._size = size + self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821 + self._vmap_stack = len(_vmap_levels) + _vmap_levels.append(LevelInfo(self._vmap_level)) + + elif self._size != size: + raise DimensionBindError( + f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}" + ) + + @property + def is_bound(self): + return self._size is not None + + def __repr__(self): + return self.name + + +def extract_name(inst): + assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME" + return inst.argval + + +_cache = {} + + +def dims(lists=0): + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + code, lasti = calling_frame.f_code, calling_frame.f_lasti + key = (code, lasti) + if key not in _cache: + first = lasti // 2 + 1 + instructions = list(dis.get_instructions(calling_frame.f_code)) + unpack = instructions[first] + + if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME": + # just a single dim, not a list + name = unpack.argval + ctor = Dim if lists == 0 else DimList + _cache[key] = lambda: ctor(name=name) + else: + assert unpack.opname == "UNPACK_SEQUENCE" + ndims = unpack.argval + names = tuple( + extract_name(instructions[first + 1 + i]) for i in range(ndims) + ) + first_list = len(names) - lists + _cache[key] = lambda: tuple( + Dim(n) if i < first_list else DimList(name=n) + for i, n in enumerate(names) + ) + return _cache[key]() + + +def _dim_set(positional, arg): + def convert(a): + if isinstance(a, Dim): + return a + else: + assert isinstance(a, int) + return positional[a] + + if arg is None: + return positional + elif not isinstance(arg, (Dim, int)): + return tuple(convert(a) for a in arg) + else: + return (convert(arg),) diff --git a/phivenv/Lib/site-packages/functorch/dim/magic_trace.py b/phivenv/Lib/site-packages/functorch/dim/magic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad92632a2ef08b300493d9e63c8a7d3ab8e9749 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/magic_trace.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import os +import signal +import subprocess +from contextlib import contextmanager + + +@contextmanager +def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"): + pid = os.getpid() + if not os.path.exists(magic_trace_cache): + print(f"Downloading magic_trace to: {magic_trace_cache}") + subprocess.run( + [ + "wget", + "-O", + magic_trace_cache, + "-q", + "https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace", + ] + ) + subprocess.run(["chmod", "+x", magic_trace_cache]) + args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output] + p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8") + while True: + x = p.stderr.readline() + print(x) + if "Attached" in x: + break + try: + yield + finally: + p.send_signal(signal.SIGINT) + r = p.wait() + print(p.stderr.read()) + p.stderr.close() + if r != 0: + raise ValueError(f"magic_trace exited abnormally: {r}") diff --git a/phivenv/Lib/site-packages/functorch/dim/op_properties.py b/phivenv/Lib/site-packages/functorch/dim/op_properties.py new file mode 100644 index 0000000000000000000000000000000000000000..3fda08afb60f75d635b5590735e7320700ed8594 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/op_properties.py @@ -0,0 +1,312 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + + +# pointwise operators can go through a faster pathway + +tensor_magic_methods = ["add", ""] +pointwise_magic_methods_with_reverse = ( + "add", + "sub", + "mul", + "floordiv", + "div", + "truediv", + "mod", + "pow", + "lshift", + "rshift", + "and", + "or", + "xor", +) +pointwise_magic_methods = ( + *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)), + "eq", + "gt", + "le", + "lt", + "ge", + "gt", + "ne", + "neg", + "pos", + "abs", + "invert", + "iadd", + "isub", + "imul", + "ifloordiv", + "idiv", + "itruediv", + "imod", + "ipow", + "ilshift", + "irshift", + "iand", + "ior", + "ixor", + "int", + "long", + "float", + "complex", +) + +pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),) + +pointwise = ( + *(getattr(torch.Tensor, m) for m in pointwise_methods), + torch.nn.functional.dropout, + torch.where, + torch.Tensor.abs, + torch.abs, + torch.Tensor.acos, + torch.acos, + torch.Tensor.acosh, + torch.acosh, + torch.Tensor.add, + torch.add, + torch.Tensor.addcdiv, + torch.addcdiv, + torch.Tensor.addcmul, + torch.addcmul, + torch.Tensor.addr, + torch.addr, + torch.Tensor.angle, + torch.angle, + torch.Tensor.asin, + torch.asin, + torch.Tensor.asinh, + torch.asinh, + torch.Tensor.atan, + torch.atan, + torch.Tensor.atan2, + torch.atan2, + torch.Tensor.atanh, + torch.atanh, + torch.Tensor.bitwise_and, + torch.bitwise_and, + torch.Tensor.bitwise_left_shift, + torch.bitwise_left_shift, + torch.Tensor.bitwise_not, + torch.bitwise_not, + torch.Tensor.bitwise_or, + torch.bitwise_or, + torch.Tensor.bitwise_right_shift, + torch.bitwise_right_shift, + torch.Tensor.bitwise_xor, + torch.bitwise_xor, + torch.Tensor.ceil, + torch.ceil, + torch.celu, + torch.nn.functional.celu, + torch.Tensor.clamp, + torch.clamp, + torch.Tensor.clamp_max, + torch.clamp_max, + torch.Tensor.clamp_min, + torch.clamp_min, + torch.Tensor.copysign, + torch.copysign, + torch.Tensor.cos, + torch.cos, + torch.Tensor.cosh, + torch.cosh, + torch.Tensor.deg2rad, + torch.deg2rad, + torch.Tensor.digamma, + torch.digamma, + 
torch.Tensor.div, + torch.div, + torch.dropout, + torch.nn.functional.dropout, + torch.nn.functional.elu, + torch.Tensor.eq, + torch.eq, + torch.Tensor.erf, + torch.erf, + torch.Tensor.erfc, + torch.erfc, + torch.Tensor.erfinv, + torch.erfinv, + torch.Tensor.exp, + torch.exp, + torch.Tensor.exp2, + torch.exp2, + torch.Tensor.expm1, + torch.expm1, + torch.feature_dropout, + torch.Tensor.float_power, + torch.float_power, + torch.Tensor.floor, + torch.floor, + torch.Tensor.floor_divide, + torch.floor_divide, + torch.Tensor.fmod, + torch.fmod, + torch.Tensor.frac, + torch.frac, + torch.Tensor.frexp, + torch.frexp, + torch.Tensor.gcd, + torch.gcd, + torch.Tensor.ge, + torch.ge, + torch.nn.functional.gelu, + torch.nn.functional.glu, + torch.Tensor.gt, + torch.gt, + torch.Tensor.hardshrink, + torch.hardshrink, + torch.nn.functional.hardshrink, + torch.nn.functional.hardsigmoid, + torch.nn.functional.hardswish, + torch.nn.functional.hardtanh, + torch.Tensor.heaviside, + torch.heaviside, + torch.Tensor.hypot, + torch.hypot, + torch.Tensor.i0, + torch.i0, + torch.Tensor.igamma, + torch.igamma, + torch.Tensor.igammac, + torch.igammac, + torch.Tensor.isclose, + torch.isclose, + torch.Tensor.isfinite, + torch.isfinite, + torch.Tensor.isinf, + torch.isinf, + torch.Tensor.isnan, + torch.isnan, + torch.Tensor.isneginf, + torch.isneginf, + torch.Tensor.isposinf, + torch.isposinf, + torch.Tensor.isreal, + torch.isreal, + torch.Tensor.kron, + torch.kron, + torch.Tensor.lcm, + torch.lcm, + torch.Tensor.ldexp, + torch.ldexp, + torch.Tensor.le, + torch.le, + torch.nn.functional.leaky_relu, + torch.Tensor.lerp, + torch.lerp, + torch.Tensor.lgamma, + torch.lgamma, + torch.Tensor.log, + torch.log, + torch.Tensor.log10, + torch.log10, + torch.Tensor.log1p, + torch.log1p, + torch.Tensor.log2, + torch.log2, + torch.nn.functional.logsigmoid, + torch.Tensor.logical_and, + torch.logical_and, + torch.Tensor.logical_not, + torch.logical_not, + torch.Tensor.logical_or, + torch.logical_or, + torch.Tensor.logical_xor, + torch.logical_xor, + torch.Tensor.logit, + torch.logit, + torch.Tensor.lt, + torch.lt, + torch.Tensor.maximum, + torch.maximum, + torch.Tensor.minimum, + torch.minimum, + torch.nn.functional.mish, + torch.Tensor.mvlgamma, + torch.mvlgamma, + torch.Tensor.nan_to_num, + torch.nan_to_num, + torch.Tensor.ne, + torch.ne, + torch.Tensor.neg, + torch.neg, + torch.Tensor.nextafter, + torch.nextafter, + torch.Tensor.outer, + torch.outer, + torch.polar, + torch.Tensor.polygamma, + torch.polygamma, + torch.Tensor.positive, + torch.positive, + torch.Tensor.pow, + torch.pow, + torch.Tensor.prelu, + torch.prelu, + torch.nn.functional.prelu, + torch.Tensor.rad2deg, + torch.rad2deg, + torch.Tensor.reciprocal, + torch.reciprocal, + torch.Tensor.relu, + torch.relu, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.Tensor.remainder, + torch.remainder, + torch.Tensor.round, + torch.round, + torch.rrelu, + torch.nn.functional.rrelu, + torch.Tensor.rsqrt, + torch.rsqrt, + torch.rsub, + torch.selu, + torch.nn.functional.selu, + torch.Tensor.sgn, + torch.sgn, + torch.Tensor.sigmoid, + torch.sigmoid, + torch.nn.functional.sigmoid, + torch.Tensor.sign, + torch.sign, + torch.Tensor.signbit, + torch.signbit, + torch.nn.functional.silu, + torch.Tensor.sin, + torch.sin, + torch.Tensor.sinc, + torch.sinc, + torch.Tensor.sinh, + torch.sinh, + torch.nn.functional.softplus, + torch.nn.functional.softshrink, + torch.Tensor.sqrt, + torch.sqrt, + torch.Tensor.square, + torch.square, + torch.Tensor.sub, + torch.sub, + 
torch.Tensor.tan,
+    torch.tan,
+    torch.Tensor.tanh,
+    torch.tanh,
+    torch.nn.functional.tanh,
+    torch.threshold,
+    torch.nn.functional.threshold,
+    torch.trapz,
+    torch.Tensor.true_divide,
+    torch.true_divide,
+    torch.Tensor.trunc,
+    torch.trunc,
+    torch.Tensor.xlogy,
+    torch.xlogy,
+    torch.rand_like,
+)
diff --git a/phivenv/Lib/site-packages/functorch/dim/reference.py b/phivenv/Lib/site-packages/functorch/dim/reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3306323224cf2df30e4e6f6410fff821a1ff41
--- /dev/null
+++ b/phivenv/Lib/site-packages/functorch/dim/reference.py
@@ -0,0 +1,645 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# reference Python implementations for C ops
+import torch
+from functorch._C import dim as _C
+
+from . import op_properties
+from .batch_tensor import _enable_layers
+from .tree_map import tree_flatten, tree_map
+
+
+DimList = _C.DimList
+import operator
+from functools import reduce
+
+
+# a set gives fast membership tests for the pointwise fast path
+pointwise = set(op_properties.pointwise)
+
+
+def prod(x):
+    return reduce(operator.mul, x, 1)
+
+
+def _wrap_dim(d, N, keepdim):
+    from . import Dim
+
+    if isinstance(d, Dim):
+        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
+        return d
+    elif d >= 0:
+        return d - N
+    else:
+        return d
+
+
+def _dims(d, N, keepdim, single_dim):
+    from . import Dim
+
+    if isinstance(d, (Dim, int)):
+        return ltuple((_wrap_dim(d, N, keepdim),))
+    assert not single_dim, f"expected a single dimension or int but found: {d}"
+    return ltuple(_wrap_dim(x, N, keepdim) for x in d)
+
+
+def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
+    from . import DimensionMismatchError
+
+    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
+    if len(not_bound) == 1:
+        idx, d = not_bound[0]
+        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
+        if lhs_size % rhs_so_far != 0:
+            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
+            raise DimensionMismatchError(
+                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
+            )
+        new_size = lhs_size // rhs_so_far
+        d.size = new_size
+    elif len(not_bound) > 1:
+        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
+        raise DimensionMismatchError(
+            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
+        )
+    else:
+        rhs_size = prod(r.size for r in rhs)
+        if lhs_size != rhs_size:
+            raise DimensionMismatchError(
+                f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
+            )
+
+
+def _tensor_levels(inp):
+    from . 
import _Tensor + + if isinstance(inp, _Tensor): + return inp._tensor, llist(inp._levels), inp._has_device + else: + return inp, llist(range(-inp.ndim, 0)), True + + +def _match_levels(v, from_levels, to_levels): + view = [] + permute = [] + requires_view = False + size = v.size() + for t in to_levels: + try: + idx = from_levels.index(t) + permute.append(idx) + view.append(size[idx]) + except ValueError: + view.append(1) + requires_view = True + if permute != list(range(len(permute))): + v = v.permute(*permute) + if requires_view: + v = v.view(*view) + return v + + +# make a single dimension positional but do not permute it, +# used to do multi-tensor operators where the dim being acted on +# should not physically move if possible +def _positional_no_permute(self, dim, expand_dim=False): + from . import Tensor + + ptensor, levels = self._tensor, llist(self._levels) + try: + idx = levels.index(dim) + except ValueError: + if not expand_dim: + raise + idx = 0 + ptensor = ptensor.expand(dim.size, *ptensor.size()) + levels.insert(0, 0) + idx_batched = 0 + for i in range(idx): + if isinstance(levels[i], int): + levels[i] -= 1 + idx_batched += 1 + levels[idx] = -idx_batched - 1 + return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched + + +def seq(a, b): + from . import Dim + + if isinstance(a, Dim) != isinstance(b, Dim): + return False + if isinstance(a, Dim): + return a is b + else: + return a == b + + +class isin: + __slots__ = () + + def __contains__(self, item): + for x in self: + if seq(item, x): + return True + return False + + def index(self, item): + for i, x in enumerate(self): + if seq(item, x): + return i + raise ValueError + + +class llist(isin, list): + __slots__ = () + + +class ltuple(isin, tuple): + __slots__ = () + + +empty_dict = {} + + +@classmethod +def __torch_function__(self, orig, cls, args, kwargs=empty_dict): + from . 
import _Tensor, Tensor, TensorLike + from .delayed_mul_tensor import DelayedMulTensor + + if orig is torch.Tensor.__mul__: + lhs, rhs = args + if ( + isinstance(lhs, _Tensor) + and isinstance(rhs, _Tensor) + and lhs.ndim == 0 + and rhs.ndim == 0 + ): + return DelayedMulTensor(lhs, rhs) + all_dims = llist() + flat_args, unflatten = tree_flatten((args, kwargs)) + device_holding_tensor = None + for f in flat_args: + if isinstance(f, _Tensor): + if f._has_device: + device_holding_tensor = f._batchtensor + for d in f.dims: + if d not in all_dims: + all_dims.append(d) + + def unwrap(t): + if isinstance(t, _Tensor): + r = t._batchtensor + if device_holding_tensor is not None and not t._has_device: + r = r.to(device=device_holding_tensor.device) + return r + return t + + if orig in pointwise: + result_levels = llist() + to_expand = [] + for i, f in enumerate(flat_args): + if isinstance(f, TensorLike): + ptensor, levels, _ = _tensor_levels(f) + if ( + isinstance(f, _Tensor) + and not f._has_device + and device_holding_tensor is not None + ): + ptensor = ptensor.to(device=device_holding_tensor.device) + flat_args[i] = ptensor + for l in levels: + if l not in result_levels: + result_levels.append(l) + to_expand.append((i, levels)) + + for i, levels in to_expand: + flat_args[i] = _match_levels(flat_args[i], levels, result_levels) + args, kwargs = unflatten(flat_args) + result = orig(*args, **kwargs) + + def wrap(t): + if isinstance(t, TensorLike): + return Tensor.from_positional( + t, result_levels, device_holding_tensor is not None + ) + return t + + return tree_map(wrap, result) + else: + + def wrap(t): + if isinstance(t, TensorLike): + return Tensor.from_batched(t, device_holding_tensor is not None) + return t + + with _enable_layers(all_dims): + print(f"batch_tensor for {orig}") + args, kwargs = unflatten(unwrap(f) for f in flat_args) + result = orig(*args, **kwargs) + # print("END", orig) + return tree_map(wrap, result) + + +def positional(self, *dims): + from . import Dim, DimensionBindError, Tensor + + ptensor, levels = self._tensor, llist(self._levels) + flat_dims = llist() + view = [] + needs_view = False + ndim = self.ndim + for d in dims: + if isinstance(d, DimList): + flat_dims.extend(d) + view.extend(e.size for e in d) + elif isinstance(d, Dim): + flat_dims.append(d) + view.append(d.size) + elif isinstance(d, int): + d = _wrap_dim(d, ndim, False) + flat_dims.append(d) + view.append(ptensor.size(d)) + else: + flat_dims.extend(d) + view.append(prod(e.size for e in d)) + needs_view = True + + permute = list(range(len(levels))) + for i, d in enumerate(flat_dims): + try: + idx = levels.index(d) + except ValueError as e: + raise DimensionBindError( + f"tensor of dimensions {self.dims} does not contain dim {d}" + ) from e + p = permute[idx] + del levels[idx] + del permute[idx] + levels.insert(i, 0) + permute.insert(i, p) + ptensor = ptensor.permute(*permute) + seen = 0 + for i in range(len(levels) - 1, -1, -1): + if isinstance(levels[i], int): + seen += 1 + levels[i] = -seen + result = Tensor.from_positional(ptensor, levels, self._has_device) + if needs_view: + result = result.reshape(*view, *result.size()[len(flat_dims) :]) + return result + + +def _contains_dim(input): + from . 
import Dim
+
+    for i in input:
+        if isinstance(i, Dim):
+            return True
+
+
+def expand(self, *sizes):
+    if not _contains_dim(sizes):
+        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
+    dims = sizes
+    sizes = [d.size for d in dims] + [-1] * self.ndim
+    self = self.expand(*sizes)
+    return self[dims]
+
+
+_not_present = object()
+
+
+def _getarg(name, offset, args, kwargs, default):
+    if len(args) > offset:
+        return args[offset]
+    return kwargs.get(name, default)
+
+
+def _patcharg(name, offset, args, kwargs, value):
+    if len(args) > offset:
+        args[offset] = value
+    else:
+        kwargs[name] = value
+
+
+def _wrap(
+    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
+):
+    from . import Dim, Tensor, TensorLike
+
+    def fn(self, *args, **kwargs):
+        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
+        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
+            with _enable_layers(self.dims):
+                print(f"dim fallback batch_tensor for {orig}")
+                return Tensor.from_batched(
+                    orig(self._batchtensor, *args, **kwargs), self._has_device
+                )
+        keepdim = (
+            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
+        )
+        t, levels = self._tensor, llist(self._levels)
+        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
+        dim_indices = tuple(levels.index(d) for d in dims)
+        if reduce and not keepdim:
+            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
+        else:
+            new_levels = levels
+
+        if len(dim_indices) == 1:
+            # ops that only take a single dim argument expect a scalar index,
+            # not a 1-tuple
+            dim_indices = dim_indices[0]
+        args = list(args)
+        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
+
+        def wrap(t):
+            if isinstance(t, TensorLike):
+                return Tensor.from_positional(t, new_levels, self._has_device)
+            return t
+
+        with _enable_layers(new_levels):
+            print(f"dim used batch_tensor for {orig}")
+            r = orig(t, *args, **kwargs)
+            return tree_map(wrap, r)
+
+    return fn
+
+
+def _def(name, *args, **kwargs):
+    from . import _Tensor
+
+    orig = getattr(torch.Tensor, name)
+    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
+
+
+no_slice = slice(None)
+
+_orig_getitem = torch.Tensor.__getitem__
+
+
+class dim_tracker:
+    def __init__(self) -> None:
+        self.dims = llist()
+        self.count = []
+
+    def record(self, d):
+        if d not in self.dims:
+            self.dims.append(d)
+            self.count.append(1)
+        else:
+            self.count[self.dims.index(d)] += 1
+
+    def __getitem__(self, d):
+        return self.count[self.dims.index(d)]
+
+
+def t__getitem__(self, input):
+    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
+
+    # * bail out to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
+    # * locate ... or an unbound tensor list, and determine its size, bind dim list
+    #   (remember that None does not count toward the total dim count)
+    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
+    #   produce the re-view if needed
+    # * for each single-use dim index, replace with no_slice and mark that it will be added
+    #   (keep track of whether we have to call super)
+    # * call super if needed
+    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
+    # this handles bool indexing, as well as some other simple cases.
+
+    is_simple = (
+        not isinstance(input, Dim)
+        and not isinstance(input, (tuple, list))
+        and
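+        # Illustrative example of the binding performed below (names are
+        # hypothetical): with `i, j = dims()` from this package, indexing a
+        # tensor of shape (3, 4) as `x[i, j]` binds i.size = 3 and j.size = 4
+        # and returns a tensor whose two positional dims are replaced by the
+        # first-class dims i and j.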
+        # WAR for a functorch bug where zero-dim tensors in getitem are not handled correctly
+        not (isinstance(input, TensorLike) and input.ndim == 0)
+    )
+
+    if is_simple:
+        if isinstance(self, _Tensor):
+            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
+        else:
+            return _orig_getitem(self, input)
+
+    # can further optimize this case
+    if not isinstance(input, tuple):
+        input = [input]
+    else:
+        input = list(input)
+
+    dims_indexed = 0
+    expanding_object = None
+    dimlists = []
+    for i, s in enumerate(input):
+        if s is ... or isinstance(s, DimList) and not s.is_bound:
+            if expanding_object is not None:
+                msg = (
+                    "at most one ... or unbound dimension list can exist in indexing list but"
+                    f" found 2 at offsets {i} and {expanding_object}"
+                )
+                raise DimensionBindError(msg)
+            expanding_object = i
+
+        if isinstance(s, DimList):
+            dims_indexed += len(s) if s.is_bound else 0
+            dimlists.append(i)
+        elif s is not None and s is not ...:
+            dims_indexed += 1
+
+    ndim = self.ndim
+    if dims_indexed > ndim:
+        raise IndexError(
+            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
+        )
+    if expanding_object is not None:
+        expanding_ndims = ndim - dims_indexed
+        obj = input[expanding_object]
+        if obj is ...:
+            input[expanding_object : expanding_object + 1] = [
+                no_slice
+            ] * expanding_ndims
+        else:
+            obj.bind_len(expanding_ndims)
+    # flatten the dim lists into the indexing
+    for i in reversed(dimlists):
+        input[i : i + 1] = input[i]
+    dims_indexed = 0
+    requires_view = False
+    size = self.size()
+    view_sizes = []
+    dims_seen = dim_tracker()
+
+    def add_dims(t):
+        if not isinstance(t, _Tensor):
+            return
+        for d in t.dims:
+            dims_seen.record(d)
+
+    add_dims(self)
+    dim_packs = []
+    for i, idx in enumerate(input):
+        if idx is None:
+            input[i] = no_slice
+            view_sizes.append(1)
+            requires_view = True
+        else:
+            sz = size[dims_indexed]
+            if isinstance(idx, Dim):
+                idx.size = sz
+                dims_seen.record(idx)
+                view_sizes.append(sz)
+            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
+                for d in idx:
+                    dims_seen.record(d)
+                _bind_dims_to_size(sz, idx, f"offset {i}")
+                view_sizes.extend(d.size for d in idx)
+                requires_view = True
+                dim_packs.append(i)
+            else:
+                add_dims(idx)
+                view_sizes.append(sz)
+            dims_indexed += 1
+    if requires_view:
+        self = self.view(*view_sizes)
+    for i in reversed(dim_packs):
+        input[i : i + 1] = input[i]
+
+    # currently:
+    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
+    # self may have first-class dims as well.
+
+    # to index:
+    # drop the first-class dims from self, they just become direct indices of their positions
+
+    # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
+ # these dimensions will appear and need to be bound at the first place tensor occurs + + if isinstance(self, _Tensor): + ptensor_self, levels = self._tensor, list(self._levels) + # indices to ptensor rather than self which has first-class dimensions + input_it = iter(input) + flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels] + has_device = self._has_device + to_pad = 0 + else: + ptensor_self, flat_inputs = self, input + to_pad = ptensor_self.ndim - len(flat_inputs) + has_device = True + + result_levels = [] + index_levels = [] + tensor_insert_point = None + to_expand = {} + requires_getindex = False + for i, inp in enumerate(flat_inputs): + if isinstance(inp, Dim) and dims_seen[inp] == 1: + flat_inputs[i] = no_slice + result_levels.append(inp) + elif isinstance(inp, TensorLike): + requires_getindex = True + if tensor_insert_point is None: + tensor_insert_point = len(result_levels) + ptensor, levels, _ = _tensor_levels(inp) + to_expand[i] = levels + flat_inputs[i] = ptensor + for l in levels: + if l not in index_levels: + index_levels.append(l) + else: + requires_getindex = True + result_levels.append(0) + + if tensor_insert_point is not None: + result_levels[tensor_insert_point:tensor_insert_point] = index_levels + + for i, levels in to_expand.items(): + flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels) + + if requires_getindex: + result = _orig_getitem(ptensor_self, flat_inputs) + else: + result = ptensor_self + + next_positional = -1 + if to_pad > 0: + result_levels.extend([0] * to_pad) + for i, r in enumerate(reversed(result_levels)): + if isinstance(r, int): + result_levels[-1 - i] = next_positional + next_positional -= 1 + + return Tensor.from_positional(result, result_levels, has_device) + + +# XXX - dim is optional and can be the outer-most dimension... +def stack(tensors, new_dim, dim=0, out=None): + if isinstance(dim, int): + return torch.stack(tensors, dim, out).index(dim, new_dim) + index = None + if out is not None: + out, index = _positional_no_permute(out, dim, expand_dim=True) + ptensors = [] + for t in tensors: + pt, pi = _positional_no_permute(t, dim, expand_dim=True) + if index is not None and pi != index: + pt = pt.move_dim(pi, index) + else: + index = pi + ptensors.append(pt) + pr = torch.stack(ptensors, index, out=out) + return pr.index((index, index + 1), (new_dim, dim)) + + +_orig_split = torch.Tensor.split + + +def split(self, split_size_or_sections, dim=0): + from . import _Tensor, Dim + + if isinstance(split_size_or_sections, int) or any( + isinstance(t, int) for t in split_size_or_sections + ): + if isinstance(dim, Dim): + raise ValueError( + "when dim is specified as a Dim object, split sizes must also be dimensions." 
+ ) + return _orig_split(self, split_size_or_sections, dim=dim) + + if isinstance(dim, Dim): + assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}" + self, dim = _positional_no_permute(self, dim) + + size = self.size(dim) + total_bound_size = 0 + unbound = [] + sizes = [] + for i, d in enumerate(split_size_or_sections): + if d.is_bound: + sizes.append(d.size) + total_bound_size += d.size + else: + sizes.append(0) + unbound.append(i) + + if unbound: + assert total_bound_size <= size, ( + f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" + ) + remaining_size = size - total_bound_size + chunk_size = -(-remaining_size // len(unbound)) + for u in unbound: + sz = min(chunk_size, remaining_size) + split_size_or_sections[u].size = sz + sizes[u] = sz + remaining_size -= sz + else: + assert total_bound_size == size, ( + f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" + ) + return tuple( + t.index(dim, d) + for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) + ) diff --git a/phivenv/Lib/site-packages/functorch/dim/tree_map.py b/phivenv/Lib/site-packages/functorch/dim/tree_map.py new file mode 100644 index 0000000000000000000000000000000000000000..014a44b95d1890bc8456bc64be6f093e2b433cb6 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/tree_map.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functorch._C import dim + + +tree_flatten = dim.tree_flatten + + +def tree_map(fn, tree): + vs, unflatten = tree_flatten(tree) + return unflatten(fn(v) for v in vs) diff --git a/phivenv/Lib/site-packages/functorch/dim/wrap_type.py b/phivenv/Lib/site-packages/functorch/dim/wrap_type.py new file mode 100644 index 0000000000000000000000000000000000000000..797ce6c34c5cbbe519e07b3c68bceb9350b98b94 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/dim/wrap_type.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from types import ( + BuiltinMethodType, + FunctionType, + GetSetDescriptorType, + MethodDescriptorType, + WrapperDescriptorType, +) + +from functorch._C import dim as _C + + +_wrap_method = _C._wrap_method + +FUNC_TYPES = ( + FunctionType, + MethodDescriptorType, + BuiltinMethodType, + WrapperDescriptorType, +) +PROPERTY_TYPES = (GetSetDescriptorType, property) + + +def _py_wrap_method(orig, __torch_function__): + def impl(*args, **kwargs): + return __torch_function__(orig, None, args, kwargs) + + return impl + + +def wrap_type(use_c, to_patch, pattern, __torch_function__): + if use_c: + wrap_method = _wrap_method + else: + wrap_method = _py_wrap_method + + all = {} + for t in reversed(pattern.mro()[:-1]): # skip object + all.update(t.__dict__) + + def wrap_attr(orig): + return property(wrap_method(orig.__get__, __torch_function__)) + + for name, obj in all.items(): + if name in ( + "__dict__", + "__new__", + "__init__", + "__repr__", + "__weakref__", + "__doc__", + "__module__", + "__dir__", + ): + continue + + # skip things that have been overloaded + # things that come from object like `__eq__` still need to be patched, however. 
+ if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr( + object, name, None + ): + continue + + if isinstance(obj, FUNC_TYPES): + setattr(to_patch, name, wrap_method(obj, __torch_function__)) + elif isinstance(obj, PROPERTY_TYPES): + setattr(to_patch, name, wrap_attr(obj)) diff --git a/phivenv/Lib/site-packages/functorch/einops/__init__.py b/phivenv/Lib/site-packages/functorch/einops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef17a03c09e65c20bff38b7ed2d2b1e05f18f780 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/einops/__init__.py @@ -0,0 +1,4 @@ +from .rearrange import rearrange + + +__all__ = ["rearrange"] diff --git a/phivenv/Lib/site-packages/functorch/einops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/einops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0defbec65e16db2dc621e3596a592831ac7f3986 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/einops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/einops/__pycache__/_parsing.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/einops/__pycache__/_parsing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4c904e1192d4b5f0b221ff515431964a5b4e930 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/einops/__pycache__/_parsing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/einops/__pycache__/rearrange.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/einops/__pycache__/rearrange.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..777add58b75859bd1a09717febe719e736d0a39e Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/einops/__pycache__/rearrange.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/einops/_parsing.py b/phivenv/Lib/site-packages/functorch/einops/_parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4a42275326c97d5796c201c48b3c322b52cfe3 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/einops/_parsing.py @@ -0,0 +1,308 @@ +"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py. + +MIT License + +Copyright (c) 2018 Alex Rogozhnikov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + +from __future__ import annotations + +import keyword +import warnings +from typing import Optional, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from collections.abc import Collection, Mapping + + +_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated + + +class AnonymousAxis: + """Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier. + + Note: Different instances of this class are not equal to each other, even if they have the same value. + """ + + def __init__(self, value: str) -> None: + self.value = int(value) + if self.value < 1: + raise ValueError( + f"Anonymous axis should have positive length, not {self.value}" + ) + + def __repr__(self) -> str: + return f"{self.value}-axis" + + +class ParsedExpression: + """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)').""" + + def __init__( + self, + expression: str, + *, + allow_underscore: bool = False, + allow_duplicates: bool = False, + ) -> None: + """Parse the expression and store relevant metadata. + + Args: + expression (str): the `einops`-pattern to parse + allow_underscore (bool): whether to allow axis identifier names to begin with an underscore + allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression + """ + self.has_ellipsis: bool = False + self.has_ellipsis_parenthesized: Optional[bool] = None + self.identifiers: set[Union[str, AnonymousAxis]] = set() + # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition + self.has_non_unitary_anonymous_axes: bool = False + # composition keeps structure of composite axes, see how different corner cases are handled in tests + self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = [] + if "." in expression: + if "..." 
not in expression:
+                raise ValueError(
+                    "Expression may contain dots only inside ellipsis (...)"
+                )
+            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
+                raise ValueError(
+                    "Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor is allowed"
+                )
+            expression = expression.replace("...", _ellipsis)
+            self.has_ellipsis = True
+
+        bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None
+
+        def add_axis_name(x: str) -> None:
+            if x in self.identifiers:
+                if not (allow_underscore and x == "_") and not allow_duplicates:
+                    raise ValueError(
+                        f"Indexing expression contains duplicate dimension '{x}'"
+                    )
+            if x == _ellipsis:
+                self.identifiers.add(_ellipsis)
+                if bracket_group is None:
+                    self.composition.append(_ellipsis)
+                    self.has_ellipsis_parenthesized = False
+                else:
+                    bracket_group.append(_ellipsis)
+                    self.has_ellipsis_parenthesized = True
+            else:
+                is_number = str.isdecimal(x)
+                if is_number and int(x) == 1:
+                    # handling the case of anonymous axis of length 1
+                    if bracket_group is None:
+                        self.composition.append([])
+                    else:
+                        pass  # no need to think about 1s inside parenthesis
+                    return
+                is_axis_name, reason = self.check_axis_name_return_reason(
+                    x, allow_underscore=allow_underscore
+                )
+                if not (is_number or is_axis_name):
+                    raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
+                axis_name: Union[str, AnonymousAxis] = (
+                    AnonymousAxis(x) if is_number else x
+                )
+                self.identifiers.add(axis_name)
+                if is_number:
+                    self.has_non_unitary_anonymous_axes = True
+                if bracket_group is None:
+                    self.composition.append([axis_name])
+                else:
+                    bracket_group.append(axis_name)
+
+        current_identifier = None
+        for char in expression:
+            if char in "() ":
+                if current_identifier is not None:
+                    add_axis_name(current_identifier)
+                current_identifier = None
+                if char == "(":
+                    if bracket_group is not None:
+                        raise ValueError(
+                            "Axis composition is one-level (brackets inside brackets not allowed)"
+                        )
+                    bracket_group = []
+                elif char == ")":
+                    if bracket_group is None:
+                        raise ValueError("Brackets are not balanced")
+                    self.composition.append(bracket_group)
+                    bracket_group = None
+            elif str.isalnum(char) or char in ["_", _ellipsis]:
+                if current_identifier is None:
+                    current_identifier = char
+                else:
+                    current_identifier += char
+            else:
+                raise ValueError(f"Unknown character '{char}'")
+
+        if bracket_group is not None:
+            raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
+        if current_identifier is not None:
+            add_axis_name(current_identifier)
+
+    @staticmethod
+    def check_axis_name_return_reason(
+        name: str, allow_underscore: bool = False
+    ) -> tuple[bool, str]:
+        """Check whether the given axis name is valid, returning a message explaining why if not.
+
+        Valid axis names are Python identifiers that are not keywords, and should not start or end with an underscore.
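+
+        Example (illustrative doctest, assuming the default arguments):
+
+            >>> ParsedExpression.check_axis_name_return_reason("height")
+            (True, '')
+            >>> ParsedExpression.check_axis_name_return_reason("_height")
+            (False, 'axis name should not start or end with underscore')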
+
+        Args:
+            name (str): the axis name to check
+            allow_underscore (bool): whether axis names are allowed to start with an underscore
+
+        Returns:
+            tuple[bool, str]: whether the axis name is valid, and a message explaining why if not
+        """
+        if not str.isidentifier(name):
+            return False, "not a valid Python identifier"
+        elif name[0] == "_" or name[-1] == "_":
+            if name == "_" and allow_underscore:
+                return True, ""
+            return False, "axis name should not start or end with underscore"
+        else:
+            if keyword.iskeyword(name):
+                warnings.warn(
+                    f"It is discouraged to use axis names that are keywords: {name}",
+                    RuntimeWarning,
+                )
+            if name in ["axis"]:
+                warnings.warn(
+                    "Using 'axis' as an axis name is discouraged and will raise an error in the future",
+                    FutureWarning,
+                )
+            return True, ""
+
+    @staticmethod
+    def check_axis_name(name: str) -> bool:
+        """Check if the name is a valid axis name.
+
+        Args:
+            name (str): the axis name to check
+
+        Returns:
+            bool: whether the axis name is valid
+        """
+        is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
+        return is_valid
+
+
+def parse_pattern(
+    pattern: str, axes_lengths: Mapping[str, int]
+) -> tuple[ParsedExpression, ParsedExpression]:
+    """Parse an `einops`-style pattern into left-hand side and right-hand side `ParsedExpression` objects.
+
+    Args:
+        pattern (str): the `einops`-style rearrangement pattern
+        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
+
+    Returns:
+        tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
+    """
+    # adapted from einops.einops._prepare_transformation_recipe
+    # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
+    try:
+        left_str, right_str = pattern.split("->")
+    except ValueError:
+        raise ValueError("Pattern must contain a single '->' separator") from None
+
+    if _ellipsis in axes_lengths:
+        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")
+
+    left = ParsedExpression(left_str)
+    right = ParsedExpression(right_str)
+
+    if not left.has_ellipsis and right.has_ellipsis:
+        raise ValueError(
+            f"Ellipsis found in the right side, but not the left side, of the pattern {pattern}"
+        )
+    if left.has_ellipsis and left.has_ellipsis_parenthesized:
+        raise ValueError(
+            f"Ellipsis inside parentheses on the left side is not allowed: {pattern}"
+        )
+
+    return left, right
+
+
+def validate_rearrange_expressions(
+    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
+) -> None:
+    """Perform expression validations that are specific to the `rearrange` operation.
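+
+    Example (illustrative, using `parse_pattern` above):
+
+        >>> left, right = parse_pattern("b h w -> b (h w)", {})
+        >>> validate_rearrange_expressions(left, right, {})  # valid: no exception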
+
+    Args:
+        left (ParsedExpression): left-hand side expression
+        right (ParsedExpression): right-hand side expression
+        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
+    """
+    for length in axes_lengths.values():
+        if (length_type := type(length)) is not int:
+            raise TypeError(
+                f"rearrange axis lengths must be integers, got: {length_type}"
+            )
+
+    if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
+        raise ValueError("rearrange only supports unnamed axes of size 1")
+
+    difference = set.symmetric_difference(left.identifiers, right.identifiers)
+    if len(difference) > 0:
+        raise ValueError(
+            f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
+        )
+
+    unmatched_axes = axes_lengths.keys() - left.identifiers
+    if len(unmatched_axes) > 0:
+        raise ValueError(
+            f"Identifiers not found in rearrange expression: {unmatched_axes}"
+        )
+
+
+def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
+    """Convert a collection of strings representing first-class dims into a comma-separated string.
+
+    Args:
+        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert
+
+    Returns:
+        str: the comma-separated string
+
+    Examples:
+        >>> comma_separate(("d0",))
+        'd0'
+
+        >>> comma_separate(("d0", "d1", "d2", "d3"))
+        'd0, d1, d2, d3'
+
+        >>> comma_separate([("d1", "d4")])
+        '(d1, d4)'
+
+        >>> comma_separate([("d0",), (), ("d1",), ("d2",), ("d3", "d4")])
+        '(d0,), (), (d1,), (d2,), (d3, d4)'
+    """
+    return ", ".join(
+        item
+        if isinstance(item, str)
+        else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
+        for item in collection
+    )
diff --git a/phivenv/Lib/site-packages/functorch/einops/rearrange.py b/phivenv/Lib/site-packages/functorch/einops/rearrange.py
new file mode 100644
index 0000000000000000000000000000000000000000..8133fd214ede6967c968c5755cfc09db106169dc
--- /dev/null
+++ b/phivenv/Lib/site-packages/functorch/einops/rearrange.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import functools
+from typing import Callable, TYPE_CHECKING, Union
+
+import torch
+from functorch._C import dim as _C
+
+from ._parsing import (
+    _ellipsis,
+    AnonymousAxis,
+    comma_separate,
+    parse_pattern,
+    validate_rearrange_expressions,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+__all__ = ["rearrange"]
+
+dims = _C.dims
+
+
+@functools.lru_cache(256)
+def _create_rearrange_callable(
+    tensor_ndim: int, pattern: str, **axes_lengths: int
+) -> Callable[[torch.Tensor], torch.Tensor]:
+    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.
+
+    Since an equivalent result is computed for tensors with the same number of dimensions, the same pattern, and the
+    same specified axes lengths, this function can be memoized.
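+
+    As an illustration (a sketch of the generated source, not normative output),
+    the pattern "b (h w) -> b h w" with h=2 on a 2-dimensional tensor compiles
+    to roughly:
+
+        def do_rearrange(tensor):
+            d0, d1, d2 = dims(3)
+            d1.size = 2
+            tensor = tensor[(d0,), (d1, d2)].order((d0,), (d1,), (d2,))
+            return tensor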
+ + Args: + tensor_ndim (int): the number of dimensions in the tensor to rearrange + pattern (str): the `einops`-style rearrangement pattern + axes_lengths (int): any additional length specifications for dimensions + + Returns: + Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement + """ + left, right = parse_pattern(pattern, axes_lengths) + validate_rearrange_expressions(left, right, axes_lengths) + + n_anon_dims = sum(not dim for dim in left.composition) + if left.has_ellipsis: + n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1) + n_named_dims = len(left.identifiers) - 1 + + if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim: + raise ValueError( + f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of " + f"dimensions in the tensor ({tensor_ndim})" + ) + else: + n_ellipsis_dims = 0 + n_named_dims = len(left.identifiers) + + if (pattern_ndim := len(left.composition)) != tensor_ndim: + raise ValueError( + f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in " + f"the tensor ({tensor_ndim})" + ) + n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims + + if n_dims == 0: + # an identity rearrangement on a 0-dimension tensor + return lambda tensor: tensor + + first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims)) + identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {} + anon_axes: list[AnonymousAxis] = [] + + # map the left-hand side identifiers to strings representing first class dims + dims_i = 0 + for dimension in left.composition: + if isinstance(dimension, list): + for identifier in dimension: + # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists + assert isinstance(identifier, str) + identifier_dim_map[identifier] = (first_class_dims[dims_i],) + dims_i += 1 + if not dimension: + # unitary anonymous axis + anon_axis = AnonymousAxis("1") + identifier_dim_map[anon_axis] = (first_class_dims[dims_i],) + anon_axes.append(anon_axis) + dimension.append(anon_axis) + dims_i += 1 + elif dimension == _ellipsis: + identifier = _ellipsis + identifier_dim_map[identifier] = tuple( + first_class_dims[dims_i + j] for j in range(n_ellipsis_dims) + ) + dims_i += n_ellipsis_dims + else: + raise ValueError(f"Unexpected dimension: {dimension}") + + def composition_to_dims( + composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]], + ) -> list[Union[str, tuple[str, ...]]]: + """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first + class dims.""" + dim_composition: list[Union[str, tuple[str, ...]]] = [] + for dimension in composition: + if isinstance(dimension, list): + dim_composition.append( + tuple( + dim + for identifier in dimension + for dim in identifier_dim_map[identifier] + ) + ) + elif dimension == _ellipsis: + dim_composition.extend(identifier_dim_map[_ellipsis]) + else: + raise ValueError(f"Unexpected dimension: {dimension}") + return dim_composition + + left_dims = composition_to_dims(left.composition) + right_dims = composition_to_dims(right.composition) + anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes) + specified_lengths = tuple( + (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items() + ) + + custom_rearrange_callable_name = "do_rearrange" + custom_rearrange_callable_code = ( + ( + f"def {custom_rearrange_callable_name}(tensor):\n" + f" 
{comma_separate(first_class_dims)} = dims({n_dims})\n" + ) + + ( + "".join( + f" {dim}.size = {length}\n" for (dim, length) in specified_lengths + ) + if specified_lengths + else "" + ) + + f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n" + + ( + f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n" + if anon_dims + else " return tensor\n" + ) + ) + + exec(custom_rearrange_callable_code) + return locals()[custom_rearrange_callable_name] + + +def rearrange( + tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]], + pattern: str, + **axes_lengths: int, +) -> torch.Tensor: + r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional + tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, + stack, concatenate and other operations. + + See: https://einops.rocks/api/rearrange/ + + Args: + tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange + pattern (str): the rearrangement pattern + axes_lengths (int): any additional length specifications for dimensions + + Returns: + Tensor: the rearranged tensor + + Examples: + >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel) + >>> images = torch.randn((32, 30, 40, 3)) + + >>> # stack along first (batch) axis, output is a single array + >>> rearrange(images, "b h w c -> b h w c").shape + torch.Size([32, 30, 40, 3]) + + >>> # concatenate images along height (vertical axis), 960 = 32 * 30 + >>> rearrange(images, "b h w c -> (b h) w c").shape + torch.Size([960, 40, 3]) + + >>> # concatenated images along horizontal axis, 1280 = 32 * 40 + >>> rearrange(images, "b h w c -> h (b w) c").shape + torch.Size([30, 1280, 3]) + + >>> # reordered axes to "b c h w" format for deep learning + >>> rearrange(images, "b h w c -> b c h w").shape + torch.Size([32, 3, 30, 40]) + + >>> # flattened each image into a vector, 3600 = 30 * 40 * 3 + >>> rearrange(images, "b h w c -> b (c h w)").shape + torch.Size([32, 3600]) + + >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 + >>> rearrange(images, "b (h1 h) (w1 w) c -> (b h1 w1) h w c", h1=2, w1=2).shape + torch.Size([128, 15, 20, 3]) + + >>> # space-to-depth operation + >>> rearrange(images, "b (h h1) (w w1) c -> b h w (c h1 w1)", h1=2, w1=2).shape + torch.Size([32, 15, 20, 12]) + """ + if not isinstance(tensor, torch.Tensor): + tensor = torch.stack(tensor) + + rearrange_callable = _create_rearrange_callable( + tensor.ndim, pattern, **axes_lengths + ) + + return rearrange_callable(tensor) diff --git a/phivenv/Lib/site-packages/functorch/experimental/__init__.py b/phivenv/Lib/site-packages/functorch/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00af1abcc55fd733e473cfa413da38ef5898bbcb --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/experimental/__init__.py @@ -0,0 +1,5 @@ +# PyTorch forward-mode is not mature yet +from functorch import functionalize +from torch._functorch.apis import chunk_vmap +from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ +from torch._functorch.eager_transforms import hessian, jacfwd, jvp diff --git a/phivenv/Lib/site-packages/functorch/experimental/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/experimental/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0df7702b8f17eee0546cf7d1ce92722665882312 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/experimental/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/experimental/__pycache__/control_flow.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/experimental/__pycache__/control_flow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d6b2532511773d95aa10c3aa392ddc56e75db4 Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/experimental/__pycache__/control_flow.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/experimental/__pycache__/ops.cpython-39.pyc b/phivenv/Lib/site-packages/functorch/experimental/__pycache__/ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8312ce49b3f6eeffdd8df58a8b3051fce38dd68a Binary files /dev/null and b/phivenv/Lib/site-packages/functorch/experimental/__pycache__/ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/functorch/experimental/control_flow.py b/phivenv/Lib/site-packages/functorch/experimental/control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..46ed7a86a1a11cb0aff6057c7cc013638a1d5d6f --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/experimental/control_flow.py @@ -0,0 +1,6 @@ +from torch import cond # noqa: F401 +from torch._higher_order_ops.map import ( # noqa: F401 + _stack_pytree, + _unstack_pytree, + map, +) diff --git a/phivenv/Lib/site-packages/functorch/experimental/ops.py b/phivenv/Lib/site-packages/functorch/experimental/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..144515478f0c537ee28ba46f72e3a70f08858f38 --- /dev/null +++ b/phivenv/Lib/site-packages/functorch/experimental/ops.py @@ -0,0 +1 @@ +from torch._ops import HigherOrderOperator # noqa: F401 diff --git a/phivenv/Lib/site-packages/huggingface_hub/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d02d069c38db65f93ae68e1adafd9995c25a6a8 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/__init__.py @@ -0,0 +1,1530 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# *********** +# `huggingface_hub` init has 2 modes: +# - Normal usage: +# If imported to use it, all modules and functions are lazy-loaded. This means +# they exist at top level in module but are imported only the first time they are +# used. This way, `from huggingface_hub import something` will import `something` +# quickly without the hassle of importing all the features from `huggingface_hub`. +# - Static check: +# If statically analyzed, all modules and functions are loaded normally. This way +# static typing check works properly as well as autocomplete in text editors and +# IDEs. 
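+#
+# A minimal sketch of the lazy pattern described above (illustrative only; the
+# real implementation is the lazy loader vendored below, and the reverse
+# mapping name `_attr_to_module` is hypothetical):
+#
+#     def __getattr__(name):  # PEP 562 module-level attribute hook
+#         submod = _attr_to_module.get(name)
+#         if submod is None:
+#             raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+#         return getattr(importlib.import_module(f"huggingface_hub.{submod}"), name)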
+# +# The static model imports are done inside the `if TYPE_CHECKING:` statement at +# the bottom of this file. Since module/functions imports are duplicated, it is +# mandatory to make sure to add them twice when adding one. This is checked in the +# `make quality` command. +# +# To update the static imports, please run the following command and commit the changes. +# ``` +# # Use script +# python utils/check_static_imports.py --update-file +# +# # Or run style on codebase +# make style +# ``` +# +# *********** +# Lazy loader vendored from https://github.com/scientific-python/lazy_loader +import importlib +import os +import sys +from typing import TYPE_CHECKING + + +__version__ = "0.34.4" + +# Alphabetical order of definitions is ensured in tests +# WARNING: any comment added in this dictionary definition will be lost when +# re-generating the file ! +_SUBMOD_ATTRS = { + "_commit_scheduler": [ + "CommitScheduler", + ], + "_inference_endpoints": [ + "InferenceEndpoint", + "InferenceEndpointError", + "InferenceEndpointStatus", + "InferenceEndpointTimeoutError", + "InferenceEndpointType", + ], + "_jobs_api": [ + "JobInfo", + "JobOwner", + "JobStage", + "JobStatus", + ], + "_login": [ + "auth_list", + "auth_switch", + "interpreter_login", + "login", + "logout", + "notebook_login", + ], + "_oauth": [ + "OAuthInfo", + "OAuthOrgInfo", + "OAuthUserInfo", + "attach_huggingface_oauth", + "parse_huggingface_oauth", + ], + "_snapshot_download": [ + "snapshot_download", + ], + "_space_api": [ + "SpaceHardware", + "SpaceRuntime", + "SpaceStage", + "SpaceStorage", + "SpaceVariable", + ], + "_tensorboard_logger": [ + "HFSummaryWriter", + ], + "_webhooks_payload": [ + "WebhookPayload", + "WebhookPayloadComment", + "WebhookPayloadDiscussion", + "WebhookPayloadDiscussionChanges", + "WebhookPayloadEvent", + "WebhookPayloadMovedTo", + "WebhookPayloadRepo", + "WebhookPayloadUrl", + "WebhookPayloadWebhook", + ], + "_webhooks_server": [ + "WebhooksServer", + "webhook_endpoint", + ], + "community": [ + "Discussion", + "DiscussionComment", + "DiscussionCommit", + "DiscussionEvent", + "DiscussionStatusChange", + "DiscussionTitleChange", + "DiscussionWithDetails", + ], + "constants": [ + "CONFIG_NAME", + "FLAX_WEIGHTS_NAME", + "HUGGINGFACE_CO_URL_HOME", + "HUGGINGFACE_CO_URL_TEMPLATE", + "PYTORCH_WEIGHTS_NAME", + "REPO_TYPE_DATASET", + "REPO_TYPE_MODEL", + "REPO_TYPE_SPACE", + "TF2_WEIGHTS_NAME", + "TF_WEIGHTS_NAME", + ], + "fastai_utils": [ + "_save_pretrained_fastai", + "from_pretrained_fastai", + "push_to_hub_fastai", + ], + "file_download": [ + "HfFileMetadata", + "_CACHED_NO_EXIST", + "get_hf_file_metadata", + "hf_hub_download", + "hf_hub_url", + "try_to_load_from_cache", + ], + "hf_api": [ + "Collection", + "CollectionItem", + "CommitInfo", + "CommitOperation", + "CommitOperationAdd", + "CommitOperationCopy", + "CommitOperationDelete", + "DatasetInfo", + "GitCommitInfo", + "GitRefInfo", + "GitRefs", + "HfApi", + "ModelInfo", + "RepoUrl", + "SpaceInfo", + "User", + "UserLikes", + "WebhookInfo", + "WebhookWatchedItem", + "accept_access_request", + "add_collection_item", + "add_space_secret", + "add_space_variable", + "auth_check", + "cancel_access_request", + "cancel_job", + "change_discussion_status", + "comment_discussion", + "create_branch", + "create_collection", + "create_commit", + "create_discussion", + "create_inference_endpoint", + "create_inference_endpoint_from_catalog", + "create_pull_request", + "create_repo", + "create_tag", + "create_webhook", + "dataset_info", + "delete_branch", + 
"delete_collection", + "delete_collection_item", + "delete_file", + "delete_folder", + "delete_inference_endpoint", + "delete_repo", + "delete_space_secret", + "delete_space_storage", + "delete_space_variable", + "delete_tag", + "delete_webhook", + "disable_webhook", + "duplicate_space", + "edit_discussion_comment", + "enable_webhook", + "fetch_job_logs", + "file_exists", + "get_collection", + "get_dataset_tags", + "get_discussion_details", + "get_full_repo_name", + "get_inference_endpoint", + "get_model_tags", + "get_paths_info", + "get_repo_discussions", + "get_safetensors_metadata", + "get_space_runtime", + "get_space_variables", + "get_token_permission", + "get_user_overview", + "get_webhook", + "grant_access", + "inspect_job", + "list_accepted_access_requests", + "list_collections", + "list_datasets", + "list_inference_catalog", + "list_inference_endpoints", + "list_jobs", + "list_lfs_files", + "list_liked_repos", + "list_models", + "list_organization_members", + "list_papers", + "list_pending_access_requests", + "list_rejected_access_requests", + "list_repo_commits", + "list_repo_files", + "list_repo_likers", + "list_repo_refs", + "list_repo_tree", + "list_spaces", + "list_user_followers", + "list_user_following", + "list_webhooks", + "merge_pull_request", + "model_info", + "move_repo", + "paper_info", + "parse_safetensors_file_metadata", + "pause_inference_endpoint", + "pause_space", + "permanently_delete_lfs_files", + "preupload_lfs_files", + "reject_access_request", + "rename_discussion", + "repo_exists", + "repo_info", + "repo_type_and_id_from_hf_id", + "request_space_hardware", + "request_space_storage", + "restart_space", + "resume_inference_endpoint", + "revision_exists", + "run_as_future", + "run_job", + "run_uv_job", + "scale_to_zero_inference_endpoint", + "set_space_sleep_time", + "space_info", + "super_squash_history", + "unlike", + "update_collection_item", + "update_collection_metadata", + "update_inference_endpoint", + "update_repo_settings", + "update_repo_visibility", + "update_webhook", + "upload_file", + "upload_folder", + "upload_large_folder", + "whoami", + ], + "hf_file_system": [ + "HfFileSystem", + "HfFileSystemFile", + "HfFileSystemResolvedPath", + "HfFileSystemStreamFile", + ], + "hub_mixin": [ + "ModelHubMixin", + "PyTorchModelHubMixin", + ], + "inference._client": [ + "InferenceClient", + "InferenceTimeoutError", + ], + "inference._generated._async_client": [ + "AsyncInferenceClient", + ], + "inference._generated.types": [ + "AudioClassificationInput", + "AudioClassificationOutputElement", + "AudioClassificationOutputTransform", + "AudioClassificationParameters", + "AudioToAudioInput", + "AudioToAudioOutputElement", + "AutomaticSpeechRecognitionEarlyStoppingEnum", + "AutomaticSpeechRecognitionGenerationParameters", + "AutomaticSpeechRecognitionInput", + "AutomaticSpeechRecognitionOutput", + "AutomaticSpeechRecognitionOutputChunk", + "AutomaticSpeechRecognitionParameters", + "ChatCompletionInput", + "ChatCompletionInputFunctionDefinition", + "ChatCompletionInputFunctionName", + "ChatCompletionInputGrammarType", + "ChatCompletionInputJSONSchema", + "ChatCompletionInputMessage", + "ChatCompletionInputMessageChunk", + "ChatCompletionInputMessageChunkType", + "ChatCompletionInputResponseFormatJSONObject", + "ChatCompletionInputResponseFormatJSONSchema", + "ChatCompletionInputResponseFormatText", + "ChatCompletionInputStreamOptions", + "ChatCompletionInputTool", + "ChatCompletionInputToolCall", + "ChatCompletionInputToolChoiceClass", + 
"ChatCompletionInputToolChoiceEnum", + "ChatCompletionInputURL", + "ChatCompletionOutput", + "ChatCompletionOutputComplete", + "ChatCompletionOutputFunctionDefinition", + "ChatCompletionOutputLogprob", + "ChatCompletionOutputLogprobs", + "ChatCompletionOutputMessage", + "ChatCompletionOutputToolCall", + "ChatCompletionOutputTopLogprob", + "ChatCompletionOutputUsage", + "ChatCompletionStreamOutput", + "ChatCompletionStreamOutputChoice", + "ChatCompletionStreamOutputDelta", + "ChatCompletionStreamOutputDeltaToolCall", + "ChatCompletionStreamOutputFunction", + "ChatCompletionStreamOutputLogprob", + "ChatCompletionStreamOutputLogprobs", + "ChatCompletionStreamOutputTopLogprob", + "ChatCompletionStreamOutputUsage", + "DepthEstimationInput", + "DepthEstimationOutput", + "DocumentQuestionAnsweringInput", + "DocumentQuestionAnsweringInputData", + "DocumentQuestionAnsweringOutputElement", + "DocumentQuestionAnsweringParameters", + "FeatureExtractionInput", + "FeatureExtractionInputTruncationDirection", + "FillMaskInput", + "FillMaskOutputElement", + "FillMaskParameters", + "ImageClassificationInput", + "ImageClassificationOutputElement", + "ImageClassificationOutputTransform", + "ImageClassificationParameters", + "ImageSegmentationInput", + "ImageSegmentationOutputElement", + "ImageSegmentationParameters", + "ImageSegmentationSubtask", + "ImageToImageInput", + "ImageToImageOutput", + "ImageToImageParameters", + "ImageToImageTargetSize", + "ImageToTextEarlyStoppingEnum", + "ImageToTextGenerationParameters", + "ImageToTextInput", + "ImageToTextOutput", + "ImageToTextParameters", + "ImageToVideoInput", + "ImageToVideoOutput", + "ImageToVideoParameters", + "ImageToVideoTargetSize", + "ObjectDetectionBoundingBox", + "ObjectDetectionInput", + "ObjectDetectionOutputElement", + "ObjectDetectionParameters", + "Padding", + "QuestionAnsweringInput", + "QuestionAnsweringInputData", + "QuestionAnsweringOutputElement", + "QuestionAnsweringParameters", + "SentenceSimilarityInput", + "SentenceSimilarityInputData", + "SummarizationInput", + "SummarizationOutput", + "SummarizationParameters", + "SummarizationTruncationStrategy", + "TableQuestionAnsweringInput", + "TableQuestionAnsweringInputData", + "TableQuestionAnsweringOutputElement", + "TableQuestionAnsweringParameters", + "Text2TextGenerationInput", + "Text2TextGenerationOutput", + "Text2TextGenerationParameters", + "Text2TextGenerationTruncationStrategy", + "TextClassificationInput", + "TextClassificationOutputElement", + "TextClassificationOutputTransform", + "TextClassificationParameters", + "TextGenerationInput", + "TextGenerationInputGenerateParameters", + "TextGenerationInputGrammarType", + "TextGenerationOutput", + "TextGenerationOutputBestOfSequence", + "TextGenerationOutputDetails", + "TextGenerationOutputFinishReason", + "TextGenerationOutputPrefillToken", + "TextGenerationOutputToken", + "TextGenerationStreamOutput", + "TextGenerationStreamOutputStreamDetails", + "TextGenerationStreamOutputToken", + "TextToAudioEarlyStoppingEnum", + "TextToAudioGenerationParameters", + "TextToAudioInput", + "TextToAudioOutput", + "TextToAudioParameters", + "TextToImageInput", + "TextToImageOutput", + "TextToImageParameters", + "TextToSpeechEarlyStoppingEnum", + "TextToSpeechGenerationParameters", + "TextToSpeechInput", + "TextToSpeechOutput", + "TextToSpeechParameters", + "TextToVideoInput", + "TextToVideoOutput", + "TextToVideoParameters", + "TokenClassificationAggregationStrategy", + "TokenClassificationInput", + "TokenClassificationOutputElement", + 
"TokenClassificationParameters", + "TranslationInput", + "TranslationOutput", + "TranslationParameters", + "TranslationTruncationStrategy", + "TypeEnum", + "VideoClassificationInput", + "VideoClassificationOutputElement", + "VideoClassificationOutputTransform", + "VideoClassificationParameters", + "VisualQuestionAnsweringInput", + "VisualQuestionAnsweringInputData", + "VisualQuestionAnsweringOutputElement", + "VisualQuestionAnsweringParameters", + "ZeroShotClassificationInput", + "ZeroShotClassificationOutputElement", + "ZeroShotClassificationParameters", + "ZeroShotImageClassificationInput", + "ZeroShotImageClassificationOutputElement", + "ZeroShotImageClassificationParameters", + "ZeroShotObjectDetectionBoundingBox", + "ZeroShotObjectDetectionInput", + "ZeroShotObjectDetectionOutputElement", + "ZeroShotObjectDetectionParameters", + ], + "inference._mcp.agent": [ + "Agent", + ], + "inference._mcp.mcp_client": [ + "MCPClient", + ], + "inference_api": [ + "InferenceApi", + ], + "keras_mixin": [ + "KerasModelHubMixin", + "from_pretrained_keras", + "push_to_hub_keras", + "save_pretrained_keras", + ], + "repocard": [ + "DatasetCard", + "ModelCard", + "RepoCard", + "SpaceCard", + "metadata_eval_result", + "metadata_load", + "metadata_save", + "metadata_update", + ], + "repocard_data": [ + "CardData", + "DatasetCardData", + "EvalResult", + "ModelCardData", + "SpaceCardData", + ], + "repository": [ + "Repository", + ], + "serialization": [ + "StateDictSplit", + "get_tf_storage_size", + "get_torch_storage_id", + "get_torch_storage_size", + "load_state_dict_from_file", + "load_torch_model", + "save_torch_model", + "save_torch_state_dict", + "split_state_dict_into_shards_factory", + "split_tf_state_dict_into_shards", + "split_torch_state_dict_into_shards", + ], + "serialization._dduf": [ + "DDUFEntry", + "export_entries_as_dduf", + "export_folder_as_dduf", + "read_dduf_file", + ], + "utils": [ + "CacheNotFound", + "CachedFileInfo", + "CachedRepoInfo", + "CachedRevisionInfo", + "CorruptedCacheException", + "DeleteCacheStrategy", + "HFCacheInfo", + "HfFolder", + "cached_assets_path", + "configure_http_backend", + "dump_environment_info", + "get_session", + "get_token", + "logging", + "scan_cache_dir", + ], +} + +# WARNING: __all__ is generated automatically, Any manual edit will be lost when re-generating this file ! +# +# To update the static imports, please run the following command and commit the changes. 
+# ``` +# # Use script +# python utils/check_all_variable.py --update +# +# # Or run style on codebase +# make style +# ``` + +__all__ = [ + "Agent", + "AsyncInferenceClient", + "AudioClassificationInput", + "AudioClassificationOutputElement", + "AudioClassificationOutputTransform", + "AudioClassificationParameters", + "AudioToAudioInput", + "AudioToAudioOutputElement", + "AutomaticSpeechRecognitionEarlyStoppingEnum", + "AutomaticSpeechRecognitionGenerationParameters", + "AutomaticSpeechRecognitionInput", + "AutomaticSpeechRecognitionOutput", + "AutomaticSpeechRecognitionOutputChunk", + "AutomaticSpeechRecognitionParameters", + "CONFIG_NAME", + "CacheNotFound", + "CachedFileInfo", + "CachedRepoInfo", + "CachedRevisionInfo", + "CardData", + "ChatCompletionInput", + "ChatCompletionInputFunctionDefinition", + "ChatCompletionInputFunctionName", + "ChatCompletionInputGrammarType", + "ChatCompletionInputJSONSchema", + "ChatCompletionInputMessage", + "ChatCompletionInputMessageChunk", + "ChatCompletionInputMessageChunkType", + "ChatCompletionInputResponseFormatJSONObject", + "ChatCompletionInputResponseFormatJSONSchema", + "ChatCompletionInputResponseFormatText", + "ChatCompletionInputStreamOptions", + "ChatCompletionInputTool", + "ChatCompletionInputToolCall", + "ChatCompletionInputToolChoiceClass", + "ChatCompletionInputToolChoiceEnum", + "ChatCompletionInputURL", + "ChatCompletionOutput", + "ChatCompletionOutputComplete", + "ChatCompletionOutputFunctionDefinition", + "ChatCompletionOutputLogprob", + "ChatCompletionOutputLogprobs", + "ChatCompletionOutputMessage", + "ChatCompletionOutputToolCall", + "ChatCompletionOutputTopLogprob", + "ChatCompletionOutputUsage", + "ChatCompletionStreamOutput", + "ChatCompletionStreamOutputChoice", + "ChatCompletionStreamOutputDelta", + "ChatCompletionStreamOutputDeltaToolCall", + "ChatCompletionStreamOutputFunction", + "ChatCompletionStreamOutputLogprob", + "ChatCompletionStreamOutputLogprobs", + "ChatCompletionStreamOutputTopLogprob", + "ChatCompletionStreamOutputUsage", + "Collection", + "CollectionItem", + "CommitInfo", + "CommitOperation", + "CommitOperationAdd", + "CommitOperationCopy", + "CommitOperationDelete", + "CommitScheduler", + "CorruptedCacheException", + "DDUFEntry", + "DatasetCard", + "DatasetCardData", + "DatasetInfo", + "DeleteCacheStrategy", + "DepthEstimationInput", + "DepthEstimationOutput", + "Discussion", + "DiscussionComment", + "DiscussionCommit", + "DiscussionEvent", + "DiscussionStatusChange", + "DiscussionTitleChange", + "DiscussionWithDetails", + "DocumentQuestionAnsweringInput", + "DocumentQuestionAnsweringInputData", + "DocumentQuestionAnsweringOutputElement", + "DocumentQuestionAnsweringParameters", + "EvalResult", + "FLAX_WEIGHTS_NAME", + "FeatureExtractionInput", + "FeatureExtractionInputTruncationDirection", + "FillMaskInput", + "FillMaskOutputElement", + "FillMaskParameters", + "GitCommitInfo", + "GitRefInfo", + "GitRefs", + "HFCacheInfo", + "HFSummaryWriter", + "HUGGINGFACE_CO_URL_HOME", + "HUGGINGFACE_CO_URL_TEMPLATE", + "HfApi", + "HfFileMetadata", + "HfFileSystem", + "HfFileSystemFile", + "HfFileSystemResolvedPath", + "HfFileSystemStreamFile", + "HfFolder", + "ImageClassificationInput", + "ImageClassificationOutputElement", + "ImageClassificationOutputTransform", + "ImageClassificationParameters", + "ImageSegmentationInput", + "ImageSegmentationOutputElement", + "ImageSegmentationParameters", + "ImageSegmentationSubtask", + "ImageToImageInput", + "ImageToImageOutput", + "ImageToImageParameters", + 
"ImageToImageTargetSize", + "ImageToTextEarlyStoppingEnum", + "ImageToTextGenerationParameters", + "ImageToTextInput", + "ImageToTextOutput", + "ImageToTextParameters", + "ImageToVideoInput", + "ImageToVideoOutput", + "ImageToVideoParameters", + "ImageToVideoTargetSize", + "InferenceApi", + "InferenceClient", + "InferenceEndpoint", + "InferenceEndpointError", + "InferenceEndpointStatus", + "InferenceEndpointTimeoutError", + "InferenceEndpointType", + "InferenceTimeoutError", + "JobInfo", + "JobOwner", + "JobStage", + "JobStatus", + "KerasModelHubMixin", + "MCPClient", + "ModelCard", + "ModelCardData", + "ModelHubMixin", + "ModelInfo", + "OAuthInfo", + "OAuthOrgInfo", + "OAuthUserInfo", + "ObjectDetectionBoundingBox", + "ObjectDetectionInput", + "ObjectDetectionOutputElement", + "ObjectDetectionParameters", + "PYTORCH_WEIGHTS_NAME", + "Padding", + "PyTorchModelHubMixin", + "QuestionAnsweringInput", + "QuestionAnsweringInputData", + "QuestionAnsweringOutputElement", + "QuestionAnsweringParameters", + "REPO_TYPE_DATASET", + "REPO_TYPE_MODEL", + "REPO_TYPE_SPACE", + "RepoCard", + "RepoUrl", + "Repository", + "SentenceSimilarityInput", + "SentenceSimilarityInputData", + "SpaceCard", + "SpaceCardData", + "SpaceHardware", + "SpaceInfo", + "SpaceRuntime", + "SpaceStage", + "SpaceStorage", + "SpaceVariable", + "StateDictSplit", + "SummarizationInput", + "SummarizationOutput", + "SummarizationParameters", + "SummarizationTruncationStrategy", + "TF2_WEIGHTS_NAME", + "TF_WEIGHTS_NAME", + "TableQuestionAnsweringInput", + "TableQuestionAnsweringInputData", + "TableQuestionAnsweringOutputElement", + "TableQuestionAnsweringParameters", + "Text2TextGenerationInput", + "Text2TextGenerationOutput", + "Text2TextGenerationParameters", + "Text2TextGenerationTruncationStrategy", + "TextClassificationInput", + "TextClassificationOutputElement", + "TextClassificationOutputTransform", + "TextClassificationParameters", + "TextGenerationInput", + "TextGenerationInputGenerateParameters", + "TextGenerationInputGrammarType", + "TextGenerationOutput", + "TextGenerationOutputBestOfSequence", + "TextGenerationOutputDetails", + "TextGenerationOutputFinishReason", + "TextGenerationOutputPrefillToken", + "TextGenerationOutputToken", + "TextGenerationStreamOutput", + "TextGenerationStreamOutputStreamDetails", + "TextGenerationStreamOutputToken", + "TextToAudioEarlyStoppingEnum", + "TextToAudioGenerationParameters", + "TextToAudioInput", + "TextToAudioOutput", + "TextToAudioParameters", + "TextToImageInput", + "TextToImageOutput", + "TextToImageParameters", + "TextToSpeechEarlyStoppingEnum", + "TextToSpeechGenerationParameters", + "TextToSpeechInput", + "TextToSpeechOutput", + "TextToSpeechParameters", + "TextToVideoInput", + "TextToVideoOutput", + "TextToVideoParameters", + "TokenClassificationAggregationStrategy", + "TokenClassificationInput", + "TokenClassificationOutputElement", + "TokenClassificationParameters", + "TranslationInput", + "TranslationOutput", + "TranslationParameters", + "TranslationTruncationStrategy", + "TypeEnum", + "User", + "UserLikes", + "VideoClassificationInput", + "VideoClassificationOutputElement", + "VideoClassificationOutputTransform", + "VideoClassificationParameters", + "VisualQuestionAnsweringInput", + "VisualQuestionAnsweringInputData", + "VisualQuestionAnsweringOutputElement", + "VisualQuestionAnsweringParameters", + "WebhookInfo", + "WebhookPayload", + "WebhookPayloadComment", + "WebhookPayloadDiscussion", + "WebhookPayloadDiscussionChanges", + "WebhookPayloadEvent", + 
"WebhookPayloadMovedTo", + "WebhookPayloadRepo", + "WebhookPayloadUrl", + "WebhookPayloadWebhook", + "WebhookWatchedItem", + "WebhooksServer", + "ZeroShotClassificationInput", + "ZeroShotClassificationOutputElement", + "ZeroShotClassificationParameters", + "ZeroShotImageClassificationInput", + "ZeroShotImageClassificationOutputElement", + "ZeroShotImageClassificationParameters", + "ZeroShotObjectDetectionBoundingBox", + "ZeroShotObjectDetectionInput", + "ZeroShotObjectDetectionOutputElement", + "ZeroShotObjectDetectionParameters", + "_CACHED_NO_EXIST", + "_save_pretrained_fastai", + "accept_access_request", + "add_collection_item", + "add_space_secret", + "add_space_variable", + "attach_huggingface_oauth", + "auth_check", + "auth_list", + "auth_switch", + "cached_assets_path", + "cancel_access_request", + "cancel_job", + "change_discussion_status", + "comment_discussion", + "configure_http_backend", + "create_branch", + "create_collection", + "create_commit", + "create_discussion", + "create_inference_endpoint", + "create_inference_endpoint_from_catalog", + "create_pull_request", + "create_repo", + "create_tag", + "create_webhook", + "dataset_info", + "delete_branch", + "delete_collection", + "delete_collection_item", + "delete_file", + "delete_folder", + "delete_inference_endpoint", + "delete_repo", + "delete_space_secret", + "delete_space_storage", + "delete_space_variable", + "delete_tag", + "delete_webhook", + "disable_webhook", + "dump_environment_info", + "duplicate_space", + "edit_discussion_comment", + "enable_webhook", + "export_entries_as_dduf", + "export_folder_as_dduf", + "fetch_job_logs", + "file_exists", + "from_pretrained_fastai", + "from_pretrained_keras", + "get_collection", + "get_dataset_tags", + "get_discussion_details", + "get_full_repo_name", + "get_hf_file_metadata", + "get_inference_endpoint", + "get_model_tags", + "get_paths_info", + "get_repo_discussions", + "get_safetensors_metadata", + "get_session", + "get_space_runtime", + "get_space_variables", + "get_tf_storage_size", + "get_token", + "get_token_permission", + "get_torch_storage_id", + "get_torch_storage_size", + "get_user_overview", + "get_webhook", + "grant_access", + "hf_hub_download", + "hf_hub_url", + "inspect_job", + "interpreter_login", + "list_accepted_access_requests", + "list_collections", + "list_datasets", + "list_inference_catalog", + "list_inference_endpoints", + "list_jobs", + "list_lfs_files", + "list_liked_repos", + "list_models", + "list_organization_members", + "list_papers", + "list_pending_access_requests", + "list_rejected_access_requests", + "list_repo_commits", + "list_repo_files", + "list_repo_likers", + "list_repo_refs", + "list_repo_tree", + "list_spaces", + "list_user_followers", + "list_user_following", + "list_webhooks", + "load_state_dict_from_file", + "load_torch_model", + "logging", + "login", + "logout", + "merge_pull_request", + "metadata_eval_result", + "metadata_load", + "metadata_save", + "metadata_update", + "model_info", + "move_repo", + "notebook_login", + "paper_info", + "parse_huggingface_oauth", + "parse_safetensors_file_metadata", + "pause_inference_endpoint", + "pause_space", + "permanently_delete_lfs_files", + "preupload_lfs_files", + "push_to_hub_fastai", + "push_to_hub_keras", + "read_dduf_file", + "reject_access_request", + "rename_discussion", + "repo_exists", + "repo_info", + "repo_type_and_id_from_hf_id", + "request_space_hardware", + "request_space_storage", + "restart_space", + "resume_inference_endpoint", + "revision_exists", + "run_as_future", + 
"run_job", + "run_uv_job", + "save_pretrained_keras", + "save_torch_model", + "save_torch_state_dict", + "scale_to_zero_inference_endpoint", + "scan_cache_dir", + "set_space_sleep_time", + "snapshot_download", + "space_info", + "split_state_dict_into_shards_factory", + "split_tf_state_dict_into_shards", + "split_torch_state_dict_into_shards", + "super_squash_history", + "try_to_load_from_cache", + "unlike", + "update_collection_item", + "update_collection_metadata", + "update_inference_endpoint", + "update_repo_settings", + "update_repo_visibility", + "update_webhook", + "upload_file", + "upload_folder", + "upload_large_folder", + "webhook_endpoint", + "whoami", +] + + +def _attach(package_name, submodules=None, submod_attrs=None): + """Attach lazily loaded submodules, functions, or other attributes. + + Typically, modules import submodules and attributes as follows: + + ```py + import mysubmodule + import anothersubmodule + + from .foo import someattr + ``` + + The idea is to replace a package's `__getattr__`, `__dir__`, such that all imports + work exactly the way they would with normal imports, except that the import occurs + upon first use. + + The typical way to call this function, replacing the above imports, is: + + ```python + __getattr__, __dir__ = lazy.attach( + __name__, + ['mysubmodule', 'anothersubmodule'], + {'foo': ['someattr']} + ) + ``` + This functionality requires Python 3.7 or higher. + + Args: + package_name (`str`): + Typically use `__name__`. + submodules (`set`): + List of submodules to attach. + submod_attrs (`dict`): + Dictionary of submodule -> list of attributes / functions. + These attributes are imported as they are used. + + Returns: + __getattr__, __dir__, __all__ + + """ + if submod_attrs is None: + submod_attrs = {} + + if submodules is None: + submodules = set() + else: + submodules = set(submodules) + + attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs} + + def __getattr__(name): + if name in submodules: + try: + return importlib.import_module(f"{package_name}.{name}") + except Exception as e: + print(f"Error importing {package_name}.{name}: {e}") + raise + elif name in attr_to_modules: + submod_path = f"{package_name}.{attr_to_modules[name]}" + try: + submod = importlib.import_module(submod_path) + except Exception as e: + print(f"Error importing {submod_path}: {e}") + raise + attr = getattr(submod, name) + + # If the attribute lives in a file (module) with the same + # name as the attribute, ensure that the attribute and *not* + # the module is accessible on the package. + if name == attr_to_modules[name]: + pkg = sys.modules[package_name] + pkg.__dict__[name] = attr + + return attr + else: + raise AttributeError(f"No {package_name} attribute {name}") + + def __dir__(): + return __all__ + + return __getattr__, __dir__ + + +__getattr__, __dir__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS) + +if os.environ.get("EAGER_IMPORT", ""): + for attr in __all__: + __getattr__(attr) + +# WARNING: any content below this statement is generated automatically. Any manual edit +# will be lost when re-generating this file ! +# +# To update the static imports, please run the following command and commit the changes. 
+# ``` +# # Use script +# python utils/check_static_imports.py --update +# +# # Or run style on codebase +# make style +# ``` +if TYPE_CHECKING: # pragma: no cover + from ._commit_scheduler import CommitScheduler # noqa: F401 + from ._inference_endpoints import ( + InferenceEndpoint, # noqa: F401 + InferenceEndpointError, # noqa: F401 + InferenceEndpointStatus, # noqa: F401 + InferenceEndpointTimeoutError, # noqa: F401 + InferenceEndpointType, # noqa: F401 + ) + from ._jobs_api import ( + JobInfo, # noqa: F401 + JobOwner, # noqa: F401 + JobStage, # noqa: F401 + JobStatus, # noqa: F401 + ) + from ._login import ( + auth_list, # noqa: F401 + auth_switch, # noqa: F401 + interpreter_login, # noqa: F401 + login, # noqa: F401 + logout, # noqa: F401 + notebook_login, # noqa: F401 + ) + from ._oauth import ( + OAuthInfo, # noqa: F401 + OAuthOrgInfo, # noqa: F401 + OAuthUserInfo, # noqa: F401 + attach_huggingface_oauth, # noqa: F401 + parse_huggingface_oauth, # noqa: F401 + ) + from ._snapshot_download import snapshot_download # noqa: F401 + from ._space_api import ( + SpaceHardware, # noqa: F401 + SpaceRuntime, # noqa: F401 + SpaceStage, # noqa: F401 + SpaceStorage, # noqa: F401 + SpaceVariable, # noqa: F401 + ) + from ._tensorboard_logger import HFSummaryWriter # noqa: F401 + from ._webhooks_payload import ( + WebhookPayload, # noqa: F401 + WebhookPayloadComment, # noqa: F401 + WebhookPayloadDiscussion, # noqa: F401 + WebhookPayloadDiscussionChanges, # noqa: F401 + WebhookPayloadEvent, # noqa: F401 + WebhookPayloadMovedTo, # noqa: F401 + WebhookPayloadRepo, # noqa: F401 + WebhookPayloadUrl, # noqa: F401 + WebhookPayloadWebhook, # noqa: F401 + ) + from ._webhooks_server import ( + WebhooksServer, # noqa: F401 + webhook_endpoint, # noqa: F401 + ) + from .community import ( + Discussion, # noqa: F401 + DiscussionComment, # noqa: F401 + DiscussionCommit, # noqa: F401 + DiscussionEvent, # noqa: F401 + DiscussionStatusChange, # noqa: F401 + DiscussionTitleChange, # noqa: F401 + DiscussionWithDetails, # noqa: F401 + ) + from .constants import ( + CONFIG_NAME, # noqa: F401 + FLAX_WEIGHTS_NAME, # noqa: F401 + HUGGINGFACE_CO_URL_HOME, # noqa: F401 + HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 + PYTORCH_WEIGHTS_NAME, # noqa: F401 + REPO_TYPE_DATASET, # noqa: F401 + REPO_TYPE_MODEL, # noqa: F401 + REPO_TYPE_SPACE, # noqa: F401 + TF2_WEIGHTS_NAME, # noqa: F401 + TF_WEIGHTS_NAME, # noqa: F401 + ) + from .fastai_utils import ( + _save_pretrained_fastai, # noqa: F401 + from_pretrained_fastai, # noqa: F401 + push_to_hub_fastai, # noqa: F401 + ) + from .file_download import ( + _CACHED_NO_EXIST, # noqa: F401 + HfFileMetadata, # noqa: F401 + get_hf_file_metadata, # noqa: F401 + hf_hub_download, # noqa: F401 + hf_hub_url, # noqa: F401 + try_to_load_from_cache, # noqa: F401 + ) + from .hf_api import ( + Collection, # noqa: F401 + CollectionItem, # noqa: F401 + CommitInfo, # noqa: F401 + CommitOperation, # noqa: F401 + CommitOperationAdd, # noqa: F401 + CommitOperationCopy, # noqa: F401 + CommitOperationDelete, # noqa: F401 + DatasetInfo, # noqa: F401 + GitCommitInfo, # noqa: F401 + GitRefInfo, # noqa: F401 + GitRefs, # noqa: F401 + HfApi, # noqa: F401 + ModelInfo, # noqa: F401 + RepoUrl, # noqa: F401 + SpaceInfo, # noqa: F401 + User, # noqa: F401 + UserLikes, # noqa: F401 + WebhookInfo, # noqa: F401 + WebhookWatchedItem, # noqa: F401 + accept_access_request, # noqa: F401 + add_collection_item, # noqa: F401 + add_space_secret, # noqa: F401 + add_space_variable, # noqa: F401 + auth_check, # noqa: F401 + 
cancel_access_request, # noqa: F401 + cancel_job, # noqa: F401 + change_discussion_status, # noqa: F401 + comment_discussion, # noqa: F401 + create_branch, # noqa: F401 + create_collection, # noqa: F401 + create_commit, # noqa: F401 + create_discussion, # noqa: F401 + create_inference_endpoint, # noqa: F401 + create_inference_endpoint_from_catalog, # noqa: F401 + create_pull_request, # noqa: F401 + create_repo, # noqa: F401 + create_tag, # noqa: F401 + create_webhook, # noqa: F401 + dataset_info, # noqa: F401 + delete_branch, # noqa: F401 + delete_collection, # noqa: F401 + delete_collection_item, # noqa: F401 + delete_file, # noqa: F401 + delete_folder, # noqa: F401 + delete_inference_endpoint, # noqa: F401 + delete_repo, # noqa: F401 + delete_space_secret, # noqa: F401 + delete_space_storage, # noqa: F401 + delete_space_variable, # noqa: F401 + delete_tag, # noqa: F401 + delete_webhook, # noqa: F401 + disable_webhook, # noqa: F401 + duplicate_space, # noqa: F401 + edit_discussion_comment, # noqa: F401 + enable_webhook, # noqa: F401 + fetch_job_logs, # noqa: F401 + file_exists, # noqa: F401 + get_collection, # noqa: F401 + get_dataset_tags, # noqa: F401 + get_discussion_details, # noqa: F401 + get_full_repo_name, # noqa: F401 + get_inference_endpoint, # noqa: F401 + get_model_tags, # noqa: F401 + get_paths_info, # noqa: F401 + get_repo_discussions, # noqa: F401 + get_safetensors_metadata, # noqa: F401 + get_space_runtime, # noqa: F401 + get_space_variables, # noqa: F401 + get_token_permission, # noqa: F401 + get_user_overview, # noqa: F401 + get_webhook, # noqa: F401 + grant_access, # noqa: F401 + inspect_job, # noqa: F401 + list_accepted_access_requests, # noqa: F401 + list_collections, # noqa: F401 + list_datasets, # noqa: F401 + list_inference_catalog, # noqa: F401 + list_inference_endpoints, # noqa: F401 + list_jobs, # noqa: F401 + list_lfs_files, # noqa: F401 + list_liked_repos, # noqa: F401 + list_models, # noqa: F401 + list_organization_members, # noqa: F401 + list_papers, # noqa: F401 + list_pending_access_requests, # noqa: F401 + list_rejected_access_requests, # noqa: F401 + list_repo_commits, # noqa: F401 + list_repo_files, # noqa: F401 + list_repo_likers, # noqa: F401 + list_repo_refs, # noqa: F401 + list_repo_tree, # noqa: F401 + list_spaces, # noqa: F401 + list_user_followers, # noqa: F401 + list_user_following, # noqa: F401 + list_webhooks, # noqa: F401 + merge_pull_request, # noqa: F401 + model_info, # noqa: F401 + move_repo, # noqa: F401 + paper_info, # noqa: F401 + parse_safetensors_file_metadata, # noqa: F401 + pause_inference_endpoint, # noqa: F401 + pause_space, # noqa: F401 + permanently_delete_lfs_files, # noqa: F401 + preupload_lfs_files, # noqa: F401 + reject_access_request, # noqa: F401 + rename_discussion, # noqa: F401 + repo_exists, # noqa: F401 + repo_info, # noqa: F401 + repo_type_and_id_from_hf_id, # noqa: F401 + request_space_hardware, # noqa: F401 + request_space_storage, # noqa: F401 + restart_space, # noqa: F401 + resume_inference_endpoint, # noqa: F401 + revision_exists, # noqa: F401 + run_as_future, # noqa: F401 + run_job, # noqa: F401 + run_uv_job, # noqa: F401 + scale_to_zero_inference_endpoint, # noqa: F401 + set_space_sleep_time, # noqa: F401 + space_info, # noqa: F401 + super_squash_history, # noqa: F401 + unlike, # noqa: F401 + update_collection_item, # noqa: F401 + update_collection_metadata, # noqa: F401 + update_inference_endpoint, # noqa: F401 + update_repo_settings, # noqa: F401 + update_repo_visibility, # noqa: F401 + update_webhook, # noqa: 
F401 + upload_file, # noqa: F401 + upload_folder, # noqa: F401 + upload_large_folder, # noqa: F401 + whoami, # noqa: F401 + ) + from .hf_file_system import ( + HfFileSystem, # noqa: F401 + HfFileSystemFile, # noqa: F401 + HfFileSystemResolvedPath, # noqa: F401 + HfFileSystemStreamFile, # noqa: F401 + ) + from .hub_mixin import ( + ModelHubMixin, # noqa: F401 + PyTorchModelHubMixin, # noqa: F401 + ) + from .inference._client import ( + InferenceClient, # noqa: F401 + InferenceTimeoutError, # noqa: F401 + ) + from .inference._generated._async_client import AsyncInferenceClient # noqa: F401 + from .inference._generated.types import ( + AudioClassificationInput, # noqa: F401 + AudioClassificationOutputElement, # noqa: F401 + AudioClassificationOutputTransform, # noqa: F401 + AudioClassificationParameters, # noqa: F401 + AudioToAudioInput, # noqa: F401 + AudioToAudioOutputElement, # noqa: F401 + AutomaticSpeechRecognitionEarlyStoppingEnum, # noqa: F401 + AutomaticSpeechRecognitionGenerationParameters, # noqa: F401 + AutomaticSpeechRecognitionInput, # noqa: F401 + AutomaticSpeechRecognitionOutput, # noqa: F401 + AutomaticSpeechRecognitionOutputChunk, # noqa: F401 + AutomaticSpeechRecognitionParameters, # noqa: F401 + ChatCompletionInput, # noqa: F401 + ChatCompletionInputFunctionDefinition, # noqa: F401 + ChatCompletionInputFunctionName, # noqa: F401 + ChatCompletionInputGrammarType, # noqa: F401 + ChatCompletionInputJSONSchema, # noqa: F401 + ChatCompletionInputMessage, # noqa: F401 + ChatCompletionInputMessageChunk, # noqa: F401 + ChatCompletionInputMessageChunkType, # noqa: F401 + ChatCompletionInputResponseFormatJSONObject, # noqa: F401 + ChatCompletionInputResponseFormatJSONSchema, # noqa: F401 + ChatCompletionInputResponseFormatText, # noqa: F401 + ChatCompletionInputStreamOptions, # noqa: F401 + ChatCompletionInputTool, # noqa: F401 + ChatCompletionInputToolCall, # noqa: F401 + ChatCompletionInputToolChoiceClass, # noqa: F401 + ChatCompletionInputToolChoiceEnum, # noqa: F401 + ChatCompletionInputURL, # noqa: F401 + ChatCompletionOutput, # noqa: F401 + ChatCompletionOutputComplete, # noqa: F401 + ChatCompletionOutputFunctionDefinition, # noqa: F401 + ChatCompletionOutputLogprob, # noqa: F401 + ChatCompletionOutputLogprobs, # noqa: F401 + ChatCompletionOutputMessage, # noqa: F401 + ChatCompletionOutputToolCall, # noqa: F401 + ChatCompletionOutputTopLogprob, # noqa: F401 + ChatCompletionOutputUsage, # noqa: F401 + ChatCompletionStreamOutput, # noqa: F401 + ChatCompletionStreamOutputChoice, # noqa: F401 + ChatCompletionStreamOutputDelta, # noqa: F401 + ChatCompletionStreamOutputDeltaToolCall, # noqa: F401 + ChatCompletionStreamOutputFunction, # noqa: F401 + ChatCompletionStreamOutputLogprob, # noqa: F401 + ChatCompletionStreamOutputLogprobs, # noqa: F401 + ChatCompletionStreamOutputTopLogprob, # noqa: F401 + ChatCompletionStreamOutputUsage, # noqa: F401 + DepthEstimationInput, # noqa: F401 + DepthEstimationOutput, # noqa: F401 + DocumentQuestionAnsweringInput, # noqa: F401 + DocumentQuestionAnsweringInputData, # noqa: F401 + DocumentQuestionAnsweringOutputElement, # noqa: F401 + DocumentQuestionAnsweringParameters, # noqa: F401 + FeatureExtractionInput, # noqa: F401 + FeatureExtractionInputTruncationDirection, # noqa: F401 + FillMaskInput, # noqa: F401 + FillMaskOutputElement, # noqa: F401 + FillMaskParameters, # noqa: F401 + ImageClassificationInput, # noqa: F401 + ImageClassificationOutputElement, # noqa: F401 + ImageClassificationOutputTransform, # noqa: F401 + 
ImageClassificationParameters, # noqa: F401 + ImageSegmentationInput, # noqa: F401 + ImageSegmentationOutputElement, # noqa: F401 + ImageSegmentationParameters, # noqa: F401 + ImageSegmentationSubtask, # noqa: F401 + ImageToImageInput, # noqa: F401 + ImageToImageOutput, # noqa: F401 + ImageToImageParameters, # noqa: F401 + ImageToImageTargetSize, # noqa: F401 + ImageToTextEarlyStoppingEnum, # noqa: F401 + ImageToTextGenerationParameters, # noqa: F401 + ImageToTextInput, # noqa: F401 + ImageToTextOutput, # noqa: F401 + ImageToTextParameters, # noqa: F401 + ImageToVideoInput, # noqa: F401 + ImageToVideoOutput, # noqa: F401 + ImageToVideoParameters, # noqa: F401 + ImageToVideoTargetSize, # noqa: F401 + ObjectDetectionBoundingBox, # noqa: F401 + ObjectDetectionInput, # noqa: F401 + ObjectDetectionOutputElement, # noqa: F401 + ObjectDetectionParameters, # noqa: F401 + Padding, # noqa: F401 + QuestionAnsweringInput, # noqa: F401 + QuestionAnsweringInputData, # noqa: F401 + QuestionAnsweringOutputElement, # noqa: F401 + QuestionAnsweringParameters, # noqa: F401 + SentenceSimilarityInput, # noqa: F401 + SentenceSimilarityInputData, # noqa: F401 + SummarizationInput, # noqa: F401 + SummarizationOutput, # noqa: F401 + SummarizationParameters, # noqa: F401 + SummarizationTruncationStrategy, # noqa: F401 + TableQuestionAnsweringInput, # noqa: F401 + TableQuestionAnsweringInputData, # noqa: F401 + TableQuestionAnsweringOutputElement, # noqa: F401 + TableQuestionAnsweringParameters, # noqa: F401 + Text2TextGenerationInput, # noqa: F401 + Text2TextGenerationOutput, # noqa: F401 + Text2TextGenerationParameters, # noqa: F401 + Text2TextGenerationTruncationStrategy, # noqa: F401 + TextClassificationInput, # noqa: F401 + TextClassificationOutputElement, # noqa: F401 + TextClassificationOutputTransform, # noqa: F401 + TextClassificationParameters, # noqa: F401 + TextGenerationInput, # noqa: F401 + TextGenerationInputGenerateParameters, # noqa: F401 + TextGenerationInputGrammarType, # noqa: F401 + TextGenerationOutput, # noqa: F401 + TextGenerationOutputBestOfSequence, # noqa: F401 + TextGenerationOutputDetails, # noqa: F401 + TextGenerationOutputFinishReason, # noqa: F401 + TextGenerationOutputPrefillToken, # noqa: F401 + TextGenerationOutputToken, # noqa: F401 + TextGenerationStreamOutput, # noqa: F401 + TextGenerationStreamOutputStreamDetails, # noqa: F401 + TextGenerationStreamOutputToken, # noqa: F401 + TextToAudioEarlyStoppingEnum, # noqa: F401 + TextToAudioGenerationParameters, # noqa: F401 + TextToAudioInput, # noqa: F401 + TextToAudioOutput, # noqa: F401 + TextToAudioParameters, # noqa: F401 + TextToImageInput, # noqa: F401 + TextToImageOutput, # noqa: F401 + TextToImageParameters, # noqa: F401 + TextToSpeechEarlyStoppingEnum, # noqa: F401 + TextToSpeechGenerationParameters, # noqa: F401 + TextToSpeechInput, # noqa: F401 + TextToSpeechOutput, # noqa: F401 + TextToSpeechParameters, # noqa: F401 + TextToVideoInput, # noqa: F401 + TextToVideoOutput, # noqa: F401 + TextToVideoParameters, # noqa: F401 + TokenClassificationAggregationStrategy, # noqa: F401 + TokenClassificationInput, # noqa: F401 + TokenClassificationOutputElement, # noqa: F401 + TokenClassificationParameters, # noqa: F401 + TranslationInput, # noqa: F401 + TranslationOutput, # noqa: F401 + TranslationParameters, # noqa: F401 + TranslationTruncationStrategy, # noqa: F401 + TypeEnum, # noqa: F401 + VideoClassificationInput, # noqa: F401 + VideoClassificationOutputElement, # noqa: F401 + VideoClassificationOutputTransform, # noqa: F401 + 
VideoClassificationParameters, # noqa: F401 + VisualQuestionAnsweringInput, # noqa: F401 + VisualQuestionAnsweringInputData, # noqa: F401 + VisualQuestionAnsweringOutputElement, # noqa: F401 + VisualQuestionAnsweringParameters, # noqa: F401 + ZeroShotClassificationInput, # noqa: F401 + ZeroShotClassificationOutputElement, # noqa: F401 + ZeroShotClassificationParameters, # noqa: F401 + ZeroShotImageClassificationInput, # noqa: F401 + ZeroShotImageClassificationOutputElement, # noqa: F401 + ZeroShotImageClassificationParameters, # noqa: F401 + ZeroShotObjectDetectionBoundingBox, # noqa: F401 + ZeroShotObjectDetectionInput, # noqa: F401 + ZeroShotObjectDetectionOutputElement, # noqa: F401 + ZeroShotObjectDetectionParameters, # noqa: F401 + ) + from .inference._mcp.agent import Agent # noqa: F401 + from .inference._mcp.mcp_client import MCPClient # noqa: F401 + from .inference_api import InferenceApi # noqa: F401 + from .keras_mixin import ( + KerasModelHubMixin, # noqa: F401 + from_pretrained_keras, # noqa: F401 + push_to_hub_keras, # noqa: F401 + save_pretrained_keras, # noqa: F401 + ) + from .repocard import ( + DatasetCard, # noqa: F401 + ModelCard, # noqa: F401 + RepoCard, # noqa: F401 + SpaceCard, # noqa: F401 + metadata_eval_result, # noqa: F401 + metadata_load, # noqa: F401 + metadata_save, # noqa: F401 + metadata_update, # noqa: F401 + ) + from .repocard_data import ( + CardData, # noqa: F401 + DatasetCardData, # noqa: F401 + EvalResult, # noqa: F401 + ModelCardData, # noqa: F401 + SpaceCardData, # noqa: F401 + ) + from .repository import Repository # noqa: F401 + from .serialization import ( + StateDictSplit, # noqa: F401 + get_tf_storage_size, # noqa: F401 + get_torch_storage_id, # noqa: F401 + get_torch_storage_size, # noqa: F401 + load_state_dict_from_file, # noqa: F401 + load_torch_model, # noqa: F401 + save_torch_model, # noqa: F401 + save_torch_state_dict, # noqa: F401 + split_state_dict_into_shards_factory, # noqa: F401 + split_tf_state_dict_into_shards, # noqa: F401 + split_torch_state_dict_into_shards, # noqa: F401 + ) + from .serialization._dduf import ( + DDUFEntry, # noqa: F401 + export_entries_as_dduf, # noqa: F401 + export_folder_as_dduf, # noqa: F401 + read_dduf_file, # noqa: F401 + ) + from .utils import ( + CachedFileInfo, # noqa: F401 + CachedRepoInfo, # noqa: F401 + CachedRevisionInfo, # noqa: F401 + CacheNotFound, # noqa: F401 + CorruptedCacheException, # noqa: F401 + DeleteCacheStrategy, # noqa: F401 + HFCacheInfo, # noqa: F401 + HfFolder, # noqa: F401 + cached_assets_path, # noqa: F401 + configure_http_backend, # noqa: F401 + dump_environment_info, # noqa: F401 + get_session, # noqa: F401 + get_token, # noqa: F401 + logging, # noqa: F401 + scan_cache_dir, # noqa: F401 + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/_commit_api.py b/phivenv/Lib/site-packages/huggingface_hub/_commit_api.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8fa86e6caf9e2db6ff5ce90f147267a2ab6e7d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_commit_api.py @@ -0,0 +1,908 @@ +""" +Type definitions and utilities for the `create_commit` API +""" + +import base64 +import io +import os +import warnings +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from itertools import groupby +from pathlib import Path, PurePosixPath +from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union + +from tqdm.contrib.concurrent import 
thread_map + +from . import constants +from .errors import EntryNotFoundError, HfHubHTTPError, XetAuthorizationError, XetRefreshTokenError +from .file_download import hf_hub_url +from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info +from .utils import ( + FORBIDDEN_FOLDERS, + XetTokenType, + are_progress_bars_disabled, + chunk_iterable, + fetch_xet_connection_info_from_repo_info, + get_session, + hf_raise_for_status, + logging, + sha, + tqdm_stream_file, + validate_hf_hub_args, +) +from .utils import tqdm as hf_tqdm + + +if TYPE_CHECKING: + from .hf_api import RepoFile + + +logger = logging.get_logger(__name__) + + +UploadMode = Literal["lfs", "regular"] + +# Max is 1,000 per request on the Hub for HfApi.get_paths_info +# Otherwise we get: +# HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters +# See https://github.com/huggingface/huggingface_hub/issues/1503 +FETCH_LFS_BATCH_SIZE = 500 + +UPLOAD_BATCH_MAX_NUM_FILES = 256 + + +@dataclass +class CommitOperationDelete: + """ + Data structure holding necessary info to delete a file or a folder from a repository + on the Hub. + + Args: + path_in_repo (`str`): + Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` + for a file or `"checkpoints/1fec34a/"` for a folder. + is_folder (`bool` or `Literal["auto"]`, *optional*): + Whether the Delete Operation applies to a folder or not. If "auto", the path + type (file or folder) is guessed automatically by checking whether the path + ends with a "/" (folder) or not (file). To explicitly set the path type, you can set + `is_folder=True` or `is_folder=False`. + """ + + path_in_repo: str + is_folder: Union[bool, Literal["auto"]] = "auto" + + def __post_init__(self): + self.path_in_repo = _validate_path_in_repo(self.path_in_repo) + + if self.is_folder == "auto": + self.is_folder = self.path_in_repo.endswith("/") + if not isinstance(self.is_folder, bool): + raise ValueError( + f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'." + ) + + +@dataclass +class CommitOperationCopy: + """ + Data structure holding necessary info to copy a file in a repository on the Hub. + + Limitations: + - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it. + - Cross-repository copies are not supported. + + Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub. + + Args: + src_path_in_repo (`str`): + Relative filepath in the repo of the file to be copied, e.g. `"checkpoints/1fec34a/weights.bin"`. + path_in_repo (`str`): + Relative filepath in the repo to copy the file to, e.g. `"checkpoints/1fec34a/weights_copy.bin"`. + src_revision (`str`, *optional*): + The git revision of the file to be copied. Can be any valid git revision. + Defaults to the target commit revision.
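+ + Example (illustrative; the repo id and filenames are hypothetical) showing the + copy + delete rename pattern mentioned in the note above: + + ```python + >>> from huggingface_hub import CommitOperationCopy, CommitOperationDelete, create_commit + >>> create_commit( + ... repo_id="user/my-repo", + ... operations=[ + ... CommitOperationCopy(src_path_in_repo="old.bin", path_in_repo="new.bin"), + ... CommitOperationDelete(path_in_repo="old.bin"), + ... ], + ... commit_message="Rename old.bin to new.bin", + ... ) + ``` + """ + + src_path_in_repo: str + path_in_repo: str + src_revision: Optional[str] = None + # set to the OID of the file to be copied if it has already been uploaded + # useful to determine if a commit will be empty or not. + _src_oid: Optional[str] = None + # set to the OID of the file to copy to if it has already been uploaded + # useful to determine if a commit will be empty or not.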
+ _dest_oid: Optional[str] = None + + def __post_init__(self): + self.src_path_in_repo = _validate_path_in_repo(self.src_path_in_repo) + self.path_in_repo = _validate_path_in_repo(self.path_in_repo) + + +@dataclass +class CommitOperationAdd: + """ + Data structure holding necessary info to upload a file to a repository on the Hub. + + Args: + path_in_repo (`str`): + Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` + path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`): + Either: + - a path to a local file (as `str` or `pathlib.Path`) to upload + - a buffer of bytes (`bytes`) holding the content of the file to upload + - a "file object" (subclass of `io.BufferedIOBase`), typically obtained + with `open(path, "rb")`. It must support `seek()` and `tell()` methods. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `path_or_fileobj` is an `io.BufferedIOBase` but it doesn't support both + `seek()` and `tell()`. + """ + + path_in_repo: str + path_or_fileobj: Union[str, Path, bytes, BinaryIO] + upload_info: UploadInfo = field(init=False, repr=False) + + # Internal attributes + + # set to "lfs" or "regular" once known + _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None) + + # set to True if .gitignore rules prevent the file from being uploaded as LFS + # (server-side check) + _should_ignore: Optional[bool] = field(init=False, repr=False, default=None) + + # set to the remote OID of the file if it has already been uploaded + # useful to determine if a commit will be empty or not + _remote_oid: Optional[str] = field(init=False, repr=False, default=None) + + # set to True once the file has been uploaded as LFS + _is_uploaded: bool = field(init=False, repr=False, default=False) + + # set to True once the file has been committed + _is_committed: bool = field(init=False, repr=False, default=False) + + def __post_init__(self) -> None: + """Validates `path_or_fileobj` and computes `upload_info`.""" + self.path_in_repo = _validate_path_in_repo(self.path_in_repo) + + # Validate `path_or_fileobj` value + if isinstance(self.path_or_fileobj, Path): + self.path_or_fileobj = str(self.path_or_fileobj) + if isinstance(self.path_or_fileobj, str): + path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj)) + if not os.path.isfile(path_or_fileobj): + raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system") + elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)): + # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode + raise ValueError( + "path_or_fileobj must be either an instance of str, bytes or" + " io.BufferedIOBase. If you passed a file-like object, make sure it is" + " in binary mode."
+ ) + if isinstance(self.path_or_fileobj, io.BufferedIOBase): + try: + self.path_or_fileobj.tell() + self.path_or_fileobj.seek(0, os.SEEK_CUR) + except (OSError, AttributeError) as exc: + raise ValueError( + "path_or_fileobj is a file-like object but does not implement seek() and tell()" + ) from exc + + # Compute "upload_info" attribute + if isinstance(self.path_or_fileobj, str): + self.upload_info = UploadInfo.from_path(self.path_or_fileobj) + elif isinstance(self.path_or_fileobj, bytes): + self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj) + else: + self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj) + + @contextmanager + def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]: + """ + A context manager that yields a file-like object for reading the underlying + data behind `path_or_fileobj`. + + Args: + with_tqdm (`bool`, *optional*, defaults to `False`): + If True, iterating over the file object will display a progress bar. Only + works when `path_or_fileobj` is a path to a local file. Pure bytes and + buffers are not supported. + + Example: + + ```python + >>> operation = CommitOperationAdd( + ... path_in_repo="remote/dir/weights.h5", + ... path_or_fileobj="./local/weights.h5", + ... ) + + >>> with operation.as_file() as file: + ... content = file.read() + + >>> with operation.as_file(with_tqdm=True) as file: + ... while True: + ... data = file.read(1024) + ... if not data: + ... break + config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] + + >>> with operation.as_file(with_tqdm=True) as file: + ... requests.put(..., data=file) + config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] + ``` + """ + if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path): + if with_tqdm: + with tqdm_stream_file(self.path_or_fileobj) as file: + yield file + else: + with open(self.path_or_fileobj, "rb") as file: + yield file + elif isinstance(self.path_or_fileobj, bytes): + yield io.BytesIO(self.path_or_fileobj) + elif isinstance(self.path_or_fileobj, io.BufferedIOBase): + prev_pos = self.path_or_fileobj.tell() + yield self.path_or_fileobj + self.path_or_fileobj.seek(prev_pos, io.SEEK_SET) + + def b64content(self) -> bytes: + """ + The base64-encoded content of `path_or_fileobj` + + Returns: `bytes` + """ + with self.as_file() as file: + return base64.b64encode(file.read()) + + @property + def _local_oid(self) -> Optional[str]: + """Return the OID of the local file. + + This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one. + If the file did not change, we won't upload it again to prevent empty commits. + + For LFS files, the OID corresponds to the SHA256 of the file content (used as an LFS ref). + For regular files, the OID corresponds to the SHA1 of the file content. + Note: this is slightly different from git OID computation since the oid of an LFS file is usually the git-SHA1 of the + pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes + and more convenient client-side. + """ + if self._upload_mode is None: + return None + elif self._upload_mode == "lfs": + return self.upload_info.sha256.hex() + else: + # Regular file => compute sha1 + # => no need to read by chunk since the file is guaranteed to be <=5MB.
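+ # (Editor's note) git hashes a blob as sha1(b"blob <size>\0" + content); + # `sha.git_hash` is assumed to implement exactly that scheme, which is why + # the raw file bytes are passed in below.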
+ with self.as_file() as file: + return sha.git_hash(file.read()) + + +def _validate_path_in_repo(path_in_repo: str) -> str: + # Validate `path_in_repo` value to prevent a server-side issue + if path_in_repo.startswith("/"): + path_in_repo = path_in_repo[1:] + if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"): + raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'") + if path_in_repo.startswith("./"): + path_in_repo = path_in_repo[2:] + for forbidden in FORBIDDEN_FOLDERS: + if any(part == forbidden for part in path_in_repo.split("/")): + raise ValueError( + f"Invalid `path_in_repo` in CommitOperation: cannot update files under a '{forbidden}/' folder (path:" + f" '{path_in_repo}')." + ) + return path_in_repo + + +CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete] + + +def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None: + """ + Warn user when a list of operations is expected to overwrite itself in a single + commit. + + Rules: + - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning + message is triggered. + - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted + by a `CommitOperationDelete`, a warning is triggered. + - If a `CommitOperationDelete` deletes a filepath that is then updated by a + `CommitOperationAdd`, no warning is triggered. This is usually useless (no need to + delete before upload) but can happen if a user deletes an entire folder and then + adds new files to it. + """ + nb_additions_per_path: Dict[str, int] = defaultdict(int) + for operation in operations: + path_in_repo = operation.path_in_repo + if isinstance(operation, CommitOperationAdd): + if nb_additions_per_path[path_in_repo] > 0: + warnings.warn( + "About to update the same file multiple times in the same commit:" + f" '{path_in_repo}'. This can cause undesired inconsistencies in" + " your repo." + ) + nb_additions_per_path[path_in_repo] += 1 + for parent in PurePosixPath(path_in_repo).parents: + # Also keep track of number of updated files per folder + # => warns if deleting a folder overwrites some contained files + nb_additions_per_path[str(parent)] += 1 + if isinstance(operation, CommitOperationDelete): + if nb_additions_per_path[str(PurePosixPath(path_in_repo))] > 0: + if operation.is_folder: + warnings.warn( + "About to delete a folder containing files that have just been" + f" updated within the same commit: '{path_in_repo}'. This can" + " cause undesired inconsistencies in your repo." + ) + else: + warnings.warn( + "About to delete a file that has just been updated within the" + f" same commit: '{path_in_repo}'. This can cause undesired" + " inconsistencies in your repo." + )
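+ + +# (Editor's note) Minimal illustration of the first rule above: two additions +# targeting the same path in one commit trigger the warning (file contents are +# made up): +# +# _warn_on_overwriting_operations( +# operations=[ +# CommitOperationAdd(path_in_repo="weights.bin", path_or_fileobj=b"v1"), +# CommitOperationAdd(path_in_repo="weights.bin", path_or_fileobj=b"v2"), +# ] +# ) +# # UserWarning: About to update the same file multiple times in the same commit: 'weights.bin'. ... + + +@validate_hf_hub_args +def _upload_lfs_files( + *, + additions: List[CommitOperationAdd], + repo_type: str, + repo_id: str, + headers: Dict[str, str], + endpoint: Optional[str] = None, + num_threads: int = 5, + revision: Optional[str] = None, +): + """ + Uploads the content of `additions` to the Hub using the large file storage protocol. + + Relevant external documentation: + - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + + Args: + additions (`List` of `CommitOperationAdd`): + The files to be uploaded. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`.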
+ headers (`Dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + num_threads (`int`, *optional*): + The number of concurrent threads to use when uploading. Defaults to 5. + revision (`str`, *optional*): + The git revision to upload to. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If an upload failed for any reason. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the server returns malformed responses. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the LFS batch endpoint returned an HTTP error. + """ + # Step 1: retrieve upload instructions from the LFS batch endpoint. + # Upload instructions are retrieved in chunks of 256 files to avoid reaching + # the payload limit. + batch_actions: List[Dict] = [] + for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES): + batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info( + upload_infos=[op.upload_info for op in chunk], + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + endpoint=endpoint, + headers=headers, + token=None, # already passed in 'headers' + ) + + # If at least 1 error occurred, we do not retrieve information for other chunks + if batch_errors_chunk: + message = "\n".join( + [ + f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}`" + for err in batch_errors_chunk + ] + ) + raise ValueError(f"LFS batch endpoint returned errors:\n{message}") + + batch_actions += batch_actions_chunk + oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions} + + # Step 2: ignore files that have already been uploaded + filtered_actions = [] + for action in batch_actions: + if action.get("actions") is None: + logger.debug( + f"Content of file {oid2addop[action['oid']].path_in_repo} is already" + " present upstream - skipping upload." + ) + else: + filtered_actions.append(action) + + if len(filtered_actions) == 0: + logger.debug("No LFS files to upload.") + return + + # Step 3: upload files concurrently according to these instructions + def _wrapped_lfs_upload(batch_action) -> None: + # Resolve the operation outside of the `try` block so that the error + # message below can always reference it. + operation = oid2addop[batch_action["oid"]] + try: + lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint) + except Exception as exc: + raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc + + if constants.HF_HUB_ENABLE_HF_TRANSFER: + logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.") + for action in hf_tqdm(filtered_actions, name="huggingface_hub.lfs_upload"): + _wrapped_lfs_upload(action) + elif len(filtered_actions) == 1: + logger.debug("Uploading 1 LFS file to the Hub") + _wrapped_lfs_upload(filtered_actions[0]) + else: + logger.debug( + f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently" + ) + thread_map( + _wrapped_lfs_upload, + filtered_actions, + desc=f"Upload {len(filtered_actions)} LFS files", + max_workers=num_threads, + tqdm_class=hf_tqdm, + )
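+ + +# (Editor's note) For orientation, a minimal LFS batch exchange has roughly this +# shape (per the batch API spec linked in the docstring above; values are made up): +# +# request: {"operation": "upload", +# "objects": [{"oid": "<sha256>", "size": 123}]} +# response: {"objects": [{"oid": "<sha256>", "size": 123, +# "actions": {"upload": {"href": "https://..."}}}]} +# +# An object that is already stored upstream comes back without an "actions" key, +# which is exactly what the filtering in step 2 above relies on. + + +@validate_hf_hub_args +def _upload_xet_files( + *, + additions: List[CommitOperationAdd], + repo_type: str, + repo_id: str, + headers: Dict[str, str], + endpoint: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, +): + """ + Uploads the content of `additions` to the Hub using the xet storage protocol.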
+ This chunks the files and deduplicates the chunks before uploading them to xetcas storage. + + Args: + additions (`List` of `CommitOperationAdd`): + The files to be uploaded. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + headers (`Dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + endpoint (`str`, *optional*): + The endpoint to use for the xetcas service. Defaults to `constants.ENDPOINT`. + revision (`str`, *optional*): + The git revision to upload to. + create_pr (`bool`, *optional*): + Whether or not to create a Pull Request with that commit. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If an upload failed for any reason. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the server returns malformed responses or if the user is unauthorized to upload to xet storage. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the LFS batch endpoint returned an HTTP error. + + **How it works:** + The file upload system uses Xet storage, which is a content-addressable storage system that breaks files into chunks + for efficient storage and transfer. + + `hf_xet.upload_files` manages uploading files by: + - Taking a list of file paths to upload + - Breaking files into smaller chunks for efficient storage + - Avoiding duplicate storage by recognizing identical chunks across files + - Connecting to a storage server (CAS server) that manages these chunks + + The upload process works like this: + 1. Create a local folder at ~/.cache/huggingface/xet/chunk-cache to store file chunks for reuse. + 2. Process files in parallel (up to 8 files at once): + 2.1. Read the file content. + 2.2. Split the file content into smaller chunks based on content patterns: each chunk gets a unique ID based on what's in it. + 2.3. For each chunk: + - Check if it already exists in storage. + - Skip uploading chunks that already exist. + 2.4. Group chunks into larger blocks for efficient transfer. + 2.5. Upload these blocks to the storage server. + 2.6. Create and upload information about how the file is structured. + 3. Return reference files that contain information about the uploaded files, which can be used later to download them. + """ + if len(additions) == 0: + return + + # at this point, we know that hf_xet is installed + from hf_xet import upload_bytes, upload_files + + from .utils._xet_progress_reporting import XetProgressReporter + + try: + xet_connection_info = fetch_xet_connection_info_from_repo_info( + token_type=XetTokenType.WRITE, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + headers=headers, + endpoint=endpoint, + params={"create_pr": "1"} if create_pr else None, + ) + except HfHubHTTPError as e: + if e.response.status_code == 401: + raise XetAuthorizationError( + f"You are unauthorized to upload to xet storage for {repo_type}/{repo_id}. " + f"Please check that you have configured your access token with write access to the repo."
+ ) from e + raise + + xet_endpoint = xet_connection_info.endpoint + access_token_info = (xet_connection_info.access_token, xet_connection_info.expiration_unix_epoch) + + def token_refresher() -> Tuple[str, int]: + new_xet_connection = fetch_xet_connection_info_from_repo_info( + token_type=XetTokenType.WRITE, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + headers=headers, + endpoint=endpoint, + params={"create_pr": "1"} if create_pr else None, + ) + if new_xet_connection is None: + raise XetRefreshTokenError("Failed to refresh xet token") + return new_xet_connection.access_token, new_xet_connection.expiration_unix_epoch + + if not are_progress_bars_disabled(): + progress = XetProgressReporter() + progress_callback = progress.update_progress + else: + progress, progress_callback = None, None + + try: + for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES): + _chunk = list(chunk) + + bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)] + paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))] + + if len(paths_ops) > 0: + upload_files( + [str(op.path_or_fileobj) for op in paths_ops], + xet_endpoint, + access_token_info, + token_refresher, + progress_callback, + repo_type, + ) + if len(bytes_ops) > 0: + upload_bytes( + [op.path_or_fileobj for op in bytes_ops], + xet_endpoint, + access_token_info, + token_refresher, + progress_callback, + repo_type, + ) + + finally: + if progress is not None: + progress.close(False) + + return + + +def _validate_preupload_info(preupload_info: dict): + files = preupload_info.get("files") + if not isinstance(files, list): + raise ValueError("preupload_info is improperly formatted") + for file_info in files: + if not ( + isinstance(file_info, dict) + and isinstance(file_info.get("path"), str) + and isinstance(file_info.get("uploadMode"), str) + and (file_info["uploadMode"] in ("lfs", "regular")) + ): + raise ValueError(f"preupload_info is improperly formatted: {file_info}") + return preupload_info
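+ + +# (Editor's note) Illustrative shape of a preupload response that the validator +# above accepts (field values are made up): +# +# {"files": [{"path": "config.json", "uploadMode": "regular", +# "shouldIgnore": false, "oid": "abc123"}]} +# +# "uploadMode" decides between the regular git blob path and LFS; "shouldIgnore" +# and "oid" are consumed by `_fetch_upload_modes` below. + + +@validate_hf_hub_args +def _fetch_upload_modes( + additions: Iterable[CommitOperationAdd], + repo_type: str, + repo_id: str, + headers: Dict[str, str], + revision: str, + endpoint: Optional[str] = None, + create_pr: bool = False, + gitignore_content: Optional[str] = None, +) -> None: + """ + Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob, + as a git LFS blob, or as a XET file. Input `additions` are mutated in-place with the upload mode. + + Args: + additions (`Iterable` of :class:`CommitOperationAdd`): + Iterable of :class:`CommitOperationAdd` describing the files to + upload to the Hub. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + headers (`Dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + revision (`str`): + The git revision to upload the files to. Can be any valid git revision. + gitignore_content (`str`, *optional*): + The content of the `.gitignore` file to know which files should be ignored. The order of priority + is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present + in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub + (if any). + Raises: + [`~utils.HfHubHTTPError`] + If the Hub API returned an error.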
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the Hub API response is improperly formatted. + """ + endpoint = endpoint if endpoint is not None else constants.ENDPOINT + + # Fetch upload mode (LFS or regular) chunk by chunk. + upload_modes: Dict[str, UploadMode] = {} + should_ignore_info: Dict[str, bool] = {} + oid_info: Dict[str, Optional[str]] = {} + + for chunk in chunk_iterable(additions, 256): + payload: Dict = { + "files": [ + { + "path": op.path_in_repo, + "sample": base64.b64encode(op.upload_info.sample).decode("ascii"), + "size": op.upload_info.size, + } + for op in chunk + ] + } + if gitignore_content is not None: + payload["gitIgnore"] = gitignore_content + + resp = get_session().post( + f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}", + json=payload, + headers=headers, + params={"create_pr": "1"} if create_pr else None, + ) + hf_raise_for_status(resp) + preupload_info = _validate_preupload_info(resp.json()) + upload_modes.update(**{file["path"]: file["uploadMode"] for file in preupload_info["files"]}) + should_ignore_info.update(**{file["path"]: file["shouldIgnore"] for file in preupload_info["files"]}) + oid_info.update(**{file["path"]: file.get("oid") for file in preupload_info["files"]}) + + # Set upload mode for each addition operation + for addition in additions: + addition._upload_mode = upload_modes[addition.path_in_repo] + addition._should_ignore = should_ignore_info[addition.path_in_repo] + addition._remote_oid = oid_info[addition.path_in_repo] + + # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented) + # => empty files are uploaded as "regular" to still allow users to commit them. + for addition in additions: + if addition.upload_info.size == 0: + addition._upload_mode = "regular" + + +@validate_hf_hub_args +def _fetch_files_to_copy( + copies: Iterable[CommitOperationCopy], + repo_type: str, + repo_id: str, + headers: Dict[str, str], + revision: str, + endpoint: Optional[str] = None, +) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]: + """ + Fetch information about the files to copy. + + For LFS files, we only need their metadata (file size and sha256) while for regular files + we need to download the raw content from the Hub. + + Args: + copies (`Iterable` of :class:`CommitOperationCopy`): + Iterable of :class:`CommitOperationCopy` describing the files to + copy on the Hub. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + headers (`Dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + revision (`str`): + The git revision to upload the files to. Can be any valid git revision. + + Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]]` + Key is the file path and revision of the file to copy. + Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files). + + Raises: + [`~utils.HfHubHTTPError`] + If the Hub API returned an error. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the Hub API response is improperly formatted. 
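+
+    Example (illustrative sketch only; the repo, paths and returned values are hypothetical):
+    ```py
+    >>> files_to_copy = _fetch_files_to_copy(copies, "model", "user/repo", headers, "main")
+    >>> files_to_copy[("model.safetensors", None)]  # LFS file => RepoFile metadata only
+    RepoFile(path='model.safetensors', ...)
+    >>> files_to_copy[("config.json", None)]  # regular file => raw bytes
+    b'{"architectures": [...]}'
+    ```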
+ """ + from .hf_api import HfApi, RepoFolder + + hf_api = HfApi(endpoint=endpoint, headers=headers) + files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {} + # Store (path, revision) -> oid mapping + oid_info: Dict[Tuple[str, Optional[str]], Optional[str]] = {} + # 1. Fetch OIDs for destination paths in batches. + dest_paths = [op.path_in_repo for op in copies] + for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE): + dest_repo_files = hf_api.get_paths_info( + repo_id=repo_id, + paths=dest_paths[offset : offset + FETCH_LFS_BATCH_SIZE], + revision=revision, + repo_type=repo_type, + ) + for file in dest_repo_files: + if not isinstance(file, RepoFolder): + oid_info[(file.path, revision)] = file.blob_id + + # 2. Group by source revision and fetch source file info in batches. + for src_revision, operations in groupby(copies, key=lambda op: op.src_revision): + operations = list(operations) # type: ignore + src_paths = [op.src_path_in_repo for op in operations] + for offset in range(0, len(src_paths), FETCH_LFS_BATCH_SIZE): + src_repo_files = hf_api.get_paths_info( + repo_id=repo_id, + paths=src_paths[offset : offset + FETCH_LFS_BATCH_SIZE], + revision=src_revision or revision, + repo_type=repo_type, + ) + + for src_repo_file in src_repo_files: + if isinstance(src_repo_file, RepoFolder): + raise NotImplementedError("Copying a folder is not implemented.") + oid_info[(src_repo_file.path, src_revision)] = src_repo_file.blob_id + # If it's an LFS file, store the RepoFile object. Otherwise, download raw bytes. + if src_repo_file.lfs: + files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file + else: + # TODO: (optimization) download regular files to copy concurrently + url = hf_hub_url( + endpoint=endpoint, + repo_type=repo_type, + repo_id=repo_id, + revision=src_revision or revision, + filename=src_repo_file.path, + ) + response = get_session().get(url, headers=headers) + hf_raise_for_status(response) + files_to_copy[(src_repo_file.path, src_revision)] = response.content + # 3. Ensure all operations found a corresponding file in the Hub + # and track src/dest OIDs for each operation. + for operation in operations: + if (operation.src_path_in_repo, src_revision) not in files_to_copy: + raise EntryNotFoundError( + f"Cannot copy {operation.src_path_in_repo} at revision " + f"{src_revision or revision}: file is missing on repo." + ) + operation._src_oid = oid_info.get((operation.src_path_in_repo, operation.src_revision)) + operation._dest_oid = oid_info.get((operation.path_in_repo, revision)) + return files_to_copy + + +def _prepare_commit_payload( + operations: Iterable[CommitOperation], + files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]], + commit_message: str, + commit_description: Optional[str] = None, + parent_commit: Optional[str] = None, +) -> Iterable[Dict[str, Any]]: + """ + Builds the payload to POST to the `/commit` API of the Hub. + + Payload is returned as an iterator so that it can be streamed as a ndjson in the + POST request. + + For more information, see: + - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073 + - http://ndjson.org/ + """ + commit_description = commit_description if commit_description is not None else "" + + # 1. 
Send a header item with the commit metadata + header_value = {"summary": commit_message, "description": commit_description} + if parent_commit is not None: + header_value["parentCommit"] = parent_commit + yield {"key": "header", "value": header_value} + + nb_ignored_files = 0 + + # 2. Send operations, one per line + for operation in operations: + # Skip ignored files + if isinstance(operation, CommitOperationAdd) and operation._should_ignore: + logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).") + nb_ignored_files += 1 + continue + + # 2.a. Case adding a regular file + if isinstance(operation, CommitOperationAdd) and operation._upload_mode == "regular": + yield { + "key": "file", + "value": { + "content": operation.b64content().decode(), + "path": operation.path_in_repo, + "encoding": "base64", + }, + } + # 2.b. Case adding an LFS file + elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == "lfs": + yield { + "key": "lfsFile", + "value": { + "path": operation.path_in_repo, + "algo": "sha256", + "oid": operation.upload_info.sha256.hex(), + "size": operation.upload_info.size, + }, + } + # 2.c. Case deleting a file or folder + elif isinstance(operation, CommitOperationDelete): + yield { + "key": "deletedFolder" if operation.is_folder else "deletedFile", + "value": {"path": operation.path_in_repo}, + } + # 2.d. Case copying a file or folder + elif isinstance(operation, CommitOperationCopy): + file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)] + if isinstance(file_to_copy, bytes): + yield { + "key": "file", + "value": { + "content": base64.b64encode(file_to_copy).decode(), + "path": operation.path_in_repo, + "encoding": "base64", + }, + } + elif file_to_copy.lfs: + yield { + "key": "lfsFile", + "value": { + "path": operation.path_in_repo, + "algo": "sha256", + "oid": file_to_copy.lfs.sha256, + }, + } + else: + raise ValueError( + "Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info." + ) + # 2.e. Never expected to happen + else: + raise ValueError( + f"Unknown operation to commit. Operation: {operation}. Upload mode:" + f" {getattr(operation, '_upload_mode', None)}" + ) + + if nb_ignored_files > 0: + logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).") diff --git a/phivenv/Lib/site-packages/huggingface_hub/_commit_scheduler.py b/phivenv/Lib/site-packages/huggingface_hub/_commit_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f1f20339e7df2d17588623dc13bb3c6be6a46b53 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_commit_scheduler.py @@ -0,0 +1,353 @@ +import atexit +import logging +import os +import time +from concurrent.futures import Future +from dataclasses import dataclass +from io import SEEK_END, SEEK_SET, BytesIO +from pathlib import Path +from threading import Lock, Thread +from typing import Dict, List, Optional, Union + +from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi +from .utils import filter_repo_objects + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class _FileToUpload: + """Temporary dataclass to store info about files to upload. Not meant to be used directly.""" + + local_path: Path + path_in_repo: str + size_limit: int + last_modified: float + + +class CommitScheduler: + """ + Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes). 
+
+    The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
+    properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
+    with the `stop` method. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
+    to learn more about how to use it.
+
+    Args:
+        repo_id (`str`):
+            The id of the repo to commit to.
+        folder_path (`str` or `Path`):
+            Path to the local folder to upload regularly.
+        every (`int` or `float`, *optional*):
+            The number of minutes between each commit. Defaults to 5 minutes.
+        path_in_repo (`str`, *optional*):
+            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
+            of the repository.
+        repo_type (`str`, *optional*):
+            The type of the repo to commit to. Defaults to `model`.
+        revision (`str`, *optional*):
+            The revision of the repo to commit to. Defaults to `main`.
+        private (`bool`, *optional*):
+            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+        token (`str`, *optional*):
+            The token to use to commit to the repo. Defaults to the token saved on the machine.
+        allow_patterns (`List[str]` or `str`, *optional*):
+            If provided, only files matching at least one pattern are uploaded.
+        ignore_patterns (`List[str]` or `str`, *optional*):
+            If provided, files matching any of the patterns are not uploaded.
+        squash_history (`bool`, *optional*):
+            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
+            useful to avoid degraded performance on the repo when it grows too large.
+        hf_api (`HfApi`, *optional*):
+            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
+
+    Example:
+    ```py
+    >>> from pathlib import Path
+    >>> from huggingface_hub import CommitScheduler
+
+    # Scheduler uploads every 10 minutes
+    >>> csv_path = Path("watched_folder/data.csv")
+    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
+
+    >>> with csv_path.open("a") as f:
+    ...     f.write("first line")
+
+    # Some time later (...)
+    >>> with csv_path.open("a") as f:
+    ...     f.write("second line")
+    ```
+
+    Example using a context manager:
+    ```py
+    >>> from pathlib import Path
+    >>> from huggingface_hub import CommitScheduler
+
+    >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
+    ...     csv_path = Path("watched_folder/data.csv")
+    ...     with csv_path.open("a") as f:
+    ...         f.write("first line")
+    ...     (...)
+    ...     with csv_path.open("a") as f:
+    ...         f.write("second line")
+
+    # Scheduler is now stopped and the last commit has been triggered
+    ```
+    """
+
+    def __init__(
+        self,
+        *,
+        repo_id: str,
+        folder_path: Union[str, Path],
+        every: Union[int, float] = 5,
+        path_in_repo: Optional[str] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        private: Optional[bool] = None,
+        token: Optional[str] = None,
+        allow_patterns: Optional[Union[List[str], str]] = None,
+        ignore_patterns: Optional[Union[List[str], str]] = None,
+        squash_history: bool = False,
+        hf_api: Optional["HfApi"] = None,
+    ) -> None:
+        self.api = hf_api or HfApi(token=token)
+
+        # Folder
+        self.folder_path = Path(folder_path).expanduser().resolve()
+        self.path_in_repo = path_in_repo or ""
+        self.allow_patterns = allow_patterns
+
+        if ignore_patterns is None:
+            ignore_patterns = []
+        elif isinstance(ignore_patterns, str):
+            ignore_patterns = [ignore_patterns]
+        self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
+
+        if self.folder_path.is_file():
+            raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.")
+        self.folder_path.mkdir(parents=True, exist_ok=True)
+
+        # Repository
+        repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True)
+        self.repo_id = repo_url.repo_id
+        self.repo_type = repo_type
+        self.revision = revision
+        self.token = token
+
+        # Keep track of already uploaded files
+        self.last_uploaded: Dict[Path, float] = {}  # key is local path, value is timestamp
+
+        # Scheduler
+        if not every > 0:
+            raise ValueError(f"'every' must be a positive number, not '{every}'.")
+        self.lock = Lock()
+        self.every = every
+        self.squash_history = squash_history
+
+        logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.")
+        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
+        self._scheduler_thread.start()
+        atexit.register(self._push_to_hub)
+
+        self.__stopped = False
+
+    def stop(self) -> None:
+        """Stop the scheduler.
+
+        A stopped scheduler cannot be restarted. Mostly for testing purposes.
+        """
+        self.__stopped = True
+
+    def __enter__(self) -> "CommitScheduler":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        # Upload last changes before exiting
+        self.trigger().result()
+        self.stop()
+        return
+
+    def _run_scheduler(self) -> None:
+        """Dumb thread waiting between each scheduled push to Hub."""
+        while True:
+            self.last_future = self.trigger()
+            time.sleep(self.every * 60)
+            if self.__stopped:
+                break
+
+    def trigger(self) -> Future:
+        """Trigger a `push_to_hub` and return a future.
+
+        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
+        immediately, without waiting for the next scheduled commit.
+ """ + return self.api.run_as_future(self._push_to_hub) + + def _push_to_hub(self) -> Optional[CommitInfo]: + if self.__stopped: # If stopped, already scheduled commits are ignored + return None + + logger.info("(Background) scheduled commit triggered.") + try: + value = self.push_to_hub() + if self.squash_history: + logger.info("(Background) squashing repo history.") + self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision) + return value + except Exception as e: + logger.error(f"Error while pushing to Hub: {e}") # Depending on the setup, error might be silenced + raise + + def push_to_hub(self) -> Optional[CommitInfo]: + """ + Push folder to the Hub and return the commit info. + + + + This method is not meant to be called directly. It is run in the background by the scheduler, respecting a + queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency + issues. + + + + The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and + uploads only changed files. If no changes are found, the method returns without committing anything. If you want + to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful + for example to compress data together in a single file before committing. For more details and examples, check + out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads). + """ + # Check files to upload (with lock) + with self.lock: + logger.debug("Listing files to upload for scheduled commit.") + + # List files from folder (taken from `_prepare_upload_folder_additions`) + relpath_to_abspath = { + path.relative_to(self.folder_path).as_posix(): path + for path in sorted(self.folder_path.glob("**/*")) # sorted to be deterministic + if path.is_file() + } + prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else "" + + # Filter with pattern + filter out unchanged files + retrieve current file size + files_to_upload: List[_FileToUpload] = [] + for relpath in filter_repo_objects( + relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns + ): + local_path = relpath_to_abspath[relpath] + stat = local_path.stat() + if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime: + files_to_upload.append( + _FileToUpload( + local_path=local_path, + path_in_repo=prefix + relpath, + size_limit=stat.st_size, + last_modified=stat.st_mtime, + ) + ) + + # Return if nothing to upload + if len(files_to_upload) == 0: + logger.debug("Dropping schedule commit: no changed file to upload.") + return None + + # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size) + logger.debug("Removing unchanged files since previous scheduled commit.") + add_operations = [ + CommitOperationAdd( + # Cap the file to its current size, even if the user append data to it while a scheduled commit is happening + path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit), + path_in_repo=file_to_upload.path_in_repo, + ) + for file_to_upload in files_to_upload + ] + + # Upload files (append mode expected - no need for lock) + logger.debug("Uploading files for scheduled commit.") + commit_info = self.api.create_commit( + repo_id=self.repo_id, + repo_type=self.repo_type, + operations=add_operations, + commit_message="Scheduled 
Commit", + revision=self.revision, + ) + + # Successful commit: keep track of the latest "last_modified" for each file + for file in files_to_upload: + self.last_uploaded[file.local_path] = file.last_modified + return commit_info + + +class PartialFileIO(BytesIO): + """A file-like object that reads only the first part of a file. + + Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the + file is uploaded (i.e. the part that was available when the filesystem was first scanned). + + In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal + disturbance for the user. The object is passed to `CommitOperationAdd`. + + Only supports `read`, `tell` and `seek` methods. + + Args: + file_path (`str` or `Path`): + Path to the file to read. + size_limit (`int`): + The maximum number of bytes to read from the file. If the file is larger than this, only the first part + will be read (and uploaded). + """ + + def __init__(self, file_path: Union[str, Path], size_limit: int) -> None: + self._file_path = Path(file_path) + self._file = self._file_path.open("rb") + self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size) + + def __del__(self) -> None: + self._file.close() + return super().__del__() + + def __repr__(self) -> str: + return f"" + + def __len__(self) -> int: + return self._size_limit + + def __getattribute__(self, name: str): + if name.startswith("_") or name in ("read", "tell", "seek"): # only 3 public methods supported + return super().__getattribute__(name) + raise NotImplementedError(f"PartialFileIO does not support '{name}'.") + + def tell(self) -> int: + """Return the current file position.""" + return self._file.tell() + + def seek(self, __offset: int, __whence: int = SEEK_SET) -> int: + """Change the stream position to the given offset. + + Behavior is the same as a regular file, except that the position is capped to the size limit. + """ + if __whence == SEEK_END: + # SEEK_END => set from the truncated end + __offset = len(self) + __offset + __whence = SEEK_SET + + pos = self._file.seek(__offset, __whence) + if pos > self._size_limit: + return self._file.seek(self._size_limit) + return pos + + def read(self, __size: Optional[int] = -1) -> bytes: + """Read at most `__size` bytes from the file. + + Behavior is the same as a regular file, except that it is capped to the size limit. 
+ """ + current = self._file.tell() + if __size is None or __size < 0: + # Read until file limit + truncated_size = self._size_limit - current + else: + # Read until file limit or __size + truncated_size = min(__size, self._size_limit - current) + return self._file.read(truncated_size) diff --git a/phivenv/Lib/site-packages/huggingface_hub/_inference_endpoints.py b/phivenv/Lib/site-packages/huggingface_hub/_inference_endpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..37f772bfbe28013ff5329d0a19a438706d50a19c --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_inference_endpoints.py @@ -0,0 +1,413 @@ +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Dict, Optional, Union + +from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError + +from .utils import get_session, logging, parse_datetime + + +if TYPE_CHECKING: + from .hf_api import HfApi + from .inference._client import InferenceClient + from .inference._generated._async_client import AsyncInferenceClient + +logger = logging.get_logger(__name__) + + +class InferenceEndpointStatus(str, Enum): + PENDING = "pending" + INITIALIZING = "initializing" + UPDATING = "updating" + UPDATE_FAILED = "updateFailed" + RUNNING = "running" + PAUSED = "paused" + FAILED = "failed" + SCALED_TO_ZERO = "scaledToZero" + + +class InferenceEndpointType(str, Enum): + PUBlIC = "public" + PROTECTED = "protected" + PRIVATE = "private" + + +@dataclass +class InferenceEndpoint: + """ + Contains information about a deployed Inference Endpoint. + + Args: + name (`str`): + The unique name of the Inference Endpoint. + namespace (`str`): + The namespace where the Inference Endpoint is located. + repository (`str`): + The name of the model repository deployed on this Inference Endpoint. + status ([`InferenceEndpointStatus`]): + The current status of the Inference Endpoint. + url (`str`, *optional*): + The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL. + framework (`str`): + The machine learning framework used for the model. + revision (`str`): + The specific model revision deployed on the Inference Endpoint. + task (`str`): + The task associated with the deployed model. + created_at (`datetime.datetime`): + The timestamp when the Inference Endpoint was created. + updated_at (`datetime.datetime`): + The timestamp of the last update of the Inference Endpoint. + type ([`InferenceEndpointType`]): + The type of the Inference Endpoint (public, protected, private). + raw (`Dict`): + The raw dictionary data returned from the API. + token (`str` or `bool`, *optional*): + Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the + locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server. + + Example: + ```python + >>> from huggingface_hub import get_inference_endpoint + >>> endpoint = get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) + + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) + + # Pause endpoint to save $$$ + >>> endpoint.pause() + + # ... + # Resume and wait for deployment + >>> endpoint.resume() + >>> endpoint.wait() + >>> endpoint.client.text_to_image(...) 
+    ```
+    """
+
+    # Fields in __repr__
+    name: str = field(init=False)
+    namespace: str
+    repository: str = field(init=False)
+    status: InferenceEndpointStatus = field(init=False)
+    health_route: str = field(init=False)
+    url: Optional[str] = field(init=False)
+
+    # Other fields
+    framework: str = field(repr=False, init=False)
+    revision: str = field(repr=False, init=False)
+    task: str = field(repr=False, init=False)
+    created_at: datetime = field(repr=False, init=False)
+    updated_at: datetime = field(repr=False, init=False)
+    type: InferenceEndpointType = field(repr=False, init=False)
+
+    # Raw dict from the API
+    raw: Dict = field(repr=False)
+
+    # Internal fields
+    _token: Union[str, bool, None] = field(repr=False, compare=False)
+    _api: "HfApi" = field(repr=False, compare=False)
+
+    @classmethod
+    def from_raw(
+        cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None
+    ) -> "InferenceEndpoint":
+        """Initialize object from raw dictionary."""
+        if api is None:
+            from .hf_api import HfApi
+
+            api = HfApi()
+        if token is None:
+            token = api.token
+
+        # All other fields are populated in __post_init__
+        return cls(raw=raw, namespace=namespace, _token=token, _api=api)
+
+    def __post_init__(self) -> None:
+        """Populate fields from raw dictionary."""
+        self._populate_from_raw()
+
+    @property
+    def client(self) -> "InferenceClient":
+        """Returns a client to make predictions on this Inference Endpoint.
+
+        Returns:
+            [`InferenceClient`]: an inference client pointing to the deployed endpoint.
+
+        Raises:
+            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
+        """
+        if self.url is None:
+            raise InferenceEndpointError(
+                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
+                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
+            )
+        from .inference._client import InferenceClient
+
+        return InferenceClient(
+            model=self.url,
+            token=self._token,  # type: ignore[arg-type]  # boolean token shouldn't be possible. In practice it's ok.
+        )
+
+    @property
+    def async_client(self) -> "AsyncInferenceClient":
+        """Returns a client to make predictions on this Inference Endpoint.
+
+        Returns:
+            [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.
+
+        Raises:
+            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
+        """
+        if self.url is None:
+            raise InferenceEndpointError(
+                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
+                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
+            )
+        from .inference._generated._async_client import AsyncInferenceClient
+
+        return AsyncInferenceClient(
+            model=self.url,
+            token=self._token,  # type: ignore[arg-type]  # boolean token shouldn't be possible. In practice it's ok.
+        )
+
+    def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
+        """Wait for the Inference Endpoint to be deployed.
+
+        Information from the server is fetched every `refresh_every` seconds. If the Inference Endpoint is not
+        deployed after `timeout` seconds, an [`InferenceEndpointTimeoutError`] will be raised. The
+        [`InferenceEndpoint`] will be mutated in place with the latest data.
+
+        Args:
+            timeout (`int`, *optional*):
+                The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
+                indefinitely.
+ refresh_every (`int`, *optional*): + The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + + Raises: + [`InferenceEndpointError`] + If the Inference Endpoint ended up in a failed state. + [`InferenceEndpointTimeoutError`] + If the Inference Endpoint is not deployed after `timeout` seconds. + """ + if timeout is not None and timeout < 0: + raise ValueError("`timeout` cannot be negative.") + if refresh_every <= 0: + raise ValueError("`refresh_every` must be positive.") + + start = time.time() + while True: + if self.status == InferenceEndpointStatus.FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.UPDATE_FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to update. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.RUNNING and self.url is not None: + # Verify the endpoint is actually reachable + _health_url = f"{self.url.rstrip('/')}/{self.health_route.lstrip('/')}" + response = get_session().get(_health_url, headers=self._api._build_hf_headers(token=self._token)) + if response.status_code == 200: + logger.info("Inference Endpoint is ready to be used.") + return self + + if timeout is not None: + if time.time() - start > timeout: + raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") + logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...") + time.sleep(refresh_every) + self.fetch() + + def fetch(self) -> "InferenceEndpoint": + """Fetch latest information about the Inference Endpoint. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def update( + self, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + scale_to_zero_timeout: Optional[int] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[Dict] = None, + secrets: Optional[Dict[str, str]] = None, + ) -> "InferenceEndpoint": + """Update the Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, or both. All arguments are + optional but at least one must be provided. + + This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Args: + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. 
+            max_replica (`int`, *optional*):
+                The maximum number of replicas (instances) to scale to for the Inference Endpoint.
+            scale_to_zero_timeout (`int`, *optional*):
+                The duration in minutes before an inactive endpoint is scaled to zero.
+
+            repository (`str`, *optional*):
+                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
+            framework (`str`, *optional*):
+                The machine learning framework used for the model (e.g. `"custom"`).
+            revision (`str`, *optional*):
+                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
+            task (`str`, *optional*):
+                The task on which to deploy the model (e.g. `"text-classification"`).
+            custom_image (`Dict`, *optional*):
+                A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
+                Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
+            secrets (`Dict[str, str]`, *optional*):
+                Secret values to inject in the container environment.
+        Returns:
+            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
+        """
+        # Make API call
+        obj = self._api.update_inference_endpoint(
+            name=self.name,
+            namespace=self.namespace,
+            accelerator=accelerator,
+            instance_size=instance_size,
+            instance_type=instance_type,
+            min_replica=min_replica,
+            max_replica=max_replica,
+            scale_to_zero_timeout=scale_to_zero_timeout,
+            repository=repository,
+            framework=framework,
+            revision=revision,
+            task=task,
+            custom_image=custom_image,
+            secrets=secrets,
+            token=self._token,  # type: ignore [arg-type]
+        )
+
+        # Mutate current object
+        self.raw = obj.raw
+        self._populate_from_raw()
+        return self
+
+    def pause(self) -> "InferenceEndpoint":
+        """Pause the Inference Endpoint.
+
+        A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`].
+        This is different from scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`]: a
+        scaled-to-zero endpoint is automatically restarted when a request is made to it.
+
+        This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
+        latest data from the server.
+
+        Returns:
+            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
+        """
+        obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)  # type: ignore [arg-type]
+        self.raw = obj.raw
+        self._populate_from_raw()
+        return self
+
+    def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
+        """Resume the Inference Endpoint.
+
+        This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
+        latest data from the server.
+
+        Args:
+            running_ok (`bool`, *optional*):
+                If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to
+                `True`.
+
+        Returns:
+            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
+        """
+        obj = self._api.resume_inference_endpoint(
+            name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token
+        )  # type: ignore [arg-type]
+        self.raw = obj.raw
+        self._populate_from_raw()
+        return self
+
+    def scale_to_zero(self) -> "InferenceEndpoint":
+        """Scale Inference Endpoint to zero.
+
+        An Inference Endpoint scaled to zero will not be charged. It will resume on the next request made to it, with
+        a cold start delay.
This is different than pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which + would require a manual resume with [`InferenceEndpoint.resume`]. + + This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def delete(self) -> None: + """Delete the Inference Endpoint. + + This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`]. + + This is an alias for [`HfApi.delete_inference_endpoint`]. + """ + self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + + def _populate_from_raw(self) -> None: + """Populate fields from raw dictionary. + + Called in __post_init__ + each time the Inference Endpoint is updated. + """ + # Repr fields + self.name = self.raw["name"] + self.repository = self.raw["model"]["repository"] + self.status = self.raw["status"]["state"] + self.url = self.raw["status"].get("url") + self.health_route = self.raw["healthRoute"] + + # Other fields + self.framework = self.raw["model"]["framework"] + self.revision = self.raw["model"]["revision"] + self.task = self.raw["model"]["task"] + self.created_at = parse_datetime(self.raw["status"]["createdAt"]) + self.updated_at = parse_datetime(self.raw["status"]["updatedAt"]) + self.type = self.raw["type"] diff --git a/phivenv/Lib/site-packages/huggingface_hub/_jobs_api.py b/phivenv/Lib/site-packages/huggingface_hub/_jobs_api.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9d76d645d0df56a37f8ff75ed0fd53348b4abd --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_jobs_api.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2025-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from huggingface_hub import constants +from huggingface_hub._space_api import SpaceHardware +from huggingface_hub.utils._datetime import parse_datetime + + +class JobStage(str, Enum): + """ + Enumeration of possible stage of a Job on the Hub. + + Value can be compared to a string: + ```py + assert JobStage.COMPLETED == "COMPLETED" + ``` + + Taken from https://github.com/huggingface/moon-landing/blob/main/server/job_types/JobInfo.ts#L61 (private url). 
+ """ + + # Copied from moon-landing > server > lib > Job.ts + COMPLETED = "COMPLETED" + CANCELED = "CANCELED" + ERROR = "ERROR" + DELETED = "DELETED" + RUNNING = "RUNNING" + + +@dataclass +class JobStatus: + stage: JobStage + message: Optional[str] + + +@dataclass +class JobOwner: + id: str + name: str + type: str + + +@dataclass +class JobInfo: + """ + Contains information about a Job. + + Args: + id (`str`): + Job ID. + created_at (`datetime` or `None`): + When the Job was created. + docker_image (`str` or `None`): + The Docker image from Docker Hub used for the Job. + Can be None if space_id is present instead. + space_id (`str` or `None`): + The Docker image from Hugging Face Spaces used for the Job. + Can be None if docker_image is present instead. + command (`List[str]` or `None`): + Command of the Job, e.g. `["python", "-c", "print('hello world')"]` + arguments (`List[str]` or `None`): + Arguments passed to the command + environment (`Dict[str]` or `None`): + Environment variables of the Job as a dictionary. + secrets (`Dict[str]` or `None`): + Secret environment variables of the Job (encrypted). + flavor (`str` or `None`): + Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values. + E.g. `"cpu-basic"`. + status: (`JobStatus` or `None`): + Status of the Job, e.g. `JobStatus(stage="RUNNING", message=None)` + See [`JobStage`] for possible stage values. + status: (`JobOwner` or `None`): + Owner of the Job, e.g. `JobOwner(id="5e9ecfc04957053f60648a3e", name="lhoestq", type="user")` + + Example: + + ```python + >>> from huggingface_hub import run_job + >>> job = run_job( + ... image="python:3.12", + ... command=["python", "-c", "print('Hello from the cloud!')"] + ... ) + >>> job + JobInfo(id='687fb701029421ae5549d998', created_at=datetime.datetime(2025, 7, 22, 16, 6, 25, 79000, tzinfo=datetime.timezone.utc), docker_image='python:3.12', space_id=None, command=['python', '-c', "print('Hello from the cloud!')"], arguments=[], environment={}, secrets={}, flavor='cpu-basic', status=JobStatus(stage='RUNNING', message=None), owner=JobOwner(id='5e9ecfc04957053f60648a3e', name='lhoestq', type='user'), endpoint='https://huggingface.co', url='https://huggingface.co/jobs/lhoestq/687fb701029421ae5549d998') + >>> job.id + '687fb701029421ae5549d998' + >>> job.url + 'https://huggingface.co/jobs/lhoestq/687fb701029421ae5549d998' + >>> job.status.stage + 'RUNNING' + ``` + """ + + id: str + created_at: Optional[datetime] + docker_image: Optional[str] + space_id: Optional[str] + command: Optional[List[str]] + arguments: Optional[List[str]] + environment: Optional[Dict[str, Any]] + secrets: Optional[Dict[str, Any]] + flavor: Optional[SpaceHardware] + status: JobStatus + owner: JobOwner + + # Inferred fields + endpoint: str + url: str + + def __init__(self, **kwargs) -> None: + self.id = kwargs["id"] + created_at = kwargs.get("createdAt") or kwargs.get("created_at") + self.created_at = parse_datetime(created_at) if created_at else None + self.docker_image = kwargs.get("dockerImage") or kwargs.get("docker_image") + self.space_id = kwargs.get("spaceId") or kwargs.get("space_id") + owner = kwargs.get("owner", {}) + self.owner = JobOwner(id=owner["id"], name=owner["name"], type=owner["type"]) + self.command = kwargs.get("command") + self.arguments = kwargs.get("arguments") + self.environment = kwargs.get("environment") + self.secrets = kwargs.get("secrets") + self.flavor = kwargs.get("flavor") + status = kwargs.get("status", {}) + self.status = 
JobStatus(stage=status["stage"], message=status.get("message")) + + # Inferred fields + self.endpoint = kwargs.get("endpoint", constants.ENDPOINT) + self.url = f"{self.endpoint}/jobs/{self.owner.name}/{self.id}" diff --git a/phivenv/Lib/site-packages/huggingface_hub/_local_folder.py b/phivenv/Lib/site-packages/huggingface_hub/_local_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..37f6c32a760ecf03794c129735fe2e15516952d1 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_local_folder.py @@ -0,0 +1,447 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle the `../.cache/huggingface` folder in local directories. + +First discussed in https://github.com/huggingface/huggingface_hub/issues/1738 to store +download metadata when downloading files from the hub to a local directory (without +using the cache). + +./.cache/huggingface folder structure: +[4.0K] data +├── [4.0K] .cache +│ └── [4.0K] huggingface +│ └── [4.0K] download +│ ├── [ 16] file.parquet.metadata +│ ├── [ 16] file.txt.metadata +│ └── [4.0K] folder +│ └── [ 16] file.parquet.metadata +│ +├── [6.5G] file.parquet +├── [1.5K] file.txt +└── [4.0K] folder + └── [ 16] file.parquet + + +Download metadata file structure: +``` +# file.txt.metadata +11c5a3d5811f50298f278a704980280950aedb10 +a16a55fda99d2f2e7b69cce5cf93ff4ad3049930 +1712656091.123 + +# file.parquet.metadata +11c5a3d5811f50298f278a704980280950aedb10 +7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421 +1712656091.123 +} +``` +""" + +import base64 +import hashlib +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .utils import WeakFileLock + + +logger = logging.getLogger(__name__) + + +@dataclass +class LocalDownloadFilePaths: + """ + Paths to the files related to a download process in a local dir. + + Returned by [`get_local_download_paths`]. + + Attributes: + file_path (`Path`): + Path where the file will be saved. + lock_path (`Path`): + Path to the lock file used to ensure atomicity when reading/writing metadata. + metadata_path (`Path`): + Path to the metadata file. + """ + + file_path: Path + lock_path: Path + metadata_path: Path + + def incomplete_path(self, etag: str) -> Path: + """Return the path where a file will be temporarily downloaded before being moved to `file_path`.""" + path = self.metadata_path.parent / f"{_short_hash(self.metadata_path.name)}.{etag}.incomplete" + resolved_path = str(path.resolve()) + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix. 
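+        # e.g. (illustrative) "C:\very\long\path\abc12.etag.incomplete" becomes
+        # "\\?\C:\very\long\path\abc12.etag.incomplete"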
+ if os.name == "nt" and len(resolved_path) > 255 and not resolved_path.startswith("\\\\?\\"): + path = Path("\\\\?\\" + resolved_path) + return path + + +@dataclass(frozen=True) +class LocalUploadFilePaths: + """ + Paths to the files related to an upload process in a local dir. + + Returned by [`get_local_upload_paths`]. + + Attributes: + path_in_repo (`str`): + Path of the file in the repo. + file_path (`Path`): + Path where the file will be saved. + lock_path (`Path`): + Path to the lock file used to ensure atomicity when reading/writing metadata. + metadata_path (`Path`): + Path to the metadata file. + """ + + path_in_repo: str + file_path: Path + lock_path: Path + metadata_path: Path + + +@dataclass +class LocalDownloadFileMetadata: + """ + Metadata about a file in the local directory related to a download process. + + Attributes: + filename (`str`): + Path of the file in the repo. + commit_hash (`str`): + Commit hash of the file in the repo. + etag (`str`): + ETag of the file in the repo. Used to check if the file has changed. + For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. + timestamp (`int`): + Unix timestamp of when the metadata was saved i.e. when the metadata was accurate. + """ + + filename: str + commit_hash: str + etag: str + timestamp: float + + +@dataclass +class LocalUploadFileMetadata: + """ + Metadata about a file in the local directory related to an upload process. + """ + + size: int + + # Default values correspond to "we don't know yet" + timestamp: Optional[float] = None + should_ignore: Optional[bool] = None + sha256: Optional[str] = None + upload_mode: Optional[str] = None + remote_oid: Optional[str] = None + is_uploaded: bool = False + is_committed: bool = False + + def save(self, paths: LocalUploadFilePaths) -> None: + """Save the metadata to disk.""" + with WeakFileLock(paths.lock_path): + with paths.metadata_path.open("w") as f: + new_timestamp = time.time() + f.write(str(new_timestamp) + "\n") + + f.write(str(self.size)) # never None + f.write("\n") + + if self.should_ignore is not None: + f.write(str(int(self.should_ignore))) + f.write("\n") + + if self.sha256 is not None: + f.write(self.sha256) + f.write("\n") + + if self.upload_mode is not None: + f.write(self.upload_mode) + f.write("\n") + + if self.remote_oid is not None: + f.write(self.remote_oid) + f.write("\n") + + f.write(str(int(self.is_uploaded)) + "\n") + f.write(str(int(self.is_committed)) + "\n") + + self.timestamp = new_timestamp + + +def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFilePaths: + """Compute paths to the files related to a download process. + + Folders containing the paths are all guaranteed to exist. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + filename (`str`): + Path of the file in the repo. + + Return: + [`LocalDownloadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path, incomplete_path). + """ + # filename is the path in the Hub repository (separated by '/') + # make sure to have a cross platform transcription + sanitized_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if sanitized_filename.startswith("..\\") or "\\..\\" in sanitized_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{sanitized_filename}' on Windows. Please ask the repository" + " owner to rename this file." 
+ ) + file_path = local_dir / sanitized_filename + metadata_path = _huggingface_dir(local_dir) / "download" / f"{sanitized_filename}.metadata" + lock_path = metadata_path.with_suffix(".lock") + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix + if os.name == "nt": + if not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_path)) > 255: + file_path = Path("\\\\?\\" + os.path.abspath(file_path)) + lock_path = Path("\\\\?\\" + os.path.abspath(lock_path)) + metadata_path = Path("\\\\?\\" + os.path.abspath(metadata_path)) + + file_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + return LocalDownloadFilePaths(file_path=file_path, lock_path=lock_path, metadata_path=metadata_path) + + +def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths: + """Compute paths to the files related to an upload process. + + Folders containing the paths are all guaranteed to exist. + + Args: + local_dir (`Path`): + Path to the local directory that is uploaded. + filename (`str`): + Path of the file in the repo. + + Return: + [`LocalUploadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path). + """ + # filename is the path in the Hub repository (separated by '/') + # make sure to have a cross platform transcription + sanitized_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if sanitized_filename.startswith("..\\") or "\\..\\" in sanitized_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{sanitized_filename}' on Windows. Please ask the repository" + " owner to rename this file." + ) + file_path = local_dir / sanitized_filename + metadata_path = _huggingface_dir(local_dir) / "upload" / f"{sanitized_filename}.metadata" + lock_path = metadata_path.with_suffix(".lock") + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix + if os.name == "nt": + if not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_path)) > 255: + file_path = Path("\\\\?\\" + os.path.abspath(file_path)) + lock_path = Path("\\\\?\\" + os.path.abspath(lock_path)) + metadata_path = Path("\\\\?\\" + os.path.abspath(metadata_path)) + + file_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + return LocalUploadFilePaths( + path_in_repo=filename, file_path=file_path, lock_path=lock_path, metadata_path=metadata_path + ) + + +def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDownloadFileMetadata]: + """Read metadata about a file in the local directory related to a download process. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + filename (`str`): + Path of the file in the repo. + + Return: + `[LocalDownloadFileMetadata]` or `None`: the metadata if it exists, `None` otherwise. 
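+
+    Example (illustrative; the hash values are hypothetical):
+    ```py
+    >>> metadata = read_download_metadata(Path("./data"), "folder/file.parquet")
+    >>> metadata.commit_hash if metadata is not None else None  # None if no valid metadata on disk
+    '11c5a3d5811f50298f278a704980280950aedb10'
+    ```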
+ """ + paths = get_local_download_paths(local_dir, filename) + with WeakFileLock(paths.lock_path): + if paths.metadata_path.exists(): + try: + with paths.metadata_path.open() as f: + commit_hash = f.readline().strip() + etag = f.readline().strip() + timestamp = float(f.readline().strip()) + metadata = LocalDownloadFileMetadata( + filename=filename, + commit_hash=commit_hash, + etag=etag, + timestamp=timestamp, + ) + except Exception as e: + # remove the metadata file if it is corrupted / not the right format + logger.warning( + f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue." + ) + try: + paths.metadata_path.unlink() + except Exception as e: + logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}") + + try: + # check if the file exists and hasn't been modified since the metadata was saved + stat = paths.file_path.stat() + if ( + stat.st_mtime - 1 <= metadata.timestamp + ): # allow 1s difference as stat.st_mtime might not be precise + return metadata + logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.") + except FileNotFoundError: + # file does not exist => metadata is outdated + return None + return None + + +def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetadata: + """Read metadata about a file in the local directory related to an upload process. + + TODO: factorize logic with `read_download_metadata`. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + filename (`str`): + Path of the file in the repo. + + Return: + `[LocalUploadFileMetadata]` or `None`: the metadata if it exists, `None` otherwise. + """ + paths = get_local_upload_paths(local_dir, filename) + with WeakFileLock(paths.lock_path): + if paths.metadata_path.exists(): + try: + with paths.metadata_path.open() as f: + timestamp = float(f.readline().strip()) + + size = int(f.readline().strip()) # never None + + _should_ignore = f.readline().strip() + should_ignore = None if _should_ignore == "" else bool(int(_should_ignore)) + + _sha256 = f.readline().strip() + sha256 = None if _sha256 == "" else _sha256 + + _upload_mode = f.readline().strip() + upload_mode = None if _upload_mode == "" else _upload_mode + if upload_mode not in (None, "regular", "lfs"): + raise ValueError(f"Invalid upload mode in metadata {paths.path_in_repo}: {upload_mode}") + + _remote_oid = f.readline().strip() + remote_oid = None if _remote_oid == "" else _remote_oid + + is_uploaded = bool(int(f.readline().strip())) + is_committed = bool(int(f.readline().strip())) + + metadata = LocalUploadFileMetadata( + timestamp=timestamp, + size=size, + should_ignore=should_ignore, + sha256=sha256, + upload_mode=upload_mode, + remote_oid=remote_oid, + is_uploaded=is_uploaded, + is_committed=is_committed, + ) + except Exception as e: + # remove the metadata file if it is corrupted / not the right format + logger.warning( + f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue." + ) + try: + paths.metadata_path.unlink() + except Exception as e: + logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}") + + # TODO: can we do better? 
+            if (
+                metadata.timestamp is not None
+                and metadata.is_uploaded  # file was uploaded
+                and not metadata.is_committed  # but not committed
+                and time.time() - metadata.timestamp > 20 * 3600  # and it's been more than 20 hours
+            ):  # => we consider it as garbage-collected by S3
+                metadata.is_uploaded = False
+
+            # check if the file exists and hasn't been modified since the metadata was saved
+            try:
+                if metadata.timestamp is not None and paths.file_path.stat().st_mtime <= metadata.timestamp:
+                    return metadata
+                logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.")
+            except FileNotFoundError:
+                # file does not exist => metadata is outdated
+                pass
+
+    # empty metadata => we don't know anything except its size
+    return LocalUploadFileMetadata(size=paths.file_path.stat().st_size)
+
+
+def write_download_metadata(local_dir: Path, filename: str, commit_hash: str, etag: str) -> None:
+    """Write metadata about a file in the local directory related to a download process.
+
+    Args:
+        local_dir (`Path`):
+            Path to the local directory in which files are downloaded.
+    """
+    paths = get_local_download_paths(local_dir, filename)
+    with WeakFileLock(paths.lock_path):
+        with paths.metadata_path.open("w") as f:
+            f.write(f"{commit_hash}\n{etag}\n{time.time()}\n")
+
+
+def _huggingface_dir(local_dir: Path) -> Path:
+    """Return the path to the `.cache/huggingface` directory in a local directory."""
+    # Note: not cached; calling this multiple times is safe since the `.gitignore`
+    # file is only written when it does not exist yet.
+    path = local_dir / ".cache" / "huggingface"
+    path.mkdir(exist_ok=True, parents=True)
+
+    # Create a .gitignore file in the .cache/huggingface directory if it doesn't exist
+    # Should be thread-safe enough like this.
+    gitignore = path / ".gitignore"
+    gitignore_lock = path / ".gitignore.lock"
+    if not gitignore.exists():
+        try:
+            with WeakFileLock(gitignore_lock, timeout=0.1):
+                gitignore.write_text("*")
+        except IndexError:
+            pass
+        except OSError:  # TimeoutError, FileNotFoundError, PermissionError, etc.
+            pass
+        try:
+            gitignore_lock.unlink()
+        except OSError:
+            pass
+    return path
+
+
+def _short_hash(filename: str) -> str:
+    return base64.urlsafe_b64encode(hashlib.sha1(filename.encode()).digest()).decode()
diff --git a/phivenv/Lib/site-packages/huggingface_hub/_login.py b/phivenv/Lib/site-packages/huggingface_hub/_login.py
new file mode 100644
index 0000000000000000000000000000000000000000..303cd2b35d834cecfaeee646644e49ffcbcd89a2
--- /dev/null
+++ b/phivenv/Lib/site-packages/huggingface_hub/_login.py
@@ -0,0 +1,520 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains methods to log in to the Hub."""
+
+import os
+import subprocess
+from getpass import getpass
+from pathlib import Path
+from typing import Optional
+
+from .
import constants +from .commands._cli_utils import ANSI +from .utils import ( + capture_output, + get_token, + is_google_colab, + is_notebook, + list_credential_helpers, + logging, + run_subprocess, + set_git_credential, + unset_git_credential, +) +from .utils._auth import ( + _get_token_by_name, + _get_token_from_environment, + _get_token_from_file, + _get_token_from_google_colab, + _save_stored_tokens, + _save_token, + get_stored_tokens, +) +from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args + + +logger = logging.get_logger(__name__) + +_HF_LOGO_ASCII = """ + _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| + _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| + _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| + _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| + _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| +""" + + +@_deprecate_arguments( + version="1.0", + deprecated_args="write_permission", + custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", +) +@_deprecate_positional_args(version="1.0") +def login( + token: Optional[str] = None, + *, + add_to_git_credential: bool = False, + new_session: bool = True, + write_permission: bool = False, +) -> None: + """Login the machine to access the Hub. + + The `token` is persisted in cache and set as a git credential. Once done, the machine + is logged in and the access token will be available across all `huggingface_hub` + components. If `token` is not provided, it will be prompted to the user either with + a widget (in a notebook) or via the terminal. + + To log in from outside of a script, one can also use `hf auth login` which is + a cli command that wraps [`login`]. + + + + [`login`] is a drop-in replacement method for [`notebook_login`] as it wraps and + extends its capabilities. + + + + + + When the token is not passed, [`login`] will automatically detect if the script runs + in a notebook or not. However, this detection might not be accurate due to the + variety of notebooks that exists nowadays. If that is the case, you can always force + the UI by using [`notebook_login`] or [`interpreter_login`]. + + + + Args: + token (`str`, *optional*): + User access token to generate from https://huggingface.co/settings/token. + add_to_git_credential (`bool`, defaults to `False`): + If `True`, token will be set as git credential. If no git credential helper + is configured, a warning will be displayed to the user. If `token` is `None`, + the value of `add_to_git_credential` is ignored and will be prompted again + to the end user. + new_session (`bool`, defaults to `True`): + If `True`, will request a token even if one is already saved on the machine. + write_permission (`bool`): + Ignored and deprecated argument. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If an organization token is passed. Only personal account tokens are valid + to log in. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If token is invalid. + [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) + If running in a notebook but `ipywidgets` is not installed. + """ + if token is not None: + if not add_to_git_credential: + logger.info( + "The token has not been saved to the git credentials helper. 
Pass " + "`add_to_git_credential=True` in this function directly or " + "`--add-to-git-credential` if using via `hf`CLI if " + "you want to set the git credential as well." + ) + _login(token, add_to_git_credential=add_to_git_credential) + elif is_notebook(): + notebook_login(new_session=new_session) + else: + interpreter_login(new_session=new_session) + + +def logout(token_name: Optional[str] = None) -> None: + """Logout the machine from the Hub. + + Token is deleted from the machine and removed from git credential. + + Args: + token_name (`str`, *optional*): + Name of the access token to logout from. If `None`, will logout from all saved access tokens. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): + If the access token name is not found. + """ + if get_token() is None and not get_stored_tokens(): # No active token and no saved access tokens + logger.warning("Not logged in!") + return + if not token_name: + # Delete all saved access tokens and token + for file_path in (constants.HF_TOKEN_PATH, constants.HF_STORED_TOKENS_PATH): + try: + Path(file_path).unlink() + except FileNotFoundError: + pass + logger.info("Successfully logged out from all access tokens.") + else: + _logout_from_token(token_name) + logger.info(f"Successfully logged out from access token: {token_name}.") + + unset_git_credential() + + # Check if still logged in + if _get_token_from_google_colab() is not None: + raise EnvironmentError( + "You are automatically logged in using a Google Colab secret.\n" + "To log out, you must unset the `HF_TOKEN` secret in your Colab settings." + ) + if _get_token_from_environment() is not None: + raise EnvironmentError( + "Token has been deleted from your machine but you are still logged in.\n" + "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables." + ) + + +def auth_switch(token_name: str, add_to_git_credential: bool = False) -> None: + """Switch to a different access token. + + Args: + token_name (`str`): + Name of the access token to switch to. + add_to_git_credential (`bool`, defaults to `False`): + If `True`, token will be set as git credential. If no git credential helper + is configured, a warning will be displayed to the user. If `token` is `None`, + the value of `add_to_git_credential` is ignored and will be prompted again + to the end user. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): + If the access token name is not found. + """ + token = _get_token_by_name(token_name) + if not token: + raise ValueError(f"Access token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}") + # Write token to HF_TOKEN_PATH + _set_active_token(token_name, add_to_git_credential) + logger.info(f"The current active token is: {token_name}") + token_from_environment = _get_token_from_environment() + if token_from_environment is not None and token_from_environment != token: + logger.warning( + "The environment variable `HF_TOKEN` is set and will override the access token you've just switched to." 
+ ) + + +def auth_list() -> None: + """List all stored access tokens.""" + tokens = get_stored_tokens() + + if not tokens: + logger.info("No access tokens found.") + return + # Find current token + current_token = get_token() + current_token_name = None + for token_name in tokens: + if tokens.get(token_name) == current_token: + current_token_name = token_name + # Print header + max_offset = max(len("token"), max(len(token) for token in tokens)) + 2 + print(f" {{:<{max_offset}}}| {{:<15}}".format("name", "token")) + print("-" * (max_offset + 2) + "|" + "-" * 15) + + # Print saved access tokens + for token_name in tokens: + token = tokens.get(token_name, "") + masked_token = f"{token[:3]}****{token[-4:]}" if token != "" else token + is_current = "*" if token == current_token else " " + + print(f"{is_current} {{:<{max_offset}}}| {{:<15}}".format(token_name, masked_token)) + + if _get_token_from_environment(): + logger.warning( + "\nNote: Environment variable `HF_TOKEN` is set and is the current active token independently from the stored tokens listed above." + ) + elif current_token_name is None: + logger.warning( + "\nNote: No active token is set and no environment variable `HF_TOKEN` is found. Use `hf auth login` to log in." + ) + + +### +# Interpreter-based login (text) +### + + +@_deprecate_arguments( + version="1.0", + deprecated_args="write_permission", + custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", +) +@_deprecate_positional_args(version="1.0") +def interpreter_login(*, new_session: bool = True, write_permission: bool = False) -> None: + """ + Displays a prompt to log in to the HF website and store the token. + + This is equivalent to [`login`] without passing a token when not run in a notebook. + [`interpreter_login`] is useful if you want to force the use of the terminal prompt + instead of a notebook widget. + + For more details, see [`login`]. + + Args: + new_session (`bool`, defaults to `True`): + If `True`, will request a token even if one is already saved on the machine. + write_permission (`bool`): + Ignored and deprecated argument. + """ + if not new_session and get_token() is not None: + logger.info("User is already logged in.") + return + + from .commands.delete_cache import _ask_for_confirmation_no_tui + + print(_HF_LOGO_ASCII) + if get_token() is not None: + logger.info( + " A token is already saved on your machine. Run `hf auth whoami`" + " to get more information or `hf auth logout` if you want" + " to log out." + ) + logger.info(" Setting a new token will erase the existing one.") + + logger.info( + " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens ." + ) + if os.name == "nt": + logger.info("Token can be pasted using 'Right-Click'.") + token = getpass("Enter your token (input will not be visible): ") + add_to_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?") + + _login(token=token, add_to_git_credential=add_to_git_credential) + + +### +# Notebook-based login (widget) +### + +NOTEBOOK_LOGIN_PASSWORD_HTML = """
+<center> <img
+src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
+alt='Hugging Face'> <br> Immediately click login after typing your password or
+it might be stored in plain text in this notebook file. </center>"""
+
+
+NOTEBOOK_LOGIN_TOKEN_HTML_START = """<center> <img
+src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
+alt='Hugging Face'> <br> Copy a token from <a
+href="https://huggingface.co/settings/tokens" target="_blank">your Hugging Face
+tokens page</a> and paste it below. <br> Immediately click login after copying
+your token or it might be stored in plain text in this notebook file. </center>
""" + + +NOTEBOOK_LOGIN_TOKEN_HTML_END = """ +Pro Tip: If you don't already have one, you can create a dedicated +'notebooks' token with 'write' access, that you can then easily reuse for all +notebooks. """ + + +@_deprecate_arguments( + version="1.0", + deprecated_args="write_permission", + custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", +) +@_deprecate_positional_args(version="1.0") +def notebook_login(*, new_session: bool = True, write_permission: bool = False) -> None: + """ + Displays a widget to log in to the HF website and store the token. + + This is equivalent to [`login`] without passing a token when run in a notebook. + [`notebook_login`] is useful if you want to force the use of the notebook widget + instead of a prompt in the terminal. + + For more details, see [`login`]. + + Args: + new_session (`bool`, defaults to `True`): + If `True`, will request a token even if one is already saved on the machine. + write_permission (`bool`): + Ignored and deprecated argument. + """ + try: + import ipywidgets.widgets as widgets # type: ignore + from IPython.display import display # type: ignore + except ImportError: + raise ImportError( + "The `notebook_login` function can only be used in a notebook (Jupyter or" + " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`." + ) + if not new_session and get_token() is not None: + logger.info("User is already logged in.") + return + + box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%") + + token_widget = widgets.Password(description="Token:") + git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?") + token_finish_button = widgets.Button(description="Login") + + login_token_widget = widgets.VBox( + [ + widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START), + token_widget, + git_checkbox_widget, + token_finish_button, + widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END), + ], + layout=box_layout, + ) + display(login_token_widget) + + # On click events + def login_token_event(t): + """Event handler for the login button.""" + token = token_widget.value + add_to_git_credential = git_checkbox_widget.value + # Erase token and clear value to make sure it's not saved in the notebook. 
+ token_widget.value = "" + # Hide inputs + login_token_widget.children = [widgets.Label("Connecting...")] + try: + with capture_output() as captured: + _login(token, add_to_git_credential=add_to_git_credential) + message = captured.getvalue() + except Exception as error: + message = str(error) + # Print result (success message or error) + login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()] + + token_finish_button.on_click(login_token_event) + + +### +# Login private helpers +### + + +def _login( + token: str, + add_to_git_credential: bool, +) -> None: + from .hf_api import whoami # avoid circular import + + if token.startswith("api_org"): + raise ValueError("You must use your personal account token, not an organization token.") + + token_info = whoami(token) + permission = token_info["auth"]["accessToken"]["role"] + logger.info(f"Token is valid (permission: {permission}).") + + token_name = token_info["auth"]["accessToken"]["displayName"] + # Store token locally + _save_token(token=token, token_name=token_name) + # Set active token + _set_active_token(token_name=token_name, add_to_git_credential=add_to_git_credential) + logger.info("Login successful.") + if _get_token_from_environment(): + logger.warning( + "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured." + ) + else: + logger.info(f"The current active token is: `{token_name}`") + + +def _logout_from_token(token_name: str) -> None: + """Logout from a specific access token. + + Args: + token_name (`str`): + The name of the access token to logout from. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): + If the access token name is not found. + """ + stored_tokens = get_stored_tokens() + # If there is no access tokens saved or the access token name is not found, do nothing + if not stored_tokens or token_name not in stored_tokens: + return + + token = stored_tokens.pop(token_name) + _save_stored_tokens(stored_tokens) + + if token == _get_token_from_file(): + logger.warning(f"Active token '{token_name}' has been deleted.") + Path(constants.HF_TOKEN_PATH).unlink(missing_ok=True) + + +def _set_active_token( + token_name: str, + add_to_git_credential: bool, +) -> None: + """Set the active access token. + + Args: + token_name (`str`): + The name of the token to set as active. + """ + token = _get_token_by_name(token_name) + if not token: + raise ValueError(f"Token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}") + if add_to_git_credential: + if _is_git_credential_helper_configured(): + set_git_credential(token) + logger.info( + "Your token has been saved in your configured git credential helpers" + + f" ({','.join(list_credential_helpers())})." + ) + else: + logger.warning("Token has not been saved to git credential helper.") + # Write token to HF_TOKEN_PATH + path = Path(constants.HF_TOKEN_PATH) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(token) + logger.info(f"Your token has been saved to {constants.HF_TOKEN_PATH}") + + +def _is_git_credential_helper_configured() -> bool: + """Check if a git credential helper is configured. + + Warns user if not the case (except for Google Colab where "store" is set by default + by `huggingface_hub`). 
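+
+    Example (a sketch; actual output depends on the machine's git configuration):
+    ```py
+    >>> from huggingface_hub.utils import list_credential_helpers
+    >>> list_credential_helpers()
+    ['store']
+    ```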
+ """ + helpers = list_credential_helpers() + if len(helpers) > 0: + return True # Do not warn: at least 1 helper is set + + # Only in Google Colab to avoid the warning message + # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710 + if is_google_colab(): + _set_store_as_git_credential_helper_globally() + return True # Do not warn: "store" is used by default in Google Colab + + # Otherwise, warn user + print( + ANSI.red( + "Cannot authenticate through git-credential as no helper is defined on your" + " machine.\nYou might have to re-authenticate when pushing to the Hugging" + " Face Hub.\nRun the following command in your terminal in case you want to" + " set the 'store' credential helper as default.\n\ngit config --global" + " credential.helper store\n\nRead" + " https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more" + " details." + ) + ) + return False + + +def _set_store_as_git_credential_helper_globally() -> None: + """Set globally the credential.helper to `store`. + + To be used only in Google Colab as we assume the user doesn't care about the git + credential config. It is the only particular case where we don't want to display the + warning message in [`notebook_login()`]. + + Related: + - https://github.com/huggingface/huggingface_hub/issues/1043 + - https://github.com/huggingface/huggingface_hub/issues/1051 + - https://git-scm.com/docs/git-credential-store + """ + try: + run_subprocess("git config --global credential.helper store") + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) diff --git a/phivenv/Lib/site-packages/huggingface_hub/_oauth.py b/phivenv/Lib/site-packages/huggingface_hub/_oauth.py new file mode 100644 index 0000000000000000000000000000000000000000..9f8eb607962bc18fec348fed18ce269524983e23 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_oauth.py @@ -0,0 +1,460 @@ +import datetime +import hashlib +import logging +import os +import time +import urllib.parse +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union + +from . import constants +from .hf_api import whoami +from .utils import experimental, get_token + + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import fastapi + + +@dataclass +class OAuthOrgInfo: + """ + Information about an organization linked to a user logged in with OAuth. + + Attributes: + sub (`str`): + Unique identifier for the org. OpenID Connect field. + name (`str`): + The org's full name. OpenID Connect field. + preferred_username (`str`): + The org's username. OpenID Connect field. + picture (`str`): + The org's profile picture URL. OpenID Connect field. + is_enterprise (`bool`): + Whether the org is an enterprise org. Hugging Face field. + can_pay (`Optional[bool]`, *optional*): + Whether the org has a payment method set up. Hugging Face field. + role_in_org (`Optional[str]`, *optional*): + The user's role in the org. Hugging Face field. + security_restrictions (`Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*): + Array of security restrictions that the user hasn't completed for this org. Possible values: "ip", "token-policy", "mfa", "sso". Hugging Face field. 
+ """ + + sub: str + name: str + preferred_username: str + picture: str + is_enterprise: bool + can_pay: Optional[bool] = None + role_in_org: Optional[str] = None + security_restrictions: Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]] = None + + +@dataclass +class OAuthUserInfo: + """ + Information about a user logged in with OAuth. + + Attributes: + sub (`str`): + Unique identifier for the user, even in case of rename. OpenID Connect field. + name (`str`): + The user's full name. OpenID Connect field. + preferred_username (`str`): + The user's username. OpenID Connect field. + email_verified (`Optional[bool]`, *optional*): + Indicates if the user's email is verified. OpenID Connect field. + email (`Optional[str]`, *optional*): + The user's email address. OpenID Connect field. + picture (`str`): + The user's profile picture URL. OpenID Connect field. + profile (`str`): + The user's profile URL. OpenID Connect field. + website (`Optional[str]`, *optional*): + The user's website URL. OpenID Connect field. + is_pro (`bool`): + Whether the user is a pro user. Hugging Face field. + can_pay (`Optional[bool]`, *optional*): + Whether the user has a payment method set up. Hugging Face field. + orgs (`Optional[List[OrgInfo]]`, *optional*): + List of organizations the user is part of. Hugging Face field. + """ + + sub: str + name: str + preferred_username: str + email_verified: Optional[bool] + email: Optional[str] + picture: str + profile: str + website: Optional[str] + is_pro: bool + can_pay: Optional[bool] + orgs: Optional[List[OAuthOrgInfo]] + + +@dataclass +class OAuthInfo: + """ + Information about the OAuth login. + + Attributes: + access_token (`str`): + The access token. + access_token_expires_at (`datetime.datetime`): + The expiration date of the access token. + user_info ([`OAuthUserInfo`]): + The user information. + state (`str`, *optional*): + State passed to the OAuth provider in the original request to the OAuth provider. + scope (`str`): + Granted scope. + """ + + access_token: str + access_token_expires_at: datetime.datetime + user_info: OAuthUserInfo + state: Optional[str] + scope: str + + +@experimental +def attach_huggingface_oauth(app: "fastapi.FastAPI", route_prefix: str = "/"): + """ + Add OAuth endpoints to a FastAPI app to enable OAuth login with Hugging Face. + + How to use: + - Call this method on your FastAPI app to add the OAuth endpoints. + - Inside your route handlers, call `parse_huggingface_oauth(request)` to retrieve the OAuth info. + - If user is logged in, an [`OAuthInfo`] object is returned with the user's info. If not, `None` is returned. + - In your app, make sure to add links to `/oauth/huggingface/login` and `/oauth/huggingface/logout` for the user to log in and out. + + Example: + ```py + from huggingface_hub import attach_huggingface_oauth, parse_huggingface_oauth + + # Create a FastAPI app + app = FastAPI() + + # Add OAuth endpoints to the FastAPI app + attach_huggingface_oauth(app) + + # Add a route that greets the user if they are logged in + @app.get("/") + def greet_json(request: Request): + # Retrieve the OAuth info from the request + oauth_info = parse_huggingface_oauth(request) # e.g. 
OAuthInfo dataclass + if oauth_info is None: + return {"msg": "Not logged in!"} + return {"msg": f"Hello, {oauth_info.user_info.preferred_username}!"} + ``` + """ + # TODO: handle generic case (handling OAuth in a non-Space environment with custom dev values) (low priority) + + # Add SessionMiddleware to the FastAPI app to store the OAuth info in the session. + # Session Middleware requires a secret key to sign the cookies. Let's use a hash + # of the OAuth secret key to make it unique to the Space + updated in case OAuth + # config gets updated. When ran locally, we use an empty string as a secret key. + try: + from starlette.middleware.sessions import SessionMiddleware + except ImportError as e: + raise ImportError( + "Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add " + "`huggingface_hub[oauth]` to your requirements.txt file in order to install the required dependencies." + ) from e + session_secret = (constants.OAUTH_CLIENT_SECRET or "") + "-v1" + app.add_middleware( + SessionMiddleware, # type: ignore[arg-type] + secret_key=hashlib.sha256(session_secret.encode()).hexdigest(), + same_site="none", + https_only=True, + ) # type: ignore + + # Add OAuth endpoints to the FastAPI app: + # - {route_prefix}/oauth/huggingface/login + # - {route_prefix}/oauth/huggingface/callback + # - {route_prefix}/oauth/huggingface/logout + # If the app is running in a Space, OAuth is enabled normally. + # Otherwise, we mock the endpoints to make the user log in with a fake user profile - without any calls to hf.co. + route_prefix = route_prefix.strip("/") + if os.getenv("SPACE_ID") is not None: + logger.info("OAuth is enabled in the Space. Adding OAuth routes.") + _add_oauth_routes(app, route_prefix=route_prefix) + else: + logger.info("App is not running in a Space. Adding mocked OAuth routes.") + _add_mocked_oauth_routes(app, route_prefix=route_prefix) + + +def parse_huggingface_oauth(request: "fastapi.Request") -> Optional[OAuthInfo]: + """ + Returns the information from a logged in user as a [`OAuthInfo`] object. + + For flexibility and future-proofing, this method is very lax in its parsing and does not raise errors. + Missing fields are set to `None` without a warning. + + Return `None`, if the user is not logged in (no info in session cookie). + + See [`attach_huggingface_oauth`] for an example on how to use this method. 
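+
+    Example (a minimal sketch, reusing the route style from [`attach_huggingface_oauth`];
+    the route path and response fields are illustrative):
+    ```py
+    @app.get("/token-info")
+    def token_info(request: fastapi.Request):
+        oauth_info = parse_huggingface_oauth(request)
+        if oauth_info is None:
+            return {"msg": "Not logged in!"}
+        return {
+            "user": oauth_info.user_info.preferred_username,
+            "expires_at": oauth_info.access_token_expires_at.isoformat(),
+        }
+    ```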
+ """ + if "oauth_info" not in request.session: + logger.debug("No OAuth info in session.") + return None + + logger.debug("Parsing OAuth info from session.") + oauth_data = request.session["oauth_info"] + user_data = oauth_data.get("userinfo", {}) + orgs_data = user_data.get("orgs", []) + + orgs = ( + [ + OAuthOrgInfo( + sub=org.get("sub"), + name=org.get("name"), + preferred_username=org.get("preferred_username"), + picture=org.get("picture"), + is_enterprise=org.get("isEnterprise"), + can_pay=org.get("canPay"), + role_in_org=org.get("roleInOrg"), + security_restrictions=org.get("securityRestrictions"), + ) + for org in orgs_data + ] + if orgs_data + else None + ) + + user_info = OAuthUserInfo( + sub=user_data.get("sub"), + name=user_data.get("name"), + preferred_username=user_data.get("preferred_username"), + email_verified=user_data.get("email_verified"), + email=user_data.get("email"), + picture=user_data.get("picture"), + profile=user_data.get("profile"), + website=user_data.get("website"), + is_pro=user_data.get("isPro"), + can_pay=user_data.get("canPay"), + orgs=orgs, + ) + + return OAuthInfo( + access_token=oauth_data.get("access_token"), + access_token_expires_at=datetime.datetime.fromtimestamp(oauth_data.get("expires_at")), + user_info=user_info, + state=oauth_data.get("state"), + scope=oauth_data.get("scope"), + ) + + +def _add_oauth_routes(app: "fastapi.FastAPI", route_prefix: str) -> None: + """Add OAuth routes to the FastAPI app (login, callback handler and logout).""" + try: + import fastapi + from authlib.integrations.base_client.errors import MismatchingStateError + from authlib.integrations.starlette_client import OAuth + from fastapi.responses import RedirectResponse + except ImportError as e: + raise ImportError( + "Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add " + "`huggingface_hub[oauth]` to your requirements.txt file." + ) from e + + # Check environment variables + msg = ( + "OAuth is required but '{}' environment variable is not set. Make sure you've enabled OAuth in your Space by" + " setting `hf_oauth: true` in the Space metadata." 
+ ) + if constants.OAUTH_CLIENT_ID is None: + raise ValueError(msg.format("OAUTH_CLIENT_ID")) + if constants.OAUTH_CLIENT_SECRET is None: + raise ValueError(msg.format("OAUTH_CLIENT_SECRET")) + if constants.OAUTH_SCOPES is None: + raise ValueError(msg.format("OAUTH_SCOPES")) + if constants.OPENID_PROVIDER_URL is None: + raise ValueError(msg.format("OPENID_PROVIDER_URL")) + + # Register OAuth server + oauth = OAuth() + oauth.register( + name="huggingface", + client_id=constants.OAUTH_CLIENT_ID, + client_secret=constants.OAUTH_CLIENT_SECRET, + client_kwargs={"scope": constants.OAUTH_SCOPES}, + server_metadata_url=constants.OPENID_PROVIDER_URL + "/.well-known/openid-configuration", + ) + + login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix) + + # Register OAuth endpoints + @app.get(login_uri) + async def oauth_login(request: fastapi.Request) -> RedirectResponse: + """Endpoint that redirects to HF OAuth page.""" + redirect_uri = _generate_redirect_uri(request) + return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore + + @app.get(callback_uri) + async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse: + """Endpoint that handles the OAuth callback.""" + try: + oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore + except MismatchingStateError: + # Parse query params + nb_redirects = int(request.query_params.get("_nb_redirects", 0)) + target_url = request.query_params.get("_target_url") + + # Build redirect URI with the same query params as before and bump nb_redirects count + query_params: Dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1} + if target_url: + query_params["_target_url"] = target_url + + redirect_uri = f"{login_uri}?{urllib.parse.urlencode(query_params)}" + + # If the user is redirected more than 3 times, it is very likely that the cookie is not working properly. + # (e.g. browser is blocking third-party cookies in iframe). In this case, redirect the user in the + # non-iframe view. + if nb_redirects > constants.OAUTH_MAX_REDIRECTS: + host = os.environ.get("SPACE_HOST") + if host is None: # cannot happen in a Space + raise RuntimeError( + "App is not running in a Space (SPACE_HOST environment variable is not set). Cannot redirect to non-iframe view." + ) from None + host_url = "https://" + host.rstrip("/") + return RedirectResponse(host_url + redirect_uri) + + # Redirect the user to the login page again + return RedirectResponse(redirect_uri) + + # OAuth login worked => store the user info in the session and redirect + logger.debug("Successfully logged in with OAuth. Storing user info in session.") + request.session["oauth_info"] = oauth_info + return RedirectResponse(_get_redirect_target(request)) + + @app.get(logout_uri) + async def oauth_logout(request: fastapi.Request) -> RedirectResponse: + """Endpoint that logs out the user (e.g. delete info from cookie session).""" + logger.debug("Logged out with OAuth. Removing user info from session.") + request.session.pop("oauth_info", None) + return RedirectResponse(_get_redirect_target(request)) + + +def _add_mocked_oauth_routes(app: "fastapi.FastAPI", route_prefix: str = "/") -> None: + """Add fake oauth routes if app is run locally and OAuth is enabled. + + Using OAuth will have the same behavior as in a Space but instead of authenticating with HF, a mocked user profile + is added to the session. 
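+
+    Example (a local-debugging sketch; assumes the machine is logged in, e.g. via `hf auth login`):
+    ```py
+    import fastapi
+    from huggingface_hub import attach_huggingface_oauth
+
+    app = fastapi.FastAPI()
+    attach_huggingface_oauth(app)  # outside a Space, the mocked routes above are attached
+    ```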
+ """ + try: + import fastapi + from fastapi.responses import RedirectResponse + from starlette.datastructures import URL + except ImportError as e: + raise ImportError( + "Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add " + "`huggingface_hub[oauth]` to your requirements.txt file." + ) from e + + warnings.warn( + "OAuth is not supported outside of a Space environment. To help you debug your app locally, the oauth endpoints" + " are mocked to return your profile and token. To make it work, your machine must be logged in to Huggingface." + ) + mocked_oauth_info = _get_mocked_oauth_info() + + login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix) + + # Define OAuth routes + @app.get(login_uri) + async def oauth_login(request: fastapi.Request) -> RedirectResponse: + """Fake endpoint that redirects to HF OAuth page.""" + # Define target (where to redirect after login) + redirect_uri = _generate_redirect_uri(request) + return RedirectResponse(callback_uri + "?" + urllib.parse.urlencode({"_target_url": redirect_uri})) + + @app.get(callback_uri) + async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse: + """Endpoint that handles the OAuth callback.""" + request.session["oauth_info"] = mocked_oauth_info + return RedirectResponse(_get_redirect_target(request)) + + @app.get(logout_uri) + async def oauth_logout(request: fastapi.Request) -> RedirectResponse: + """Endpoint that logs out the user (e.g. delete cookie session).""" + request.session.pop("oauth_info", None) + logout_url = URL("/").include_query_params(**request.query_params) + return RedirectResponse(url=logout_url, status_code=302) # see https://github.com/gradio-app/gradio/pull/9659 + + +def _generate_redirect_uri(request: "fastapi.Request") -> str: + if "_target_url" in request.query_params: + # if `_target_url` already in query params => respect it + target = request.query_params["_target_url"] + else: + # otherwise => keep query params + target = "/?" + urllib.parse.urlencode(request.query_params) + + redirect_uri = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target) + redirect_uri_as_str = str(redirect_uri) + if redirect_uri.netloc.endswith(".hf.space"): + # In Space, FastAPI redirect as http but we want https + redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://") + return redirect_uri_as_str + + +def _get_redirect_target(request: "fastapi.Request", default_target: str = "/") -> str: + return request.query_params.get("_target_url", default_target) + + +def _get_mocked_oauth_info() -> Dict: + token = get_token() + if token is None: + raise ValueError( + "Your machine must be logged in to HF to debug an OAuth app locally. Please" + " run `hf auth login` or set `HF_TOKEN` as environment variable " + "with one of your access token. You can generate a new token in your " + "settings page (https://huggingface.co/settings/tokens)." + ) + + user = whoami() + if user["type"] != "user": + raise ValueError( + "Your machine is not logged in with a personal account. Please use a " + "personal access token. You can generate a new token in your settings page" + " (https://huggingface.co/settings/tokens)." 
+ ) + + return { + "access_token": token, + "token_type": "bearer", + "expires_in": 8 * 60 * 60, # 8 hours + "id_token": "FOOBAR", + "scope": "openid profile", + "refresh_token": "hf_oauth__refresh_token", + "expires_at": int(time.time()) + 8 * 60 * 60, # 8 hours + "userinfo": { + "sub": "0123456789", + "name": user["fullname"], + "preferred_username": user["name"], + "profile": f"https://huggingface.co/{user['name']}", + "picture": user["avatarUrl"], + "website": "", + "aud": "00000000-0000-0000-0000-000000000000", + "auth_time": 1691672844, + "nonce": "aaaaaaaaaaaaaaaaaaa", + "iat": 1691672844, + "exp": 1691676444, + "iss": "https://huggingface.co", + }, + } + + +def _get_oauth_uris(route_prefix: str = "/") -> Tuple[str, str, str]: + route_prefix = route_prefix.strip("/") + if route_prefix: + route_prefix = f"/{route_prefix}" + return ( + f"{route_prefix}/oauth/huggingface/login", + f"{route_prefix}/oauth/huggingface/callback", + f"{route_prefix}/oauth/huggingface/logout", + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/_snapshot_download.py b/phivenv/Lib/site-packages/huggingface_hub/_snapshot_download.py new file mode 100644 index 0000000000000000000000000000000000000000..0db8a29f7e65a4841590d033f6b7b51d46647bf0 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_snapshot_download.py @@ -0,0 +1,343 @@ +import os +from pathlib import Path +from typing import Dict, Iterable, List, Literal, Optional, Type, Union + +import requests +from tqdm.auto import tqdm as base_tqdm +from tqdm.contrib.concurrent import thread_map + +from . import constants +from .errors import ( + GatedRepoError, + HfHubHTTPError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name +from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo +from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args +from .utils import tqdm as hf_tqdm + + +logger = logging.get_logger(__name__) + +VERY_LARGE_REPO_THRESHOLD = 50000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough + + +@validate_hf_hub_args +def snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + proxies: Optional[Dict] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + force_download: bool = False, + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[Type[base_tqdm]] = None, + headers: Optional[Dict[str, str]] = None, + endpoint: Optional[str] = None, + # Deprecated args + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + resume_download: Optional[bool] = None, +) -> str: + """Download repo files. + + Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from + a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. You can also filter which files to download using + `allow_patterns` and `ignore_patterns`. 
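+
+    Example (an illustrative repo id; filtering with the patterns described above):
+    ```py
+    >>> from huggingface_hub import snapshot_download
+    >>> # Fetch only the JSON files of a model repo
+    >>> snapshot_download(repo_id="gpt2", allow_patterns=["*.json"])
+    ```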
+ + If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this + option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` + to store some metadata related to the downloaded files. While this mechanism is not as robust as the main + cache-system, it's optimized for regularly pulling the latest version of a repository. + + An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly + configured. It is also not possible to filter which files to download when cloning a repository using git. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded files will be placed under this directory. + library_name (`str`, *optional*): + The name of the library to which the object corresponds. + library_version (`str`, *optional*): + The version of the library. + user_agent (`str`, `dict`, *optional*): + The user-agent info in the form of a dictionary or a string. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in the local cache. + token (`str`, `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the HuggingFace config + folder. + - If a string, it's used as the authentication token. + headers (`dict`, *optional*): + Additional headers to include in the request. Those headers take precedence over the others. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are downloaded. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not downloaded. + max_workers (`int`, *optional*): + Number of concurrent threads to download files (1 thread = 1 file download). + Defaults to 8. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. + Note that the `tqdm_class` is not passed to each individual download. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + + Returns: + `str`: folder path of the repo snapshot. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. 
+ [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` and the token cannot be found. + [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid. + """ + if cache_dir is None: + cache_dir = constants.HF_HUB_CACHE + if revision is None: + revision = constants.DEFAULT_REVISION + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if repo_type is None: + repo_type = "model" + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}") + + storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) + + api = HfApi( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + endpoint=endpoint, + headers=headers, + token=token, + ) + + repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None + api_call_error: Optional[Exception] = None + if not local_files_only: + # try/except logic to handle different errors => taken from `hf_hub_download` + try: + # if we have internet connection we want to list files to download + repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision) + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + OfflineModeIsEnabled, + ) as error: + # Internet connection is down + # => will try to use local files only + api_call_error = error + pass + except RevisionNotFoundError: + # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted) + raise + except requests.HTTPError as error: + # Multiple reasons for an http error: + # - Repository is private and invalid/missing token sent + # - Repository is gated and invalid/missing token sent + # - Hub is down (error 500 or 504) + # => let's switch to 'local_files_only=True' to check if the files are already cached. + # (if it's not the case, the error will be re-raised) + api_call_error = error + pass + + # At this stage, if `repo_info` is None it means either: + # - internet connection is down + # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True) + # - repo is private/gated and invalid/missing token sent + # - Hub is down + # => let's look if we can find the appropriate folder in the cache: + # - if the specified revision is a commit hash, look inside "snapshots". + # - f the specified revision is a branch or tag, look inside "refs". + # => if local_dir is not None, we will return the path to the local folder if it exists. 
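+    # (Cache layout reminder: `<cache_dir>/<repo_folder>/refs/<revision>` stores the commit
+    # hash of a branch/tag, and `<cache_dir>/<repo_folder>/snapshots/<commit_hash>/` holds
+    # the downloaded files.)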
+ if repo_info is None: + # Try to get which commit hash corresponds to the specified revision + commit_hash = None + if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + else: + ref_path = os.path.join(storage_folder, "refs", revision) + if os.path.exists(ref_path): + # retrieve commit_hash from refs file + with open(ref_path) as f: + commit_hash = f.read() + + # Try to locate snapshot folder for this commit hash + if commit_hash is not None and local_dir is None: + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(snapshot_folder): + # Snapshot folder exists => let's return it + # (but we can't check if all the files are actually there) + return snapshot_folder + + # If local_dir is not None, return it if it exists and is not empty + if local_dir is not None: + local_dir = Path(local_dir) + if local_dir.is_dir() and any(local_dir.iterdir()): + logger.warning( + f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})." + ) + return str(local_dir.resolve()) + # If we couldn't find the appropriate folder on disk, raise an error. + if local_files_only: + raise LocalEntryNotFoundError( + "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " + "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass " + "'local_files_only=False' as input." + ) + elif isinstance(api_call_error, OfflineModeIsEnabled): + raise LocalEntryNotFoundError( + "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " + "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set " + "'HF_HUB_OFFLINE=0' as environment variable." + ) from api_call_error + elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)) or ( + isinstance(api_call_error, HfHubHTTPError) and api_call_error.response.status_code == 401 + ): + # Repo not found, gated, or specific authentication error => let's raise the actual error + raise api_call_error + else: + # Otherwise: most likely a connection issue or Hub downtime => let's warn the user + raise LocalEntryNotFoundError( + "An error happened while trying to locate the files on the Hub and we cannot find the appropriate" + " snapshot folder for the specified revision on the local disk. Please check your internet connection" + " and try again." + ) from api_call_error + + # At this stage, internet connection is up and running + # => let's download the files! + assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." + + # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files. + # In that case, we need to use the `list_repo_tree` method to prevent caching issues. + repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else [] + unreliable_nb_files = ( + repo_info.siblings is None + or len(repo_info.siblings) == 0 + or len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD + ) + if unreliable_nb_files: + logger.info( + "Number of files in the repo is unreliable. Using `list_repo_tree` to ensure all files are listed." 
+ ) + repo_files = ( + f.rfilename + for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type) + if isinstance(f, RepoFile) + ) + + filtered_repo_files: Iterable[str] = filter_repo_objects( + items=repo_files, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + if not unreliable_nb_files: + filtered_repo_files = list(filtered_repo_files) + tqdm_desc = f"Fetching {len(filtered_repo_files)} files" + else: + tqdm_desc = "Fetching ... files" + + commit_hash = repo_info.sha + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. + if revision != commit_hash: + ref_path = os.path.join(storage_folder, "refs", revision) + try: + os.makedirs(os.path.dirname(ref_path), exist_ok=True) + with open(ref_path, "w") as f: + f.write(commit_hash) + except OSError as e: + logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.") + + # we pass the commit_hash to hf_hub_download + # so no network call happens if we already + # have the file locally. + def _inner_hf_hub_download(repo_file: str): + return hf_hub_download( + repo_id, + filename=repo_file, + repo_type=repo_type, + revision=commit_hash, + endpoint=endpoint, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + force_download=force_download, + token=token, + headers=headers, + ) + + if constants.HF_HUB_ENABLE_HF_TRANSFER: + # when using hf_transfer we don't want extra parallelism + # from the one hf_transfer provides + for file in filtered_repo_files: + _inner_hf_hub_download(file) + else: + thread_map( + _inner_hf_hub_download, + filtered_repo_files, + desc=tqdm_desc, + max_workers=max_workers, + # User can use its own tqdm class or the default one from `huggingface_hub.utils` + tqdm_class=tqdm_class or hf_tqdm, + ) + + if local_dir is not None: + return str(os.path.realpath(local_dir)) + return snapshot_folder diff --git a/phivenv/Lib/site-packages/huggingface_hub/_space_api.py b/phivenv/Lib/site-packages/huggingface_hub/_space_api.py new file mode 100644 index 0000000000000000000000000000000000000000..05fccfbc1ebdfc14840a88751914b8fc0d1a498d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_space_api.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Dict, Optional + +from huggingface_hub.utils import parse_datetime + + +class SpaceStage(str, Enum): + """ + Enumeration of possible stage of a Space on the Hub. 
+ + Value can be compared to a string: + ```py + assert SpaceStage.BUILDING == "BUILDING" + ``` + + Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url). + """ + + # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo) + NO_APP_FILE = "NO_APP_FILE" + CONFIG_ERROR = "CONFIG_ERROR" + BUILDING = "BUILDING" + BUILD_ERROR = "BUILD_ERROR" + RUNNING = "RUNNING" + RUNNING_BUILDING = "RUNNING_BUILDING" + RUNTIME_ERROR = "RUNTIME_ERROR" + DELETING = "DELETING" + STOPPED = "STOPPED" + PAUSED = "PAUSED" + + +class SpaceHardware(str, Enum): + """ + Enumeration of hardwares available to run your Space on the Hub. + + Value can be compared to a string: + ```py + assert SpaceHardware.CPU_BASIC == "cpu-basic" + ``` + + Taken from https://github.com/huggingface-internal/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts (private url). + """ + + # CPU + CPU_BASIC = "cpu-basic" + CPU_UPGRADE = "cpu-upgrade" + CPU_XL = "cpu-xl" + + # ZeroGPU + ZERO_A10G = "zero-a10g" + + # GPU + T4_SMALL = "t4-small" + T4_MEDIUM = "t4-medium" + L4X1 = "l4x1" + L4X4 = "l4x4" + L40SX1 = "l40sx1" + L40SX4 = "l40sx4" + L40SX8 = "l40sx8" + A10G_SMALL = "a10g-small" + A10G_LARGE = "a10g-large" + A10G_LARGEX2 = "a10g-largex2" + A10G_LARGEX4 = "a10g-largex4" + A100_LARGE = "a100-large" + H100 = "h100" + H100X8 = "h100x8" + + +class SpaceStorage(str, Enum): + """ + Enumeration of persistent storage available for your Space on the Hub. + + Value can be compared to a string: + ```py + assert SpaceStorage.SMALL == "small" + ``` + + Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url). + """ + + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + + +@dataclass +class SpaceRuntime: + """ + Contains information about the current runtime of a Space. + + Args: + stage (`str`): + Current stage of the space. Example: RUNNING. + hardware (`str` or `None`): + Current hardware of the space. Example: "cpu-basic". Can be `None` if Space + is `BUILDING` for the first time. + requested_hardware (`str` or `None`): + Requested hardware. Can be different than `hardware` especially if the request + has just been made. Example: "t4-medium". Can be `None` if no hardware has + been requested yet. + sleep_time (`int` or `None`): + Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the + Space will never go to sleep if it's running on an upgraded hardware, while it will go to sleep after 48 + hours on a free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time. + raw (`dict`): + Raw response from the server. Contains more information about the Space + runtime like number of replicas, number of cpu, memory size,... + """ + + stage: SpaceStage + hardware: Optional[SpaceHardware] + requested_hardware: Optional[SpaceHardware] + sleep_time: Optional[int] + storage: Optional[SpaceStorage] + raw: Dict + + def __init__(self, data: Dict) -> None: + self.stage = data["stage"] + self.hardware = data.get("hardware", {}).get("current") + self.requested_hardware = data.get("hardware", {}).get("requested") + self.sleep_time = data.get("gcTimeout") + self.storage = data.get("storage") + self.raw = data + + +@dataclass +class SpaceVariable: + """ + Contains information about the current variables of a Space. + + Args: + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + value (`str`): + Variable value. 
Example: `"the_model_repo_id"`. + description (`str` or None): + Description of the variable. Example: `"Model Repo ID of the implemented model"`. + updatedAt (`datetime` or None): + datetime of the last update of the variable (if the variable has been updated at least once). + """ + + key: str + value: str + description: Optional[str] + updated_at: Optional[datetime] + + def __init__(self, key: str, values: Dict) -> None: + self.key = key + self.value = values["value"] + self.description = values.get("description") + updated_at = values.get("updatedAt") + self.updated_at = parse_datetime(updated_at) if updated_at is not None else None diff --git a/phivenv/Lib/site-packages/huggingface_hub/_tensorboard_logger.py b/phivenv/Lib/site-packages/huggingface_hub/_tensorboard_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..5e910972463d3e6bc8b8796c95fde5696d999952 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_tensorboard_logger.py @@ -0,0 +1,194 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a logger to push training logs to the Hub, using Tensorboard.""" + +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Union + +from ._commit_scheduler import CommitScheduler +from .errors import EntryNotFoundError +from .repocard import ModelCard +from .utils import experimental + + +# Depending on user's setup, SummaryWriter can come either from 'tensorboardX' +# or from 'torch.utils.tensorboard'. Both are compatible so let's try to load +# from either of them. +try: + from tensorboardX import SummaryWriter + + is_summary_writer_available = True + +except ImportError: + try: + from torch.utils.tensorboard import SummaryWriter + + is_summary_writer_available = False + except ImportError: + # Dummy class to avoid failing at import. Will raise on instance creation. + SummaryWriter = object + is_summary_writer_available = False + +if TYPE_CHECKING: + from tensorboardX import SummaryWriter + + +class HFSummaryWriter(SummaryWriter): + """ + Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub. + + Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate + thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection + issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every` + minutes (default to every 5 minutes). + + + + `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice. + + + + Args: + repo_id (`str`): + The id of the repo to which the logs will be pushed. + logdir (`str`, *optional*): + The directory where the logs will be written. If not specified, a local directory will be created by the + underlying `SummaryWriter` object. 
+ commit_every (`int` or `float`, *optional*): + The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes. + squash_history (`bool`, *optional*): + Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is + useful to avoid degraded performances on the repo when it grows too large. + repo_type (`str`, *optional*): + The type of the repo to which the logs will be pushed. Defaults to "model". + repo_revision (`str`, *optional*): + The revision of the repo to which the logs will be pushed. Defaults to "main". + repo_private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + path_in_repo (`str`, *optional*): + The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/". + repo_allow_patterns (`List[str]` or `str`, *optional*): + A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the + [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. + repo_ignore_patterns (`List[str]` or `str`, *optional*): + A list of patterns to exclude in the upload. Check out the + [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. + token (`str`, *optional*): + Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more + details + kwargs: + Additional keyword arguments passed to `SummaryWriter`. + + Examples: + ```diff + # Taken from https://pytorch.org/docs/stable/tensorboard.html + - from torch.utils.tensorboard import SummaryWriter + + from huggingface_hub import HFSummaryWriter + + import numpy as np + + - writer = SummaryWriter() + + writer = HFSummaryWriter(repo_id="username/my-trained-model") + + for n_iter in range(100): + writer.add_scalar('Loss/train', np.random.random(), n_iter) + writer.add_scalar('Loss/test', np.random.random(), n_iter) + writer.add_scalar('Accuracy/train', np.random.random(), n_iter) + writer.add_scalar('Accuracy/test', np.random.random(), n_iter) + ``` + + ```py + >>> from huggingface_hub import HFSummaryWriter + + # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager + >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger: + ... logger.add_scalar("a", 1) + ... logger.add_scalar("b", 2) + ``` + """ + + @experimental + def __new__(cls, *args, **kwargs) -> "HFSummaryWriter": + if not is_summary_writer_available: + raise ImportError( + "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade" + " tensorboardX` first." + ) + return super().__new__(cls) + + def __init__( + self, + repo_id: str, + *, + logdir: Optional[str] = None, + commit_every: Union[int, float] = 5, + squash_history: bool = False, + repo_type: Optional[str] = None, + repo_revision: Optional[str] = None, + repo_private: Optional[bool] = None, + path_in_repo: Optional[str] = "tensorboard", + repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*", + repo_ignore_patterns: Optional[Union[List[str], str]] = None, + token: Optional[str] = None, + **kwargs, + ): + # Initialize SummaryWriter + super().__init__(logdir=logdir, **kwargs) + + # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it. 
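+        # (With `tensorboardX`, `self.logdir` is derived from the `logdir` argument or auto-generated; failing
+        # early here avoids scheduling background uploads from an undefined folder.)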
+ if not isinstance(self.logdir, str): + raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.") + + # Append logdir name to `path_in_repo` + if path_in_repo is None or path_in_repo == "": + path_in_repo = Path(self.logdir).name + else: + path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name + + # Initialize scheduler + self.scheduler = CommitScheduler( + folder_path=self.logdir, + path_in_repo=path_in_repo, + repo_id=repo_id, + repo_type=repo_type, + revision=repo_revision, + private=repo_private, + token=token, + allow_patterns=repo_allow_patterns, + ignore_patterns=repo_ignore_patterns, + every=commit_every, + squash_history=squash_history, + ) + + # Exposing some high-level info at root level + self.repo_id = self.scheduler.repo_id + self.repo_type = self.scheduler.repo_type + self.repo_revision = self.scheduler.revision + + # Add `hf-summary-writer` tag to the model card metadata + try: + card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type) + except EntryNotFoundError: + card = ModelCard("") + tags = card.data.get("tags", []) + if "hf-summary-writer" not in tags: + tags.append("hf-summary-writer") + card.data["tags"] = tags + card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Push to hub in a non-blocking way when exiting the logger's context manager.""" + super().__exit__(exc_type, exc_val, exc_tb) + future = self.scheduler.trigger() + future.result() diff --git a/phivenv/Lib/site-packages/huggingface_hub/_upload_large_folder.py b/phivenv/Lib/site-packages/huggingface_hub/_upload_large_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..1ccbc07d39d3d03e9bb8c39f1bb16aa2ca4ab41f --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_upload_large_folder.py @@ -0,0 +1,755 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +import logging +import os +import queue +import shutil +import sys +import threading +import time +import traceback +from datetime import datetime +from pathlib import Path +from threading import Lock +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from urllib.parse import quote + +from . 
import constants +from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes +from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata +from .constants import DEFAULT_REVISION, REPO_TYPES +from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm +from .utils._cache_manager import _format_size +from .utils._runtime import is_xet_available +from .utils.sha import sha_fileobj + + +if TYPE_CHECKING: + from .hf_api import HfApi + +logger = logging.getLogger(__name__) + +WAITING_TIME_IF_NO_TASKS = 10 # seconds +MAX_NB_FILES_FETCH_UPLOAD_MODE = 100 +COMMIT_SIZE_SCALE: List[int] = [20, 50, 75, 100, 125, 200, 250, 400, 600, 1000] + +UPLOAD_BATCH_SIZE_XET = 256 # Max 256 files per upload batch for XET-enabled repos +UPLOAD_BATCH_SIZE_LFS = 1 # Otherwise, batches of 1 for regular LFS upload + +# Repository limits (from https://huggingface.co/docs/hub/repositories-recommendations) +MAX_FILES_PER_REPO = 100_000 # Recommended maximum number of files per repository +MAX_FILES_PER_FOLDER = 10_000 # Recommended maximum number of files per folder +MAX_FILE_SIZE_GB = 50 # Hard limit for individual file size +RECOMMENDED_FILE_SIZE_GB = 20 # Recommended maximum for individual file size + + +def _validate_upload_limits(paths_list: List[LocalUploadFilePaths]) -> None: + """ + Validate upload against repository limits and warn about potential issues. + + Args: + paths_list: List of file paths to be uploaded + + Warns about: + - Too many files in the repository (>100k) + - Too many entries (files or subdirectories) in a single folder (>10k) + - Files exceeding size limits (>20GB recommended, >50GB hard limit) + """ + logger.info("Running validation checks on files to upload...") + + # Check 1: Total file count + if len(paths_list) > MAX_FILES_PER_REPO: + logger.warning( + f"You are about to upload {len(paths_list):,} files. " + f"This exceeds the recommended limit of {MAX_FILES_PER_REPO:,} files per repository.\n" + f"Consider:\n" + f" - Splitting your data into multiple repositories\n" + f" - Using fewer, larger files (e.g., parquet files)\n" + f" - See: https://huggingface.co/docs/hub/repositories-recommendations" + ) + + # Check 2: Files and subdirectories per folder + # Track immediate children (files and subdirs) for each folder + from collections import defaultdict + + entries_per_folder: Dict[str, Any] = defaultdict(lambda: {"files": 0, "subdirs": set()}) + + for paths in paths_list: + path = Path(paths.path_in_repo) + parts = path.parts + + # Count this file in its immediate parent directory + parent = str(path.parent) if str(path.parent) != "." else "." + entries_per_folder[parent]["files"] += 1 + + # Track immediate subdirectories for each parent folder + # Walk through the path components to track parent-child relationships + for i, child in enumerate(parts[:-1]): + parent = "." if i == 0 else "/".join(parts[:i]) + entries_per_folder[parent]["subdirs"].add(child) + + # Check limits for each folder + for folder, data in entries_per_folder.items(): + file_count = data["files"] + subdir_count = len(data["subdirs"]) + total_entries = file_count + subdir_count + + if total_entries > MAX_FILES_PER_FOLDER: + folder_display = "root" if folder == "." else folder + logger.warning( + f"Folder '{folder_display}' contains {total_entries:,} entries " + f"({file_count:,} files and {subdir_count:,} subdirectories). 
" + f"This exceeds the recommended {MAX_FILES_PER_FOLDER:,} entries per folder.\n" + "Consider reorganising into sub-folders." + ) + + # Check 3: File sizes + large_files = [] + very_large_files = [] + + for paths in paths_list: + size = paths.file_path.stat().st_size + size_gb = size / 1_000_000_000 # Use decimal GB as per Hub limits + + if size_gb > MAX_FILE_SIZE_GB: + very_large_files.append((paths.path_in_repo, size_gb)) + elif size_gb > RECOMMENDED_FILE_SIZE_GB: + large_files.append((paths.path_in_repo, size_gb)) + + # Warn about very large files (>50GB) + if very_large_files: + files_str = "\n - ".join(f"{path}: {size:.1f}GB" for path, size in very_large_files[:5]) + more_str = f"\n ... and {len(very_large_files) - 5} more files" if len(very_large_files) > 5 else "" + logger.warning( + f"Found {len(very_large_files)} files exceeding the {MAX_FILE_SIZE_GB}GB hard limit:\n" + f" - {files_str}{more_str}\n" + f"These files may fail to upload. Consider splitting them into smaller chunks." + ) + + # Warn about large files (>20GB) + if large_files: + files_str = "\n - ".join(f"{path}: {size:.1f}GB" for path, size in large_files[:5]) + more_str = f"\n ... and {len(large_files) - 5} more files" if len(large_files) > 5 else "" + logger.warning( + f"Found {len(large_files)} files larger than {RECOMMENDED_FILE_SIZE_GB}GB (recommended limit):\n" + f" - {files_str}{more_str}\n" + f"Large files may slow down loading and processing." + ) + + logger.info("Validation checks complete.") + + +def upload_large_folder_internal( + api: "HfApi", + repo_id: str, + folder_path: Union[str, Path], + *, + repo_type: str, # Repo type is required! + revision: Optional[str] = None, + private: Optional[bool] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + num_workers: Optional[int] = None, + print_report: bool = True, + print_report_every: int = 60, +): + """Upload a large folder to the Hub in the most resilient way possible. + + See [`HfApi.upload_large_folder`] for the full documentation. + """ + # 1. Check args and setup + if repo_type is None: + raise ValueError( + "For large uploads, `repo_type` is explicitly required. Please set it to `model`, `dataset` or `space`." + " If you are using the CLI, pass it as `--repo-type=model`." + ) + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if revision is None: + revision = DEFAULT_REVISION + + folder_path = Path(folder_path).expanduser().resolve() + if not folder_path.is_dir(): + raise ValueError(f"Provided path: '{folder_path}' is not a directory") + + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + ignore_patterns += DEFAULT_IGNORE_PATTERNS + + if num_workers is None: + nb_cores = os.cpu_count() or 1 + num_workers = max(nb_cores - 2, 2) # Use all but 2 cores, or at least 2 cores + + # 2. Create repo if missing + repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True) + logger.info(f"Repo created: {repo_url}") + repo_id = repo_url.repo_id + # 2.1 Check if xet is enabled to set batch file upload size + is_xet_enabled = ( + is_xet_available() + and api.repo_info( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + expand="xetEnabled", + ).xet_enabled + ) + upload_batch_size = UPLOAD_BATCH_SIZE_XET if is_xet_enabled else UPLOAD_BATCH_SIZE_LFS + + # 3. 
List files to upload + filtered_paths_list = filter_repo_objects( + (path.relative_to(folder_path).as_posix() for path in folder_path.glob("**/*") if path.is_file()), + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + paths_list = [get_local_upload_paths(folder_path, relpath) for relpath in filtered_paths_list] + logger.info(f"Found {len(paths_list)} candidate files to upload") + + # Validate upload against repository limits + _validate_upload_limits(paths_list) + + logger.info("Starting upload...") + + # Read metadata for each file + items = [ + (paths, read_upload_metadata(folder_path, paths.path_in_repo)) + for paths in tqdm(paths_list, desc="Recovering from metadata files") + ] + + # 4. Start workers + status = LargeUploadStatus(items, upload_batch_size) + threads = [ + threading.Thread( + target=_worker_job, + kwargs={ + "status": status, + "api": api, + "repo_id": repo_id, + "repo_type": repo_type, + "revision": revision, + }, + ) + for _ in range(num_workers) + ] + + for thread in threads: + thread.start() + + # 5. Print regular reports + if print_report: + print("\n\n" + status.current_report()) + last_report_ts = time.time() + while True: + time.sleep(1) + if time.time() - last_report_ts >= print_report_every: + if print_report: + _print_overwrite(status.current_report()) + last_report_ts = time.time() + if status.is_done(): + logging.info("Is done: exiting main loop") + break + + for thread in threads: + thread.join() + + logger.info(status.current_report()) + logging.info("Upload is complete!") + + +#################### +# Logic to manage workers and synchronize tasks +#################### + + +class WorkerJob(enum.Enum): + SHA256 = enum.auto() + GET_UPLOAD_MODE = enum.auto() + PREUPLOAD_LFS = enum.auto() + COMMIT = enum.auto() + WAIT = enum.auto() # if no tasks are available but we don't want to exit + + +JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata] + + +class LargeUploadStatus: + """Contains information, queues and tasks for a large upload process.""" + + def __init__(self, items: List[JOB_ITEM_T], upload_batch_size: int = 1): + self.items = items + self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.queue_preupload_lfs: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.queue_commit: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.lock = Lock() + + self.nb_workers_sha256: int = 0 + self.nb_workers_get_upload_mode: int = 0 + self.nb_workers_preupload_lfs: int = 0 + self.upload_batch_size: int = upload_batch_size + self.nb_workers_commit: int = 0 + self.nb_workers_waiting: int = 0 + self.last_commit_attempt: Optional[float] = None + + self._started_at = datetime.now() + self._chunk_idx: int = 1 + self._chunk_lock: Lock = Lock() + + # Setup queues + for item in self.items: + paths, metadata = item + if metadata.sha256 is None: + self.queue_sha256.put(item) + elif metadata.upload_mode is None: + self.queue_get_upload_mode.put(item) + elif metadata.upload_mode == "lfs" and not metadata.is_uploaded: + self.queue_preupload_lfs.put(item) + elif not metadata.is_committed: + self.queue_commit.put(item) + else: + logger.debug(f"Skipping file {paths.path_in_repo} (already uploaded and committed)") + + def target_chunk(self) -> int: + with self._chunk_lock: + return COMMIT_SIZE_SCALE[self._chunk_idx] + + def update_chunk(self, success: bool, nb_items: int, duration: float) -> None: + with self._chunk_lock: + if not success: + logger.warning(f"Failed to commit 
{nb_items} files at once. Will retry with fewer files in the next batch.")
+                self._chunk_idx -= 1
+            elif nb_items >= COMMIT_SIZE_SCALE[self._chunk_idx] and duration < 40:
+                logger.info(f"Successfully committed {nb_items} files at once. Increasing the limit for the next batch.")
+                self._chunk_idx += 1
+
+            self._chunk_idx = max(0, min(self._chunk_idx, len(COMMIT_SIZE_SCALE) - 1))
+
+    def current_report(self) -> str:
+        """Generate a report of the current status of the large upload."""
+        nb_hashed = 0
+        size_hashed = 0
+        nb_preuploaded = 0
+        nb_lfs = 0
+        nb_lfs_unsure = 0
+        size_preuploaded = 0
+        nb_committed = 0
+        size_committed = 0
+        total_size = 0
+        ignored_files = 0
+        total_files = 0
+
+        with self.lock:
+            for _, metadata in self.items:
+                if metadata.should_ignore:
+                    ignored_files += 1
+                    continue
+                total_size += metadata.size
+                total_files += 1
+                if metadata.sha256 is not None:
+                    nb_hashed += 1
+                    size_hashed += metadata.size
+                if metadata.upload_mode == "lfs":
+                    nb_lfs += 1
+                if metadata.upload_mode is None:
+                    nb_lfs_unsure += 1
+                if metadata.is_uploaded:
+                    nb_preuploaded += 1
+                    size_preuploaded += metadata.size
+                if metadata.is_committed:
+                    nb_committed += 1
+                    size_committed += metadata.size
+        total_size_str = _format_size(total_size)
+
+        now = datetime.now()
+        now_str = now.strftime("%Y-%m-%d %H:%M:%S")
+        elapsed = now - self._started_at
+        elapsed_str = str(elapsed).split(".")[0]  # remove sub-second precision
+
+        message = "\n" + "-" * 10
+        message += f" {now_str} ({elapsed_str}) "
+        message += "-" * 10 + "\n"
+
+        message += "Files:   "
+        message += f"hashed {nb_hashed}/{total_files} ({_format_size(size_hashed)}/{total_size_str}) | "
+        message += f"pre-uploaded: {nb_preuploaded}/{nb_lfs} ({_format_size(size_preuploaded)}/{total_size_str})"
+        if nb_lfs_unsure > 0:
+            message += f" (+{nb_lfs_unsure} unsure)"
+        message += f" | committed: {nb_committed}/{total_files} ({_format_size(size_committed)}/{total_size_str})"
+        message += f" | ignored: {ignored_files}\n"
+
+        message += "Workers: "
+        message += f"hashing: {self.nb_workers_sha256} | "
+        message += f"get upload mode: {self.nb_workers_get_upload_mode} | "
+        message += f"pre-uploading: {self.nb_workers_preupload_lfs} | "
+        message += f"committing: {self.nb_workers_commit} | "
+        message += f"waiting: {self.nb_workers_waiting}\n"
+        message += "-" * 51
+
+        return message
+
+    def is_done(self) -> bool:
+        with self.lock:
+            return all(metadata.is_committed or metadata.should_ignore for _, metadata in self.items)
+
+
+def _worker_job(
+    status: LargeUploadStatus,
+    api: "HfApi",
+    repo_id: str,
+    repo_type: str,
+    revision: str,
+):
+    """
+    Main process for a worker. The worker will perform tasks based on the priority list until all files are uploaded
+    and committed. If no tasks are available, the worker will wait for 10 seconds before checking again.
+
+    If a task fails for any reason, the item(s) are put back in the queue for another worker to pick up.
+
+    Read the `upload_large_folder` docstring for more information on how tasks are prioritized.
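+
+    Roughly, committing files that are ready takes priority (triggered by elapsed time or queue size), followed by
+    fetching upload modes, pre-uploading LFS files and computing sha256 checksums; a worker with nothing to do
+    sleeps for `WAITING_TIME_IF_NO_TASKS` (10 seconds) before checking the queues again.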
+ """ + while True: + next_job: Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]] = None + + # Determine next task + next_job = _determine_next_job(status) + if next_job is None: + return + job, items = next_job + + # Perform task + if job == WorkerJob.SHA256: + item = items[0] # single item + try: + _compute_sha256(item) + status.queue_get_upload_mode.put(item) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to compute sha256: {e}") + traceback.format_exc() + status.queue_sha256.put(item) + + with status.lock: + status.nb_workers_sha256 -= 1 + + elif job == WorkerJob.GET_UPLOAD_MODE: + try: + _get_upload_mode(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to get upload mode: {e}") + traceback.format_exc() + + # Items are either: + # - dropped (if should_ignore) + # - put in LFS queue (if LFS) + # - put in commit queue (if regular) + # - or put back (if error occurred). + for item in items: + _, metadata = item + if metadata.should_ignore: + continue + if metadata.upload_mode == "lfs": + status.queue_preupload_lfs.put(item) + elif metadata.upload_mode == "regular": + status.queue_commit.put(item) + else: + status.queue_get_upload_mode.put(item) + + with status.lock: + status.nb_workers_get_upload_mode -= 1 + + elif job == WorkerJob.PREUPLOAD_LFS: + try: + _preupload_lfs(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) + for item in items: + status.queue_commit.put(item) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to preupload LFS: {e}") + traceback.format_exc() + for item in items: + status.queue_preupload_lfs.put(item) + + with status.lock: + status.nb_workers_preupload_lfs -= 1 + + elif job == WorkerJob.COMMIT: + start_ts = time.time() + success = True + try: + _commit(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to commit: {e}") + traceback.format_exc() + for item in items: + status.queue_commit.put(item) + success = False + duration = time.time() - start_ts + status.update_chunk(success, len(items), duration) + with status.lock: + status.last_commit_attempt = time.time() + status.nb_workers_commit -= 1 + + elif job == WorkerJob.WAIT: + time.sleep(WAITING_TIME_IF_NO_TASKS) + with status.lock: + status.nb_workers_waiting -= 1 + + +def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]]: + with status.lock: + # 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file) + if ( + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.last_commit_attempt is not None + and time.time() - status.last_commit_attempt > 5 * 60 + ): + status.nb_workers_commit += 1 + logger.debug("Job: commit (more than 5 minutes since last commit attempt)") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 2. Commit if at least 100 files are ready to commit + elif status.nb_workers_commit == 0 and status.queue_commit.qsize() >= 150: + status.nb_workers_commit += 1 + logger.debug("Job: commit (>100 files ready)") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 3. 
Get upload mode if at least 100 files + elif status.queue_get_upload_mode.qsize() >= MAX_NB_FILES_FETCH_UPLOAD_MODE: + status.nb_workers_get_upload_mode += 1 + logger.debug(f"Job: get upload mode (>{MAX_NB_FILES_FETCH_UPLOAD_MODE} files ready)") + return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE)) + + # 4. Preupload LFS file if at least `status.upload_batch_size` files and no worker is preuploading LFS + elif status.queue_preupload_lfs.qsize() >= status.upload_batch_size and status.nb_workers_preupload_lfs == 0: + status.nb_workers_preupload_lfs += 1 + logger.debug("Job: preupload LFS (no other worker preuploading LFS)") + return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size)) + + # 5. Compute sha256 if at least 1 file and no worker is computing sha256 + elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0: + status.nb_workers_sha256 += 1 + logger.debug("Job: sha256 (no other worker computing sha256)") + return (WorkerJob.SHA256, _get_one(status.queue_sha256)) + + # 6. Get upload mode if at least 1 file and no worker is getting upload mode + elif status.queue_get_upload_mode.qsize() > 0 and status.nb_workers_get_upload_mode == 0: + status.nb_workers_get_upload_mode += 1 + logger.debug("Job: get upload mode (no other worker getting upload mode)") + return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE)) + + # 7. Preupload LFS file if at least `status.upload_batch_size` files + # Skip if hf_transfer is enabled and there is already a worker preuploading LFS + elif status.queue_preupload_lfs.qsize() >= status.upload_batch_size and ( + status.nb_workers_preupload_lfs == 0 or not constants.HF_HUB_ENABLE_HF_TRANSFER + ): + status.nb_workers_preupload_lfs += 1 + logger.debug("Job: preupload LFS") + return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size)) + + # 8. Compute sha256 if at least 1 file + elif status.queue_sha256.qsize() > 0: + status.nb_workers_sha256 += 1 + logger.debug("Job: sha256") + return (WorkerJob.SHA256, _get_one(status.queue_sha256)) + + # 9. Get upload mode if at least 1 file + elif status.queue_get_upload_mode.qsize() > 0: + status.nb_workers_get_upload_mode += 1 + logger.debug("Job: get upload mode") + return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE)) + + # 10. Preupload LFS file if at least 1 file + elif status.queue_preupload_lfs.qsize() > 0: + status.nb_workers_preupload_lfs += 1 + logger.debug("Job: preupload LFS") + return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size)) + + # 11. Commit if at least 1 file and 1 min since last commit attempt + elif ( + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.last_commit_attempt is not None + and time.time() - status.last_commit_attempt > 1 * 60 + ): + status.nb_workers_commit += 1 + logger.debug("Job: commit (1 min since last commit attempt)") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 12. Commit if at least 1 file all other queues are empty and all workers are waiting + # e.g. 
when it's the last commit + elif ( + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.queue_sha256.qsize() == 0 + and status.queue_get_upload_mode.qsize() == 0 + and status.queue_preupload_lfs.qsize() == 0 + and status.nb_workers_sha256 == 0 + and status.nb_workers_get_upload_mode == 0 + and status.nb_workers_preupload_lfs == 0 + ): + status.nb_workers_commit += 1 + logger.debug("Job: commit") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 13. If all queues are empty, exit + elif all(metadata.is_committed or metadata.should_ignore for _, metadata in status.items): + logger.info("All files have been processed! Exiting worker.") + return None + + # 14. If no task is available, wait + else: + status.nb_workers_waiting += 1 + logger.debug(f"No task available, waiting... ({WAITING_TIME_IF_NO_TASKS}s)") + return (WorkerJob.WAIT, []) + + +#################### +# Atomic jobs (sha256, get_upload_mode, preupload_lfs, commit) +#################### + + +def _compute_sha256(item: JOB_ITEM_T) -> None: + """Compute sha256 of a file and save it in metadata.""" + paths, metadata = item + if metadata.sha256 is None: + with paths.file_path.open("rb") as f: + metadata.sha256 = sha_fileobj(f).hex() + metadata.save(paths) + + +def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: + """Get upload mode for each file and update metadata. + + Also receive info if the file should be ignored. + """ + additions = [_build_hacky_operation(item) for item in items] + _fetch_upload_modes( + additions=additions, + repo_type=repo_type, + repo_id=repo_id, + headers=api._build_hf_headers(), + revision=quote(revision, safe=""), + endpoint=api.endpoint, + ) + for item, addition in zip(items, additions): + paths, metadata = item + metadata.upload_mode = addition._upload_mode + metadata.should_ignore = addition._should_ignore + metadata.remote_oid = addition._remote_oid + metadata.save(paths) + + +def _preupload_lfs(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: + """Preupload LFS files and update metadata.""" + additions = [_build_hacky_operation(item) for item in items] + api.preupload_lfs_files( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + additions=additions, + ) + + for paths, metadata in items: + metadata.is_uploaded = True + metadata.save(paths) + + +def _commit(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: + """Commit files to the repo.""" + additions = [_build_hacky_operation(item) for item in items] + api.create_commit( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + operations=additions, + commit_message="Add files using upload-large-folder tool", + ) + for paths, metadata in items: + metadata.is_committed = True + metadata.save(paths) + + +#################### +# Hacks with CommitOperationAdd to bypass checks/sha256 calculation +#################### + + +class HackyCommitOperationAdd(CommitOperationAdd): + def __post_init__(self) -> None: + if isinstance(self.path_or_fileobj, Path): + self.path_or_fileobj = str(self.path_or_fileobj) + + +def _build_hacky_operation(item: JOB_ITEM_T) -> HackyCommitOperationAdd: + paths, metadata = item + operation = HackyCommitOperationAdd(path_in_repo=paths.path_in_repo, path_or_fileobj=paths.file_path) + with paths.file_path.open("rb") as file: + sample = file.peek(512)[:512] + if metadata.sha256 is None: + raise 
ValueError("sha256 must have been computed by now!")
+    operation.upload_info = UploadInfo(sha256=bytes.fromhex(metadata.sha256), size=metadata.size, sample=sample)
+    operation._upload_mode = metadata.upload_mode  # type: ignore[assignment]
+    operation._should_ignore = metadata.should_ignore
+    operation._remote_oid = metadata.remote_oid
+    return operation
+
+
+####################
+# Misc helpers
+####################
+
+
+def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]:
+    return [queue.get()]
+
+
+def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> List[JOB_ITEM_T]:
+    return [queue.get() for _ in range(min(queue.qsize(), n))]
+
+
+def _print_overwrite(report: str) -> None:
+    """Print a report, overwriting the previous lines.
+
+    Since tqdm is using `sys.stderr` to (re-)write progress bars, we need to use `sys.stdout`
+    to print the report.
+
+    Note: works well only if no other process is writing to `sys.stdout`!
+    """
+    report += "\n"
+    # Get terminal width
+    terminal_width = shutil.get_terminal_size().columns
+
+    # Count number of lines that should be cleared
+    nb_lines = sum(len(line) // terminal_width + 1 for line in report.splitlines())
+
+    # Clear previous lines based on the number of lines in the report
+    for _ in range(nb_lines):
+        sys.stdout.write("\r\033[K")  # Clear line
+        sys.stdout.write("\033[F")  # Move cursor up one line
+
+    # Print the new report, filling remaining space with whitespace
+    sys.stdout.write(report)
+    sys.stdout.write(" " * (terminal_width - len(report.splitlines()[-1])))
+    sys.stdout.flush()
diff --git a/phivenv/Lib/site-packages/huggingface_hub/_webhooks_payload.py b/phivenv/Lib/site-packages/huggingface_hub/_webhooks_payload.py
new file mode 100644
index 0000000000000000000000000000000000000000..288f4b08b9428980e99ca06703442eab62fad277
--- /dev/null
+++ b/phivenv/Lib/site-packages/huggingface_hub/_webhooks_payload.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2023-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains data structures to parse the webhooks payload."""
+
+from typing import List, Literal, Optional
+
+from .utils import is_pydantic_available
+
+
+if is_pydantic_available():
+    from pydantic import BaseModel
+else:
+    # Define a dummy BaseModel to avoid import errors when pydantic is not installed
+    # Import error will be raised when trying to use the class
+
+    class BaseModel:  # type: ignore [no-redef]
+        def __init__(self, *args, **kwargs) -> None:
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
+
+# This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
+# are not in use anymore. To be kept in sync whenever the format is updated in
+# https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link).
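+#
+# As a rough illustration of the models defined below (assuming `pydantic` is installed; field values here are
+# made up, not a real Hub payload), a payload dict can be validated by instantiating the top-level model, with
+# nested dicts coerced into the nested models:
+#
+#     from huggingface_hub import WebhookPayload
+#
+#     payload = WebhookPayload(
+#         **{
+#             "event": {"action": "update", "scope": "repo.content"},
+#             "repo": {
+#                 "id": "0123456789",
+#                 "owner": {"id": "9876543210"},
+#                 "name": "user/my-model",
+#                 "private": False,
+#                 "type": "model",
+#                 "url": {"web": "https://huggingface.co/user/my-model"},
+#             },
+#             "webhook": {"id": "abc123", "version": 3},
+#         }
+#     )
+#     assert payload.event.action == "update"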
+ + +WebhookEvent_T = Literal[ + "create", + "delete", + "move", + "update", +] +RepoChangeEvent_T = Literal[ + "add", + "move", + "remove", + "update", +] +RepoType_T = Literal[ + "dataset", + "model", + "space", +] +DiscussionStatus_T = Literal[ + "closed", + "draft", + "open", + "merged", +] +SupportedWebhookVersion = Literal[3] + + +class ObjectId(BaseModel): + id: str + + +class WebhookPayloadUrl(BaseModel): + web: str + api: Optional[str] = None + + +class WebhookPayloadMovedTo(BaseModel): + name: str + owner: ObjectId + + +class WebhookPayloadWebhook(ObjectId): + version: SupportedWebhookVersion + + +class WebhookPayloadEvent(BaseModel): + action: WebhookEvent_T + scope: str + + +class WebhookPayloadDiscussionChanges(BaseModel): + base: str + mergeCommitId: Optional[str] = None + + +class WebhookPayloadComment(ObjectId): + author: ObjectId + hidden: bool + content: Optional[str] = None + url: WebhookPayloadUrl + + +class WebhookPayloadDiscussion(ObjectId): + num: int + author: ObjectId + url: WebhookPayloadUrl + title: str + isPullRequest: bool + status: DiscussionStatus_T + changes: Optional[WebhookPayloadDiscussionChanges] = None + pinned: Optional[bool] = None + + +class WebhookPayloadRepo(ObjectId): + owner: ObjectId + head_sha: Optional[str] = None + name: str + private: bool + subdomain: Optional[str] = None + tags: Optional[List[str]] = None + type: Literal["dataset", "model", "space"] + url: WebhookPayloadUrl + + +class WebhookPayloadUpdatedRef(BaseModel): + ref: str + oldSha: Optional[str] = None + newSha: Optional[str] = None + + +class WebhookPayload(BaseModel): + event: WebhookPayloadEvent + repo: WebhookPayloadRepo + discussion: Optional[WebhookPayloadDiscussion] = None + comment: Optional[WebhookPayloadComment] = None + webhook: WebhookPayloadWebhook + movedTo: Optional[WebhookPayloadMovedTo] = None + updatedRefs: Optional[List[WebhookPayloadUpdatedRef]] = None diff --git a/phivenv/Lib/site-packages/huggingface_hub/_webhooks_server.py b/phivenv/Lib/site-packages/huggingface_hub/_webhooks_server.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bd6c86261b1cc26dfcfe3a65f5aec9851a1162 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/_webhooks_server.py @@ -0,0 +1,388 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily.""" + +import atexit +import inspect +import os +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +from .utils import experimental, is_fastapi_available, is_gradio_available + + +if TYPE_CHECKING: + import gradio as gr + from fastapi import Request + +if is_fastapi_available(): + from fastapi import FastAPI, Request + from fastapi.responses import JSONResponse +else: + # Will fail at runtime if FastAPI is not available + FastAPI = Request = JSONResponse = None # type: ignore [misc, assignment] + + +_global_app: Optional["WebhooksServer"] = None +_is_local = os.environ.get("SPACE_ID") is None + + +@experimental +class WebhooksServer: + """ + The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks. + These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to + the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to be + called to start the app. + + It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic + model that contains all the information about the webhook event. The data will be parsed automatically for you. + + Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your + WebhooksServer and deploy it on a Space. + + + + `WebhooksServer` is experimental. Its API is subject to change in the future. + + + + + + You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`). + + + + Args: + ui (`gradio.Blocks`, optional): + A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions + about the configured webhooks is created. + webhook_secret (`str`, optional): + A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as + you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You + can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the + webhook endpoints are opened without any security. + + Example: + + ```python + import gradio as gr + from huggingface_hub import WebhooksServer, WebhookPayload + + with gr.Blocks() as ui: + ... + + app = WebhooksServer(ui=ui, webhook_secret="my_secret_key") + + @app.add_webhook("/say_hello") + async def hello(payload: WebhookPayload): + return {"message": "hello"} + + app.launch() + ``` + """ + + def __new__(cls, *args, **kwargs) -> "WebhooksServer": + if not is_gradio_available(): + raise ImportError( + "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`" + " first." + ) + if not is_fastapi_available(): + raise ImportError( + "You must have `fastapi` installed to use `WebhooksServer`. Please run `pip install --upgrade fastapi`" + " first." + ) + return super().__new__(cls) + + def __init__( + self, + ui: Optional["gr.Blocks"] = None, + webhook_secret: Optional[str] = None, + ) -> None: + self._ui = ui + + self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET") + self.registered_webhooks: Dict[str, Callable] = {} + _warn_on_empty_secret(self.webhook_secret) + + def add_webhook(self, path: Optional[str] = None) -> Callable: + """ + Decorator to add a webhook to the [`WebhooksServer`] server. 
+ + Args: + path (`str`, optional): + The URL path to register the webhook function. If not provided, the function name will be used as the + path. In any case, all webhooks are registered under `/webhooks`. + + Raises: + ValueError: If the provided path is already registered as a webhook. + + Example: + ```python + from huggingface_hub import WebhooksServer, WebhookPayload + + app = WebhooksServer() + + @app.add_webhook + async def trigger_training(payload: WebhookPayload): + if payload.repo.type == "dataset" and payload.event.action == "update": + # Trigger a training job if a dataset is updated + ... + + app.launch() + ``` + """ + # Usage: directly as decorator. Example: `@app.add_webhook` + if callable(path): + # If path is a function, it means it was used as a decorator without arguments + return self.add_webhook()(path) + + # Usage: provide a path. Example: `@app.add_webhook(...)` + @wraps(FastAPI.post) + def _inner_post(*args, **kwargs): + func = args[0] + abs_path = f"/webhooks/{(path or func.__name__).strip('/')}" + if abs_path in self.registered_webhooks: + raise ValueError(f"Webhook {abs_path} already exists.") + self.registered_webhooks[abs_path] = func + + return _inner_post + + def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None: + """Launch the Gradio app and register webhooks to the underlying FastAPI server. + + Input parameters are forwarded to Gradio when launching the app. + """ + ui = self._ui or self._get_default_ui() + + # Start Gradio App + # - as non-blocking so that webhooks can be added afterwards + # - as shared if launch locally (to debug webhooks) + launch_kwargs.setdefault("share", _is_local) + self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs) + + # Register webhooks to FastAPI app + for path, func in self.registered_webhooks.items(): + # Add secret check if required + if self.webhook_secret is not None: + func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret) + + # Add route to FastAPI app + self.fastapi_app.post(path)(func) + + # Print instructions and block main thread + space_host = os.environ.get("SPACE_HOST") + url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url) + if url is None: + raise ValueError("Cannot find the URL of the app. Please provide a valid `ui` or update `gradio` version.") + url = url.strip("/") + message = "\nWebhooks are correctly setup and ready to use:" + message += "\n" + "\n".join(f" - POST {url}{webhook}" for webhook in self.registered_webhooks) + message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks." + print(message) + + if not prevent_thread_lock: + ui.block_thread() + + def _get_default_ui(self) -> "gr.Blocks": + """Default UI if not provided (lists webhooks and provides basic instructions).""" + import gradio as gr + + with gr.Blocks() as ui: + gr.Markdown("# This is an app to process 🤗 Webhooks") + gr.Markdown( + "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on" + " specific repos or to all repos belonging to particular set of users/organizations (not just your" + " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to" + " know more about webhooks on the Huggingface Hub." 
+ ) + gr.Markdown( + f"{len(self.registered_webhooks)} webhook(s) are registered:" + + "\n\n" + + "\n ".join( + f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})" + for webhook_path, webhook in self.registered_webhooks.items() + ) + ) + gr.Markdown( + "Go to https://huggingface.co/settings/webhooks to setup your webhooks." + + "\nYou app is running locally. Please look at the logs to check the full URL you need to set." + if _is_local + else ( + "\nThis app is running on a Space. You can find the corresponding URL in the options menu" + " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'." + ) + ) + return ui + + +@experimental +def webhook_endpoint(path: Optional[str] = None) -> Callable: + """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint. + + This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret), + you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using + this decorator multiple times. + + Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your + server and deploy it on a Space. + + + + `webhook_endpoint` is experimental. Its API is subject to change in the future. + + + + + + You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`). + + + + Args: + path (`str`, optional): + The URL path to register the webhook function. If not provided, the function name will be used as the path. + In any case, all webhooks are registered under `/webhooks`. + + Examples: + The default usage is to register a function as a webhook endpoint. The function name will be used as the path. + The server will be started automatically at exit (i.e. at the end of the script). + + ```python + from huggingface_hub import webhook_endpoint, WebhookPayload + + @webhook_endpoint + async def trigger_training(payload: WebhookPayload): + if payload.repo.type == "dataset" and payload.event.action == "update": + # Trigger a training job if a dataset is updated + ... + + # Server is automatically started at the end of the script. + ``` + + Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you + are running it in a notebook. + + ```python + from huggingface_hub import webhook_endpoint, WebhookPayload + + @webhook_endpoint + async def trigger_training(payload: WebhookPayload): + if payload.repo.type == "dataset" and payload.event.action == "update": + # Trigger a training job if a dataset is updated + ... 
+ + # Start the server manually + trigger_training.launch() + ``` + """ + if callable(path): + # If path is a function, it means it was used as a decorator without arguments + return webhook_endpoint()(path) + + @wraps(WebhooksServer.add_webhook) + def _inner(func: Callable) -> Callable: + app = _get_global_app() + app.add_webhook(path)(func) + if len(app.registered_webhooks) == 1: + # Register `app.launch` to run at exit (only once) + atexit.register(app.launch) + + @wraps(app.launch) + def _launch_now(): + # Run the app directly (without waiting atexit) + atexit.unregister(app.launch) + app.launch() + + func.launch = _launch_now # type: ignore + return func + + return _inner + + +def _get_global_app() -> WebhooksServer: + global _global_app + if _global_app is None: + _global_app = WebhooksServer() + return _global_app + + +def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None: + if webhook_secret is None: + print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.") + print( + "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: " + "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`" + ) + print( + "For more details about webhook secrets, please refer to" + " https://huggingface.co/docs/hub/webhooks#webhook-secret." + ) + else: + print("Webhook secret is correctly defined.") + + +def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str: + """Returns the anchor to a given webhook in the docs (experimental)""" + return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post" + + +def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable: + """Wraps a webhook function to check the webhook secret before calling the function. + + This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route + parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request` + object (and hence the headers). A far cleaner solution would be to use a middleware. However, since + `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by + Gradio internals (and not by us), we cannot add a middleware. + + This method is called only when a secret has been defined by the user. If a request is sent without the + "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect, + the function will return a 403 error (forbidden). + + Inspired by https://stackoverflow.com/a/33112180. 
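+
+    For example, with a secret configured, a caller must send the matching header. A minimal sketch (the Space
+    URL is hypothetical):
+
+    ```python
+    import requests
+
+    response = requests.post(
+        "https://username-repo.hf.space/webhooks/trigger_training",  # hypothetical URL
+        json={},  # webhook payload
+        headers={"x-webhook-secret": "my_secret_key"},
+    )
+    # 401 if the header is missing, 403 if the secret is incorrect.
+    ```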
+ """ + initial_sig = inspect.signature(func) + + @wraps(func) + async def _protected_func(request: Request, **kwargs): + request_secret = request.headers.get("x-webhook-secret") + if request_secret is None: + return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401) + if request_secret != webhook_secret: + return JSONResponse({"error": "Invalid webhook secret."}, status_code=403) + + # Inject `request` in kwargs if required + if "request" in initial_sig.parameters: + kwargs["request"] = request + + # Handle both sync and async routes + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + else: + return func(**kwargs) + + # Update signature to include request + if "request" not in initial_sig.parameters: + _protected_func.__signature__ = initial_sig.replace( # type: ignore + parameters=( + inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request), + ) + + tuple(initial_sig.parameters.values()) + ) + + # Return protected route + return _protected_func diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1a8d793b89e16e5fa46ec5d420ec96fe1d72fe --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
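+"""Base class shared by all `hf` CLI commands.
+
+A concrete command registers itself on the shared subparser and exposes a `run` method. An illustrative sketch
+(the `dummy` command does not exist; it only shows the expected shape):
+
+    class DummyCommand(BaseHuggingfaceCLICommand):
+        @staticmethod
+        def register_subcommand(parser):
+            dummy_parser = parser.add_parser("dummy", help="Do nothing.")
+            dummy_parser.set_defaults(func=lambda args: DummyCommand(args))
+
+        def __init__(self, args):
+            self.args = args
+
+        def run(self):
+            print("dummy")
+"""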
+ +from abc import ABC, abstractmethod +from argparse import _SubParsersAction + + +class BaseHuggingfaceCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: _SubParsersAction): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/auth.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/auth.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1072e718e427a5fb860a0e6d990ac6c0b0be2f5 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/auth.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/cache.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29cc21ecfeb02873f13534c7e98e9109ea229132 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/download.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/download.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3e5df70bdfba93a85ab9e5f543a25ce749a0ad0 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/download.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/hf.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/hf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f7bb94e420910edf611a265d31e50b670ebf08 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/hf.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/jobs.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/jobs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26f631a7d2f5fba477ec283fb0d041e2fbd93615 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/jobs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/lfs.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/lfs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d90acb218af2208d7bfa0b0075d412072a9473 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/lfs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/repo.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/repo.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7864584a05ada14d88161ae3548abe94adef946 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/repo.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/repo_files.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/repo_files.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da84f5b87f4c003d7947d52813b52fc8e3b0eb14 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/repo_files.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/system.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/system.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60dad85d2b5004bf2098a4d7d04a48f800a7b7c2 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/system.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/_cli_utils.py b/phivenv/Lib/site-packages/huggingface_hub/cli/_cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bd56ad6896db2a257323e022896940c0ba0d68d3 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/_cli_utils.py @@ -0,0 +1,69 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a utility for good-looking prints.""" + +import os +from typing import List, Union + + +class ANSI: + """ + Helper for en.wikipedia.org/wiki/ANSI_escape_code + """ + + _bold = "\u001b[1m" + _gray = "\u001b[90m" + _red = "\u001b[31m" + _reset = "\u001b[0m" + _yellow = "\u001b[33m" + + @classmethod + def bold(cls, s: str) -> str: + return cls._format(s, cls._bold) + + @classmethod + def gray(cls, s: str) -> str: + return cls._format(s, cls._gray) + + @classmethod + def red(cls, s: str) -> str: + return cls._format(s, cls._bold + cls._red) + + @classmethod + def yellow(cls, s: str) -> str: + return cls._format(s, cls._yellow) + + @classmethod + def _format(cls, s: str, code: str) -> str: + if os.environ.get("NO_COLOR"): + # See https://no-color.org/ + return s + return f"{code}{s}{cls._reset}" + + +def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: + """ + Inspired by: + + - stackoverflow.com/a/8356620/593036 + - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data + """ + col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] + row_format = ("{{:{}}} " * len(headers)).format(*col_widths) + lines = [] + lines.append(row_format.format(*headers)) + lines.append(row_format.format(*["-" * w for w in col_widths])) + for row in rows: + lines.append(row_format.format(*row)) + return "\n".join(lines) diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/auth.py b/phivenv/Lib/site-packages/huggingface_hub/cli/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c5952ea8aca76bf1c2e75a39e23ac15a913060 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/auth.py @@ -0,0 +1,213 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to authenticate to the Hugging Face Hub and interact with your repositories. + +Usage: + # login and save token locally. + hf auth login --token=hf_*** --add-to-git-credential + + # switch between tokens + hf auth switch + + # list all tokens + hf auth list + + # logout from all tokens + hf auth logout + + # check which account you are logged in as + hf auth whoami +""" + +from argparse import _SubParsersAction +from typing import List, Optional + +from requests.exceptions import HTTPError + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.constants import ENDPOINT +from huggingface_hub.hf_api import HfApi + +from .._login import auth_list, auth_switch, login, logout +from ..utils import get_stored_tokens, get_token, logging +from ._cli_utils import ANSI + + +logger = logging.get_logger(__name__) + +try: + from InquirerPy import inquirer + from InquirerPy.base.control import Choice + + _inquirer_py_available = True +except ImportError: + _inquirer_py_available = False + + +class AuthCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + # Create the main 'auth' command + auth_parser = parser.add_parser("auth", help="Manage authentication (login, logout, etc.).") + auth_subparsers = auth_parser.add_subparsers(help="Authentication subcommands") + + # Show help if no subcommand is provided + auth_parser.set_defaults(func=lambda args: auth_parser.print_help()) + + # Add 'login' as a subcommand of 'auth' + login_parser = auth_subparsers.add_parser( + "login", help="Log in using a token from huggingface.co/settings/tokens" + ) + login_parser.add_argument( + "--token", + type=str, + help="Token generated from https://huggingface.co/settings/tokens", + ) + login_parser.add_argument( + "--add-to-git-credential", + action="store_true", + help="Optional: Save token to git credential helper.", + ) + login_parser.set_defaults(func=lambda args: AuthLogin(args)) + + # Add 'logout' as a subcommand of 'auth' + logout_parser = auth_subparsers.add_parser("logout", help="Log out") + logout_parser.add_argument( + "--token-name", + type=str, + help="Optional: Name of the access token to log out from.", + ) + logout_parser.set_defaults(func=lambda args: AuthLogout(args)) + + # Add 'whoami' as a subcommand of 'auth' + whoami_parser = auth_subparsers.add_parser( + "whoami", help="Find out which huggingface.co account you are logged in as." 
+ ) + whoami_parser.set_defaults(func=lambda args: AuthWhoami(args)) + + # Existing subcommands + auth_switch_parser = auth_subparsers.add_parser("switch", help="Switch between access tokens") + auth_switch_parser.add_argument( + "--token-name", + type=str, + help="Optional: Name of the access token to switch to.", + ) + auth_switch_parser.add_argument( + "--add-to-git-credential", + action="store_true", + help="Optional: Save token to git credential helper.", + ) + auth_switch_parser.set_defaults(func=lambda args: AuthSwitch(args)) + + auth_list_parser = auth_subparsers.add_parser("list", help="List all stored access tokens") + auth_list_parser.set_defaults(func=lambda args: AuthList(args)) + + +class BaseAuthCommand: + def __init__(self, args): + self.args = args + self._api = HfApi() + + +class AuthLogin(BaseAuthCommand): + def run(self): + logging.set_verbosity_info() + login( + token=self.args.token, + add_to_git_credential=self.args.add_to_git_credential, + ) + + +class AuthLogout(BaseAuthCommand): + def run(self): + logging.set_verbosity_info() + logout(token_name=self.args.token_name) + + +class AuthSwitch(BaseAuthCommand): + def run(self): + logging.set_verbosity_info() + token_name = self.args.token_name + if token_name is None: + token_name = self._select_token_name() + + if token_name is None: + print("No token name provided. Aborting.") + exit() + auth_switch(token_name, add_to_git_credential=self.args.add_to_git_credential) + + def _select_token_name(self) -> Optional[str]: + token_names = list(get_stored_tokens().keys()) + + if not token_names: + logger.error("No stored tokens found. Please login first.") + return None + + if _inquirer_py_available: + return self._select_token_name_tui(token_names) + # if inquirer is not available, use a simpler terminal UI + print("Available stored tokens:") + for i, token_name in enumerate(token_names, 1): + print(f"{i}. {token_name}") + while True: + try: + choice = input("Enter the number of the token to switch to (or 'q' to quit): ") + if choice.lower() == "q": + return None + index = int(choice) - 1 + if 0 <= index < len(token_names): + return token_names[index] + else: + print("Invalid selection. Please try again.") + except ValueError: + print("Invalid input. 
Please enter a number or 'q' to quit.") + + def _select_token_name_tui(self, token_names: List[str]) -> Optional[str]: + choices = [Choice(token_name, name=token_name) for token_name in token_names] + try: + return inquirer.select( + message="Select a token to switch to:", + choices=choices, + default=None, + ).execute() + except KeyboardInterrupt: + logger.info("Token selection cancelled.") + return None + + +class AuthList(BaseAuthCommand): + def run(self): + logging.set_verbosity_info() + auth_list() + + +class AuthWhoami(BaseAuthCommand): + def run(self): + token = get_token() + if token is None: + print("Not logged in") + exit() + try: + info = self._api.whoami(token) + print(info["name"]) + orgs = [org["name"] for org in info["orgs"]] + if orgs: + print(ANSI.bold("orgs: "), ",".join(orgs)) + + if ENDPOINT != "https://huggingface.co": + print(f"Authenticated through private endpoint: {ENDPOINT}") + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/cache.py b/phivenv/Lib/site-packages/huggingface_hub/cli/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b93ceda874491ea5b377e7d74c4e47b2bc4426 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/cache.py @@ -0,0 +1,409 @@ +# coding=utf-8 +# Copyright 2025-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains the 'hf cache' command group with 'scan' and 'delete' subcommands.""" + +import os +import time +from argparse import Namespace, _SubParsersAction +from functools import wraps +from tempfile import mkstemp +from typing import Any, Callable, Iterable, List, Literal, Optional, Union + +from ..utils import ( + CachedRepoInfo, + CachedRevisionInfo, + CacheNotFound, + HFCacheInfo, + scan_cache_dir, +) +from . import BaseHuggingfaceCLICommand +from ._cli_utils import ANSI, tabulate + + +# --- DELETE helpers (from delete_cache.py) --- +try: + from InquirerPy import inquirer + from InquirerPy.base.control import Choice + from InquirerPy.separator import Separator + + _inquirer_py_available = True +except ImportError: + _inquirer_py_available = False + +SortingOption_T = Literal["alphabetical", "lastUpdated", "lastUsed", "size"] +_CANCEL_DELETION_STR = "CANCEL_DELETION" + + +def require_inquirer_py(fn: Callable) -> Callable: + @wraps(fn) + def _inner(*args, **kwargs): + if not _inquirer_py_available: + raise ImportError( + "The 'cache delete' command requires extra dependencies for the TUI.\n" + "Please run 'pip install huggingface_hub[cli]' to install them.\n" + "Otherwise, disable TUI using the '--disable-tui' flag." 
+ ) + return fn(*args, **kwargs) + + return _inner + + +class CacheCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + cache_parser = parser.add_parser("cache", help="Manage local cache directory.") + cache_subparsers = cache_parser.add_subparsers(dest="cache_command", help="Cache subcommands") + + # Show help if no subcommand is provided + cache_parser.set_defaults(func=lambda args: cache_parser.print_help()) + + # Scan subcommand + scan_parser = cache_subparsers.add_parser("scan", help="Scan cache directory.") + scan_parser.add_argument( + "--dir", + type=str, + default=None, + help="cache directory to scan (optional). Defaults to the default Hugging Face cache.", + ) + scan_parser.add_argument( + "-v", + "--verbose", + action="count", + default=0, + help="show a more verbose output", + ) + scan_parser.set_defaults(func=CacheCommand, cache_command="scan") + # Delete subcommand + delete_parser = cache_subparsers.add_parser("delete", help="Delete revisions from the cache directory.") + delete_parser.add_argument( + "--dir", + type=str, + default=None, + help="cache directory (optional). Defaults to the default Hugging Face cache.", + ) + delete_parser.add_argument( + "--disable-tui", + action="store_true", + help=( + "Disable Terminal User Interface (TUI) mode. Useful if your platform/terminal doesn't support the multiselect menu." + ), + ) + delete_parser.add_argument( + "--sort", + nargs="?", + choices=["alphabetical", "lastUpdated", "lastUsed", "size"], + help=( + "Sort repositories by the specified criteria. Options: " + "'alphabetical' (A-Z), " + "'lastUpdated' (newest first), " + "'lastUsed' (most recent first), " + "'size' (largest first)." + ), + ) + delete_parser.set_defaults(func=CacheCommand, cache_command="delete") + + def __init__(self, args: Namespace) -> None: + self.args = args + self.verbosity: int = getattr(args, "verbose", 0) + self.cache_dir: Optional[str] = getattr(args, "dir", None) + self.disable_tui: bool = getattr(args, "disable_tui", False) + self.sort_by: Optional[SortingOption_T] = getattr(args, "sort", None) + self.cache_command: Optional[str] = getattr(args, "cache_command", None) + + def run(self): + if self.cache_command == "scan": + self._run_scan() + elif self.cache_command == "delete": + self._run_delete() + else: + print("Please specify a cache subcommand (scan or delete). Use -h for help.") + + def _run_scan(self): + try: + t0 = time.time() + hf_cache_info = scan_cache_dir(self.cache_dir) + t1 = time.time() + except CacheNotFound as exc: + cache_dir = exc.cache_dir + print(f"Cache directory not found: {cache_dir}") + return + print(get_table(hf_cache_info, verbosity=self.verbosity)) + print( + f"\nDone in {round(t1 - t0, 1)}s. Scanned {len(hf_cache_info.repos)} repo(s)" + f" for a total of {ANSI.red(hf_cache_info.size_on_disk_str)}." + ) + if len(hf_cache_info.warnings) > 0: + message = f"Got {len(hf_cache_info.warnings)} warning(s) while scanning."
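+ # Individual warnings are only printed at -vvv; lower verbosity just reports the count.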
+ if self.verbosity >= 3: + print(ANSI.gray(message)) + for warning in hf_cache_info.warnings: + print(ANSI.gray(warning)) + else: + print(ANSI.gray(message + " Use -vvv to print details.")) + + def _run_delete(self): + hf_cache_info = scan_cache_dir(self.cache_dir) + if self.disable_tui: + selected_hashes = _manual_review_no_tui(hf_cache_info, preselected=[], sort_by=self.sort_by) + else: + selected_hashes = _manual_review_tui(hf_cache_info, preselected=[], sort_by=self.sort_by) + if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes: + confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion?" + if self.disable_tui: + confirmed = _ask_for_confirmation_no_tui(confirm_message) + else: + confirmed = _ask_for_confirmation_tui(confirm_message) + if confirmed: + strategy = hf_cache_info.delete_revisions(*selected_hashes) + print("Starting deletion.") + strategy.execute() + print( + f"Done. Deleted {len(strategy.repos)} repo(s) and" + f" {len(strategy.snapshots)} revision(s) for a total of" + f" {strategy.expected_freed_size_str}." + ) + return + print("Deletion cancelled. Nothing was deleted.") + + +def get_table(hf_cache_info: HFCacheInfo, *, verbosity: int = 0) -> str: + if verbosity == 0: + return tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + "{:>12}".format(repo.size_on_disk_str), + repo.nb_files, + repo.last_accessed_str, + repo.last_modified_str, + ", ".join(sorted(repo.refs)), + str(repo.repo_path), + ] + for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) + ], + headers=[ + "REPO ID", + "REPO TYPE", + "SIZE ON DISK", + "NB FILES", + "LAST_ACCESSED", + "LAST_MODIFIED", + "REFS", + "LOCAL PATH", + ], + ) + else: + return tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + revision.commit_hash, + "{:>12}".format(revision.size_on_disk_str), + revision.nb_files, + revision.last_modified_str, + ", ".join(sorted(revision.refs)), + str(revision.snapshot_path), + ] + for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) + for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash) + ], + headers=[ + "REPO ID", + "REPO TYPE", + "REVISION", + "SIZE ON DISK", + "NB FILES", + "LAST_MODIFIED", + "REFS", + "LOCAL PATH", + ], + ) + + +def _get_repo_sorting_key(repo: CachedRepoInfo, sort_by: Optional[SortingOption_T] = None): + if sort_by == "alphabetical": + return (repo.repo_type, repo.repo_id.lower()) + elif sort_by == "lastUpdated": + return -max(rev.last_modified for rev in repo.revisions) + elif sort_by == "lastUsed": + return -repo.last_accessed + elif sort_by == "size": + return -repo.size_on_disk + else: + return (repo.repo_type, repo.repo_id) + + +@require_inquirer_py +def _manual_review_tui( + hf_cache_info: HFCacheInfo, preselected: List[str], sort_by: Optional[SortingOption_T] = None +) -> List[str]: + choices = _get_tui_choices_from_scan(repos=hf_cache_info.repos, preselected=preselected, sort_by=sort_by) + checkbox = inquirer.checkbox( + message="Select revisions to delete:", + choices=choices, + cycle=False, + height=100, + instruction=_get_expectations_str( + hf_cache_info, selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled] + ), + long_instruction="Press <space> to select, <enter> to validate and <ctrl+c> to quit without modification.", + transformer=lambda result: f"{len(result)} revision(s) selected.", + ) + + def _update_expectations(_): + checkbox._instruction = _get_expectations_str( + hf_cache_info, + selected_hashes=[choice["value"] for choice
in checkbox.content_control.choices if choice["enabled"]], + ) + + checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations}) + try: + return checkbox.execute() + except KeyboardInterrupt: + return [] + + +@require_inquirer_py +def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: + return inquirer.confirm(message, default=default).execute() + + +def _get_tui_choices_from_scan( + repos: Iterable[CachedRepoInfo], preselected: List[str], sort_by: Optional[SortingOption_T] = None +) -> List: + choices: List[Union["Choice", "Separator"]] = [] + choices.append( + Choice( + _CANCEL_DELETION_STR, name="None of the following (if selected, nothing will be deleted).", enabled=False + ) + ) + sorted_repos = sorted(repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) + for repo in sorted_repos: + choices.append( + Separator( + f"\n{repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}, used {repo.last_accessed_str})" + ) + ) + for revision in sorted(repo.revisions, key=_revision_sorting_order): + choices.append( + Choice( + revision.commit_hash, + name=( + f"{revision.commit_hash[:8]}: {', '.join(sorted(revision.refs)) or '(detached)'} # modified {revision.last_modified_str}" + ), + enabled=revision.commit_hash in preselected, + ) + ) + return choices + + +def _manual_review_no_tui( + hf_cache_info: HFCacheInfo, preselected: List[str], sort_by: Optional[SortingOption_T] = None +) -> List[str]: + fd, tmp_path = mkstemp(suffix=".txt") + os.close(fd) + lines = [] + sorted_repos = sorted(hf_cache_info.repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) + for repo in sorted_repos: + lines.append( + f"\n# {repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}, used {repo.last_accessed_str})" + ) + for revision in sorted(repo.revisions, key=_revision_sorting_order): + lines.append( + f"{'' if revision.commit_hash in preselected else '#'} {revision.commit_hash} # Refs: {', '.join(sorted(revision.refs)) or '(detached)'} # modified {revision.last_modified_str}" + ) + with open(tmp_path, "w") as f: + f.write(_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS) + f.write("\n".join(lines)) + instructions = f""" + TUI is disabled. In order to select which revisions you want to delete, please edit + the following file using the text editor of your choice. Instructions for manual + editing are located at the beginning of the file. Edit the file, save it and confirm + to continue. + File to edit: {ANSI.bold(tmp_path)} + """ + print("\n".join(line.strip() for line in instructions.strip().split("\n"))) + while True: + selected_hashes = _read_manual_review_tmp_file(tmp_path) + if _ask_for_confirmation_no_tui( + _get_expectations_str(hf_cache_info, selected_hashes) + " Continue?", default=False + ): + break + os.remove(tmp_path) + return sorted(selected_hashes) + + +def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: + YES = ("y", "yes", "1") + NO = ("n", "no", "0") + DEFAULT = "" + ALL = YES + NO + (DEFAULT,) + full_message = message + (" (Y/n) " if default else " (y/N) ") + while True: + answer = input(full_message).lower() + if answer == DEFAULT: + return default + if answer in YES: + return True + if answer in NO: + return False + print(f"Invalid input. Must be one of {ALL}") + + +def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str: + if _CANCEL_DELETION_STR in selected_hashes: + return "Nothing will be deleted."
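+ # Note: delete_revisions() only computes a deletion plan (a DeleteCacheStrategy); nothing is removed until its execute() method runs, as done in _run_delete() above.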
+ strategy = hf_cache_info.delete_revisions(*selected_hashes) + return f"{len(selected_hashes)} revision(s) selected, accounting for {strategy.expected_freed_size_str}." + + +def _read_manual_review_tmp_file(tmp_path: str) -> List[str]: + with open(tmp_path) as f: + content = f.read() + lines = [line.strip() for line in content.split("\n")] + selected_lines = [line for line in lines if not line.startswith("#")] + selected_hashes = [line.split("#")[0].strip() for line in selected_lines] + return [hash for hash in selected_hashes if len(hash) > 0] + + +_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS = f""" +# INSTRUCTIONS +# ------------ +# This is a temporary file created by running `hf cache delete --disable-tui`. It contains a set of revisions that can be deleted from your local cache directory. +# +# Please manually review the revisions you want to delete: +# - Revision hashes can be commented out with '#'. +# - Only non-commented revisions in this file will be deleted. +# - Revision hashes that are removed from this file are ignored as well. +# - If the `{_CANCEL_DELETION_STR}` line is uncommented, the whole deletion is cancelled and no changes will be applied. +# +# Once you've manually reviewed this file, please confirm deletion in the terminal. This file will be automatically removed once done. +# ------------ + +# KILL SWITCH +# ------------ +# Un-comment the following line to completely cancel the deletion process. +# {_CANCEL_DELETION_STR} +# ------------ + +# REVISIONS +# ------------ +""".strip() + + +def _revision_sorting_order(revision: CachedRevisionInfo) -> Any: + return revision.last_modified diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/download.py b/phivenv/Lib/site-packages/huggingface_hub/cli/download.py new file mode 100644 index 0000000000000000000000000000000000000000..3e59233da1d4386c8f0018a65c8d7f8d91b34612 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/download.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to download files from the Hub with the CLI.
+ +Usage: + hf download --help + + # Download file + hf download gpt2 config.json + + # Download entire repo + hf download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78 + + # Download repo with filters + hf download gpt2 --include="*.safetensors" + + # Download with token + hf download Wauplin/private-model --token=hf_*** + + # Download quietly (no progress bar, no warnings, only the returned path) + hf download gpt2 config.json --quiet + + # Download to local dir + hf download gpt2 --local-dir=./models/gpt2 +""" + +import warnings +from argparse import Namespace, _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub._snapshot_download import snapshot_download +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.utils import disable_progress_bars, enable_progress_bars + + +logger = logging.get_logger(__name__) + + +class DownloadCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + download_parser = parser.add_parser("download", help="Download files from the Hub") + download_parser.add_argument( + "repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)." + ) + download_parser.add_argument( + "filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)." + ) + download_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Type of repo to download from (defaults to 'model').", + ) + download_parser.add_argument( + "--revision", + type=str, + help="An optional Git revision id which can be a branch name, a tag, or a commit hash.", + ) + download_parser.add_argument( + "--include", nargs="*", type=str, help="Glob patterns to match files to download." + ) + download_parser.add_argument( + "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download." + ) + download_parser.add_argument( + "--cache-dir", type=str, help="Path to the directory in which to save the downloaded files." + ) + download_parser.add_argument( + "--local-dir", + type=str, + help=( + "If set, the downloaded file will be placed under this directory. Check out" + " https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder for more" + " details." + ), + ) + download_parser.add_argument( + "--force-download", + action="store_true", + help="If True, the files will be downloaded even if they are already cached.", + ) + download_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + download_parser.add_argument( + "--quiet", + action="store_true", + help="If True, progress bars are disabled and only the path to the downloaded files is printed.", + ) + download_parser.add_argument( + "--max-workers", + type=int, + default=8, + help="Maximum number of workers to use for downloading files.
Default is 8.", + ) + download_parser.set_defaults(func=DownloadCommand) + + def __init__(self, args: Namespace) -> None: + self.token = args.token + self.repo_id: str = args.repo_id + self.filenames: List[str] = args.filenames + self.repo_type: str = args.repo_type + self.revision: Optional[str] = args.revision + self.include: Optional[List[str]] = args.include + self.exclude: Optional[List[str]] = args.exclude + self.cache_dir: Optional[str] = args.cache_dir + self.local_dir: Optional[str] = args.local_dir + self.force_download: bool = args.force_download + self.quiet: bool = args.quiet + self.max_workers: int = args.max_workers + + def run(self) -> None: + if self.quiet: + disable_progress_bars() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + print(self._download()) # Print path to downloaded files + enable_progress_bars() + else: + logging.set_verbosity_info() + print(self._download()) # Print path to downloaded files + logging.set_verbosity_warning() + + def _download(self) -> str: + # Warn user if patterns are ignored + if len(self.filenames) > 0: + if self.include is not None and len(self.include) > 0: + warnings.warn("Ignoring `--include` since filenames have been explicitly set.") + if self.exclude is not None and len(self.exclude) > 0: + warnings.warn("Ignoring `--exclude` since filenames have been explicitly set.") + + # Single file to download: use `hf_hub_download` + if len(self.filenames) == 1: + return hf_hub_download( + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + filename=self.filenames[0], + cache_dir=self.cache_dir, + force_download=self.force_download, + token=self.token, + local_dir=self.local_dir, + library_name="hf", + ) + + # Otherwise: use `snapshot_download` to ensure all files come from the same revision + elif len(self.filenames) == 0: + allow_patterns = self.include + ignore_patterns = self.exclude + else: + allow_patterns = self.filenames + ignore_patterns = None + + return snapshot_download( + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + force_download=self.force_download, + cache_dir=self.cache_dir, + token=self.token, + local_dir=self.local_dir, + library_name="hf", + max_workers=self.max_workers, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/hf.py b/phivenv/Lib/site-packages/huggingface_hub/cli/hf.py new file mode 100644 index 0000000000000000000000000000000000000000..2587918b294b427fb8f3e0f990884826b66514a8 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/hf.py @@ -0,0 +1,63 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
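+# Entry point module for the `hf` CLI: main() below builds the top-level parser and registers every subcommand group.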
+ +from argparse import ArgumentParser + +from huggingface_hub.cli.auth import AuthCommands +from huggingface_hub.cli.cache import CacheCommand +from huggingface_hub.cli.download import DownloadCommand +from huggingface_hub.cli.jobs import JobsCommands +from huggingface_hub.cli.lfs import LfsCommands +from huggingface_hub.cli.repo import RepoCommands +from huggingface_hub.cli.repo_files import RepoFilesCommand +from huggingface_hub.cli.system import EnvironmentCommand, VersionCommand +from huggingface_hub.cli.upload import UploadCommand +from huggingface_hub.cli.upload_large_folder import UploadLargeFolderCommand + + +def main(): + parser = ArgumentParser("hf", usage="hf <command> [<args>]") + commands_parser = parser.add_subparsers(help="hf command helpers") + + # Register commands + AuthCommands.register_subcommand(commands_parser) + CacheCommand.register_subcommand(commands_parser) + DownloadCommand.register_subcommand(commands_parser) + JobsCommands.register_subcommand(commands_parser) + RepoCommands.register_subcommand(commands_parser) + RepoFilesCommand.register_subcommand(commands_parser) + UploadCommand.register_subcommand(commands_parser) + UploadLargeFolderCommand.register_subcommand(commands_parser) + + # System commands + EnvironmentCommand.register_subcommand(commands_parser) + VersionCommand.register_subcommand(commands_parser) + + # LFS commands (hidden in --help) + LfsCommands.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + service = args.func(args) + if service is not None: + service.run() + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/jobs.py b/phivenv/Lib/site-packages/huggingface_hub/cli/jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..bd88648c7901b08ac653a669eef12ccb7c5b634f --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/jobs.py @@ -0,0 +1,550 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with jobs on the Hugging Face Hub. + +Usage: + # Run a job + hf jobs run <image> <command> + + # List running or completed jobs + hf jobs ps [-a] [-f key=value] [--format TEMPLATE] + + # Stream logs from a job + hf jobs logs <job_id> + + # Inspect detailed information about a job + hf jobs inspect <job_id> + + # Cancel a running job + hf jobs cancel <job_id> +""" + +import json +import os +import re +from argparse import Namespace, _SubParsersAction +from dataclasses import asdict +from pathlib import Path +from typing import Dict, List, Optional, Union + +import requests + +from huggingface_hub import HfApi, SpaceHardware, get_token +from huggingface_hub.utils import logging +from huggingface_hub.utils._dotenv import load_dotenv + +from .
import BaseHuggingfaceCLICommand + + +logger = logging.get_logger(__name__) + +SUGGESTED_FLAVORS = [item.value for item in SpaceHardware if item.value != "zero-a10g"] + + +class JobsCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + jobs_parser = parser.add_parser("jobs", help="Run and manage Jobs on the Hub.") + jobs_subparsers = jobs_parser.add_subparsers(help="huggingface.co jobs related commands") + + # Show help if no subcommand is provided + jobs_parser.set_defaults(func=lambda args: jobs_parser.print_help()) + + # Register commands + InspectCommand.register_subcommand(jobs_subparsers) + LogsCommand.register_subcommand(jobs_subparsers) + PsCommand.register_subcommand(jobs_subparsers) + RunCommand.register_subcommand(jobs_subparsers) + CancelCommand.register_subcommand(jobs_subparsers) + UvCommand.register_subcommand(jobs_subparsers) + + +class RunCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("run", help="Run a Job") + run_parser.add_argument("image", type=str, help="The Docker image to use.") + run_parser.add_argument("-e", "--env", action="append", help="Set environment variables. E.g. --env ENV=value") + run_parser.add_argument( + "-s", + "--secrets", + action="append", + help=( + "Set secret environment variables. E.g. --secrets SECRET=value " + "or `--secrets HF_TOKEN` to pass your Hugging Face token." + ), + ) + run_parser.add_argument("--env-file", type=str, help="Read in a file of environment variables.") + run_parser.add_argument("--secrets-file", type=str, help="Read in a file of secret environment variables.") + run_parser.add_argument( + "--flavor", + type=str, + help=f"Flavor for the hardware, as in HF Spaces. Defaults to `cpu-basic`. Possible values: {', '.join(SUGGESTED_FLAVORS)}.", + ) + run_parser.add_argument( + "--timeout", + type=str, + help="Max duration: int/float with s (seconds, default), m (minutes), h (hours) or d (days).", + ) + run_parser.add_argument( + "-d", + "--detach", + action="store_true", + help="Run the Job in the background and print the Job ID.", + ) + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the Job will be created. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", + type=str, + help="A User Access Token generated from https://huggingface.co/settings/tokens", + ) + run_parser.add_argument("command", nargs="...", help="The command to run.") + run_parser.set_defaults(func=RunCommand) + + def __init__(self, args: Namespace) -> None: + self.image: str = args.image + self.command: List[str] = args.command + self.env: dict[str, Optional[str]] = {} + if args.env_file: + self.env.update(load_dotenv(Path(args.env_file).read_text(), environ=os.environ.copy())) + for env_value in args.env or []: + self.env.update(load_dotenv(env_value, environ=os.environ.copy())) + self.secrets: dict[str, Optional[str]] = {} + extended_environ = _get_extended_environ() + if args.secrets_file: + self.secrets.update(load_dotenv(Path(args.secrets_file).read_text(), environ=extended_environ)) + for secret in args.secrets or []: + self.secrets.update(load_dotenv(secret, environ=extended_environ)) + self.flavor: Optional[SpaceHardware] = args.flavor + self.timeout: Optional[str] = args.timeout + self.detach: bool = args.detach + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + + def run(self) -> None: + api = HfApi(token=self.token) + job = api.run_job( + image=self.image, + command=self.command, + env=self.env, + secrets=self.secrets, + flavor=self.flavor, + timeout=self.timeout, + namespace=self.namespace, + ) + # Always print the job ID to the user + print(f"Job started with ID: {job.id}") + print(f"View at: {job.url}") + + if self.detach: + return + + # Now let's stream the logs + for log in api.fetch_job_logs(job_id=job.id): + print(log) + + +class LogsCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("logs", help="Fetch the logs of a Job") + run_parser.add_argument("job_id", type=str, help="Job ID") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the job is running. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + run_parser.set_defaults(func=LogsCommand) + + def __init__(self, args: Namespace) -> None: + self.job_id: str = args.job_id + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + + def run(self) -> None: + api = HfApi(token=self.token) + for log in api.fetch_job_logs(job_id=self.job_id, namespace=self.namespace): + print(log) + + +def _tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: + """ + Inspired by: + + - stackoverflow.com/a/8356620/593036 + - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data + """ + col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] + terminal_width = max(os.get_terminal_size().columns, len(headers) * 12) + while len(headers) + sum(col_widths) > terminal_width: + col_to_minimize = col_widths.index(max(col_widths)) + col_widths[col_to_minimize] //= 2 + if len(headers) + sum(col_widths) <= terminal_width: + col_widths[col_to_minimize] = terminal_width - sum(col_widths) - len(headers) + col_widths[col_to_minimize] + row_format = ("{{:{}}} " * len(headers)).format(*col_widths) + lines = [] + lines.append(row_format.format(*headers)) + lines.append(row_format.format(*["-" * w for w in col_widths])) + for row in rows: + row_format_args = [ + str(x)[: col_width - 3] + "..." if len(str(x)) > col_width else str(x) + for x, col_width in zip(row, col_widths) + ] + lines.append(row_format.format(*row_format_args)) + return "\n".join(lines) + + +class PsCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("ps", help="List Jobs") + run_parser.add_argument( + "-a", + "--all", + action="store_true", + help="Show all Jobs (default shows just running)", + ) + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace to list jobs from. Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", + type=str, + help="A User Access Token generated from https://huggingface.co/settings/tokens", + ) + # Add Docker-style filtering argument + run_parser.add_argument( + "-f", + "--filter", + action="append", + default=[], + help="Filter output based on conditions provided (format: key=value)", + ) + # Add option to format output + run_parser.add_argument( + "--format", + type=str, + help="Format output using a custom template", + ) + run_parser.set_defaults(func=PsCommand) + + def __init__(self, args: Namespace) -> None: + self.all: bool = args.all + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + self.format: Optional[str] = args.format + self.filters: Dict[str, str] = {} + + # Parse filter arguments (key=value pairs) + for f in args.filter: + if "=" in f: + key, value = f.split("=", 1) + self.filters[key.lower()] = value + else: + print(f"Warning: Ignoring invalid filter format '{f}'. Use key=value format.") + + def run(self) -> None: + """ + Fetch and display job information for the current user. + Uses Docker-style filtering with the -f/--filter flag and key=value pairs.
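+ Example (illustrative): `hf jobs ps -a -f image=*pytorch*` lists all jobs whose Docker image matches the glob pattern.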
+ """ + try: + api = HfApi(token=self.token) + + # Fetch jobs data + jobs = api.list_jobs(namespace=self.namespace) + + # Define table headers + table_headers = ["JOB ID", "IMAGE/SPACE", "COMMAND", "CREATED", "STATUS"] + + # Process jobs data + rows = [] + + for job in jobs: + # Extract job data for filtering + status = job.status.stage if job.status else "UNKNOWN" + + # Skip job if not all jobs should be shown and status doesn't match criteria + if not self.all and status not in ("RUNNING", "UPDATING"): + continue + + # Extract job ID + job_id = job.id + + # Extract image or space information + image_or_space = job.docker_image or "N/A" + + # Extract and format command + command = job.command or [] + command_str = " ".join(command) if command else "N/A" + + # Extract creation time + created_at = job.created_at or "N/A" + + # Create a dict with all job properties for filtering + job_properties = { + "id": job_id, + "image": image_or_space, + "status": status.lower(), + "command": command_str, + } + + # Check if job matches all filters + if not self._matches_filters(job_properties): + continue + + # Create row + rows.append([job_id, image_or_space, command_str, created_at, status]) + + # Handle empty results + if not rows: + filters_msg = "" + if self.filters: + filters_msg = f" matching filters: {', '.join([f'{k}={v}' for k, v in self.filters.items()])}" + + print(f"No jobs found{filters_msg}") + return + + # Apply custom format if provided or use default tabular format + self._print_output(rows, table_headers) + + except requests.RequestException as e: + print(f"Error fetching jobs data: {e}") + except (KeyError, ValueError, TypeError) as e: + print(f"Error processing jobs data: {e}") + except Exception as e: + print(f"Unexpected error - {type(e).__name__}: {e}") + + def _matches_filters(self, job_properties: Dict[str, str]) -> bool: + """Check if job matches all specified filters.""" + for key, pattern in self.filters.items(): + # Check if property exists + if key not in job_properties: + return False + + # Support pattern matching with wildcards + if "*" in pattern or "?" in pattern: + # Convert glob pattern to regex + regex_pattern = pattern.replace("*", ".*").replace("?", ".") + if not re.search(f"^{regex_pattern}$", job_properties[key], re.IGNORECASE): + return False + # Simple substring matching + elif pattern.lower() not in job_properties[key].lower(): + return False + + return True + + def _print_output(self, rows, headers): + """Print output according to the chosen format.""" + if self.format: + # Custom template formatting (simplified) + template = self.format + for row in rows: + line = template + for i, field in enumerate(["id", "image", "command", "created", "status"]): + placeholder = f"{{{{.{field}}}}}" + if placeholder in line: + line = line.replace(placeholder, str(row[i])) + print(line) + else: + # Default tabular format + print( + _tabulate( + rows, + headers=headers, + ) + ) + + +class InspectCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("inspect", help="Display detailed information on one or more Jobs") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the job is running. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + run_parser.add_argument("job_ids", nargs="...", help="The jobs to inspect") + run_parser.set_defaults(func=InspectCommand) + + def __init__(self, args: Namespace) -> None: + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + self.job_ids: List[str] = args.job_ids + + def run(self) -> None: + api = HfApi(token=self.token) + jobs = [api.inspect_job(job_id=job_id, namespace=self.namespace) for job_id in self.job_ids] + print(json.dumps([asdict(job) for job in jobs], indent=4, default=str)) + + +class CancelCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("cancel", help="Cancel a Job") + run_parser.add_argument("job_id", type=str, help="Job ID") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the job is running. Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + run_parser.set_defaults(func=CancelCommand) + + def __init__(self, args: Namespace) -> None: + self.job_id: str = args.job_id + self.namespace = args.namespace + self.token: Optional[str] = args.token + + def run(self) -> None: + api = HfApi(token=self.token) + api.cancel_job(job_id=self.job_id, namespace=self.namespace) + + +class UvCommand(BaseHuggingfaceCLICommand): + """Run UV scripts on Hugging Face infrastructure.""" + + @staticmethod + def register_subcommand(parser): + """Register UV run subcommand.""" + uv_parser = parser.add_parser( + "uv", + help="Run UV scripts (Python with inline dependencies) on HF infrastructure", + ) + + subparsers = uv_parser.add_subparsers(dest="uv_command", help="UV commands", required=True) + + # Run command only + run_parser = subparsers.add_parser( + "run", + help="Run a UV script (local file or URL) on HF infrastructure", + ) + run_parser.add_argument("script", help="UV script to run (local file or URL)") + run_parser.add_argument("script_args", nargs="...", help="Arguments for the script", default=[]) + run_parser.add_argument("--image", type=str, help="Use a custom Docker image with `uv` installed.") + run_parser.add_argument( + "--repo", + help="Repository name for the script (creates ephemeral if not specified)", + ) + run_parser.add_argument( + "--flavor", + type=str, + help=f"Flavor for the hardware, as in HF Spaces. Defaults to `cpu-basic`. Possible values: {', '.join(SUGGESTED_FLAVORS)}.", + ) + run_parser.add_argument("-e", "--env", action="append", help="Environment variables") + run_parser.add_argument( + "-s", + "--secrets", + action="append", + help=( + "Set secret environment variables. E.g. --secrets SECRET=value " + "or `--secrets HF_TOKEN` to pass your Hugging Face token." + ), + ) + run_parser.add_argument("--env-file", type=str, help="Read in a file of environment variables.") + run_parser.add_argument( + "--secrets-file", + type=str, + help="Read in a file of secret environment variables.", + ) + run_parser.add_argument("--timeout", type=str, help="Max duration (e.g., 30s, 5m, 1h)") + run_parser.add_argument("-d", "--detach", action="store_true", help="Run in background") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the Job will be created. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument("--token", type=str, help="HF token") + # UV options + run_parser.add_argument("--with", action="append", help="Run with the given packages installed", dest="with_") + run_parser.add_argument( + "-p", "--python", type=str, help="The Python interpreter to use for the run environment" + ) + run_parser.set_defaults(func=UvCommand) + + def __init__(self, args: Namespace) -> None: + """Initialize the command with parsed arguments.""" + self.script = args.script + self.script_args = args.script_args + self.dependencies = args.with_ + self.python = args.python + self.image = args.image + self.env: dict[str, Optional[str]] = {} + if args.env_file: + self.env.update(load_dotenv(Path(args.env_file).read_text(), environ=os.environ.copy())) + for env_value in args.env or []: + self.env.update(load_dotenv(env_value, environ=os.environ.copy())) + self.secrets: dict[str, Optional[str]] = {} + extended_environ = _get_extended_environ() + if args.secrets_file: + self.secrets.update(load_dotenv(Path(args.secrets_file).read_text(), environ=extended_environ)) + for secret in args.secrets or []: + self.secrets.update(load_dotenv(secret, environ=extended_environ)) + self.flavor: Optional[SpaceHardware] = args.flavor + self.timeout: Optional[str] = args.timeout + self.detach: bool = args.detach + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + self._repo = args.repo + + def run(self) -> None: + """Execute UV command.""" + logging.set_verbosity(logging.INFO) + api = HfApi(token=self.token) + job = api.run_uv_job( + script=self.script, + script_args=self.script_args, + dependencies=self.dependencies, + python=self.python, + image=self.image, + env=self.env, + secrets=self.secrets, + flavor=self.flavor, + timeout=self.timeout, + namespace=self.namespace, + _repo=self._repo, + ) + + # Always print the job ID to the user + print(f"Job started with ID: {job.id}") + print(f"View at: {job.url}") + + if self.detach: + return + + # Now let's stream the logs + for log in api.fetch_job_logs(job_id=job.id): + print(log) + + +def _get_extended_environ() -> Dict[str, str]: + extended_environ = os.environ.copy() + if (token := get_token()) is not None: + extended_environ["HF_TOKEN"] = token + return extended_environ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/lfs.py b/phivenv/Lib/site-packages/huggingface_hub/cli/lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c5b900c816494c260f6c440843a2d83703fab5 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/lfs.py @@ -0,0 +1,198 @@ +""" +Implementation of a custom transfer agent for the transfer type "multipart" for +git-lfs. 
+ +Inspired by: +github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py + +Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md + + +To launch a debugger while developing, configure the transfer agent in the repo's git config: + +``` +[lfs "customtransfer.multipart"] +path = /path/to/huggingface_hub/.env/bin/python +args = -m debugpy --listen 5678 --wait-for-client /path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py lfs-multipart-upload +``` +""" + +import json +import os +import subprocess +import sys +from argparse import _SubParsersAction +from typing import Dict, List, Optional + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND + +from ..utils import get_session, hf_raise_for_status, logging +from ..utils._lfs import SliceFileObj + + +logger = logging.get_logger(__name__) + + +class LfsCommands(BaseHuggingfaceCLICommand): + """ + Implementation of a custom transfer agent for the transfer type "multipart" + for git-lfs. This lets users upload large files >5GB 🔥. Spec for LFS custom + transfer agent is: + https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md + + This introduces two commands to the CLI: + + 1. $ hf lfs-enable-largefiles + + This should be executed once for each model repo that contains a model file + >5GB. It's documented in the error message you get if you just try to git + push a 5GB file without having enabled it before. + + 2. $ hf lfs-multipart-upload + + This command is called by lfs directly and is not meant to be called by the + user. + """ + + @staticmethod + def register_subcommand(parser: _SubParsersAction): + enable_parser = parser.add_parser("lfs-enable-largefiles", add_help=False) + enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.") + enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args)) + + # Command will get called by git-lfs, do not call it directly. + upload_parser = parser.add_parser(LFS_MULTIPART_UPLOAD_COMMAND, add_help=False) + upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args)) + + +class LfsEnableCommand: + def __init__(self, args): + self.args = args + + def run(self): + local_path = os.path.abspath(self.args.path) + if not os.path.isdir(local_path): + print("This does not look like a valid git repo.") + exit(1) + subprocess.run( + "git config lfs.customtransfer.multipart.path hf".split(), + check=True, + cwd=local_path, + ) + subprocess.run( + f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(), + check=True, + cwd=local_path, + ) + print("Local repo set up for largefiles") + + +def write_msg(msg: Dict): + """Write out the message in line-delimited JSON.""" + msg_str = json.dumps(msg) + "\n" + sys.stdout.write(msg_str) + sys.stdout.flush() + + +def read_msg() -> Optional[Dict]: + """Read line-delimited JSON from stdin.""" + msg = json.loads(sys.stdin.readline().strip()) + + if "terminate" in (msg.get("type"), msg.get("event")): + # terminate message received + return None + + if msg.get("event") not in ("download", "upload"): + logger.critical("Received unexpected message") + sys.exit(1) + + return msg + + +class LfsUploadCommand: + def __init__(self, args) -> None: + self.args = args + + def run(self) -> None: + # Immediately after invoking a custom transfer process, git-lfs + # sends initiation data to the process over stdin. + # This tells the process useful information about the configuration.
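+ # An init line, per the git-lfs custom-transfers spec, looks like (illustrative): {"event": "init", "operation": "upload", "remote": "origin", "concurrent": true}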
+ init_msg = json.loads(sys.stdin.readline().strip()) + if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"): + write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}}) + sys.exit(1) + + # The transfer process should use the information it needs from the + # initiation structure, and also perform any one-off setup tasks it + # needs to do. It should then respond on stdout with a simple empty + # confirmation structure, as follows: + write_msg({}) + + # After the initiation exchange, git-lfs will send any number of + # transfer requests to the stdin of the transfer process, in a serial sequence. + while True: + msg = read_msg() + if msg is None: + # When all transfers have been processed, git-lfs will send + # a terminate event to the stdin of the transfer process. + # On receiving this message the transfer process should + # clean up and terminate. No response is expected. + sys.exit(0) + + oid = msg["oid"] + filepath = msg["path"] + completion_url = msg["action"]["href"] + header = msg["action"]["header"] + chunk_size = int(header.pop("chunk_size")) + presigned_urls: List[str] = list(header.values()) + + # Send a "started" progress event to allow other workers to start. + # Otherwise they're delayed until first "progress" event is reported, + # i.e. after the first 5GB by default (!) + write_msg( + { + "event": "progress", + "oid": oid, + "bytesSoFar": 1, + "bytesSinceLast": 0, + } + ) + + parts = [] + with open(filepath, "rb") as file: + for i, presigned_url in enumerate(presigned_urls): + with SliceFileObj( + file, + seek_from=i * chunk_size, + read_limit=chunk_size, + ) as data: + r = get_session().put(presigned_url, data=data) + hf_raise_for_status(r) + parts.append( + { + "etag": r.headers.get("etag"), + "partNumber": i + 1, + } + ) + # In order to support progress reporting while data is uploading / downloading, + # the transfer process should post messages to stdout + write_msg( + { + "event": "progress", + "oid": oid, + "bytesSoFar": (i + 1) * chunk_size, + "bytesSinceLast": chunk_size, + } + ) + # Not precise but that's ok. + + r = get_session().post( + completion_url, + json={ + "oid": oid, + "parts": parts, + }, + ) + hf_raise_for_status(r) + + write_msg({"event": "complete", "oid": oid}) diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/repo.py b/phivenv/Lib/site-packages/huggingface_hub/cli/repo.py new file mode 100644 index 0000000000000000000000000000000000000000..e11f316610830527c360353334b1d36b95e9171d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/repo.py @@ -0,0 +1,247 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with repositories on the Hugging Face Hub. 
+ +Usage: + # create a new dataset repo on the Hub + hf repo create my-cool-dataset --repo-type=dataset + + # create a private model repo on the Hub + hf repo create my-cool-model --private +""" + +import argparse +from argparse import _SubParsersAction +from typing import Optional + +from requests.exceptions import HTTPError + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.commands._cli_utils import ANSI +from huggingface_hub.constants import REPO_TYPES, SPACES_SDK_TYPES +from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import logging + + +logger = logging.get_logger(__name__) + + +class RepoCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + repo_parser = parser.add_parser("repo", help="Manage repos on the Hub.") + repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands") + + # Show help if no subcommand is provided + repo_parser.set_defaults(func=lambda args: repo_parser.print_help()) + + # CREATE + repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co") + repo_create_parser.add_argument( + "repo_id", + type=str, + help="The ID of the repo to create (e.g. `username/repo-name`). The username is optional and will be set to your username if not provided.", + ) + repo_create_parser.add_argument( + "--repo-type", + type=str, + help='Optional: set to "dataset" or "space" if creating a dataset or space; defaults to model.', + ) + repo_create_parser.add_argument( + "--space_sdk", + type=str, + help='Optional: Hugging Face Spaces SDK type. Required when --repo-type is set to "space".', + choices=SPACES_SDK_TYPES, + ) + repo_create_parser.add_argument( + "--private", + action="store_true", + help="Whether to create a private repository. Defaults to public unless the organization's default is private.", + ) + repo_create_parser.add_argument( + "--token", + type=str, + help="Hugging Face token. Will default to the locally saved token if not provided.", + ) + repo_create_parser.add_argument( + "--exist-ok", + action="store_true", + help="Do not raise an error if repo already exists.", + ) + repo_create_parser.add_argument( + "--resource-group-id", + type=str, + help="Resource group in which to create the repo. Resource groups are only available for Enterprise Hub organizations.", + ) + repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args)) + + # TAG SUBCOMMANDS + repo_tag_parser = repo_subparsers.add_parser("tag", help="Manage tags for a repo on the Hub.") + tag_subparsers = repo_tag_parser.add_subparsers(help="Tag actions", dest="tag_action", required=True) + + # tag create + tag_create_parser = tag_subparsers.add_parser("create", help="Create a tag for a repo.") + tag_create_parser.add_argument( + "repo_id", type=str, help="The ID of the repo to tag (e.g. `username/repo-name`)." + ) + tag_create_parser.add_argument("tag", type=str, help="The name of the tag to create.") + tag_create_parser.add_argument("-m", "--message", type=str, help="The description of the tag to create.") + tag_create_parser.add_argument("--revision", type=str, help="The git revision to tag.") + tag_create_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens."
+ ) + tag_create_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Set the type of repository (model, dataset, or space).", + ) + tag_create_parser.set_defaults(func=lambda args: RepoTagCreateCommand(args)) + + # tag list + tag_list_parser = tag_subparsers.add_parser("list", help="List tags for a repo.") + tag_list_parser.add_argument( + "repo_id", type=str, help="The ID of the repo to list tags for (e.g. `username/repo-name`)." + ) + tag_list_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens." + ) + tag_list_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Set the type of repository (model, dataset, or space).", + ) + tag_list_parser.set_defaults(func=lambda args: RepoTagListCommand(args)) + + # tag delete + tag_delete_parser = tag_subparsers.add_parser("delete", help="Delete a tag from a repo.") + tag_delete_parser.add_argument( + "repo_id", type=str, help="The ID of the repo to delete the tag from (e.g. `username/repo-name`)." + ) + tag_delete_parser.add_argument("tag", type=str, help="The name of the tag to delete.") + tag_delete_parser.add_argument("-y", "--yes", action="store_true", help="Answer Yes to prompts automatically.") + tag_delete_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens." + ) + tag_delete_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Set the type of repository (model, dataset, or space).", + ) + tag_delete_parser.set_defaults(func=lambda args: RepoTagDeleteCommand(args)) + + +class RepoCreateCommand: + def __init__(self, args: argparse.Namespace): + self.repo_id: str = args.repo_id + self.repo_type: Optional[str] = args.repo_type + self.space_sdk: Optional[str] = args.space_sdk + self.private: bool = args.private + self.token: Optional[str] = args.token + self.exist_ok: bool = args.exist_ok + self.resource_group_id: Optional[str] = args.resource_group_id + self._api = HfApi() + + def run(self): + repo_url = self._api.create_repo( + repo_id=self.repo_id, + repo_type=self.repo_type, + private=self.private, + token=self.token, + exist_ok=self.exist_ok, + resource_group_id=self.resource_group_id, + space_sdk=self.space_sdk, + ) + print(f"Successfully created {ANSI.bold(repo_url.repo_id)} on the Hub.") + print(f"Your repo is now available at {ANSI.bold(repo_url)}") + + +class RepoTagCommand: + def __init__(self, args): + self.args = args + self.api = HfApi(token=getattr(args, "token", None)) + self.repo_id = args.repo_id + self.repo_type = getattr(args, "repo_type", "model") + if self.repo_type not in REPO_TYPES: + print("Invalid repo --repo-type") + exit(1) + + +class RepoTagCreateCommand(RepoTagCommand): + def run(self): + print(f"You are about to create tag {ANSI.bold(self.args.tag)} on {self.repo_type} {ANSI.bold(self.repo_id)}") + try: + self.api.create_tag( + repo_id=self.repo_id, + tag=self.args.tag, + tag_message=getattr(self.args, "message", None), + revision=getattr(self.args, "revision", None), + repo_type=self.repo_type, + ) + except RepositoryNotFoundError: + print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") + exit(1) + except RevisionNotFoundError: + print(f"Revision {ANSI.bold(getattr(self.args, 'revision', None))} not found.") + exit(1) + except HfHubHTTPError as e: + if e.response.status_code == 409: + 
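+                # Editorial note: a 409 Conflict response from the Hub means the tag already
+                # exists on this repo, so a friendly message is printed instead of raising.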
print(f"Tag {ANSI.bold(self.args.tag)} already exists on {ANSI.bold(self.repo_id)}") + exit(1) + raise e + print(f"Tag {ANSI.bold(self.args.tag)} created on {ANSI.bold(self.repo_id)}") + + +class RepoTagListCommand(RepoTagCommand): + def run(self): + try: + refs = self.api.list_repo_refs( + repo_id=self.repo_id, + repo_type=self.repo_type, + ) + except RepositoryNotFoundError: + print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") + exit(1) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + if len(refs.tags) == 0: + print("No tags found") + exit(0) + print(f"Tags for {self.repo_type} {ANSI.bold(self.repo_id)}:") + for tag in refs.tags: + print(tag.name) + + +class RepoTagDeleteCommand(RepoTagCommand): + def run(self): + print(f"You are about to delete tag {ANSI.bold(self.args.tag)} on {self.repo_type} {ANSI.bold(self.repo_id)}") + if not getattr(self.args, "yes", False): + choice = input("Proceed? [Y/n] ").lower() + if choice not in ("", "y", "yes"): + print("Abort") + exit() + try: + self.api.delete_tag(repo_id=self.repo_id, tag=self.args.tag, repo_type=self.repo_type) + except RepositoryNotFoundError: + print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") + exit(1) + except RevisionNotFoundError: + print(f"Tag {ANSI.bold(self.args.tag)} not found on {ANSI.bold(self.repo_id)}") + exit(1) + print(f"Tag {ANSI.bold(self.args.tag)} deleted on {ANSI.bold(self.repo_id)}") diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/repo_files.py b/phivenv/Lib/site-packages/huggingface_hub/cli/repo_files.py new file mode 100644 index 0000000000000000000000000000000000000000..34fbeb09c23ef7b7a4619fed7352e3553a134a0d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/repo_files.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to update or delete files in a repository using the CLI. 
+
+Usage:
+    # delete all
+    hf repo-files delete "*"
+
+    # delete single file
+    hf repo-files delete file.txt
+
+    # delete single folder
+    hf repo-files delete folder/
+
+    # delete multiple
+    hf repo-files delete file.txt folder/ file2.txt
+
+    # delete multiple patterns
+    hf repo-files delete file.txt "*.json" "folder/*.parquet"
+
+    # delete from different revision / repo-type
+    hf repo-files delete file.txt --revision=refs/pr/1 --repo-type=dataset
+"""
+
+from argparse import _SubParsersAction
+from typing import List, Optional
+
+from huggingface_hub import logging
+from huggingface_hub.commands import BaseHuggingfaceCLICommand
+from huggingface_hub.hf_api import HfApi
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeleteFilesSubCommand:
+    def __init__(self, args) -> None:
+        self.args = args
+        self.repo_id: str = args.repo_id
+        self.repo_type: Optional[str] = args.repo_type
+        self.revision: Optional[str] = args.revision
+        self.api: HfApi = HfApi(token=args.token, library_name="hf")
+        self.patterns: List[str] = args.patterns
+        self.commit_message: Optional[str] = args.commit_message
+        self.commit_description: Optional[str] = args.commit_description
+        self.create_pr: bool = args.create_pr
+        self.token: Optional[str] = args.token
+
+    def run(self) -> None:
+        logging.set_verbosity_info()
+        url = self.api.delete_files(
+            delete_patterns=self.patterns,
+            repo_id=self.repo_id,
+            repo_type=self.repo_type,
+            revision=self.revision,
+            commit_message=self.commit_message,
+            commit_description=self.commit_description,
+            create_pr=self.create_pr,
+        )
+        print(f"Files successfully deleted from repo. Commit: {url}.")
+        logging.set_verbosity_warning()
+
+
+class RepoFilesCommand(BaseHuggingfaceCLICommand):
+    @staticmethod
+    def register_subcommand(parser: _SubParsersAction):
+        repo_files_parser = parser.add_parser("repo-files", help="Manage files in a repo on the Hub.")
+        repo_files_subparsers = repo_files_parser.add_subparsers(
+            help="Action to execute against the files.",
+            required=True,
+        )
+        delete_subparser = repo_files_subparsers.add_parser(
+            "delete",
+            help="Delete files from a repo on the Hub",
+        )
+        delete_subparser.set_defaults(func=lambda args: DeleteFilesSubCommand(args))
+        delete_subparser.add_argument(
+            "repo_id", type=str, help="The ID of the repo to manage (e.g. `username/repo-name`)."
+        )
+        delete_subparser.add_argument(
+            "patterns",
+            nargs="+",
+            type=str,
+            help="Glob patterns to match files to delete.",
+        )
+        delete_subparser.add_argument(
+            "--repo-type",
+            choices=["model", "dataset", "space"],
+            default="model",
+            help="Type of the repo to delete files from (e.g. `dataset`).",
+        )
+        delete_subparser.add_argument(
+            "--revision",
+            type=str,
+            help=(
+                "An optional Git revision to commit to. It can be a branch name or a PR reference. If the revision"
+                " does not exist and `--create-pr` is not set, a branch will be automatically created."
+            ),
+        )
+        delete_subparser.add_argument(
+            "--commit-message", type=str, help="The summary / title / first line of the generated commit."
+        )
+        delete_subparser.add_argument(
+            "--commit-description", type=str, help="The description of the generated commit."
+        )
+        delete_subparser.add_argument(
+            "--create-pr", action="store_true", help="Whether to create a new Pull Request for these changes."
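+            # Editorial sketch (not part of the original file): the programmatic equivalent is
+            #     HfApi().delete_files(repo_id="username/my-dataset", repo_type="dataset",
+            #                          delete_patterns=["*.json", "folder/*.parquet"])
+            # where the repo id and patterns are hypothetical.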
+ ) + delete_subparser.add_argument( + "--token", + type=str, + help="A User Access Token generated from https://huggingface.co/settings/tokens", + ) + + repo_files_parser.set_defaults(func=RepoFilesCommand) diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/system.py b/phivenv/Lib/site-packages/huggingface_hub/cli/system.py new file mode 100644 index 0000000000000000000000000000000000000000..03650175e9b71e329755de5c86e5bbf50569d4b7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/system.py @@ -0,0 +1,52 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to print information about the environment and version. + +Usage: + hf env + hf version +""" + +from argparse import _SubParsersAction + +from huggingface_hub import __version__ + +from ..utils import dump_environment_info +from . import BaseHuggingfaceCLICommand + + +class EnvironmentCommand(BaseHuggingfaceCLICommand): + def __init__(self, args): + self.args = args + + @staticmethod + def register_subcommand(parser: _SubParsersAction): + env_parser = parser.add_parser("env", help="Print information about the environment.") + env_parser.set_defaults(func=EnvironmentCommand) + + def run(self) -> None: + dump_environment_info() + + +class VersionCommand(BaseHuggingfaceCLICommand): + def __init__(self, args): + self.args = args + + @staticmethod + def register_subcommand(parser: _SubParsersAction): + version_parser = parser.add_parser("version", help="Print information about the hf version.") + version_parser.set_defaults(func=VersionCommand) + + def run(self) -> None: + print(f"huggingface_hub version: {__version__}") diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/upload.py b/phivenv/Lib/site-packages/huggingface_hub/cli/upload.py new file mode 100644 index 0000000000000000000000000000000000000000..07ab79bf24faadaf08224a64ccfbbbfb4018d658 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/upload.py @@ -0,0 +1,316 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to upload a repo or file with the CLI. + +Usage: + # Upload file (implicit) + hf upload my-cool-model ./my-cool-model.safetensors + + # Upload file (explicit) + hf upload my-cool-model ./my-cool-model.safetensors model.safetensors + + # Upload directory (implicit). If `my-cool-model/` is a directory it will be uploaded, otherwise an exception is raised. 
+ hf upload my-cool-model + + # Upload directory (explicit) + hf upload my-cool-model ./models/my-cool-model . + + # Upload filtered directory (example: tensorboard logs except for the last run) + hf upload my-cool-model ./model/training /logs --include "*.tfevents.*" --exclude "*20230905*" + + # Upload with wildcard + hf upload my-cool-model "./model/training/*.safetensors" + + # Upload private dataset + hf upload Wauplin/my-cool-dataset ./data . --repo-type=dataset --private + + # Upload with token + hf upload Wauplin/my-cool-model --token=hf_**** + + # Sync local Space with Hub (upload new files, delete removed files) + hf upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub" + + # Schedule commits every 30 minutes + hf upload Wauplin/my-cool-model --every=30 +""" + +import os +import time +import warnings +from argparse import Namespace, _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub._commit_scheduler import CommitScheduler +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER +from huggingface_hub.errors import RevisionNotFoundError +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import disable_progress_bars, enable_progress_bars +from huggingface_hub.utils._runtime import is_xet_available + + +logger = logging.get_logger(__name__) + + +class UploadCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + upload_parser = parser.add_parser( + "upload", help="Upload a file or a folder to the Hub. Recommended for single-commit uploads." + ) + upload_parser.add_argument( + "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)." + ) + upload_parser.add_argument( + "local_path", + nargs="?", + help="Local path to the file or folder to upload. Wildcard patterns are supported. Defaults to current directory.", + ) + upload_parser.add_argument( + "path_in_repo", + nargs="?", + help="Path of the file or folder in the repo. Defaults to the relative path of the file or folder.", + ) + upload_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Type of the repo to upload to (e.g. `dataset`).", + ) + upload_parser.add_argument( + "--revision", + type=str, + help=( + "An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not" + " exist and `--create-pr` is not set, a branch will be automatically created." + ), + ) + upload_parser.add_argument( + "--private", + action="store_true", + help=( + "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already" + " exists." + ), + ) + upload_parser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.") + upload_parser.add_argument( + "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload." + ) + upload_parser.add_argument( + "--delete", + nargs="*", + type=str, + help="Glob patterns for file to be deleted from the repo while committing.", + ) + upload_parser.add_argument( + "--commit-message", type=str, help="The summary / title / first line of the generated commit." 
+        )
+        upload_parser.add_argument("--commit-description", type=str, help="The description of the generated commit.")
+        upload_parser.add_argument(
+            "--create-pr", action="store_true", help="Whether to upload content as a new Pull Request."
+        )
+        upload_parser.add_argument(
+            "--every",
+            type=float,
+            help="If set, a background job is scheduled to create commits every `every` minutes.",
+        )
+        upload_parser.add_argument(
+            "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
+        )
+        upload_parser.add_argument(
+            "--quiet",
+            action="store_true",
+            help="If True, progress bars are disabled and only the path to the uploaded files is printed.",
+        )
+        upload_parser.set_defaults(func=UploadCommand)
+
+    def __init__(self, args: Namespace) -> None:
+        self.repo_id: str = args.repo_id
+        self.repo_type: Optional[str] = args.repo_type
+        self.revision: Optional[str] = args.revision
+        self.private: bool = args.private
+
+        self.include: Optional[List[str]] = args.include
+        self.exclude: Optional[List[str]] = args.exclude
+        self.delete: Optional[List[str]] = args.delete
+
+        self.commit_message: Optional[str] = args.commit_message
+        self.commit_description: Optional[str] = args.commit_description
+        self.create_pr: bool = args.create_pr
+        self.api: HfApi = HfApi(token=args.token, library_name="hf")
+        self.quiet: bool = args.quiet  # disable warnings and progress bars
+
+        # Check `--every` is valid
+        if args.every is not None and args.every <= 0:
+            raise ValueError(f"`every` must be a positive value (got '{args.every}')")
+        self.every: Optional[float] = args.every
+
+        # Resolve `local_path` and `path_in_repo`
+        repo_name: str = args.repo_id.split("/")[-1]  # e.g. "Wauplin/my-cool-model" => "my-cool-model"
+        self.local_path: str
+        self.path_in_repo: str
+
+        if args.local_path is not None and any(c in args.local_path for c in ["*", "?", "["]):
+            if args.include is not None:
+                raise ValueError("Cannot set `--include` when passing a `local_path` containing a wildcard.")
+            if args.path_in_repo is not None and args.path_in_repo != ".":
+                raise ValueError("Cannot set `path_in_repo` when passing a `local_path` containing a wildcard.")
+            self.local_path = "."
+            self.include = args.local_path
+            self.path_in_repo = "."
+        elif args.local_path is None and os.path.isfile(repo_name):
+            # Implicit case 1: the user provided only a repo_id which happens to be a local file as well => upload it with the same name
+            self.local_path = repo_name
+            self.path_in_repo = repo_name
+        elif args.local_path is None and os.path.isdir(repo_name):
+            # Implicit case 2: the user provided only a repo_id which happens to be a local folder as well => upload it at the root
+            self.local_path = repo_name
+            self.path_in_repo = "."
+        elif args.local_path is None:
+            # Implicit case 3: the user provided only a repo_id that does not match a local file or folder
+            # => the user must explicitly provide a local_path => raise an exception
+            raise ValueError(f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly.")
+        elif args.path_in_repo is None and os.path.isfile(args.local_path):
+            # Explicit local path to a file, no path in repo => upload it at the root with the same name
+            self.local_path = args.local_path
+            self.path_in_repo = os.path.basename(args.local_path)
+        elif args.path_in_repo is None:
+            # Explicit local path to a folder, no path in repo => upload at the root
+            self.local_path = args.local_path
+            self.path_in_repo = "."
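+        # Editorial illustration (not part of the original file): how the rules above resolve a
+        # few hypothetical invocations to (local_path, path_in_repo):
+        #     hf upload my-model ./file.txt              -> ("./file.txt", "file.txt")
+        #     hf upload my-model ./dir .                 -> ("./dir", ".")
+        #     hf upload my-model "./ckpts/*.safetensors" -> (".", ".") with include="./ckpts/*.safetensors"
+        #     hf upload my-model                         -> ("my-model", ".") if ./my-model is a local folder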
+ else: + # Finally, if both paths are explicit + self.local_path = args.local_path + self.path_in_repo = args.path_in_repo + + def run(self) -> None: + if self.quiet: + disable_progress_bars() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + print(self._upload()) + enable_progress_bars() + else: + logging.set_verbosity_info() + print(self._upload()) + logging.set_verbosity_warning() + + def _upload(self) -> str: + if os.path.isfile(self.local_path): + if self.include is not None and len(self.include) > 0: + warnings.warn("Ignoring `--include` since a single file is uploaded.") + if self.exclude is not None and len(self.exclude) > 0: + warnings.warn("Ignoring `--exclude` since a single file is uploaded.") + if self.delete is not None and len(self.delete) > 0: + warnings.warn("Ignoring `--delete` since a single file is uploaded.") + + if not is_xet_available() and not HF_HUB_ENABLE_HF_TRANSFER: + logger.info( + "Consider using `hf_transfer` for faster uploads. This solution comes with some limitations. See" + " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details." + ) + + # Schedule commits if `every` is set + if self.every is not None: + if os.path.isfile(self.local_path): + # If file => watch entire folder + use allow_patterns + folder_path = os.path.dirname(self.local_path) + path_in_repo = ( + self.path_in_repo[: -len(self.local_path)] # remove filename from path_in_repo + if self.path_in_repo.endswith(self.local_path) + else self.path_in_repo + ) + allow_patterns = [self.local_path] + ignore_patterns = [] + else: + folder_path = self.local_path + path_in_repo = self.path_in_repo + allow_patterns = self.include or [] + ignore_patterns = self.exclude or [] + if self.delete is not None and len(self.delete) > 0: + warnings.warn("Ignoring `--delete` when uploading with scheduled commits.") + + scheduler = CommitScheduler( + folder_path=folder_path, + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + path_in_repo=path_in_repo, + private=self.private, + every=self.every, + hf_api=self.api, + ) + print(f"Scheduling commits every {self.every} minutes to {scheduler.repo_id}.") + try: # Block main thread until KeyboardInterrupt + while True: + time.sleep(100) + except KeyboardInterrupt: + scheduler.stop() + return "Stopped scheduled commits." + + # Otherwise, create repo and proceed with the upload + if not os.path.isfile(self.local_path) and not os.path.isdir(self.local_path): + raise FileNotFoundError(f"No such file or directory: '{self.local_path}'.") + repo_id = self.api.create_repo( + repo_id=self.repo_id, + repo_type=self.repo_type, + exist_ok=True, + private=self.private, + space_sdk="gradio" if self.repo_type == "space" else None, + # ^ We don't want it to fail when uploading to a Space => let's set Gradio by default. + # ^ I'd rather not add CLI args to set it explicitly as we already have `hf repo create` for that. + ).repo_id + + # Check if branch already exists and if not, create it + if self.revision is not None and not self.create_pr: + try: + self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision) + except RevisionNotFoundError: + logger.info(f"Branch '{self.revision}' not found. 
Creating it...") + self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True) + # ^ `exist_ok=True` to avoid race concurrency issues + + # File-based upload + if os.path.isfile(self.local_path): + return self.api.upload_file( + path_or_fileobj=self.local_path, + path_in_repo=self.path_in_repo, + repo_id=repo_id, + repo_type=self.repo_type, + revision=self.revision, + commit_message=self.commit_message, + commit_description=self.commit_description, + create_pr=self.create_pr, + ) + + # Folder-based upload + else: + return self.api.upload_folder( + folder_path=self.local_path, + path_in_repo=self.path_in_repo, + repo_id=repo_id, + repo_type=self.repo_type, + revision=self.revision, + commit_message=self.commit_message, + commit_description=self.commit_description, + create_pr=self.create_pr, + allow_patterns=self.include, + ignore_patterns=self.exclude, + delete_patterns=self.delete, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/upload_large_folder.py b/phivenv/Lib/site-packages/huggingface_hub/cli/upload_large_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..618cd21b52748d840710d632115df7d132de6115 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/cli/upload_large_folder.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to upload a large folder with the CLI.""" + +import os +from argparse import Namespace, _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import disable_progress_bars + +from ._cli_utils import ANSI + + +logger = logging.get_logger(__name__) + + +class UploadLargeFolderCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + subparser = parser.add_parser( + "upload-large-folder", + help="Upload a large folder to the Hub. Recommended for resumable uploads.", + ) + subparser.add_argument( + "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)." + ) + subparser.add_argument("local_path", type=str, help="Local path to the file or folder to upload.") + subparser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + help="Type of the repo to upload to (e.g. `dataset`).", + ) + subparser.add_argument( + "--revision", + type=str, + help=("An optional Git revision to push to. It can be a branch name or a PR reference."), + ) + subparser.add_argument( + "--private", + action="store_true", + help=( + "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already exists." 
+ ), + ) + subparser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.") + subparser.add_argument("--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload.") + subparser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + subparser.add_argument( + "--num-workers", type=int, help="Number of workers to use to hash, upload and commit files." + ) + subparser.add_argument("--no-report", action="store_true", help="Whether to disable regular status report.") + subparser.add_argument("--no-bars", action="store_true", help="Whether to disable progress bars.") + subparser.set_defaults(func=UploadLargeFolderCommand) + + def __init__(self, args: Namespace) -> None: + self.repo_id: str = args.repo_id + self.local_path: str = args.local_path + self.repo_type: str = args.repo_type + self.revision: Optional[str] = args.revision + self.private: bool = args.private + + self.include: Optional[List[str]] = args.include + self.exclude: Optional[List[str]] = args.exclude + + self.api: HfApi = HfApi(token=args.token, library_name="hf") + + self.num_workers: Optional[int] = args.num_workers + self.no_report: bool = args.no_report + self.no_bars: bool = args.no_bars + + if not os.path.isdir(self.local_path): + raise ValueError("Large upload is only supported for folders.") + + def run(self) -> None: + logging.set_verbosity_info() + + print( + ANSI.yellow( + "You are about to upload a large folder to the Hub using `hf upload-large-folder`. " + "This is a new feature so feedback is very welcome!\n" + "\n" + "A few things to keep in mind:\n" + " - Repository limits still apply: https://huggingface.co/docs/hub/repositories-recommendations\n" + " - Do not start several processes in parallel.\n" + " - You can interrupt and resume the process at any time. " + "The script will pick up where it left off except for partially uploaded files that would have to be entirely reuploaded.\n" + " - Do not upload the same folder to several repositories. If you need to do so, you must delete the `./.cache/huggingface/` folder first.\n" + "\n" + f"Some temporary metadata will be stored under `{self.local_path}/.cache/huggingface`.\n" + " - You must not modify those files manually.\n" + " - You must not delete the `./.cache/huggingface/` folder while a process is running.\n" + " - You can delete the `./.cache/huggingface/` folder to reinitialize the upload state when process is not running. Files will have to be hashed and preuploaded again, except for already committed files.\n" + "\n" + "If the process output is too verbose, you can disable the progress bars with `--no-bars`. " + "You can also entirely disable the status report with `--no-report`.\n" + "\n" + "For more details, run `hf upload-large-folder --help` or check the documentation at " + "https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-large-folder." 
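+                # Editorial sketch (not part of the original file): the programmatic equivalent
+                # of this command is (repo id hypothetical):
+                #     HfApi().upload_large_folder(repo_id="username/big-model",
+                #                                 folder_path="./big-model", repo_type="model")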
+ ) + ) + + if self.no_bars: + disable_progress_bars() + + self.api.upload_large_folder( + repo_id=self.repo_id, + folder_path=self.local_path, + repo_type=self.repo_type, + revision=self.revision, + private=self.private, + allow_patterns=self.include, + ignore_patterns=self.exclude, + num_workers=self.num_workers, + print_report=not self.no_report, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/community.py b/phivenv/Lib/site-packages/huggingface_hub/community.py new file mode 100644 index 0000000000000000000000000000000000000000..16f2f02428dd5c2ce6437534af0397801bda45c5 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/community.py @@ -0,0 +1,355 @@ +""" +Data structures to interact with Discussions and Pull Requests on the Hub. + +See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) +for more information on Pull Requests, Discussions, and the community tab. +""" + +from dataclasses import dataclass +from datetime import datetime +from typing import List, Literal, Optional, Union + +from . import constants +from .utils import parse_datetime + + +DiscussionStatus = Literal["open", "closed", "merged", "draft"] + + +@dataclass +class Discussion: + """ + A Discussion or Pull Request on the Hub. + + This dataclass is not intended to be instantiated directly. + + Attributes: + title (`str`): + The title of the Discussion / Pull Request + status (`str`): + The status of the Discussion / Pull Request. + It must be one of: + * `"open"` + * `"closed"` + * `"merged"` (only for Pull Requests ) + * `"draft"` (only for Pull Requests ) + num (`int`): + The number of the Discussion / Pull Request. + repo_id (`str`): + The id (`"{namespace}/{repo_name}"`) of the repo on which + the Discussion / Pull Request was open. + repo_type (`str`): + The type of the repo on which the Discussion / Pull Request was open. + Possible values are: `"model"`, `"dataset"`, `"space"`. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + is_pull_request (`bool`): + Whether or not this is a Pull Request. + created_at (`datetime`): + The `datetime` of creation of the Discussion / Pull Request. + endpoint (`str`): + Endpoint of the Hub. Default is https://huggingface.co. + git_reference (`str`, *optional*): + (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise. + url (`str`): + (property) URL of the discussion on the Hub. + """ + + title: str + status: DiscussionStatus + num: int + repo_id: str + repo_type: str + author: str + is_pull_request: bool + created_at: datetime + endpoint: str + + @property + def git_reference(self) -> Optional[str]: + """ + If this is a Pull Request , returns the git reference to which changes can be pushed. + Returns `None` otherwise. + """ + if self.is_pull_request: + return f"refs/pr/{self.num}" + return None + + @property + def url(self) -> str: + """Returns the URL of the discussion on the Hub.""" + if self.repo_type is None or self.repo_type == constants.REPO_TYPE_MODEL: + return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}" + return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}" + + +@dataclass +class DiscussionWithDetails(Discussion): + """ + Subclass of [`Discussion`]. + + Attributes: + title (`str`): + The title of the Discussion / Pull Request + status (`str`): + The status of the Discussion / Pull Request. 
+ It can be one of: + * `"open"` + * `"closed"` + * `"merged"` (only for Pull Requests ) + * `"draft"` (only for Pull Requests ) + num (`int`): + The number of the Discussion / Pull Request. + repo_id (`str`): + The id (`"{namespace}/{repo_name}"`) of the repo on which + the Discussion / Pull Request was open. + repo_type (`str`): + The type of the repo on which the Discussion / Pull Request was open. + Possible values are: `"model"`, `"dataset"`, `"space"`. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + is_pull_request (`bool`): + Whether or not this is a Pull Request. + created_at (`datetime`): + The `datetime` of creation of the Discussion / Pull Request. + events (`list` of [`DiscussionEvent`]) + The list of [`DiscussionEvents`] in this Discussion or Pull Request. + conflicting_files (`Union[List[str], bool, None]`, *optional*): + A list of conflicting files if this is a Pull Request. + `None` if `self.is_pull_request` is `False`. + `True` if there are conflicting files but the list can't be retrieved. + target_branch (`str`, *optional*): + The branch into which changes are to be merged if this is a + Pull Request . `None` if `self.is_pull_request` is `False`. + merge_commit_oid (`str`, *optional*): + If this is a merged Pull Request , this is set to the OID / SHA of + the merge commit, `None` otherwise. + diff (`str`, *optional*): + The git diff if this is a Pull Request , `None` otherwise. + endpoint (`str`): + Endpoint of the Hub. Default is https://huggingface.co. + git_reference (`str`, *optional*): + (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise. + url (`str`): + (property) URL of the discussion on the Hub. + """ + + events: List["DiscussionEvent"] + conflicting_files: Union[List[str], bool, None] + target_branch: Optional[str] + merge_commit_oid: Optional[str] + diff: Optional[str] + + +@dataclass +class DiscussionEvent: + """ + An event in a Discussion or Pull Request. + + Use concrete classes: + * [`DiscussionComment`] + * [`DiscussionStatusChange`] + * [`DiscussionCommit`] + * [`DiscussionTitleChange`] + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. + created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + """ + + id: str + type: str + created_at: datetime + author: str + + _event: dict + """Stores the original event data, in case we need to access it later.""" + + +@dataclass +class DiscussionComment(DiscussionEvent): + """A comment in a Discussion / Pull Request. + + Subclass of [`DiscussionEvent`]. + + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. + created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + content (`str`): + The raw markdown content of the comment. Mentions, links and images are not rendered. + edited (`bool`): + Whether or not this comment has been edited. 
+        hidden (`bool`):
+            Whether or not this comment has been hidden.
+    """
+
+    content: str
+    edited: bool
+    hidden: bool
+
+    @property
+    def rendered(self) -> str:
+        """The rendered comment, as a HTML string"""
+        return self._event["data"]["latest"]["html"]
+
+    @property
+    def last_edited_at(self) -> datetime:
+        """The last edit time, as a `datetime` object."""
+        return parse_datetime(self._event["data"]["latest"]["updatedAt"])
+
+    @property
+    def last_edited_by(self) -> str:
+        """The username of the last editor, or `"deleted"` if that user has been deleted since."""
+        return self._event["data"]["latest"].get("author", {}).get("name", "deleted")
+
+    @property
+    def edit_history(self) -> List[dict]:
+        """The edit history of the comment"""
+        return self._event["data"]["history"]
+
+    @property
+    def number_of_edits(self) -> int:
+        """The number of edits in the comment's history."""
+        return len(self.edit_history)
+
+
+@dataclass
+class DiscussionStatusChange(DiscussionEvent):
+    """A change of status in a Discussion / Pull Request.
+
+    Subclass of [`DiscussionEvent`].
+
+    Attributes:
+        id (`str`):
+            The ID of the event. An hexadecimal string.
+        type (`str`):
+            The type of the event.
+        created_at (`datetime`):
+            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
+            object holding the creation timestamp for the event.
+        author (`str`):
+            The username of the Discussion / Pull Request author.
+            Can be `"deleted"` if the user has been deleted since.
+        new_status (`str`):
+            The status of the Discussion / Pull Request after the change.
+            It can be one of:
+                * `"open"`
+                * `"closed"`
+                * `"merged"` (only for Pull Requests)
+    """
+
+    new_status: str
+
+
+@dataclass
+class DiscussionCommit(DiscussionEvent):
+    """A commit in a Pull Request.
+
+    Subclass of [`DiscussionEvent`].
+
+    Attributes:
+        id (`str`):
+            The ID of the event. An hexadecimal string.
+        type (`str`):
+            The type of the event.
+        created_at (`datetime`):
+            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
+            object holding the creation timestamp for the event.
+        author (`str`):
+            The username of the Discussion / Pull Request author.
+            Can be `"deleted"` if the user has been deleted since.
+        summary (`str`):
+            The summary of the commit.
+        oid (`str`):
+            The OID / SHA of the commit, as a hexadecimal string.
+    """
+
+    summary: str
+    oid: str
+
+
+@dataclass
+class DiscussionTitleChange(DiscussionEvent):
+    """A rename event in a Discussion / Pull Request.
+
+    Subclass of [`DiscussionEvent`].
+
+    Attributes:
+        id (`str`):
+            The ID of the event. An hexadecimal string.
+        type (`str`):
+            The type of the event.
+        created_at (`datetime`):
+            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
+            object holding the creation timestamp for the event.
+        author (`str`):
+            The username of the Discussion / Pull Request author.
+            Can be `"deleted"` if the user has been deleted since.
+        old_title (`str`):
+            The previous title for the Discussion / Pull Request.
+        new_title (`str`):
+            The new title.
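+
+    Example (editorial sketch; the repo id and discussion number are hypothetical):
+
+    ```py
+    >>> from huggingface_hub import HfApi
+    >>> details = HfApi().get_discussion_details(repo_id="username/repo-name", discussion_num=1)
+    >>> for event in details.events:
+    ...     if isinstance(event, DiscussionTitleChange):
+    ...         print(event.old_title, "->", event.new_title)
+    ```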
+ """ + + old_title: str + new_title: str + + +def deserialize_event(event: dict) -> DiscussionEvent: + """Instantiates a [`DiscussionEvent`] from a dict""" + event_id: str = event["id"] + event_type: str = event["type"] + created_at = parse_datetime(event["createdAt"]) + + common_args = dict( + id=event_id, + type=event_type, + created_at=created_at, + author=event.get("author", {}).get("name", "deleted"), + _event=event, + ) + + if event_type == "comment": + return DiscussionComment( + **common_args, + edited=event["data"]["edited"], + hidden=event["data"]["hidden"], + content=event["data"]["latest"]["raw"], + ) + if event_type == "status-change": + return DiscussionStatusChange( + **common_args, + new_status=event["data"]["status"], + ) + if event_type == "commit": + return DiscussionCommit( + **common_args, + summary=event["data"]["subject"], + oid=event["data"]["oid"], + ) + if event_type == "title-change": + return DiscussionTitleChange( + **common_args, + old_title=event["data"]["from"], + new_title=event["data"]["to"], + ) + + return DiscussionEvent(**common_args) diff --git a/phivenv/Lib/site-packages/huggingface_hub/constants.py b/phivenv/Lib/site-packages/huggingface_hub/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..b30b2c01d99c5ee5428875f3711227024f5d0829 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/constants.py @@ -0,0 +1,294 @@ +import os +import re +import typing +from typing import Literal, Optional, Tuple + + +# Possible values for env variables + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + + +def _is_true(value: Optional[str]) -> bool: + if value is None: + return False + return value.upper() in ENV_VARS_TRUE_VALUES + + +def _as_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + return int(value) + + +# Constants for file downloads + +PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF_WEIGHTS_NAME = "model.ckpt" +FLAX_WEIGHTS_NAME = "flax_model.msgpack" +CONFIG_NAME = "config.json" +REPOCARD_NAME = "README.md" +DEFAULT_ETAG_TIMEOUT = 10 +DEFAULT_DOWNLOAD_TIMEOUT = 10 +DEFAULT_REQUEST_TIMEOUT = 10 +DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024 +HF_TRANSFER_CONCURRENCY = 100 +MAX_HTTP_DOWNLOAD_SIZE = 50 * 1000 * 1000 * 1000 # 50 GB + +# Constants for serialization + +PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin" # Unsafe pickle: use safetensors instead +SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors" +TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5" + +# Constants for safetensors repos + +SAFETENSORS_SINGLE_FILE = "model.safetensors" +SAFETENSORS_INDEX_FILE = "model.safetensors.index.json" +SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000 + +# Timeout of aquiring file lock and logging the attempt +FILELOCK_LOG_EVERY_SECONDS = 10 + +# Git-related constants + +DEFAULT_REVISION = "main" +REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}") + +HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/" + +_staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING")) + +_HF_DEFAULT_ENDPOINT = "https://huggingface.co" +_HF_DEFAULT_STAGING_ENDPOINT = "https://hub-ci.huggingface.co" +ENDPOINT = os.getenv("HF_ENDPOINT", _HF_DEFAULT_ENDPOINT).rstrip("/") +HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" + +if _staging_mode: + ENDPOINT = _HF_DEFAULT_STAGING_ENDPOINT + HUGGINGFACE_CO_URL_TEMPLATE = _HF_DEFAULT_STAGING_ENDPOINT + 
"/{repo_id}/resolve/{revision}/{filename}" + +HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" +HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag" +HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size" +HUGGINGFACE_HEADER_X_BILL_TO = "X-HF-Bill-To" + +INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co") + +# See https://huggingface.co/docs/inference-endpoints/index +INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2" +INFERENCE_CATALOG_ENDPOINT = "https://endpoints.huggingface.co/api/catalog" + +# See https://api.endpoints.huggingface.cloud/#post-/v2/endpoint/-namespace- +INFERENCE_ENDPOINT_IMAGE_KEYS = [ + "custom", + "huggingface", + "huggingfaceNeuron", + "llamacpp", + "tei", + "tgi", + "tgiNeuron", +] + +# Proxy for third-party providers +INFERENCE_PROXY_TEMPLATE = "https://router.huggingface.co/{provider}" + +REPO_ID_SEPARATOR = "--" +# ^ this substring is not allowed in repo_ids on hf.co +# and is the canonical one we use for serialization of repo ids elsewhere. + + +REPO_TYPE_DATASET = "dataset" +REPO_TYPE_SPACE = "space" +REPO_TYPE_MODEL = "model" +REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE] +SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"] + +REPO_TYPES_URL_PREFIXES = { + REPO_TYPE_DATASET: "datasets/", + REPO_TYPE_SPACE: "spaces/", +} +REPO_TYPES_MAPPING = { + "datasets": REPO_TYPE_DATASET, + "spaces": REPO_TYPE_SPACE, + "models": REPO_TYPE_MODEL, +} + +DiscussionTypeFilter = Literal["all", "discussion", "pull_request"] +DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) +DiscussionStatusFilter = Literal["all", "open", "closed"] +DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter) + +# Webhook subscription types +WEBHOOK_DOMAIN_T = Literal["repo", "discussions"] + +# default cache +default_home = os.path.join(os.path.expanduser("~"), ".cache") +HF_HOME = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_HOME", + os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"), + ) + ) +) +hf_cache_home = HF_HOME # for backward compatibility. TODO: remove this in 1.0.0 + +default_cache_path = os.path.join(HF_HOME, "hub") +default_assets_cache_path = os.path.join(HF_HOME, "assets") + +# Legacy env variables +HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path) +HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path) + +# New env variables +HF_HUB_CACHE = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_HUB_CACHE", + HUGGINGFACE_HUB_CACHE, + ) + ) +) +HF_ASSETS_CACHE = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_ASSETS_CACHE", + HUGGINGFACE_ASSETS_CACHE, + ) + ) +) + +HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE")) + +# If set, log level will be set to DEBUG and all requests made to the Hub will be logged +# as curl commands for reproducibility. 
+HF_DEBUG = _is_true(os.environ.get("HF_DEBUG"))
+
+# Opt-out from telemetry requests
+HF_HUB_DISABLE_TELEMETRY = (
+    _is_true(os.environ.get("HF_HUB_DISABLE_TELEMETRY"))  # HF-specific env variable
+    or _is_true(os.environ.get("DISABLE_TELEMETRY"))
+    or _is_true(os.environ.get("DO_NOT_TRACK"))  # https://consoledonottrack.com/
+)
+
+HF_TOKEN_PATH = os.path.expandvars(
+    os.path.expanduser(
+        os.getenv(
+            "HF_TOKEN_PATH",
+            os.path.join(HF_HOME, "token"),
+        )
+    )
+)
+HF_STORED_TOKENS_PATH = os.path.join(os.path.dirname(HF_TOKEN_PATH), "stored_tokens")
+
+if _staging_mode:
+    # In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens.
+    # In practice in `huggingface_hub` tests, we monkeypatch these values with temporary directories. The following
+    # lines are only used in third-party libraries tests (e.g. `transformers`, `diffusers`, etc.).
+    _staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging")
+    HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub")
+    HF_TOKEN_PATH = os.path.join(_staging_home, "token")
+
+# Here, `True` will disable progress bars globally, without the possibility of enabling them
+# programmatically. `False` will enable them, without the possibility of disabling them.
+# If the environment variable is not set (None), then the user is free to enable/disable
+# them programmatically.
+# TL;DR: the env variable has priority over code.
+__HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS")
+HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = (
+    _is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None
+)
+
+# Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
+HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))
+
+# Disable warning when using experimental features
+HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))
+
+# Disable sending the cached token by default in all HTTP requests to the Hub
+HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))
+
+# Enable fast-download using external dependency "hf_transfer"
+# See:
+# - https://pypi.org/project/hf-transfer/
+# - https://github.com/huggingface/hf_transfer (private)
+HF_HUB_ENABLE_HF_TRANSFER: bool = _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER"))
+
+
+# UNUSED
+# We don't use symlinks in local dir anymore.
+HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = (
+    _as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024
+)
+
+# Used to override the etag timeout on a system level
+HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT
+
+# Used to override the get request timeout on a system level
+HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT
+
+# Allows adding information about the requester to the user-agent header (e.g. a partner name)
+HF_HUB_USER_AGENT_ORIGIN: Optional[str] = os.environ.get("HF_HUB_USER_AGENT_ORIGIN")
+
+# List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are
+# deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by
+# default. We still keep the full list of supported frameworks in case we want to scan all of them.
+MAIN_INFERENCE_API_FRAMEWORKS = [ + "diffusers", + "sentence-transformers", + "text-generation-inference", + "transformers", +] + +ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [ + "adapter-transformers", + "allennlp", + "asteroid", + "bertopic", + "doctr", + "espnet", + "fairseq", + "fastai", + "fasttext", + "flair", + "k2", + "keras", + "mindspore", + "nemo", + "open_clip", + "paddlenlp", + "peft", + "pyannote-audio", + "sklearn", + "spacy", + "span-marker", + "speechbrain", + "stanza", + "timm", +] + +# If OAuth didn't work after 2 redirects, there's likely a third-party cookie issue in the Space iframe view. +# In this case, we redirect the user to the non-iframe view. +OAUTH_MAX_REDIRECTS = 2 + +# OAuth-related environment variables injected by the Space +OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID") +OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET") +OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES") +OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL") + +# Xet constants +HUGGINGFACE_HEADER_X_XET_ENDPOINT = "X-Xet-Cas-Url" +HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN = "X-Xet-Access-Token" +HUGGINGFACE_HEADER_X_XET_EXPIRATION = "X-Xet-Token-Expiration" +HUGGINGFACE_HEADER_X_XET_HASH = "X-Xet-Hash" +HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE = "X-Xet-Refresh-Route" +HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY = "xet-auth" + +default_xet_cache_path = os.path.join(HF_HOME, "xet") +HF_XET_CACHE = os.getenv("HF_XET_CACHE", default_xet_cache_path) +HF_HUB_DISABLE_XET: bool = _is_true(os.environ.get("HF_HUB_DISABLE_XET")) diff --git a/phivenv/Lib/site-packages/huggingface_hub/dataclasses.py b/phivenv/Lib/site-packages/huggingface_hub/dataclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..bf81c522d527c9078c611a4ac73f53217bed0879 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/dataclasses.py @@ -0,0 +1,481 @@ +import inspect +from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields +from functools import wraps +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, + overload, +) + +from .errors import ( + StrictDataclassClassValidationError, + StrictDataclassDefinitionError, + StrictDataclassFieldValidationError, +) + + +Validator_T = Callable[[Any], None] +T = TypeVar("T") + + +# The overload decorator helps type checkers understand the different return types +@overload +def strict(cls: Type[T]) -> Type[T]: ... + + +@overload +def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ... + + +def strict( + cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False +) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + """ + Decorator to add strict validation to a dataclass. + + This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools + recognize the class as a dataclass. + + Can be used with or without arguments: + - `@strict` + - `@strict(accept_kwargs=True)` + + Args: + cls: + The class to convert to a strict dataclass. + accept_kwargs (`bool`, *optional*): + If True, allows arbitrary keyword arguments in `__init__`. Defaults to False. + + Returns: + The enhanced dataclass with strict validation on field assignment. + + Example: + ```py + >>> from dataclasses import dataclass + >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field + + >>> @as_validated_field + >>> def positive_int(value: int): + ... if not value >= 0: + ... 
raise ValueError(f"Value must be positive, got {value}") + + >>> @strict(accept_kwargs=True) + ... @dataclass + ... class User: + ... name: str + ... age: int = positive_int(default=10) + + # Initialize + >>> User(name="John") + User(name='John', age=10) + + # Extra kwargs are accepted + >>> User(name="John", age=30, lastname="Doe") + User(name='John', age=30, *lastname='Doe') + + # Invalid type => raises + >>> User(name="John", age="30") + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + TypeError: Field 'age' expected int, got str (value: '30') + + # Invalid value => raises + >>> User(name="John", age=-1) + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + ValueError: Value must be positive, got -1 + ``` + """ + + def wrap(cls: Type[T]) -> Type[T]: + if not hasattr(cls, "__dataclass_fields__"): + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' must be a dataclass before applying @strict." + ) + + # List and store validators + field_validators: Dict[str, List[Validator_T]] = {} + for f in fields(cls): # type: ignore [arg-type] + validators = [] + validators.append(_create_type_validator(f)) + custom_validator = f.metadata.get("validator") + if custom_validator is not None: + if not isinstance(custom_validator, list): + custom_validator = [custom_validator] + for validator in custom_validator: + if not _is_validator(validator): + raise StrictDataclassDefinitionError( + f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument." + ) + validators.extend(custom_validator) + field_validators[f.name] = validators + cls.__validators__ = field_validators # type: ignore + + # Override __setattr__ to validate fields on assignment + original_setattr = cls.__setattr__ + + def __strict_setattr__(self: Any, name: str, value: Any) -> None: + """Custom __setattr__ method for strict dataclasses.""" + # Run all validators + for validator in self.__validators__.get(name, []): + try: + validator(value) + except (ValueError, TypeError) as e: + raise StrictDataclassFieldValidationError(field=name, cause=e) from e + + # If validation passed, set the attribute + original_setattr(self, name, value) + + cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign] + + if accept_kwargs: + # (optional) Override __init__ to accept arbitrary keyword arguments + original_init = cls.__init__ + + @wraps(original_init) + def __init__(self, **kwargs: Any) -> None: + # Extract only the fields that are part of the dataclass + dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type] + standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields} + + # Call the original __init__ with standard fields + original_init(self, **standard_kwargs) + + # Add any additional kwargs as attributes + for name, value in kwargs.items(): + if name not in dataclass_fields: + self.__setattr__(name, value) + + cls.__init__ = __init__ # type: ignore[method-assign] + + # (optional) Override __repr__ to include additional kwargs + original_repr = cls.__repr__ + + @wraps(original_repr) + def __repr__(self) -> str: + # Call the original __repr__ to get the standard fields + standard_repr = original_repr(self) + + # Get additional kwargs + additional_kwargs = [ + # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass + f"*{k}={v!r}" + for k, v in self.__dict__.items() + if k not in cls.__dataclass_fields__ # type: ignore 
[attr-defined] + ] + additional_repr = ", ".join(additional_kwargs) + + # Combine both representations + return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr + + cls.__repr__ = __repr__ # type: ignore [method-assign] + + # List all public methods starting with `validate_` => class validators. + class_validators = [] + + for name in dir(cls): + if not name.startswith("validate_"): + continue + method = getattr(cls, name) + if not callable(method): + continue + if len(inspect.signature(method).parameters) != 1: + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument." + " Class validators must take only 'self' as an argument. Methods starting with 'validate_'" + " are considered to be class validators." + ) + class_validators.append(method) + + cls.__class_validators__ = class_validators # type: ignore [attr-defined] + + # Add `validate` method to the class, but first check if it already exists + def validate(self: T) -> None: + """Run class validators on the instance.""" + for validator in cls.__class_validators__: # type: ignore [attr-defined] + try: + validator(self) + except (ValueError, TypeError) as e: + raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e + + # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class + # (in which case we just override it) + validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined] + + if hasattr(cls, "validate"): + if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined] + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' already implements a method called 'validate'." + " This method name is reserved when using the @strict decorator on a dataclass." + " If you want to keep your own method, please rename it." + ) + + cls.validate = validate # type: ignore + + # Run class validators after initialization + initial_init = cls.__init__ + + @wraps(initial_init) + def init_with_validate(self, *args, **kwargs) -> None: + """Run class validators after initialization.""" + initial_init(self, *args, **kwargs) # type: ignore [call-arg] + cls.validate(self) # type: ignore [attr-defined] + + setattr(cls, "__init__", init_with_validate) + + return cls + + # Return wrapped class or the decorator itself + return wrap(cls) if cls is not None else wrap + + +def validated_field( + validator: Union[List[Validator_T], Validator_T], + default: Union[Any, _MISSING_TYPE] = MISSING, + default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict] = None, + **kwargs: Any, +) -> Any: + """ + Create a dataclass field with a custom validator. + + Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator. + + Args: + validator (`Callable` or `List[Callable]`): + A method that takes a value as input and raises ValueError/TypeError if the value is invalid. + Can be a list of validators to apply multiple checks. + **kwargs: + Additional arguments to pass to `dataclasses.field()`. 
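+
+        Example (editorial sketch):
+
+        ```py
+        >>> from dataclasses import dataclass
+        >>> from huggingface_hub.dataclasses import strict, validated_field
+
+        >>> def non_empty(value: str):
+        ...     if not value:
+        ...         raise ValueError("Value must be a non-empty string")
+
+        >>> @strict
+        ... @dataclass
+        ... class Repo:
+        ...     name: str = validated_field(non_empty, default="model")
+        ```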
+ + Returns: + A field with the validator attached in metadata + """ + if not isinstance(validator, list): + validator = [validator] + if metadata is None: + metadata = {} + metadata["validator"] = validator + return field( # type: ignore + default=default, # type: ignore [arg-type] + default_factory=default_factory, # type: ignore [arg-type] + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + **kwargs, + ) + + +def as_validated_field(validator: Validator_T): + """ + Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator). + + Args: + validator (`Callable`): + A method that takes a value as input and raises ValueError/TypeError if the value is invalid. + """ + + def _inner( + default: Union[Any, _MISSING_TYPE] = MISSING, + default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict] = None, + **kwargs: Any, + ): + return validated_field( + validator, + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + **kwargs, + ) + + return _inner + + +def type_validator(name: str, value: Any, expected_type: Any) -> None: + """Validate that 'value' matches 'expected_type'.""" + origin = get_origin(expected_type) + args = get_args(expected_type) + + if expected_type is Any: + return + elif validator := _BASIC_TYPE_VALIDATORS.get(origin): + validator(name, value, args) + elif isinstance(expected_type, type): # simple types + _validate_simple_type(name, value, expected_type) + else: + raise TypeError(f"Unsupported type for field '{name}': {expected_type}") + + +def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate that value matches one of the types in a Union.""" + errors = [] + for t in args: + try: + type_validator(name, value, t) + return # Valid if any type matches + except TypeError as e: + errors.append(str(e)) + + raise TypeError( + f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. 
Errors: {'; '.join(errors)}" + ) + + +def _validate_literal(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Literal type.""" + if value not in args: + raise TypeError(f"Field '{name}' expected one of {args}, got {value}") + + +def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate List[T] type.""" + if not isinstance(value, list): + raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}") + + # Validate each item in the list + item_type = args[0] + for i, item in enumerate(value): + try: + type_validator(f"{name}[{i}]", item, item_type) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in list '{name}'") from e + + +def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Dict[K, V] type.""" + if not isinstance(value, dict): + raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}") + + # Validate keys and values + key_type, value_type = args + for k, v in value.items(): + try: + type_validator(f"{name}.key", k, key_type) + type_validator(f"{name}[{k!r}]", v, value_type) + except TypeError as e: + raise TypeError(f"Invalid key or value in dict '{name}'") from e + + +def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Tuple type.""" + if not isinstance(value, tuple): + raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}") + + # Handle variable-length tuples: Tuple[T, ...] + if len(args) == 2 and args[1] is Ellipsis: + for i, item in enumerate(value): + try: + type_validator(f"{name}[{i}]", item, args[0]) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e + # Handle fixed-length tuples: Tuple[T1, T2, ...] + elif len(args) != len(value): + raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}") + else: + for i, (item, expected) in enumerate(zip(value, args)): + try: + type_validator(f"{name}[{i}]", item, expected) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e + + +def _validate_set(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Set[T] type.""" + if not isinstance(value, set): + raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}") + + # Validate each item in the set + item_type = args[0] + for i, item in enumerate(value): + try: + type_validator(f"{name} item", item, item_type) + except TypeError as e: + raise TypeError(f"Invalid item in set '{name}'") from e + + +def _validate_simple_type(name: str, value: Any, expected_type: type) -> None: + """Validate simple type (int, str, etc.).""" + if not isinstance(value, expected_type): + raise TypeError( + f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})" + ) + + +def _create_type_validator(field: Field) -> Validator_T: + """Create a type validator function for a field.""" + # Hacky: we cannot use a lambda here because of reference issues + + def validator(value: Any) -> None: + type_validator(field.name, value, field.type) + + return validator + + +def _is_validator(validator: Any) -> bool: + """Check if a function is a validator. + + A validator is a Callable that can be called with a single positional argument. + The validator can have more arguments with default values. + + Basically, returns True if `validator(value)` is possible. 
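+
+    Example (illustrative; follows directly from the rules above):
+
+        >>> _is_validator(lambda value: None)
+        True
+        >>> _is_validator(lambda: None)  # no positional argument
+        False
+        >>> _is_validator(lambda value, flag=True: None)  # extra parameters must have defaults
+        True
+        >>> _is_validator("not a callable")
+        False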
+    """
+    if not callable(validator):
+        return False
+
+    signature = inspect.signature(validator)
+    parameters = list(signature.parameters.values())
+    if len(parameters) == 0:
+        return False
+    if parameters[0].kind not in (
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.POSITIONAL_ONLY,
+        inspect.Parameter.VAR_POSITIONAL,
+    ):
+        return False
+    for parameter in parameters[1:]:
+        if parameter.default == inspect.Parameter.empty:
+            return False
+    return True
+
+
+_BASIC_TYPE_VALIDATORS = {
+    Union: _validate_union,
+    Literal: _validate_literal,
+    list: _validate_list,
+    dict: _validate_dict,
+    tuple: _validate_tuple,
+    set: _validate_set,
+}
+
+
+__all__ = [
+    "strict",
+    "validated_field",
+    "Validator_T",
+    "StrictDataclassClassValidationError",
+    "StrictDataclassDefinitionError",
+    "StrictDataclassFieldValidationError",
+]
diff --git a/phivenv/Lib/site-packages/huggingface_hub/errors.py b/phivenv/Lib/site-packages/huggingface_hub/errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f7ed80e35a7cbe1dcc0f21dfa0354e467676f3
--- /dev/null
+++ b/phivenv/Lib/site-packages/huggingface_hub/errors.py
@@ -0,0 +1,377 @@
+"""Contains all custom errors."""
+
+from pathlib import Path
+from typing import Optional, Union
+
+from requests import HTTPError, Response
+
+
+# CACHE ERRORS
+
+
+class CacheNotFound(Exception):
+    """Exception thrown when the Huggingface cache is not found."""
+
+    cache_dir: Union[str, Path]
+
+    def __init__(self, msg: str, cache_dir: Union[str, Path], *args, **kwargs):
+        super().__init__(msg, *args, **kwargs)
+        self.cache_dir = cache_dir
+
+
+class CorruptedCacheException(Exception):
+    """Exception for any unexpected structure in the Huggingface cache-system."""
+
+
+# HEADERS ERRORS
+
+
+class LocalTokenNotFoundError(EnvironmentError):
+    """Raised if local token is required but not found."""
+
+
+# HTTP ERRORS
+
+
+class OfflineModeIsEnabled(ConnectionError):
+    """Raised when a request is made but `HF_HUB_OFFLINE=1` is set as environment variable."""
+
+
+class HfHubHTTPError(HTTPError):
+    """
+    HTTPError to inherit from for any custom HTTP Error raised in HF Hub.
+
+    Any HTTPError is converted at least into a `HfHubHTTPError`. If some information is
+    sent back by the server, it will be added to the error message.
+
+    Added details:
+    - Request id from the "X-Request-Id" header if it exists. If not, falls back to the "X-Amzn-Trace-Id" header if it exists.
+    - Server error message from the header "X-Error-Message".
+    - Server error message if we can find one in the response body.
+
+    Example:
+    ```py
+    import requests
+    from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError
+
+    response = get_session().post(...)
+    try:
+        hf_raise_for_status(response)
+    except HfHubHTTPError as e:
+        print(str(e))  # formatted message
+        e.request_id, e.server_message  # details returned by server
+
+        # Complete the error message with additional information once it's raised
+        e.append_to_message("\n`create_commit` expects the repository to exist.")
+        raise
+    ```
+    """
+
+    def __init__(self, message: str, response: Optional[Response] = None, *, server_message: Optional[str] = None):
+        self.request_id = (
+            response.headers.get("x-request-id") or response.headers.get("X-Amzn-Trace-Id")
+            if response is not None
+            else None
+        )
+        self.server_message = server_message
+
+        super().__init__(
+            message,
+            response=response,  # type: ignore [arg-type]
+            request=response.request if response is not None else None,  # type: ignore [arg-type]
+        )
+
+    def append_to_message(self, additional_message: str) -> None:
+        """Append additional information to the `HfHubHTTPError` initial message."""
+        self.args = (self.args[0] + additional_message,) + self.args[1:]
+
+
+# INFERENCE CLIENT ERRORS
+
+
+class InferenceTimeoutError(HTTPError, TimeoutError):
+    """Error raised when a model is unavailable or the request times out."""
+
+
+# INFERENCE ENDPOINT ERRORS
+
+
+class InferenceEndpointError(Exception):
+    """Generic exception when dealing with Inference Endpoints."""
+
+
+class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError):
+    """Exception for timeouts while waiting for an Inference Endpoint."""
+
+
+# SAFETENSORS ERRORS
+
+
+class SafetensorsParsingError(Exception):
+    """Raised when failing to parse a safetensors file metadata.
+
+    This can be the case if the file is not a safetensors file or does not respect the specification.
+    """
+
+
+class NotASafetensorsRepoError(Exception):
+    """Raised when a repo is not a Safetensors repo, i.e. doesn't have either a `model.safetensors` or a
+    `model.safetensors.index.json` file.
+    """
+
+
+# TEXT GENERATION ERRORS
+
+
+class TextGenerationError(HTTPError):
+    """Generic error raised if text-generation went wrong."""
+
+
+# Text Generation Inference Errors
+class ValidationError(TextGenerationError):
+    """Server-side validation error."""
+
+
+class GenerationError(TextGenerationError):
+    pass
+
+
+class OverloadedError(TextGenerationError):
+    pass
+
+
+class IncompleteGenerationError(TextGenerationError):
+    pass
+
+
+class UnknownError(TextGenerationError):
+    pass
+
+
+# VALIDATION ERRORS
+
+
+class HFValidationError(ValueError):
+    """Generic exception thrown by `huggingface_hub` validators.
+
+    Inherits from [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError).
+    """
+
+
+# FILE METADATA ERRORS
+
+
+class FileMetadataError(OSError):
+    """Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash).
+
+    Inherits from `OSError` for backward compatibility.
+    """
+
+
+# REPOSITORY ERRORS
+
+
+class RepositoryNotFoundError(HfHubHTTPError):
+    """
+    Raised when trying to access a hf.co URL with an invalid repository name, or
+    with a private repo name the user does not have access to.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import model_info
+    >>> model_info("<non_existent_repository>")
+    (...)
+    huggingface_hub.utils._errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP)
+
+    Repository Not Found for url: https://huggingface.co/api/models/%3Cnon_existent_repository%3E.
+    Please make sure you specified the correct `repo_id` and `repo_type`.
+    If the repo is private, make sure you are authenticated.
+    Invalid username or password.
+    ```
+    """
+
+
+class GatedRepoError(RepositoryNotFoundError):
+    """
+    Raised when trying to access a gated repository for which the user is not on the
+    authorized list.
+
+    Note: derives from `RepositoryNotFoundError` to ensure backward compatibility.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import model_info
+    >>> model_info("ardent-figment/gated-model")
+    (...)
+    huggingface_hub.utils._errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa)
+
+    Cannot access gated repo for url https://huggingface.co/api/models/ardent-figment/gated-model.
+    Access to model ardent-figment/gated-model is restricted and you are not in the authorized list.
+    Visit https://huggingface.co/ardent-figment/gated-model to ask for access.
+    ```
+    """
+
+
+class DisabledRepoError(HfHubHTTPError):
+    """
+    Raised when trying to access a repository that has been disabled by its author.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import dataset_info
+    >>> dataset_info("laion/laion-art")
+    (...)
+    huggingface_hub.utils._errors.DisabledRepoError: 403 Client Error. (Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8)
+
+    Cannot access repository for url https://huggingface.co/api/datasets/laion/laion-art.
+    Access to this resource is disabled.
+    ```
+    """
+
+
+# REVISION ERROR
+
+
+class RevisionNotFoundError(HfHubHTTPError):
+    """
+    Raised when trying to access a hf.co URL with a valid repository but an invalid
+    revision.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import hf_hub_download
+    >>> hf_hub_download('bert-base-cased', 'config.json', revision='<non-existent-revision>')
+    (...)
+    huggingface_hub.utils._errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX)
+
+    Revision Not Found for url: https://huggingface.co/bert-base-cased/resolve/%3Cnon-existent-revision%3E/config.json.
+    ```
+    """
+
+
+# ENTRY ERRORS
+class EntryNotFoundError(HfHubHTTPError):
+    """
+    Raised when trying to access a hf.co URL with a valid repository and revision
+    but an invalid filename.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import hf_hub_download
+    >>> hf_hub_download('bert-base-cased', '<non-existent-file>')
+    (...)
+    huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x)
+
+    Entry Not Found for url: https://huggingface.co/bert-base-cased/resolve/main/%3Cnon-existent-file%3E.
+    ```
+    """
+
+
+class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError):
+    """
+    Raised when trying to access a file or snapshot that is not on the disk when network is
+    disabled or unavailable (connection issue). The entry may exist on the Hub.
+
+    Note: `ValueError` type is to ensure backward compatibility.
+    Note: `LocalEntryNotFoundError` derives from `HTTPError` (via `EntryNotFoundError`)
+    even though it is not necessarily a network issue.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import hf_hub_download
+    >>> hf_hub_download('bert-base-cased', '<non-cached-file>', local_files_only=True)
+    (...)
+    huggingface_hub.utils._errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False.
+    ```
+    """
+
+    def __init__(self, message: str):
+        super().__init__(message, response=None)
+
+
+# REQUEST ERROR
+class BadRequestError(HfHubHTTPError, ValueError):
+    """
+    Raised by `hf_raise_for_status` when the server returns an HTTP 400 error.
+ + Example: + + ```py + >>> resp = requests.post("hf.co/api/check", ...) + >>> hf_raise_for_status(resp, endpoint_name="check") + huggingface_hub.utils._errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX) + ``` + """ + + +# DDUF file format ERROR + + +class DDUFError(Exception): + """Base exception for errors related to the DDUF format.""" + + +class DDUFCorruptedFileError(DDUFError): + """Exception thrown when the DDUF file is corrupted.""" + + +class DDUFExportError(DDUFError): + """Base exception for errors during DDUF export.""" + + +class DDUFInvalidEntryNameError(DDUFExportError): + """Exception thrown when the entry name is invalid.""" + + +# STRICT DATACLASSES ERRORS + + +class StrictDataclassError(Exception): + """Base exception for strict dataclasses.""" + + +class StrictDataclassDefinitionError(StrictDataclassError): + """Exception thrown when a strict dataclass is defined incorrectly.""" + + +class StrictDataclassFieldValidationError(StrictDataclassError): + """Exception thrown when a strict dataclass fails validation for a given field.""" + + def __init__(self, field: str, cause: Exception): + error_message = f"Validation error for field '{field}':" + error_message += f"\n {cause.__class__.__name__}: {cause}" + super().__init__(error_message) + + +class StrictDataclassClassValidationError(StrictDataclassError): + """Exception thrown when a strict dataclass fails validation on a class validator.""" + + def __init__(self, validator: str, cause: Exception): + error_message = f"Class validation error for validator '{validator}':" + error_message += f"\n {cause.__class__.__name__}: {cause}" + super().__init__(error_message) + + +# XET ERRORS + + +class XetError(Exception): + """Base exception for errors related to Xet Storage.""" + + +class XetAuthorizationError(XetError): + """Exception thrown when the user does not have the right authorization to use Xet Storage.""" + + +class XetRefreshTokenError(XetError): + """Exception thrown when the refresh token is invalid.""" + + +class XetDownloadError(Exception): + """Exception thrown when the download from Xet Storage fails.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/fastai_utils.py b/phivenv/Lib/site-packages/huggingface_hub/fastai_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e75eba2a8baee7bdeb8d36a1c06bd950cf857c44 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/fastai_utils.py @@ -0,0 +1,425 @@ +import json +import os +from pathlib import Path +from pickle import DEFAULT_PROTOCOL, PicklingError +from typing import Any, Dict, List, Optional, Union + +from packaging import version + +from huggingface_hub import constants, snapshot_download +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import ( + SoftTemporaryDirectory, + get_fastai_version, + get_fastcore_version, + get_python_version, +) + +from .utils import logging, validate_hf_hub_args +from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility... + + +logger = logging.get_logger(__name__) + + +def _check_fastai_fastcore_versions( + fastai_min_version: str = "2.4", + fastcore_min_version: str = "1.3.27", +): + """ + Checks that the installed fastai and fastcore versions are compatible for pickle serialization. + + Args: + fastai_min_version (`str`, *optional*): + The minimum fastai version supported. + fastcore_min_version (`str`, *optional*): + The minimum fastcore version supported. 
+
+
+    Raises the following error:
+
+        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+          if the fastai or fastcore libraries are not available or are of an invalid version.
+
+
+    """
+
+    # `get_*_version()` returns "N/A" when the package is not installed: check both libraries.
+    if "N/A" in (get_fastai_version(), get_fastcore_version()):
+        raise ImportError(
+            f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
+            f" required. Currently using fastai=={get_fastai_version()} and"
+            f" fastcore=={get_fastcore_version()}."
+        )
+
+    current_fastai_version = version.Version(get_fastai_version())
+    current_fastcore_version = version.Version(get_fastcore_version())
+
+    if current_fastai_version < version.Version(fastai_min_version):
+        raise ImportError(
+            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
+            f" fastai>={fastai_min_version} version, but you are using fastai version"
+            f" {get_fastai_version()} which is incompatible. Upgrade with `pip install"
+            " fastai==2.5.6`."
+        )
+
+    if current_fastcore_version < version.Version(fastcore_min_version):
+        raise ImportError(
+            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
+            f" fastcore>={fastcore_min_version} version, but you are using fastcore"
+            f" version {get_fastcore_version()} which is incompatible. Upgrade with"
+            " `pip install fastcore==1.3.27`."
+        )
+
+
+def _check_fastai_fastcore_pyproject_versions(
+    storage_folder: str,
+    fastai_min_version: str = "2.4",
+    fastcore_min_version: str = "1.3.27",
+):
+    """
+    Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions
+    that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist
+    or does not contain versions for fastai and fastcore, then it logs a warning.
+
+    Args:
+        storage_folder (`str`):
+            Folder to look for the `pyproject.toml` file.
+        fastai_min_version (`str`, *optional*):
+            The minimum fastai version supported.
+        fastcore_min_version (`str`, *optional*):
+            The minimum fastcore version supported.
+
+
+    Raises the following errors:
+
+        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+          if the `toml` module is not installed.
+        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+          if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.
+
+
+    """
+
+    try:
+        import toml
+    except ModuleNotFoundError:
+        raise ImportError(
+            "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module."
+            " Install it with `pip install toml`."
+        )
+
+    # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages.
+    if not os.path.isfile(f"{storage_folder}/pyproject.toml"):
+        logger.warning(
+            "There is no `pyproject.toml` in the repository that contains the fastai"
+            " `Learner`. The `pyproject.toml` would allow us to verify that your fastai"
+            " and fastcore versions are compatible with those of the model you want to"
+            " load."
+        )
+        return
+    pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml")
+
+    if "build-system" not in pyproject_toml.keys():
+        logger.warning(
+            "There is no `build-system` section in the pyproject.toml of the repository"
+            " that contains the fastai `Learner`. The `build-system` would allow us to"
+            " verify that your fastai and fastcore versions are compatible with those"
+            " of the model you want to load."
+        )
+        return
+    build_system_toml = pyproject_toml["build-system"]
+
+    if "requires" not in build_system_toml.keys():
+        logger.warning(
+            "There is no `requires` section in the pyproject.toml of the repository"
+            " that contains the fastai `Learner`. The `requires` would allow us to"
+            " verify that your fastai and fastcore versions are compatible with those"
+            " of the model you want to load."
+        )
+        return
+    package_versions = build_system_toml["requires"]
+
+    # Extract the fastai and fastcore versions from `pyproject.toml` if available.
+    # If a package is listed without a version (e.g. "fastai" instead of "fastai=2.4"), the extracted version is an empty string and the check is skipped.
+    fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")]
+    if len(fastai_packages) == 0:
+        logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.")
+    # fastai_version is an empty string if not specified
+    else:
+        fastai_version = str(fastai_packages[0]).partition("=")[2]
+        if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version):
+            raise ImportError(
+                "`from_pretrained_fastai` requires"
+                f" fastai>={fastai_min_version} version but the model to load uses"
+                f" {fastai_version} which is incompatible."
+            )
+
+    fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")]
+    if len(fastcore_packages) == 0:
+        logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.")
+    # fastcore_version is an empty string if not specified
+    else:
+        fastcore_version = str(fastcore_packages[0]).partition("=")[2]
+        if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version):
+            raise ImportError(
+                "`from_pretrained_fastai` requires"
+                f" fastcore>={fastcore_min_version} version, but you are using fastcore"
+                f" version {fastcore_version} which is incompatible."
+            )
+
+
+README_TEMPLATE = """---
+tags:
+- fastai
+---
+
+# Amazing!
+
+🥳 Congratulations on hosting your fastai model on the Hugging Face Hub!
+
+# Some next steps
+1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))!
+
+2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)).
+
+3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)!
+
+Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card.
+
+
+---
+
+
+# Model card
+
+## Model description
+More information needed
+
+## Intended uses & limitations
+More information needed
+
+## Training and evaluation data
+More information needed
+"""
+
+PYPROJECT_TEMPLATE = f"""[build-system]
+requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"]
+build-backend = "setuptools.build_meta:__legacy__"
+"""
+
+
+def _create_model_card(repo_dir: Path):
+    """
+    Creates a model card for the repository.
+
+    Args:
+        repo_dir (`Path`):
+            Directory where model card is created.
+    """
+    readme_path = repo_dir / "README.md"
+
+    if not readme_path.exists():
+        with readme_path.open("w", encoding="utf-8") as f:
+            f.write(README_TEMPLATE)
+
+
+def _create_model_pyproject(repo_dir: Path):
+    """
+    Creates a `pyproject.toml` for the repository.
+
+    Args:
+        repo_dir (`Path`):
+            Directory where `pyproject.toml` is created.
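+
+    The written file is `PYPROJECT_TEMPLATE` (defined above), which pins the Python, fastai
+    and fastcore versions of the current environment, e.g. (illustrative versions):
+
+        [build-system]
+        requires = ["setuptools>=40.8.0", "wheel", "python=3.10.12", "fastai=2.7.14", "fastcore=1.5.29"]
+        build-backend = "setuptools.build_meta:__legacy__"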
+    """
+    pyproject_path = repo_dir / "pyproject.toml"
+
+    if not pyproject_path.exists():
+        with pyproject_path.open("w", encoding="utf-8") as f:
+            f.write(PYPROJECT_TEMPLATE)
+
+
+def _save_pretrained_fastai(
+    learner,
+    save_directory: Union[str, Path],
+    config: Optional[Dict[str, Any]] = None,
+):
+    """
+    Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of Python used.
+
+    Args:
+        learner (`Learner`):
+            The `fastai.Learner` you'd like to save.
+        save_directory (`str` or `Path`):
+            Specific directory in which you want to save the fastai learner.
+        config (`dict`, *optional*):
+            Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.
+
+
+
+    Raises the following error:
+
+        - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
+          if the config file provided is not a dictionary.
+
+
+    """
+    _check_fastai_fastcore_versions()
+
+    os.makedirs(save_directory, exist_ok=True)
+
+    # If the user provides a config, save it alongside the model weights as a config.json file.
+    if config is not None:
+        if not isinstance(config, dict):
+            raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
+        path = os.path.join(save_directory, constants.CONFIG_NAME)
+        with open(path, "w") as f:
+            json.dump(config, f)
+
+    _create_model_card(Path(save_directory))
+    _create_model_pyproject(Path(save_directory))
+
+    # learner.export saves the model in `self.path`.
+    learner.path = Path(save_directory)
+    os.makedirs(save_directory, exist_ok=True)
+    try:
+        learner.export(
+            fname="model.pkl",
+            pickle_protocol=DEFAULT_PROTOCOL,
+        )
+    except PicklingError:
+        raise PicklingError(
+            "You are using a lambda function, i.e., an anonymous function. `pickle`"
+            " cannot pickle function objects and requires that all functions have"
+            " names. One possible solution is to name the function."
+        )
+
+
+@validate_hf_hub_args
+def from_pretrained_fastai(
+    repo_id: str,
+    revision: Optional[str] = None,
+):
+    """
+    Load pretrained fastai model from the Hub or from a local directory.
+
+    Args:
+        repo_id (`str`):
+            The location where the pickled fastai.Learner is. It can be either of the two:
+                - Hosted on the Hugging Face Hub. E.g.: 'espejelomar/fastai-pet-breeds-classification' or 'distilgpt2'.
+                  You can add a `revision` by appending `@` at the end of `repo_id`. E.g.: `dbmdz/bert-base-german-cased@main`.
+                  Revision is the specific model version to use. Since we use a git-based system for storing models and other
+                  artifacts on the Hugging Face Hub, it can be a branch name, a tag name, or a commit id.
+                - Hosted locally. `repo_id` would be a directory containing the pickle and a pyproject.toml
+                  indicating the fastai and fastcore versions used to build the `fastai.Learner`. E.g.: `./my_model_directory/`.
+        revision (`str`, *optional*):
+            Revision at which the repo's files are downloaded. See documentation of `snapshot_download`.
+
+    Returns:
+        The `fastai.Learner` model in the `repo_id` repo.
+    """
+    _check_fastai_fastcore_versions()
+
+    # Load the `repo_id` repo.
+    # `snapshot_download` returns the folder where the model was stored.
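+    # If `repo_id` is an existing local directory, the download is skipped and that directory is used as-is.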
+    # `cache_dir` will be the default '/root/.cache/huggingface/hub'
+    if not os.path.isdir(repo_id):
+        storage_folder = snapshot_download(
+            repo_id=repo_id,
+            revision=revision,
+            library_name="fastai",
+            library_version=get_fastai_version(),
+        )
+    else:
+        storage_folder = repo_id
+
+    _check_fastai_fastcore_pyproject_versions(storage_folder)
+
+    from fastai.learner import load_learner  # type: ignore
+
+    return load_learner(os.path.join(storage_folder, "model.pkl"))
+
+
+@validate_hf_hub_args
+def push_to_hub_fastai(
+    learner,
+    *,
+    repo_id: str,
+    commit_message: str = "Push FastAI model using huggingface_hub.",
+    private: Optional[bool] = None,
+    token: Optional[str] = None,
+    config: Optional[dict] = None,
+    branch: Optional[str] = None,
+    create_pr: Optional[bool] = None,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
+    delete_patterns: Optional[Union[List[str], str]] = None,
+    api_endpoint: Optional[str] = None,
+):
+    """
+    Upload learner checkpoint files to the Hub.
+
+    Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
+    `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
+    details.
+
+    Args:
+        learner (`Learner`):
+            The `fastai.Learner` you'd like to push to the Hub.
+        repo_id (`str`):
+            The repository id for your model on the Hub in the format of "namespace/repo_name". The namespace can be your individual account or an organization to which you have write access (for example, 'stanfordnlp/stanza-de').
+        commit_message (`str`, *optional*):
+            Message to commit while pushing. Defaults to `"Push FastAI model using huggingface_hub."`.
+        private (`bool`, *optional*):
+            Whether or not the repository created should be private.
+            If `None` (default), will default to being public except if the organization's default is private.
+        token (`str`, *optional*):
+            The Hugging Face account token to use as HTTP bearer authorization for remote files. If :obj:`None`, the token will be requested by a prompt.
+        config (`dict`, *optional*):
+            Configuration object to be saved alongside the model weights.
+        branch (`str`, *optional*):
+            The git branch on which to push the model. This defaults to
+            the default branch as specified in your repository, which
+            defaults to `"main"`.
+        create_pr (`boolean`, *optional*):
+            Whether or not to create a Pull Request from `branch` with that commit.
+            Defaults to `False`.
+        api_endpoint (`str`, *optional*):
+            The API endpoint to use when pushing the model to the hub.
+        allow_patterns (`List[str]` or `str`, *optional*):
+            If provided, only files matching at least one pattern are pushed.
+        ignore_patterns (`List[str]` or `str`, *optional*):
+            If provided, files matching any of the patterns are not pushed.
+        delete_patterns (`List[str]` or `str`, *optional*):
+            If provided, remote files matching any of the patterns will be deleted from the repo.
+
+    Returns:
+        The url of the commit of your model in the given repository.
+
+
+
+    Raises the following error:
+
+        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+          if the user is not logged in to the Hugging Face Hub.
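+
+    Example (illustrative; assumes a trained fastai `Learner` named `learn` and write access to the target repo):
+
+    ```py
+    >>> from huggingface_hub import push_to_hub_fastai
+    >>> push_to_hub_fastai(learner=learn, repo_id="espejelomar/fastai-pet-breeds-classification")
+    ```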
+ + + """ + _check_fastai_fastcore_versions() + api = HfApi(endpoint=api_endpoint) + repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp: + saved_path = Path(tmp) / repo_id + _save_pretrained_fastai(learner, saved_path, config=config) + return api.upload_folder( + repo_id=repo_id, + token=token, + folder_path=saved_path, + commit_message=commit_message, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/file_download.py b/phivenv/Lib/site-packages/huggingface_hub/file_download.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc063796a50d10546f1d273e4ded9e84c3e59b2 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/file_download.py @@ -0,0 +1,1816 @@ +import copy +import errno +import inspect +import os +import re +import shutil +import stat +import time +import uuid +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union +from urllib.parse import quote, urlparse + +import requests + +from . import ( + __version__, # noqa: F401 # for backward compatibility + constants, +) +from ._local_folder import get_local_download_paths, read_download_metadata, write_download_metadata +from .constants import ( + HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 # for backward compatibility + HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility +) +from .errors import ( + EntryNotFoundError, + FileMetadataError, + GatedRepoError, + HfHubHTTPError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from .utils import ( + OfflineModeIsEnabled, + SoftTemporaryDirectory, + WeakFileLock, + XetFileData, + build_hf_headers, + get_fastai_version, # noqa: F401 # for backward compatibility + get_fastcore_version, # noqa: F401 # for backward compatibility + get_graphviz_version, # noqa: F401 # for backward compatibility + get_jinja_version, # noqa: F401 # for backward compatibility + get_pydot_version, # noqa: F401 # for backward compatibility + get_tf_version, # noqa: F401 # for backward compatibility + get_torch_version, # noqa: F401 # for backward compatibility + hf_raise_for_status, + is_fastai_available, # noqa: F401 # for backward compatibility + is_fastcore_available, # noqa: F401 # for backward compatibility + is_graphviz_available, # noqa: F401 # for backward compatibility + is_jinja_available, # noqa: F401 # for backward compatibility + is_pydot_available, # noqa: F401 # for backward compatibility + is_tf_available, # noqa: F401 # for backward compatibility + is_torch_available, # noqa: F401 # for backward compatibility + logging, + parse_xet_file_data_from_response, + refresh_xet_connection_info, + reset_sessions, + tqdm, + validate_hf_hub_args, +) +from .utils._http import _adjust_range_header, http_backoff +from .utils._runtime import _PY_VERSION, is_xet_available # noqa: F401 # for backward compatibility +from .utils._typing import HTTP_METHOD_T +from .utils.sha import sha_fileobj +from .utils.tqdm import _get_progress_bar_context + + +logger = logging.get_logger(__name__) + +# Return value when trying to load a file from cache but the file does not exist in the distant repo. 
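+# Note: compare against this sentinel with `is` (identity), not `==`, since it is a plain `object()` instance.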
+_CACHED_NO_EXIST = object()
+_CACHED_NO_EXIST_T = Any
+
+# Regex to get filename from a "Content-Disposition" header for CDN-served files
+HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P<filename>.*?)";')
+
+# Regex to check if the revision IS directly a commit_hash
+REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
+
+# Regex to check if the file etag IS a valid sha256
+REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$")
+
+_are_symlinks_supported_in_dir: Dict[str, bool] = {}
+
+
+def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
+    """Return whether symlinks are supported on the machine.
+
+    Since symlinks support can change depending on the mounted disk, we need to check
+    on the precise cache folder. By default, the default HF cache directory is checked.
+
+    Args:
+        cache_dir (`str`, `Path`, *optional*):
+            Path to the folder where cached files are stored.
+
+    Returns: [bool] Whether symlinks are supported in the directory.
+    """
+    # Defaults to HF cache
+    if cache_dir is None:
+        cache_dir = constants.HF_HUB_CACHE
+    cache_dir = str(Path(cache_dir).expanduser().resolve())  # make it unique
+
+    # Check symlink compatibility only once (per cache directory) at first time use
+    if cache_dir not in _are_symlinks_supported_in_dir:
+        _are_symlinks_supported_in_dir[cache_dir] = True
+
+        os.makedirs(cache_dir, exist_ok=True)
+        with SoftTemporaryDirectory(dir=cache_dir) as tmpdir:
+            src_path = Path(tmpdir) / "dummy_file_src"
+            src_path.touch()
+            dst_path = Path(tmpdir) / "dummy_file_dst"
+
+            # Relative source path as in `_create_symlink`
+            relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
+            try:
+                os.symlink(relative_src, dst_path)
+            except OSError:
+                # Likely running on Windows
+                _are_symlinks_supported_in_dir[cache_dir] = False
+
+                if not constants.HF_HUB_DISABLE_SYMLINKS_WARNING:
+                    message = (
+                        "`huggingface_hub` cache-system uses symlinks by default to"
+                        " efficiently store duplicated files but your machine does not"
+                        f" support them in {cache_dir}. Caching files will still work"
+                        " but in a degraded version that might require more space on"
+                        " your disk. This warning can be disabled by setting the"
+                        " `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For"
+                        " more details, see"
+                        " https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
+                    )
+                    if os.name == "nt":
+                        message += (
+                            "\nTo support symlinks on Windows, you either need to"
+                            " activate Developer Mode or to run Python as an"
+                            " administrator. In order to activate developer mode,"
+                            " see this article:"
+                            " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
+                        )
+                    warnings.warn(message)
+
+    return _are_symlinks_supported_in_dir[cache_dir]
+
+
+@dataclass(frozen=True)
+class HfFileMetadata:
+    """Data structure containing information about a file versioned on the Hub.
+
+    Returned by [`get_hf_file_metadata`] based on a URL.
+
+    Args:
+        commit_hash (`str`, *optional*):
+            The commit_hash related to the file.
+        etag (`str`, *optional*):
+            Etag of the file on the server.
+        location (`str`):
+            Location where to download the file. Can be a Hub url or not (CDN).
+        size (`int`):
+            Size of the file. In case of an LFS file, contains the size of the actual
+            LFS file, not the pointer.
+        xet_file_data (`XetFileData`, *optional*):
+            Xet information for the file. This is only set if the file is stored using Xet storage.
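+
+    Example (illustrative):
+
+    ```py
+    >>> from huggingface_hub import get_hf_file_metadata, hf_hub_url
+    >>> metadata = get_hf_file_metadata(hf_hub_url(repo_id="gpt2", filename="config.json"))
+    >>> metadata.etag is not None, metadata.size
+    ```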
+ """ + + commit_hash: Optional[str] + etag: Optional[str] + location: str + size: Optional[int] + xet_file_data: Optional[XetFileData] + + +@validate_hf_hub_args +def hf_hub_url( + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + endpoint: Optional[str] = None, +) -> str: + """Construct the URL of a file from the given information. + + The resolved address can either be a huggingface.co-hosted url, or a link to + Cloudfront (a Content Delivery Network, or CDN) for large files which are + more than a few MBs. + + Args: + repo_id (`str`): + A namespace (user or an organization) name and a repo name separated + by a `/`. + filename (`str`): + The name of the file in the repo. + subfolder (`str`, *optional*): + An optional value corresponding to a folder inside the repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + + Example: + + ```python + >>> from huggingface_hub import hf_hub_url + + >>> hf_hub_url( + ... repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin" + ... ) + 'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin' + ``` + + + + Notes: + + Cloudfront is replicated over the globe so downloads are way faster for + the end user (and it also lowers our bandwidth costs). + + Cloudfront aggressively caches files by default (default TTL is 24 + hours), however this is not an issue here because we implement a + git-based versioning system on huggingface.co, which means that we store + the files on S3/Cloudfront in a content-addressable way (i.e., the file + name is its hash). Using content-addressable filenames means cache can't + ever be stale. + + In terms of client-side caching from this library, we base our caching + on the objects' entity tag (`ETag`), which is an identifier of a + specific version of a resource [1]_. An object's ETag is: its git-sha1 + if stored in git, or its sha256 if stored in git-lfs. + + + + References: + + - [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag + """ + if subfolder == "": + subfolder = None + if subfolder is not None: + filename = f"{subfolder}/{filename}" + + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + + if revision is None: + revision = constants.DEFAULT_REVISION + url = HUGGINGFACE_CO_URL_TEMPLATE.format( + repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename) + ) + # Update endpoint if provided + if endpoint is not None and url.startswith(constants.ENDPOINT): + url = endpoint + url[len(constants.ENDPOINT) :] + return url + + +def _request_wrapper( + method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params +) -> requests.Response: + """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when + `allow_redirection=False`. + + A backoff mechanism retries the HTTP call on 429, 503 and 504 errors. + + Args: + method (`str`): + HTTP method, such as 'GET' or 'HEAD'. + url (`str`): + The URL of the resource to fetch. 
+ follow_relative_redirects (`bool`, *optional*, defaults to `False`) + If True, relative redirection (redirection to the same site) will be resolved even when `allow_redirection` + kwarg is set to False. Useful when we want to follow a redirection to a renamed repository without + following redirection to a CDN. + **params (`dict`, *optional*): + Params to pass to `requests.request`. + """ + # Recursively follow relative redirects + if follow_relative_redirects: + response = _request_wrapper( + method=method, + url=url, + follow_relative_redirects=False, + **params, + ) + + # If redirection, we redirect only relative paths. + # This is useful in case of a renamed repository. + if 300 <= response.status_code <= 399: + parsed_target = urlparse(response.headers["Location"]) + if parsed_target.netloc == "": + # This means it is a relative 'location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + # We want to follow this relative redirect ! + # + # Highly inspired by `resolve_redirects` from requests library. + # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159 + next_url = urlparse(url)._replace(path=parsed_target.path).geturl() + return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params) + return response + + # Perform request and return if status_code is not in the retry list. + response = http_backoff(method=method, url=url, **params, retry_on_exceptions=(), retry_on_status_codes=(429,)) + hf_raise_for_status(response) + return response + + +def _get_file_length_from_http_response(response: requests.Response) -> Optional[int]: + """ + Get the length of the file from the HTTP response headers. + + This function extracts the file size from the HTTP response headers, either from the + `Content-Range` or `Content-Length` header, if available (in that order). + + Args: + response (`requests.Response`): + The HTTP response object. + + Returns: + `int` or `None`: The length of the file in bytes, or None if not available. + """ + + # If HTTP response contains compressed body (e.g. gzip), the `Content-Length` header will + # contain the length of the compressed body, not the uncompressed file size. + # And at the start of transmission there's no way to know the uncompressed file size for gzip, + # thus we return None in that case. + content_encoding = response.headers.get("Content-Encoding", "identity").lower() + if content_encoding != "identity": + # gzip/br/deflate/zstd etc + return None + + content_range = response.headers.get("Content-Range") + if content_range is not None: + return int(content_range.rsplit("/")[-1]) + + content_length = response.headers.get("Content-Length") + if content_length is not None: + return int(content_length) + + return None + + +def http_get( + url: str, + temp_file: BinaryIO, + *, + proxies: Optional[Dict] = None, + resume_size: int = 0, + headers: Optional[Dict[str, Any]] = None, + expected_size: Optional[int] = None, + displayed_filename: Optional[str] = None, + _nb_retries: int = 5, + _tqdm_bar: Optional[tqdm] = None, +) -> None: + """ + Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub. + + If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a + transient error (network outage?). We log a warning message and try to resume the download a few times before + giving up. 
The method gives up after 5 attempts if no new data has been received from the server.
+
+    Args:
+        url (`str`):
+            The URL of the file to download.
+        temp_file (`BinaryIO`):
+            The file-like object where to save the file.
+        proxies (`dict`, *optional*):
+            Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
+        resume_size (`int`, *optional*):
+            The number of bytes already downloaded. If set to 0 (default), the whole file is downloaded. If set to a
+            positive number, the download will resume at the given position.
+        headers (`dict`, *optional*):
+            Dictionary of HTTP Headers to send with the request.
+        expected_size (`int`, *optional*):
+            The expected size of the file to download. If set, the download will raise an error if the size of the
+            received content is different from the expected one.
+        displayed_filename (`str`, *optional*):
+            The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If
+            not set, the filename is guessed from the URL or the `Content-Disposition` header.
+    """
+    if expected_size is not None and resume_size == expected_size:
+        # If the file is already fully downloaded, we don't need to download it again.
+        return
+
+    has_custom_range_header = headers is not None and any(h.lower() == "range" for h in headers)
+    hf_transfer = None
+    if constants.HF_HUB_ENABLE_HF_TRANSFER:
+        if resume_size != 0:
+            warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method")
+        elif proxies is not None:
+            warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method")
+        elif has_custom_range_header:
+            warnings.warn("'hf_transfer' ignores custom 'Range' headers; falling back to regular download method")
+        else:
+            try:
+                import hf_transfer  # type: ignore[no-redef]
+            except ImportError:
+                raise ValueError(
+                    "Fast download using 'hf_transfer' is enabled"
+                    " (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is not"
+                    " available in your environment. Try `pip install hf_transfer`."
+                )
+
+    initial_headers = headers
+    headers = copy.deepcopy(headers) or {}
+    if resume_size > 0:
+        headers["Range"] = _adjust_range_header(headers.get("Range"), resume_size)
+    elif expected_size and expected_size > constants.MAX_HTTP_DOWNLOAD_SIZE:
+        # Any file over 50GB will not be available through a basic HTTP request.
+        # Setting the range header to 0-0 will force the server to return the file size in the Content-Range header.
+        # Since hf_transfer splits the download into chunks, the process will succeed afterwards.
+        if hf_transfer:
+            headers["Range"] = "bytes=0-0"
+        else:
+            raise ValueError(
+                "The file is too large to be downloaded using the regular download method. Use `hf_transfer` or `hf_xet` instead."
+                " Try `pip install hf_transfer` or `pip install hf_xet`."
+            )
+
+    r = _request_wrapper(
+        method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT
+    )
+
+    hf_raise_for_status(r)
+    total: Optional[int] = _get_file_length_from_http_response(r)
+
+    if displayed_filename is None:
+        displayed_filename = url
+        content_disposition = r.headers.get("Content-Disposition")
+        if content_disposition is not None:
+            match = HEADER_FILENAME_PATTERN.search(content_disposition)
+            if match is not None:
+                # Means file is on CDN
+                displayed_filename = match.groupdict()["filename"]
+
+    # Truncate filename if too long to display
+    if len(displayed_filename) > 40:
+        displayed_filename = f"(…){displayed_filename[-40:]}"
+
+    consistency_error_message = (
+        f"Consistency check failed: file should be of size {expected_size} but has size"
+        f" {{actual_size}} ({displayed_filename}).\nThis is usually due to network issues while downloading the file."
+        " Please retry with `force_download=True`."
+    )
+    progress_cm = _get_progress_bar_context(
+        desc=displayed_filename,
+        log_level=logger.getEffectiveLevel(),
+        total=total,
+        initial=resume_size,
+        name="huggingface_hub.http_get",
+        _tqdm_bar=_tqdm_bar,
+    )
+
+    with progress_cm as progress:
+        if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE:
+            supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
+            if not supports_callback:
+                warnings.warn(
+                    "You are using an outdated version of `hf_transfer`. "
+                    "Consider upgrading to the latest version to enable progress bars "
+                    "using `pip install -U hf_transfer`."
+                )
+            try:
+                hf_transfer.download(
+                    url=url,
+                    filename=temp_file.name,
+                    max_files=constants.HF_TRANSFER_CONCURRENCY,
+                    chunk_size=constants.DOWNLOAD_CHUNK_SIZE,
+                    headers=initial_headers,
+                    parallel_failures=3,
+                    max_retries=5,
+                    **({"callback": progress.update} if supports_callback else {}),
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    "An error occurred while downloading using `hf_transfer`. Consider"
+                    " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
+                ) from e
+            if not supports_callback:
+                progress.update(total)
+            if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
+                raise EnvironmentError(
+                    consistency_error_message.format(
+                        actual_size=os.path.getsize(temp_file.name),
+                    )
+                )
+            return
+        new_resume_size = resume_size
+        try:
+            for chunk in r.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE):
+                if chunk:  # filter out keep-alive new chunks
+                    progress.update(len(chunk))
+                    temp_file.write(chunk)
+                    new_resume_size += len(chunk)
+                    # Some data has been downloaded from the server so we reset the number of retries.
+                    _nb_retries = 5
+        except (requests.ConnectionError, requests.ReadTimeout) as e:
+            # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
+            # a transient error (network outage?). We log a warning message and try to resume the download a few times
+            # before giving up. The retry mechanism is basic but should be enough in most cases.
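+            # Retries are performed via a recursive call below, resuming from `new_resume_size` after resetting
+            # the shared HTTP sessions.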
+ if _nb_retries <= 0: + logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e)) + raise + logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e)) + time.sleep(1) + reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects + return http_get( + url=url, + temp_file=temp_file, + proxies=proxies, + resume_size=new_resume_size, + headers=initial_headers, + expected_size=expected_size, + _nb_retries=_nb_retries - 1, + _tqdm_bar=_tqdm_bar, + ) + + if expected_size is not None and expected_size != temp_file.tell(): + raise EnvironmentError( + consistency_error_message.format( + actual_size=temp_file.tell(), + ) + ) + + +def xet_get( + *, + incomplete_path: Path, + xet_file_data: XetFileData, + headers: Dict[str, str], + expected_size: Optional[int] = None, + displayed_filename: Optional[str] = None, + _tqdm_bar: Optional[tqdm] = None, +) -> None: + """ + Download a file using Xet storage service. + + Args: + incomplete_path (`Path`): + The path to the file to download. + xet_file_data (`XetFileData`): + The file metadata needed to make the request to the xet storage service. + headers (`Dict[str, str]`): + The headers to send to the xet storage service. + expected_size (`int`, *optional*): + The expected size of the file to download. If set, the download will raise an error if the size of the + received content is different from the expected one. + displayed_filename (`str`, *optional*): + The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If + not set, the filename is guessed from the URL or the `Content-Disposition` header. + + **How it works:** + The file download system uses Xet storage, which is a content-addressable storage system that breaks files into chunks + for efficient storage and transfer. + + `hf_xet.download_files` manages downloading files by: + - Taking a list of files to download (each with its unique content hash) + - Connecting to a storage server (CAS server) that knows how files are chunked + - Using authentication to ensure secure access + - Providing progress updates during download + + Authentication works by regularly refreshing access tokens through `refresh_xet_connection_info` to maintain a valid + connection to the storage server. + + The download process works like this: + 1. Create a local cache folder at `~/.cache/huggingface/xet/chunk-cache` to store reusable file chunks + 2. Download files in parallel: + 2.1. Prepare to write the file to disk + 2.2. Ask the server "how is this file split into chunks?" using the file's unique hash + The server responds with: + - Which chunks make up the complete file + - Where each chunk can be downloaded from + 2.3. For each needed chunk: + - Checks if we already have it in our local cache + - If not, download it from cloud storage (S3) + - Save it to cache for future use + - Assemble the chunks in order to recreate the original file + + """ + try: + from hf_xet import PyXetDownloadInfo, download_files # type: ignore[no-redef] + except ImportError: + raise ValueError( + "To use optimized download using Xet storage, you need to install the hf_xet package. " + 'Try `pip install "huggingface_hub[hf_xet]"` or `pip install hf_xet`.' 
+        )
+
+    connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers=headers)
+
+    def token_refresher() -> Tuple[str, int]:
+        connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers=headers)
+        if connection_info is None:
+            raise ValueError("Failed to refresh token using xet metadata.")
+        return connection_info.access_token, connection_info.expiration_unix_epoch
+
+    xet_download_info = [
+        PyXetDownloadInfo(
+            destination_path=str(incomplete_path.absolute()), hash=xet_file_data.file_hash, file_size=expected_size
+        )
+    ]
+
+    if not displayed_filename:
+        displayed_filename = incomplete_path.name
+
+    # Truncate filename if too long to display
+    if len(displayed_filename) > 40:
+        displayed_filename = f"{displayed_filename[:40]}(…)"
+
+    progress_cm = _get_progress_bar_context(
+        desc=displayed_filename,
+        log_level=logger.getEffectiveLevel(),
+        total=expected_size,
+        initial=0,
+        name="huggingface_hub.xet_get",
+        _tqdm_bar=_tqdm_bar,
+    )
+
+    with progress_cm as progress:
+
+        def progress_updater(progress_bytes: float):
+            progress.update(progress_bytes)
+
+        download_files(
+            xet_download_info,
+            endpoint=connection_info.endpoint,
+            token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
+            token_refresher=token_refresher,
+            progress_updater=[progress_updater],
+        )
+
+
+def _normalize_etag(etag: Optional[str]) -> Optional[str]:
+    """Normalize ETag HTTP header, so it can be used to create nice filepaths.
+
+    The HTTP spec allows two forms of ETag:
+      ETag: W/"<etag_value>"
+      ETag: "<etag_value>"
+
+    For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
+    more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.
+
+    Args:
+        etag (`str`, *optional*): HTTP header
+
+    Returns:
+        `str` or `None`: string that can be used as a nice directory name.
+        Returns `None` if input is None.
+    """
+    if etag is None:
+        return None
+    return etag.lstrip("W/").strip('"')
+
+
+def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
+    """Alias method used in `transformers` conversion script."""
+    return _create_symlink(src=src, dst=dst, new_blob=new_blob)
+
+
+def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
+    """Create a symbolic link named dst pointing to src.
+
+    By default, it will try to create a symlink using a relative path. Relative paths have 2 advantages:
+    - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder will
+      not break.
+    - Relative paths seem to be better handled on Windows. The issue was reported 3 times in less than a week when
+      changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398,
+      https://github.com/huggingface/diffusers/issues/2729 and https://github.com/huggingface/transformers/pull/22228.
+      NOTE: The issue with absolute paths doesn't happen in admin mode.
+    When creating a symlink from the cache to a local folder, it is possible that a relative path cannot be created.
+    This happens when paths are not on the same volume. In that case, we use absolute paths.
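+
+    For example (illustrative paths): with `src="/cache/blobs/abc"` and
+    `dst="/cache/snapshots/2439f6/README.md"`, the created symlink target is the relative
+    path `../../blobs/abc`.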
+
+    The result layout looks something like
+        └── [ 128]  snapshots
+            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
+            │   ├── [  52]  README.md -> ../../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
+            │   └── [  76]  pytorch_model.bin -> ../../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+
+    If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by
+    having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file
+    (`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing
+    cache, the file is duplicated on the disk.
+
+    In case symlinks are not supported, a warning message is displayed to the user once when loading
+    `huggingface_hub`. The warning message can be disabled with the `DISABLE_SYMLINKS_WARNING` environment variable.
+    """
+    try:
+        os.remove(dst)
+    except OSError:
+        pass
+
+    abs_src = os.path.abspath(os.path.expanduser(src))
+    abs_dst = os.path.abspath(os.path.expanduser(dst))
+    abs_dst_folder = os.path.dirname(abs_dst)
+
+    # Use relative_dst in priority
+    try:
+        relative_src = os.path.relpath(abs_src, abs_dst_folder)
+    except ValueError:
+        # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink to a
+        # local_dir instead of within the cache directory.
+        # See https://docs.python.org/3/library/os.path.html#os.path.relpath
+        relative_src = None
+
+    try:
+        commonpath = os.path.commonpath([abs_src, abs_dst])
+        _support_symlinks = are_symlinks_supported(commonpath)
+    except ValueError:
+        # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/MacOS.
+        # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
+        _support_symlinks = os.name != "nt"
+    except PermissionError:
+        # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
+        # by the user via `local_dir`. Let's test symlink support there)
+        _support_symlinks = are_symlinks_supported(abs_dst_folder)
+    except OSError as e:
+        # OS error (errno=30) means that the commonpath is readonly on Linux/MacOS.
+        if e.errno == errno.EROFS:
+            _support_symlinks = are_symlinks_supported(abs_dst_folder)
+        else:
+            raise
+
+    # Symlinks are supported => let's create a symlink.
+    if _support_symlinks:
+        src_rel_or_abs = relative_src or abs_src
+        logger.debug(f"Creating pointer from {src_rel_or_abs} to {abs_dst}")
+        try:
+            os.symlink(src_rel_or_abs, abs_dst)
+            return
+        except FileExistsError:
+            if os.path.islink(abs_dst) and os.path.realpath(abs_dst) == os.path.realpath(abs_src):
+                # `abs_dst` already exists and is a symlink to the `abs_src` blob. It is most likely that the file has
+                # been cached twice concurrently (exactly between `os.remove` and `os.symlink`). Do nothing.
+                return
+            else:
+                # Very unlikely to happen. Means a file `dst` has been created exactly between `os.remove` and
+                # `os.symlink` and is not a symlink to the `abs_src` blob file. Raise exception.
+                raise
+        except PermissionError:
+            # Permission error means src and dst are not in the same volume (e.g. download to local dir) and symlink
+            # is supported on both volumes but not between them. Let's just make a hard copy in that case.
+            pass
+
+    # Symlinks are not supported => let's move or copy the file.
+    if new_blob:
+        logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
+        shutil.move(abs_src, abs_dst, copy_function=_copy_no_matter_what)
+    else:
+        logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
+        shutil.copyfile(abs_src, abs_dst)
+
+
+def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
+    """Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash.
+
+    Does nothing if `revision` is already a proper `commit_hash` or reference is already cached.
+    """
+    if revision != commit_hash:
+        ref_path = Path(storage_folder) / "refs" / revision
+        ref_path.parent.mkdir(parents=True, exist_ok=True)
+        if not ref_path.exists() or commit_hash != ref_path.read_text():
+            # Update ref only if it has been updated. Could cause useless error in case
+            # repo is already cached and user doesn't have write access to cache folder.
+            # See https://github.com/huggingface/huggingface_hub/issues/1216.
+            ref_path.write_text(commit_hash)
+
+
+@validate_hf_hub_args
+def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
+    """Return a serialized version of a hf.co repo name and type, safe for disk storage
+    as a single non-nested folder.
+
+    Example: models--julien-c--EsperBERTo-small
+    """
+    # remove all `/` occurrences to correctly convert repo to directory name
+    parts = [f"{repo_type}s", *repo_id.split("/")]
+    return constants.REPO_ID_SEPARATOR.join(parts)
+
+
+def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None:
+    """Check disk usage and emit a warning if there is not enough disk space to download the file.
+
+    Args:
+        expected_size (`int`):
+            The expected size of the file in bytes.
+        target_dir (`str`):
+            The directory where the file will be stored after downloading.
+    """
+
+    target_dir = Path(target_dir)  # format as `Path`
+    for path in [target_dir] + list(target_dir.parents):  # first check target_dir, then each parent one by one
+        try:
+            target_dir_free = shutil.disk_usage(path).free
+            if target_dir_free < expected_size:
+                warnings.warn(
+                    "Not enough free disk space to download the file. "
+                    f"The expected file size is: {expected_size / 1e6:.2f} MB. "
+                    f"The target location {target_dir} only has {target_dir_free / 1e6:.2f} MB free disk space."
+                )
+            return
+        except OSError:  # ignore the error: path does not exist or disk usage cannot be checked
+            pass
+
+
+@validate_hf_hub_args
+def hf_hub_download(
+    repo_id: str,
+    filename: str,
+    *,
+    subfolder: Optional[str] = None,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    user_agent: Union[Dict, str, None] = None,
+    force_download: bool = False,
+    proxies: Optional[Dict] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    token: Union[bool, str, None] = None,
+    local_files_only: bool = False,
+    headers: Optional[Dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    resume_download: Optional[bool] = None,
+    force_filename: Optional[str] = None,
+    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+) -> str:
+    """Download a given file if it's not already present in the local cache.
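+
+    A minimal usage sketch (the `repo_id` and `filename` below are hypothetical):
+
+    ```python
+    >>> from huggingface_hub import hf_hub_download
+    >>> # Returns the local path of the cached file, downloading it first if needed.
+    >>> path = hf_hub_download(repo_id="user/repo", filename="config.json")
+    ```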
+
+    The new cache file layout looks like this:
+    - The cache directory contains one subfolder per repo_id (namespaced by repo type)
+    - inside each repo folder:
+        - refs is a list of the latest known revision => commit_hash pairs
+        - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on
+          whether they're LFS files or not)
+        - snapshots contains one subfolder per commit, each "commit" contains the subset of the files
+          that have been resolved at that particular commit. Each filename is a symlink to the blob
+          at that particular commit.
+
+    ```
+    [  96]  .
+    └── [ 160]  models--julien-c--EsperBERTo-small
+        ├── [ 160]  blobs
+        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
+        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
+        ├── [  96]  refs
+        │   └── [  40]  main
+        └── [ 128]  snapshots
+            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
+            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
+            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
+                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
+                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+    ```
+
+    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
+    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
+    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
+    cache-system, it's optimized for regularly pulling the latest version of a repository.
+
+    Args:
+        repo_id (`str`):
+            A user or an organization name and a repo name separated by a `/`.
+        filename (`str`):
+            The name of the file in the repo.
+        subfolder (`str`, *optional*):
+            An optional value corresponding to a folder inside the model repo.
+        repo_type (`str`, *optional*):
+            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
+            `None` or `"model"` if downloading from a model. Default is `None`.
+        revision (`str`, *optional*):
+            An optional Git revision id which can be a branch name, a tag, or a
+            commit hash.
+        library_name (`str`, *optional*):
+            The name of the library to which the object corresponds.
+        library_version (`str`, *optional*):
+            The version of the library.
+        cache_dir (`str`, `Path`, *optional*):
+            Path to the folder where cached files are stored.
+        local_dir (`str` or `Path`, *optional*):
+            If provided, the downloaded file will be placed under this directory.
+        user_agent (`dict`, `str`, *optional*):
+            The user-agent info in the form of a dictionary or a string.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether the file should be downloaded even if it already exists in
+            the local cache.
+        proxies (`dict`, *optional*):
+            Dictionary mapping protocol to the URL of the proxy passed to
+            `requests.request`.
+        etag_timeout (`float`, *optional*, defaults to `10`):
+            When fetching ETag, how many seconds to wait for the server to send
+            data before giving up, which is passed to `requests.request`.
+        token (`str`, `bool`, *optional*):
+            A token to be used for the download.
+                - If `True`, the token is read from the HuggingFace config
+                  folder.
+                - If a string, it's used as the authentication token.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, avoid downloading the file and return the path to the
+            local cached file if it exists.
+        headers (`dict`, *optional*):
+            Additional headers to be sent with the request.
+
+    Returns:
+        `str`: Local path of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        [`~utils.RepositoryNotFoundError`]
+            If the repository to download from cannot be found. This may be because it doesn't exist,
+            or because it is set to `private` and you do not have access.
+        [`~utils.RevisionNotFoundError`]
+            If the revision to download from cannot be found.
+        [`~utils.EntryNotFoundError`]
+            If the file to download cannot be found.
+        [`~utils.LocalEntryNotFoundError`]
+            If network is disabled or unavailable and file is not found in cache.
+        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+            If `token=True` but the token cannot be found.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+            If ETag cannot be determined.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If some parameter value is invalid.
+
+    """
+    if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT:
+        # Respect environment variable above user value
+        etag_timeout = constants.HF_HUB_ETAG_TIMEOUT
+
+    if force_filename is not None:
+        warnings.warn(
+            "The `force_filename` parameter is deprecated as a new caching system, "
+            "which keeps the filenames as they are on the Hub, is now in place.",
+            FutureWarning,
+        )
+    if resume_download is not None:
+        warnings.warn(
+            "`resume_download` is deprecated and will be removed in version 1.0.0. "
+            "Downloads always resume when possible. "
+            "If you want to force a new download, use `force_download=True`.",
+            FutureWarning,
+        )
+
+    if cache_dir is None:
+        cache_dir = constants.HF_HUB_CACHE
+    if revision is None:
+        revision = constants.DEFAULT_REVISION
+    if isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+    if isinstance(local_dir, Path):
+        local_dir = str(local_dir)
+
+    if subfolder == "":
+        subfolder = None
+    if subfolder is not None:
+        # This is used to create a URL, and not a local path, hence the forward slash.
+        filename = f"{subfolder}/{filename}"
+
+    if repo_type is None:
+        repo_type = "model"
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
+
+    hf_headers = build_hf_headers(
+        token=token,
+        library_name=library_name,
+        library_version=library_version,
+        user_agent=user_agent,
+        headers=headers,
+    )
+
+    if local_dir is not None:
+        if local_dir_use_symlinks != "auto":
+            warnings.warn(
+                "`local_dir_use_symlinks` parameter is deprecated and will be ignored. "
+                "The process to download files to a local folder has been updated and does "
+                "not rely on symlinks anymore. You only need to pass a destination folder "
+                "as `local_dir`.\n"
+                "For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder."
+ ) + + return _hf_hub_download_to_local_dir( + # Destination + local_dir=local_dir, + # File info + repo_id=repo_id, + repo_type=repo_type, + filename=filename, + revision=revision, + # HTTP info + endpoint=endpoint, + etag_timeout=etag_timeout, + headers=hf_headers, + proxies=proxies, + token=token, + # Additional options + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + else: + return _hf_hub_download_to_cache_dir( + # Destination + cache_dir=cache_dir, + # File info + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + # HTTP info + endpoint=endpoint, + etag_timeout=etag_timeout, + headers=hf_headers, + proxies=proxies, + token=token, + # Additional options + local_files_only=local_files_only, + force_download=force_download, + ) + + +def _hf_hub_download_to_cache_dir( + *, + # Destination + cache_dir: str, + # File info + repo_id: str, + filename: str, + repo_type: str, + revision: str, + # HTTP info + endpoint: Optional[str], + etag_timeout: float, + headers: Dict[str, str], + proxies: Optional[Dict], + token: Optional[Union[bool, str]], + # Additional options + local_files_only: bool, + force_download: bool, +) -> str: + """Download a given file to a cache folder, if not already present. + + Method should not be called directly. Please use `hf_hub_download` instead. + """ + locks_dir = os.path.join(cache_dir, ".locks") + storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) + + # cross platform transcription of filename, to be used as a local file path. + relative_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if relative_filename.startswith("..\\") or "\\..\\" in relative_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository" + " owner to rename this file." + ) + + # if user provides a commit_hash and they already have the file on disk, shortcut everything. + if REGEX_COMMIT_HASH.match(revision): + pointer_path = _get_pointer_path(storage_folder, revision, relative_filename) + if os.path.exists(pointer_path) and not force_download: + return pointer_path + + # Try to get metadata (etag, commit_hash, url, size) from the server. + # If we can't, a HEAD request error is returned. + (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = _get_metadata_or_catch_error( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + endpoint=endpoint, + proxies=proxies, + etag_timeout=etag_timeout, + headers=headers, + token=token, + local_files_only=local_files_only, + storage_folder=storage_folder, + relative_filename=relative_filename, + ) + + # etag can be None for several reasons: + # 1. we passed local_files_only. + # 2. we don't have a connection + # 3. Hub is down (HTTP 500, 503, 504) + # 4. repo is not found -for example private or gated- and invalid/missing token sent + # 5. Hub is blocked by a firewall or proxy is not set correctly. + # => Try to get the last downloaded one from the specified revision. + # + # If the specified revision is a commit hash, look inside "snapshots". + # If the specified revision is a branch or tag, look inside "refs". 
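+    #
+    # As an illustration (hypothetical repo and hashes), a cached ref lookup reads like:
+    #   <cache_dir>/models--user--repo/refs/main                        -> file containing "abc123..." (a commit hash)
+    #   <cache_dir>/models--user--repo/snapshots/abc123.../config.json  -> pointer file returned to the caller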
+    if head_call_error is not None:
+        # Couldn't make a HEAD call => let's try to find a local file
+        if not force_download:
+            commit_hash = None
+            if REGEX_COMMIT_HASH.match(revision):
+                commit_hash = revision
+            else:
+                ref_path = os.path.join(storage_folder, "refs", revision)
+                if os.path.isfile(ref_path):
+                    with open(ref_path) as f:
+                        commit_hash = f.read()
+
+            # Return pointer file if exists
+            if commit_hash is not None:
+                pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
+                if os.path.exists(pointer_path) and not force_download:
+                    return pointer_path
+
+        # Otherwise, raise appropriate error
+        _raise_on_head_call_error(head_call_error, force_download, local_files_only)
+
+    # From now on, etag, commit_hash, url and size are not None.
+    assert etag is not None, "etag must have been retrieved from server"
+    assert commit_hash is not None, "commit_hash must have been retrieved from server"
+    assert url_to_download is not None, "file location must have been retrieved from server"
+    assert expected_size is not None, "expected_size must have been retrieved from server"
+    blob_path = os.path.join(storage_folder, "blobs", etag)
+    pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
+
+    os.makedirs(os.path.dirname(blob_path), exist_ok=True)
+    os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
+
+    # if passed revision is not identical to commit_hash
+    # then revision has to be a branch name or tag name.
+    # In that case store a ref.
+    _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
+
+    # Prevent parallel downloads of the same file with a lock.
+    # etag could be duplicated across repos, so the lock is scoped to the repo folder.
+    lock_path = os.path.join(locks_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type), f"{etag}.lock")
+
+    # Some Windows versions do not allow for paths longer than 255 characters.
+    # In this case, we must specify it as an extended path by using the "\\?\" prefix.
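+    # For illustration, a hypothetical lock path such as
+    #   C:\Users\me\.cache\huggingface\hub\.locks\models--user--repo\<etag>.lock
+    # that exceeds 255 characters would be rewritten below as
+    #   \\?\C:\Users\me\.cache\huggingface\hub\.locks\models--user--repo\<etag>.lock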
+ if ( + os.name == "nt" + and len(os.path.abspath(lock_path)) > 255 + and not os.path.abspath(lock_path).startswith("\\\\?\\") + ): + lock_path = "\\\\?\\" + os.path.abspath(lock_path) + + if ( + os.name == "nt" + and len(os.path.abspath(blob_path)) > 255 + and not os.path.abspath(blob_path).startswith("\\\\?\\") + ): + blob_path = "\\\\?\\" + os.path.abspath(blob_path) + + Path(lock_path).parent.mkdir(parents=True, exist_ok=True) + + # pointer already exists -> immediate return + if not force_download and os.path.exists(pointer_path): + return pointer_path + + # Blob exists but pointer must be (safely) created -> take the lock + if not force_download and os.path.exists(blob_path): + with WeakFileLock(lock_path): + if not os.path.exists(pointer_path): + _create_symlink(blob_path, pointer_path, new_blob=False) + return pointer_path + + # Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache) + + with WeakFileLock(lock_path): + _download_to_tmp_and_move( + incomplete_path=Path(blob_path + ".incomplete"), + destination_path=Path(blob_path), + url_to_download=url_to_download, + proxies=proxies, + headers=headers, + expected_size=expected_size, + filename=filename, + force_download=force_download, + etag=etag, + xet_file_data=xet_file_data, + ) + if not os.path.exists(pointer_path): + _create_symlink(blob_path, pointer_path, new_blob=True) + + return pointer_path + + +def _hf_hub_download_to_local_dir( + *, + # Destination + local_dir: Union[str, Path], + # File info + repo_id: str, + repo_type: str, + filename: str, + revision: str, + # HTTP info + endpoint: Optional[str], + etag_timeout: float, + headers: Dict[str, str], + proxies: Optional[Dict], + token: Union[bool, str, None], + # Additional options + cache_dir: str, + force_download: bool, + local_files_only: bool, +) -> str: + """Download a given file to a local folder, if not already present. + + Method should not be called directly. Please use `hf_hub_download` instead. + """ + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix. + if os.name == "nt" and len(os.path.abspath(local_dir)) > 255: + local_dir = "\\\\?\\" + os.path.abspath(local_dir) + local_dir = Path(local_dir) + paths = get_local_download_paths(local_dir=local_dir, filename=filename) + local_metadata = read_download_metadata(local_dir=local_dir, filename=filename) + + # Local file exists + metadata exists + commit_hash matches => return file + if ( + not force_download + and REGEX_COMMIT_HASH.match(revision) + and paths.file_path.is_file() + and local_metadata is not None + and local_metadata.commit_hash == revision + ): + return str(paths.file_path) + + # Local file doesn't exist or commit_hash doesn't match => we need the etag + (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = _get_metadata_or_catch_error( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + endpoint=endpoint, + proxies=proxies, + etag_timeout=etag_timeout, + headers=headers, + token=token, + local_files_only=local_files_only, + ) + + if head_call_error is not None: + # No HEAD call but local file exists => default to local file + if not force_download and paths.file_path.is_file(): + logger.warning( + f"Couldn't access the Hub to check for update but local file already exists. Defaulting to existing file. 
(error: {head_call_error})"
+            )
+            return str(paths.file_path)
+        # Otherwise => raise
+        _raise_on_head_call_error(head_call_error, force_download, local_files_only)
+
+    # From now on, etag, commit_hash, url and size are not None.
+    assert etag is not None, "etag must have been retrieved from server"
+    assert commit_hash is not None, "commit_hash must have been retrieved from server"
+    assert url_to_download is not None, "file location must have been retrieved from server"
+    assert expected_size is not None, "expected_size must have been retrieved from server"
+
+    # Local file exists => check if it's up-to-date
+    if not force_download and paths.file_path.is_file():
+        # etag matches => update metadata and return file
+        if local_metadata is not None and local_metadata.etag == etag:
+            write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
+            return str(paths.file_path)
+
+        # metadata is outdated + etag is a sha256
+        # => means it's an LFS file (large)
+        # => let's compute local hash and compare
+        # => if match, update metadata and return file
+        if local_metadata is None and REGEX_SHA256.match(etag) is not None:
+            with open(paths.file_path, "rb") as f:
+                file_hash = sha_fileobj(f).hex()
+            if file_hash == etag:
+                write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
+                return str(paths.file_path)
+
+    # Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache)
+
+    # If we are lucky enough, the file is already in the cache => copy it
+    if not force_download:
+        cached_path = try_to_load_from_cache(
+            repo_id=repo_id,
+            filename=filename,
+            cache_dir=cache_dir,
+            revision=commit_hash,
+            repo_type=repo_type,
+        )
+        if isinstance(cached_path, str):
+            with WeakFileLock(paths.lock_path):
+                paths.file_path.parent.mkdir(parents=True, exist_ok=True)
+                shutil.copyfile(cached_path, paths.file_path)
+            write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
+            return str(paths.file_path)
+
+    # Otherwise, let's download the file!
+    with WeakFileLock(paths.lock_path):
+        paths.file_path.unlink(missing_ok=True)  # delete outdated file first
+        _download_to_tmp_and_move(
+            incomplete_path=paths.incomplete_path(etag),
+            destination_path=paths.file_path,
+            url_to_download=url_to_download,
+            proxies=proxies,
+            headers=headers,
+            expected_size=expected_size,
+            filename=filename,
+            force_download=force_download,
+            etag=etag,
+            xet_file_data=xet_file_data,
+        )
+
+    write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
+    return str(paths.file_path)
+
+
+@validate_hf_hub_args
+def try_to_load_from_cache(
+    repo_id: str,
+    filename: str,
+    cache_dir: Union[str, Path, None] = None,
+    revision: Optional[str] = None,
+    repo_type: Optional[str] = None,
+) -> Union[str, _CACHED_NO_EXIST_T, None]:
+    """
+    Explores the cache to return the latest cached file for a given revision if found.
+
+    This function will not raise any exception if the file is not cached.
+
+    Args:
+        cache_dir (`str` or `os.PathLike`):
+            The folder where the cached files lie.
+        repo_id (`str`):
+            The ID of the repo on huggingface.co.
+        filename (`str`):
+            The filename to look for inside `repo_id`.
+        revision (`str`, *optional*):
+            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
+            provided either.
+        repo_type (`str`, *optional*):
+            The type of the repository. Will default to `"model"`.
+
+    Returns:
+        `Optional[str]` or `_CACHED_NO_EXIST`:
+            Will return `None` if the file was not cached. Otherwise:
+            - The exact path to the cached file if it's found in the cache
+            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
+              cached.
+
+    Example:
+
+    ```python
+    from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST
+
+    # `repo_id` and `filename` below are placeholder values
+    filepath = try_to_load_from_cache(repo_id="user/repo", filename="config.json")
+    if isinstance(filepath, str):
+        # file exists and is cached
+        ...
+    elif filepath is _CACHED_NO_EXIST:
+        # non-existence of file is cached
+        ...
+    else:
+        # file is not cached
+        ...
+    ```
+    """
+    if revision is None:
+        revision = "main"
+    if repo_type is None:
+        repo_type = "model"
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
+    if cache_dir is None:
+        cache_dir = constants.HF_HUB_CACHE
+
+    object_id = repo_id.replace("/", "--")
+    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
+    if not os.path.isdir(repo_cache):
+        # No cache for this model
+        return None
+
+    refs_dir = os.path.join(repo_cache, "refs")
+    snapshots_dir = os.path.join(repo_cache, "snapshots")
+    no_exist_dir = os.path.join(repo_cache, ".no_exist")
+
+    # Resolve refs (for instance to convert main to the associated commit sha)
+    if os.path.isdir(refs_dir):
+        revision_file = os.path.join(refs_dir, revision)
+        if os.path.isfile(revision_file):
+            with open(revision_file) as f:
+                revision = f.read()
+
+    # Check if file is cached as "no_exist"
+    if os.path.isfile(os.path.join(no_exist_dir, revision, filename)):
+        return _CACHED_NO_EXIST
+
+    # Check if revision folder exists
+    if not os.path.exists(snapshots_dir):
+        return None
+    cached_shas = os.listdir(snapshots_dir)
+    if revision not in cached_shas:
+        # No cache for this revision and we won't try to return a random revision
+        return None
+
+    # Check if file exists in cache
+    cached_file = os.path.join(snapshots_dir, revision, filename)
+    return cached_file if os.path.isfile(cached_file) else None
+
+
+@validate_hf_hub_args
+def get_hf_file_metadata(
+    url: str,
+    token: Union[bool, str, None] = None,
+    proxies: Optional[Dict] = None,
+    timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Union[Dict, str, None] = None,
+    headers: Optional[Dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+) -> HfFileMetadata:
+    """Fetch metadata of a file versioned on the Hub for a given url.
+
+    Args:
+        url (`str`):
+            File url, for example returned by [`hf_hub_url`].
+        token (`str` or `bool`, *optional*):
+            A token to be used for the download.
+                - If `True`, the token is read from the HuggingFace config
+                  folder.
+                - If `False` or `None`, no token is provided.
+                - If a string, it's used as the authentication token.
+        proxies (`dict`, *optional*):
+            Dictionary mapping protocol to the URL of the proxy passed to
+            `requests.request`.
+        timeout (`float`, *optional*, defaults to 10):
+            How many seconds to wait for the server to send metadata before giving up.
+        library_name (`str`, *optional*):
+            The name of the library to which the object corresponds.
+        library_version (`str`, *optional*):
+            The version of the library.
+        user_agent (`dict`, `str`, *optional*):
+            The user-agent info in the form of a dictionary or a string.
+        headers (`dict`, *optional*):
+            Additional headers to be sent with the request.
+        endpoint (`str`, *optional*):
+            Endpoint of the Hub. Defaults to https://huggingface.co.
+
+    Returns:
+        A [`HfFileMetadata`] object containing metadata such as location, etag, size and
+        commit_hash.
+    """
+    hf_headers = build_hf_headers(
+        token=token,
+        library_name=library_name,
+        library_version=library_version,
+        user_agent=user_agent,
+        headers=headers,
+    )
+    hf_headers["Accept-Encoding"] = "identity"  # prevent any compression => we want to know the real size of the file
+
+    # Retrieve metadata
+    r = _request_wrapper(
+        method="HEAD",
+        url=url,
+        headers=hf_headers,
+        allow_redirects=False,
+        follow_relative_redirects=True,
+        proxies=proxies,
+        timeout=timeout,
+    )
+    hf_raise_for_status(r)
+
+    # Return
+    return HfFileMetadata(
+        commit_hash=r.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT),
+        # We favor a custom header indicating the etag of the linked resource, and
+        # we fallback to the regular etag header.
+        etag=_normalize_etag(r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")),
+        # Either from response headers (if redirected) or defaults to request url
+        # Do not use directly `url`, as `_request_wrapper` might have followed relative
+        # redirects.
+        location=r.headers.get("Location") or r.request.url,  # type: ignore
+        size=_int_or_none(
+            r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")
+        ),
+        xet_file_data=parse_xet_file_data_from_response(r, endpoint=endpoint),  # type: ignore
+    )
+
+
+def _get_metadata_or_catch_error(
+    *,
+    repo_id: str,
+    filename: str,
+    repo_type: str,
+    revision: str,
+    endpoint: Optional[str],
+    proxies: Optional[Dict],
+    etag_timeout: Optional[float],
+    headers: Dict[str, str],  # mutated inplace!
+    token: Union[bool, str, None],
+    local_files_only: bool,
+    relative_filename: Optional[str] = None,  # only used to store `.no_exists` in cache
+    storage_folder: Optional[str] = None,  # only used to store `.no_exists` in cache
+) -> Union[
+    # Either an exception is caught and returned
+    Tuple[None, None, None, None, None, Exception],
+    # Or the metadata is returned as
+    # `(url_to_download, etag, commit_hash, expected_size, xet_file_data, None)`
+    Tuple[str, str, str, int, Optional[XetFileData], None],
+]:
+    """Get metadata for a file on the Hub, safely handling network issues.
+
+    Returns either the etag, commit_hash and expected size of the file, or the error
+    raised while fetching the metadata.
+
+    NOTE: This function mutates `headers` inplace! It removes the `authorization` header
+    if the file is an LFS blob and the domain of the url is different from the
+    domain of the location (typically an S3 bucket).
+    """
+    if local_files_only:
+        return (
+            None,
+            None,
+            None,
+            None,
+            None,
+            OfflineModeIsEnabled(
+                f"Cannot access file since 'local_files_only=True' has been set. (repo_id: {repo_id}, repo_type: {repo_type}, revision: {revision}, filename: {filename})"
+            ),
+        )
+
+    url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)
+    url_to_download: str = url
+    etag: Optional[str] = None
+    commit_hash: Optional[str] = None
+    expected_size: Optional[int] = None
+    head_error_call: Optional[Exception] = None
+    xet_file_data: Optional[XetFileData] = None
+
+    # Try to get metadata from the server.
+    # Do not raise yet if the file is not found or not accessible.
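+    # On success, the values above end up shaped roughly like this (illustrative values
+    # only, not real ones):
+    #   url_to_download = "https://huggingface.co/user/repo/resolve/<commit>/file.bin"
+    #   etag = "a1b2c3..."          # normalized ETag, see `_normalize_etag`
+    #   commit_hash = "d4e5f6..."   # 40-character commit sha
+    #   expected_size = 123456      # in bytes, from the Content-Length / linked-size header
+    #   xet_file_data = None        # or an XetFileData when the repo uses Xet storage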
+ if not local_files_only: + try: + try: + metadata = get_hf_file_metadata( + url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token, endpoint=endpoint + ) + except EntryNotFoundError as http_error: + if storage_folder is not None and relative_filename is not None: + # Cache the non-existence of the file + commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT) + if commit_hash is not None: + no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename + try: + no_exist_file_path.parent.mkdir(parents=True, exist_ok=True) + no_exist_file_path.touch() + except OSError as e: + logger.error( + f"Could not cache non-existence of file. Will ignore error and continue. Error: {e}" + ) + _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash) + raise + + # Commit hash must exist + commit_hash = metadata.commit_hash + if commit_hash is None: + raise FileMetadataError( + "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue" + " prevents you from downloading resources from https://huggingface.co. Please check your firewall" + " and proxy settings and make sure your SSL certificates are updated." + ) + + # Etag must exist + # If we don't have any of those, raise an error. + etag = metadata.etag + if etag is None: + raise FileMetadataError( + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + ) + + # Size must exist + expected_size = metadata.size + if expected_size is None: + raise FileMetadataError("Distant resource does not have a Content-Length.") + + xet_file_data = metadata.xet_file_data + + # In case of a redirect, save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + # + # If url domain is different => we are downloading from a CDN => url is signed => don't send auth + # If url domain is the same => redirect due to repo rename AND downloading a regular file => keep auth + if xet_file_data is None and url != metadata.location: + url_to_download = metadata.location + if urlparse(url).netloc != urlparse(metadata.location).netloc: + # Remove authorization header when downloading a LFS blob + headers.pop("authorization", None) + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + OfflineModeIsEnabled, + ) as error: + # Otherwise, our Internet connection is down. + # etag is None + head_error_call = error + except (RevisionNotFoundError, EntryNotFoundError): + # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted) + raise + except requests.HTTPError as error: + # Multiple reasons for an http error: + # - Repository is private and invalid/missing token sent + # - Repository is gated and invalid/missing token sent + # - Hub is down (error 500 or 504) + # => let's switch to 'local_files_only=True' to check if the files are already cached. 
+            # (if it's not the case, the error will be re-raised)
+            head_error_call = error
+        except FileMetadataError as error:
+            # Multiple reasons for a FileMetadataError:
+            # - Wrong network configuration (proxy, firewall, SSL certificates)
+            # - Inconsistency on the Hub
+            # => let's switch to 'local_files_only=True' to check if the files are already cached.
+            # (if it's not the case, the error will be re-raised)
+            head_error_call = error
+
+    if not (local_files_only or etag is not None or head_error_call is not None):
+        raise RuntimeError("etag is empty due to uncovered problems")
+
+    return (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_error_call)  # type: ignore [return-value]
+
+
+def _raise_on_head_call_error(head_call_error: Exception, force_download: bool, local_files_only: bool) -> NoReturn:
+    """Raise an appropriate error when the HEAD call failed and we cannot locate a local file."""
+    # No head call => we cannot force download.
+    if force_download:
+        if local_files_only:
+            raise ValueError("Cannot pass 'force_download=True' and 'local_files_only=True' at the same time.")
+        elif isinstance(head_call_error, OfflineModeIsEnabled):
+            raise ValueError("Cannot pass 'force_download=True' when offline mode is enabled.") from head_call_error
+        else:
+            raise ValueError("Force download failed due to the above error.") from head_call_error
+
+    # No head call + couldn't find an appropriate file on disk => raise an error.
+    if local_files_only:
+        raise LocalEntryNotFoundError(
+            "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable"
+            " hf.co look-ups and downloads online, set 'local_files_only' to False."
+        )
+    elif isinstance(head_call_error, (RepositoryNotFoundError, GatedRepoError)) or (
+        isinstance(head_call_error, HfHubHTTPError) and head_call_error.response.status_code == 401
+    ):
+        # Repo not found or gated => let's raise the actual error
+        # Unauthorized => likely a token issue => let's raise the actual error
+        raise head_call_error
+    else:
+        # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
+        raise LocalEntryNotFoundError(
+            "An error happened while trying to locate the file on the Hub and we cannot find the requested files"
+            " in the local cache. Please check your internet connection and try again."
+        ) from head_call_error
+
+
+def _download_to_tmp_and_move(
+    incomplete_path: Path,
+    destination_path: Path,
+    url_to_download: str,
+    proxies: Optional[Dict],
+    headers: Dict[str, str],
+    expected_size: Optional[int],
+    filename: str,
+    force_download: bool,
+    etag: Optional[str],
+    xet_file_data: Optional[XetFileData],
+) -> None:
+    """Download content from a URL to a destination path.
+
+    Internal logic:
+    - return early if file is already downloaded
+    - resume download if possible (from incomplete file)
+    - do not resume download if `force_download=True` or `HF_HUB_ENABLE_HF_TRANSFER=True`
+    - check disk space before downloading
+    - download content to a temporary file
+    - set correct permissions on temporary file
+    - move the temporary file to the destination path
+
+    Both `incomplete_path` and `destination_path` must be on the same volume to avoid a local copy.
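+
+    As an illustration, when called from the cache path the temporary file is the blob path suffixed with
+    `.incomplete` (e.g. a hypothetical `blobs/<etag>.incomplete`), which is moved to `blobs/<etag>` once the
+    download has finished (see the call sites above).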
+ """ + if destination_path.exists() and not force_download: + # Do nothing if already exists (except if force_download=True) + return + + if incomplete_path.exists() and (force_download or (constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies)): + # By default, we will try to resume the download if possible. + # However, if the user has set `force_download=True` or if `hf_transfer` is enabled, then we should + # not resume the download => delete the incomplete file. + message = f"Removing incomplete file '{incomplete_path}'" + if force_download: + message += " (force_download=True)" + elif constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies: + message += " (hf_transfer=True)" + logger.info(message) + incomplete_path.unlink(missing_ok=True) + + with incomplete_path.open("ab") as f: + resume_size = f.tell() + message = f"Downloading '{filename}' to '{incomplete_path}'" + if resume_size > 0 and expected_size is not None: + message += f" (resume from {resume_size}/{expected_size})" + logger.info(message) + + if expected_size is not None: # might be None if HTTP header not set correctly + # Check disk space in both tmp and destination path + _check_disk_space(expected_size, incomplete_path.parent) + _check_disk_space(expected_size, destination_path.parent) + + if xet_file_data is not None and is_xet_available(): + logger.debug("Xet Storage is enabled for this repo. Downloading file from Xet Storage..") + xet_get( + incomplete_path=incomplete_path, + xet_file_data=xet_file_data, + headers=headers, + expected_size=expected_size, + displayed_filename=filename, + ) + else: + if xet_file_data is not None and not constants.HF_HUB_DISABLE_XET: + logger.warning( + "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. " + "Falling back to regular HTTP download. " + "For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`" + ) + + http_get( + url_to_download, + f, + proxies=proxies, + resume_size=resume_size, + headers=headers, + expected_size=expected_size, + ) + + logger.info(f"Download complete. Moving file to {destination_path}") + _chmod_and_move(incomplete_path, destination_path) + + +def _int_or_none(value: Optional[str]) -> Optional[int]: + try: + return int(value) # type: ignore + except (TypeError, ValueError): + return None + + +def _chmod_and_move(src: Path, dst: Path) -> None: + """Set correct permission before moving a blob from tmp directory to cache dir. + + Do not take into account the `umask` from the process as there is no convenient way + to get it that is thread-safe. + + See: + - About umask: https://docs.python.org/3/library/os.html#os.umask + - Thread-safety: https://stackoverflow.com/a/70343066 + - About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591 + - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141 + - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215 + """ + # Get umask by creating a temporary file in the cached repo folder. + tmp_file = dst.parent.parent / f"tmp_{uuid.uuid4()}" + try: + tmp_file.touch() + cache_dir_mode = Path(tmp_file).stat().st_mode + os.chmod(str(src), stat.S_IMODE(cache_dir_mode)) + except OSError as e: + logger.warning( + f"Could not set the permissions on the file '{src}'. Error: {e}.\nContinuing without setting permissions." 
+ ) + finally: + try: + tmp_file.unlink() + except OSError: + # fails if `tmp_file.touch()` failed => do nothing + # See https://github.com/huggingface/huggingface_hub/issues/2359 + pass + + shutil.move(str(src), str(dst), copy_function=_copy_no_matter_what) + + +def _copy_no_matter_what(src: str, dst: str) -> None: + """Copy file from src to dst. + + If `shutil.copy2` fails, fallback to `shutil.copyfile`. + """ + try: + # Copy file with metadata and permission + # Can fail e.g. if dst is an S3 mount + shutil.copy2(src, dst) + except OSError: + # Copy only file content + shutil.copyfile(src, dst) + + +def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str: + # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks + snapshot_path = os.path.join(storage_folder, "snapshots") + pointer_path = os.path.join(snapshot_path, revision, relative_filename) + if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents: + raise ValueError( + "Invalid pointer path: cannot create pointer path in snapshot folder if" + f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and" + f" `relative_filename='{relative_filename}'`." + ) + return pointer_path diff --git a/phivenv/Lib/site-packages/huggingface_hub/hf_api.py b/phivenv/Lib/site-packages/huggingface_hub/hf_api.py new file mode 100644 index 0000000000000000000000000000000000000000..ec41ae26f2f2079ff1a865da12bc4b7a4992e7b5 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/hf_api.py @@ -0,0 +1,10619 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import inspect +import io +import json +import re +import struct +import time +import warnings +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import asdict, dataclass, field +from datetime import datetime +from functools import wraps +from itertools import islice +from pathlib import Path +from textwrap import dedent +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from urllib.parse import quote, unquote + +import requests +from requests.exceptions import HTTPError +from tqdm.auto import tqdm as base_tqdm +from tqdm.contrib.concurrent import thread_map + +from . 
import constants +from ._commit_api import ( + CommitOperation, + CommitOperationAdd, + CommitOperationCopy, + CommitOperationDelete, + _fetch_files_to_copy, + _fetch_upload_modes, + _prepare_commit_payload, + _upload_lfs_files, + _upload_xet_files, + _warn_on_overwriting_operations, +) +from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType +from ._jobs_api import JobInfo +from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable +from ._upload_large_folder import upload_large_folder_internal +from .community import ( + Discussion, + DiscussionComment, + DiscussionStatusChange, + DiscussionTitleChange, + DiscussionWithDetails, + deserialize_event, +) +from .constants import ( + DEFAULT_ETAG_TIMEOUT, # noqa: F401 # kept for backward compatibility + DEFAULT_REQUEST_TIMEOUT, # noqa: F401 # kept for backward compatibility + DEFAULT_REVISION, # noqa: F401 # kept for backward compatibility + DISCUSSION_STATUS, # noqa: F401 # kept for backward compatibility + DISCUSSION_TYPES, # noqa: F401 # kept for backward compatibility + ENDPOINT, # noqa: F401 # kept for backward compatibility + INFERENCE_ENDPOINTS_ENDPOINT, # noqa: F401 # kept for backward compatibility + REGEX_COMMIT_OID, # noqa: F401 # kept for backward compatibility + REPO_TYPE_MODEL, # noqa: F401 # kept for backward compatibility + REPO_TYPES, # noqa: F401 # kept for backward compatibility + REPO_TYPES_MAPPING, # noqa: F401 # kept for backward compatibility + REPO_TYPES_URL_PREFIXES, # noqa: F401 # kept for backward compatibility + SAFETENSORS_INDEX_FILE, # noqa: F401 # kept for backward compatibility + SAFETENSORS_MAX_HEADER_LENGTH, # noqa: F401 # kept for backward compatibility + SAFETENSORS_SINGLE_FILE, # noqa: F401 # kept for backward compatibility + SPACES_SDK_TYPES, # noqa: F401 # kept for backward compatibility + WEBHOOK_DOMAIN_T, # noqa: F401 # kept for backward compatibility + DiscussionStatusFilter, # noqa: F401 # kept for backward compatibility + DiscussionTypeFilter, # noqa: F401 # kept for backward compatibility +) +from .errors import ( + BadRequestError, + EntryNotFoundError, + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from .file_download import HfFileMetadata, get_hf_file_metadata, hf_hub_url +from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData +from .utils import ( + DEFAULT_IGNORE_PATTERNS, + HfFolder, # noqa: F401 # kept for backward compatibility + LocalTokenNotFoundError, + NotASafetensorsRepoError, + SafetensorsFileMetadata, + SafetensorsParsingError, + SafetensorsRepoMetadata, + TensorInfo, + build_hf_headers, + chunk_iterable, + experimental, + filter_repo_objects, + fix_hf_endpoint_in_url, + get_session, + get_token, + hf_raise_for_status, + logging, + paginate, + parse_datetime, + validate_hf_hub_args, +) +from .utils import tqdm as hf_tqdm +from .utils._auth import ( + _get_token_from_environment, + _get_token_from_file, + _get_token_from_google_colab, +) +from .utils._deprecation import _deprecate_method +from .utils._runtime import is_xet_available +from .utils._typing import CallableT +from .utils.endpoint_helpers import _is_emission_within_threshold + + +if TYPE_CHECKING: + from .inference._providers import PROVIDER_T + +R = TypeVar("R") # Return type +CollectionItemType_T = Literal["model", "dataset", "space", "paper", "collection"] + +ExpandModelProperty_T = Literal[ + "author", + "baseModels", + "cardData", + "childrenModelCount", + "config", + "createdAt", + "disabled", + "downloads", + 
"downloadsAllTime", + "gated", + "gguf", + "inference", + "inferenceProviderMapping", + "lastModified", + "library_name", + "likes", + "mask_token", + "model-index", + "pipeline_tag", + "private", + "resourceGroup", + "safetensors", + "sha", + "siblings", + "spaces", + "tags", + "transformersInfo", + "trendingScore", + "usedStorage", + "widgetData", + "xetEnabled", +] + +ExpandDatasetProperty_T = Literal[ + "author", + "cardData", + "citation", + "createdAt", + "description", + "disabled", + "downloads", + "downloadsAllTime", + "gated", + "lastModified", + "likes", + "paperswithcode_id", + "private", + "resourceGroup", + "sha", + "siblings", + "tags", + "trendingScore", + "usedStorage", + "xetEnabled", +] + +ExpandSpaceProperty_T = Literal[ + "author", + "cardData", + "createdAt", + "datasets", + "disabled", + "lastModified", + "likes", + "models", + "private", + "resourceGroup", + "runtime", + "sdk", + "sha", + "siblings", + "subdomain", + "tags", + "trendingScore", + "usedStorage", + "xetEnabled", +] + +USERNAME_PLACEHOLDER = "hf_user" +_REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") + +_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = ( + "\nNote: Creating a commit assumes that the repo already exists on the" + " Huggingface Hub. Please use `create_repo` if it's not the case." +) +_AUTH_CHECK_NO_REPO_ERROR_MESSAGE = ( + "\nNote: The repository either does not exist or you do not have access rights." + " Please check the repository ID and your access permissions." + " If this is a private repository, ensure that your token is correct." +) +logger = logging.get_logger(__name__) + + +def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]: + """ + Returns the repo type and ID from a huggingface.co URL linking to a + repository + + Args: + hf_id (`str`): + An URL or ID of a repository on the HF hub. Accepted values are: + + - https://huggingface.co/// + - https://huggingface.co// + - hf://// + - hf:/// + - // + - / + - + hub_url (`str`, *optional*): + The URL of the HuggingFace Hub, defaults to https://huggingface.co + + Returns: + A tuple with three items: repo_type (`str` or `None`), namespace (`str` or + `None`) and repo_id (`str`). + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If URL cannot be parsed. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `repo_type` is unknown. 
+ """ + input_hf_id = hf_id + + hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT) + is_hf_url = hub_url in hf_id and "@" not in hf_id + + HFFS_PREFIX = "hf://" + if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists + hf_id = hf_id[len(HFFS_PREFIX) :] + + url_segments = hf_id.split("/") + is_hf_id = len(url_segments) <= 3 + + namespace: Optional[str] + if is_hf_url: + namespace, repo_id = url_segments[-2:] + if namespace == hub_url: + namespace = None + if len(url_segments) > 2 and hub_url not in url_segments[-3]: + repo_type = url_segments[-3] + elif namespace in constants.REPO_TYPES_MAPPING: + # Mean canonical dataset or model + repo_type = constants.REPO_TYPES_MAPPING[namespace] + namespace = None + else: + repo_type = None + elif is_hf_id: + if len(url_segments) == 3: + # Passed // or // + repo_type, namespace, repo_id = url_segments[-3:] + elif len(url_segments) == 2: + if url_segments[0] in constants.REPO_TYPES_MAPPING: + # Passed '' or 'datasets/' for a canonical model or dataset + repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] + namespace = None + repo_id = hf_id.split("/")[-1] + else: + # Passed / or / + namespace, repo_id = hf_id.split("/")[-2:] + repo_type = None + else: + # Passed + repo_id = url_segments[0] + namespace, repo_type = None, None + else: + raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") + + # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) + if repo_type in constants.REPO_TYPES_MAPPING: + repo_type = constants.REPO_TYPES_MAPPING[repo_type] + if repo_type == "": + repo_type = None + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')") + + return repo_type, namespace, repo_id + + +@dataclass +class LastCommitInfo(dict): + oid: str + title: str + date: datetime + + def __post_init__(self): # hack to make LastCommitInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class BlobLfsInfo(dict): + size: int + sha256: str + pointer_size: int + + def __post_init__(self): # hack to make BlobLfsInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class BlobSecurityInfo(dict): + safe: bool # duplicate information with "status" field, keeping it for backward compatibility + status: str + av_scan: Optional[Dict] + pickle_import_scan: Optional[Dict] + + def __post_init__(self): # hack to make BlogSecurityInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class TransformersInfo(dict): + auto_model: str + custom_class: Optional[str] = None + # possible `pipeline_tag` values: https://github.com/huggingface/huggingface.js/blob/3ee32554b8620644a6287e786b2a83bf5caf559c/packages/tasks/src/pipelines.ts#L72 + pipeline_tag: Optional[str] = None + processor: Optional[str] = None + + def __post_init__(self): # hack to make TransformersInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class SafeTensorsInfo(dict): + parameters: Dict[str, int] + total: int + + def __post_init__(self): # hack to make SafeTensorsInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class CommitInfo(str): + """Data structure containing information about a newly created commit. + + Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`], + [`delete_file`], [`delete_folder`]. 
It inherits from `str` for backward compatibility but using methods specific + to `str` is deprecated. + + Attributes: + commit_url (`str`): + Url where to find the commit. + + commit_message (`str`): + The summary (first line) of the commit that has been created. + + commit_description (`str`): + Description of the commit that has been created. Can be empty. + + oid (`str`): + Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. + + pr_url (`str`, *optional*): + Url to the PR that has been created, if any. Populated when `create_pr=True` + is passed. + + pr_revision (`str`, *optional*): + Revision of the PR that has been created, if any. Populated when + `create_pr=True` is passed. Example: `"refs/pr/1"`. + + pr_num (`int`, *optional*): + Number of the PR discussion that has been created, if any. Populated when + `create_pr=True` is passed. Can be passed as `discussion_num` in + [`get_discussion_details`]. Example: `1`. + + repo_url (`RepoUrl`): + Repo URL of the commit containing info like repo_id, repo_type, etc. + + _url (`str`, *optional*): + Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by + [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on + the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this + attribute. Please use `commit_url` instead. + """ + + commit_url: str + commit_message: str + commit_description: str + oid: str + pr_url: Optional[str] = None + + # Computed from `commit_url` in `__post_init__` + repo_url: RepoUrl = field(init=False) + + # Computed from `pr_url` in `__post_init__` + pr_revision: Optional[str] = field(init=False) + pr_num: Optional[str] = field(init=False) + + # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.) + _url: str = field(repr=False, default=None) # type: ignore # defaults to `commit_url` + + def __new__(cls, *args, commit_url: str, _url: Optional[str] = None, **kwargs): + return str.__new__(cls, _url or commit_url) + + def __post_init__(self): + """Populate pr-related fields after initialization. + + See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing. + """ + # Repo info + self.repo_url = RepoUrl(self.commit_url.split("/commit/")[0]) + + # PR info + if self.pr_url is not None: + self.pr_revision = _parse_revision_from_pr_url(self.pr_url) + self.pr_num = int(self.pr_revision.split("/")[-1]) + else: + self.pr_revision = None + self.pr_num = None + + +@dataclass +class AccessRequest: + """Data structure containing information about a user access request. + + Attributes: + username (`str`): + Username of the user who requested access. + fullname (`str`): + Fullname of the user who requested access. + email (`Optional[str]`): + Email of the user who requested access. + Can only be `None` in the /accepted list if the user was granted access manually. + timestamp (`datetime`): + Timestamp of the request. + status (`Literal["pending", "accepted", "rejected"]`): + Status of the request. Can be one of `["pending", "accepted", "rejected"]`. + fields (`Dict[str, Any]`, *optional*): + Additional fields filled by the user in the gate form. 
+ """ + + username: str + fullname: str + email: Optional[str] + timestamp: datetime + status: Literal["pending", "accepted", "rejected"] + + # Additional fields filled by the user in the gate form + fields: Optional[Dict[str, Any]] = None + + +@dataclass +class WebhookWatchedItem: + """Data structure containing information about the items watched by a webhook. + + Attributes: + type (`Literal["dataset", "model", "org", "space", "user"]`): + Type of the item to be watched. Can be one of `["dataset", "model", "org", "space", "user"]`. + name (`str`): + Name of the item to be watched. Can be the username, organization name, model name, dataset name or space name. + """ + + type: Literal["dataset", "model", "org", "space", "user"] + name: str + + +@dataclass +class WebhookInfo: + """Data structure containing information about a webhook. + + Attributes: + id (`str`): + ID of the webhook. + url (`str`): + URL of the webhook. + watched (`List[WebhookWatchedItem]`): + List of items watched by the webhook, see [`WebhookWatchedItem`]. + domains (`List[WEBHOOK_DOMAIN_T]`): + List of domains the webhook is watching. Can be one of `["repo", "discussions"]`. + secret (`str`, *optional*): + Secret of the webhook. + disabled (`bool`): + Whether the webhook is disabled or not. + """ + + id: str + url: str + watched: List[WebhookWatchedItem] + domains: List[constants.WEBHOOK_DOMAIN_T] + secret: Optional[str] + disabled: bool + + +class RepoUrl(str): + """Subclass of `str` describing a repo URL on the Hub. + + `RepoUrl` is returned by `HfApi.create_repo`. It inherits from `str` for backward + compatibility. At initialization, the URL is parsed to populate properties: + - endpoint (`str`) + - namespace (`Optional[str]`) + - repo_name (`str`) + - repo_id (`str`) + - repo_type (`Literal["model", "dataset", "space"]`) + - url (`str`) + + Args: + url (`Any`): + String value of the repo url. + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + + Example: + ```py + >>> RepoUrl('https://huggingface.co/gpt2') + RepoUrl('https://huggingface.co/gpt2', endpoint='https://huggingface.co', repo_type='model', repo_id='gpt2') + + >>> RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co') + RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co', repo_type='dataset', repo_id='dummy_user/dummy_dataset') + + >>> RepoUrl('hf://datasets/my-user/my-dataset') + RepoUrl('hf://datasets/my-user/my-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='user/dataset') + + >>> HfApi.create_repo("dummy_model") + RepoUrl('https://huggingface.co/Wauplin/dummy_model', endpoint='https://huggingface.co', repo_type='model', repo_id='Wauplin/dummy_model') + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If URL cannot be parsed. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `repo_type` is unknown. 
+ """ + + def __new__(cls, url: Any, endpoint: Optional[str] = None): + url = fix_hf_endpoint_in_url(url, endpoint=endpoint) + return super(RepoUrl, cls).__new__(cls, url) + + def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: + super().__init__() + # Parse URL + self.endpoint = endpoint or constants.ENDPOINT + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(self, hub_url=self.endpoint) + + # Populate fields + self.namespace = namespace + self.repo_name = repo_name + self.repo_id = repo_name if namespace is None else f"{namespace}/{repo_name}" + self.repo_type = repo_type or constants.REPO_TYPE_MODEL + self.url = str(self) # just in case it's needed + + def __repr__(self) -> str: + return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')" + + +@dataclass +class RepoSibling: + """ + Contains basic information about a repo file inside a repo on the Hub. + + + + All attributes of this class are optional except `rfilename`. This is because only the file names are returned when + listing repositories on the Hub (with [`list_models`], [`list_datasets`] or [`list_spaces`]). If you need more + information like file size, blob id or lfs details, you must request them specifically from one repo at a time + (using [`model_info`], [`dataset_info`] or [`space_info`]) as it adds more constraints on the backend server to + retrieve these. + + + + Attributes: + rfilename (str): + file name, relative to the repo root. + size (`int`, *optional*): + The file's size, in bytes. This attribute is defined when `files_metadata` argument of [`repo_info`] is set + to `True`. It's `None` otherwise. + blob_id (`str`, *optional*): + The file's git OID. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to + `True`. It's `None` otherwise. + lfs (`BlobLfsInfo`, *optional*): + The file's LFS metadata. This attribute is defined when`files_metadata` argument of [`repo_info`] is set to + `True` and the file is stored with Git LFS. It's `None` otherwise. + """ + + rfilename: str + size: Optional[int] = None + blob_id: Optional[str] = None + lfs: Optional[BlobLfsInfo] = None + + +@dataclass +class RepoFile: + """ + Contains information about a file on the Hub. + + Attributes: + path (str): + file path relative to the repo root. + size (`int`): + The file's size, in bytes. + blob_id (`str`): + The file's git OID. + lfs (`BlobLfsInfo`): + The file's LFS metadata. + last_commit (`LastCommitInfo`, *optional*): + The file's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + security (`BlobSecurityInfo`, *optional*): + The file's security scan metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. 
+ """ + + path: str + size: int + blob_id: str + lfs: Optional[BlobLfsInfo] = None + last_commit: Optional[LastCommitInfo] = None + security: Optional[BlobSecurityInfo] = None + + def __init__(self, **kwargs): + self.path = kwargs.pop("path") + self.size = kwargs.pop("size") + self.blob_id = kwargs.pop("oid") + lfs = kwargs.pop("lfs", None) + if lfs is not None: + lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"]) + self.lfs = lfs + last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) + if last_commit is not None: + last_commit = LastCommitInfo( + oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) + ) + self.last_commit = last_commit + security = kwargs.pop("securityFileStatus", None) + if security is not None: + safe = security["status"] == "safe" + security = BlobSecurityInfo( + safe=safe, + status=security["status"], + av_scan=security["avScan"], + pickle_import_scan=security["pickleImportScan"], + ) + self.security = security + + # backwards compatibility + self.rfilename = self.path + self.lastCommit = self.last_commit + + +@dataclass +class RepoFolder: + """ + Contains information about a folder on the Hub. + + Attributes: + path (str): + folder path relative to the repo root. + tree_id (`str`): + The folder's git OID. + last_commit (`LastCommitInfo`, *optional*): + The folder's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + """ + + path: str + tree_id: str + last_commit: Optional[LastCommitInfo] = None + + def __init__(self, **kwargs): + self.path = kwargs.pop("path") + self.tree_id = kwargs.pop("oid") + last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) + if last_commit is not None: + last_commit = LastCommitInfo( + oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) + ) + self.last_commit = last_commit + + +@dataclass +class InferenceProviderMapping: + provider: "PROVIDER_T" # Provider name + hf_model_id: str # ID of the model on the Hugging Face Hub + provider_id: str # ID of the model on the provider's side + status: Literal["error", "live", "staging"] + task: str + + adapter: Optional[str] = None + adapter_weights_path: Optional[str] = None + type: Optional[Literal["single-model", "tag-filter"]] = None + + def __init__(self, **kwargs): + self.provider = kwargs.pop("provider") + self.hf_model_id = kwargs.pop("hf_model_id") + self.provider_id = kwargs.pop("providerId") + self.status = kwargs.pop("status") + self.task = kwargs.pop("task") + + self.adapter = kwargs.pop("adapter", None) + self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None) + self.type = kwargs.pop("type", None) + self.__dict__.update(**kwargs) + + +@dataclass +class ModelInfo: + """ + Contains information about a model on the Hub. This object is returned by [`model_info`] and [`list_models`]. + + + + Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + In general, the more specific the query, the more information is returned. On the contrary, when listing models + using [`list_models`] only a subset of the attributes are returned. + + + + Attributes: + id (`str`): + ID of model. + author (`str`, *optional*): + Author of the model. + sha (`str`, *optional*): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. 
+            Note that the lowest value is `2022-03-02T23:29:04.000Z`,
+            corresponding to the date when we began to store creation dates.
+        last_modified (`datetime`, *optional*):
+            Date of last commit to the repo.
+        private (`bool`):
+            Is the repo private.
+        disabled (`bool`, *optional*):
+            Is the repo disabled.
+        downloads (`int`):
+            Number of downloads of the model over the last 30 days.
+        downloads_all_time (`int`):
+            Cumulative number of downloads of the model since its creation.
+        gated (`Literal["auto", "manual", False]`, *optional*):
+            Is the repo gated.
+            If so, whether there is manual or automatic approval.
+        gguf (`Dict`, *optional*):
+            GGUF information of the model.
+        inference (`Literal["warm"]`, *optional*):
+            Status of the model on Inference Providers. Warm if the model is served by at least one provider.
+        inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*):
+            A list of [`InferenceProviderMapping`] ordered according to the user's provider preferences.
+        likes (`int`):
+            Number of likes of the model.
+        library_name (`str`, *optional*):
+            Library associated with the model.
+        tags (`List[str]`):
+            List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub
+            (e.g. supported libraries, model's arXiv).
+        pipeline_tag (`str`, *optional*):
+            Pipeline tag associated with the model.
+        mask_token (`str`, *optional*):
+            Mask token used by the model.
+        widget_data (`Any`, *optional*):
+            Widget data associated with the model.
+        model_index (`Dict`, *optional*):
+            Model index for evaluation.
+        config (`Dict`, *optional*):
+            Model configuration.
+        transformers_info (`TransformersInfo`, *optional*):
+            Transformers-specific info (auto class, processor, etc.) associated with the model.
+        trending_score (`int`, *optional*):
+            Trending score of the model.
+        card_data (`ModelCardData`, *optional*):
+            Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object.
+        siblings (`List[RepoSibling]`):
+            List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model.
+        spaces (`List[str]`, *optional*):
+            List of spaces using the model.
+        safetensors (`SafeTensorsInfo`, *optional*):
+            Model's safetensors information.
+        security_repo_status (`Dict`, *optional*):
+            Model's security scan status.
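+
+    Example (illustrative sketch; assumes the public `gpt2` model and the default endpoint):
+
+    ```py
+    >>> from huggingface_hub import HfApi
+    >>> api = HfApi()
+    >>> model = api.model_info("gpt2")
+    >>> model.pipeline_tag
+    'text-generation'
+    >>> model.library_name
+    'transformers'
+    ```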
+ """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: Optional[bool] + disabled: Optional[bool] + downloads: Optional[int] + downloads_all_time: Optional[int] + gated: Optional[Literal["auto", "manual", False]] + gguf: Optional[Dict] + inference: Optional[Literal["warm"]] + inference_provider_mapping: Optional[List[InferenceProviderMapping]] + likes: Optional[int] + library_name: Optional[str] + tags: Optional[List[str]] + pipeline_tag: Optional[str] + mask_token: Optional[str] + card_data: Optional[ModelCardData] + widget_data: Optional[Any] + model_index: Optional[Dict] + config: Optional[Dict] + transformers_info: Optional[TransformersInfo] + trending_score: Optional[int] + siblings: Optional[List[RepoSibling]] + spaces: Optional[List[str]] + safetensors: Optional[SafeTensorsInfo] + security_repo_status: Optional[Dict] + xet_enabled: Optional[bool] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + self.private = kwargs.pop("private", None) + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.downloads = kwargs.pop("downloads", None) + self.downloads_all_time = kwargs.pop("downloadsAllTime", None) + self.likes = kwargs.pop("likes", None) + self.library_name = kwargs.pop("library_name", None) + self.gguf = kwargs.pop("gguf", None) + + self.inference = kwargs.pop("inference", None) + + # little hack to simplify Inference Providers logic and make it backward and forward compatible + # right now, API returns a dict on model_info and a list on list_models. Let's harmonize to list. + mapping = kwargs.pop("inferenceProviderMapping", None) + if isinstance(mapping, list): + self.inference_provider_mapping = [ + InferenceProviderMapping(**{**value, "hf_model_id": self.id}) for value in mapping + ] + elif isinstance(mapping, dict): + self.inference_provider_mapping = [ + InferenceProviderMapping(**{**value, "hf_model_id": self.id, "provider": provider}) + for provider, value in mapping.items() + ] + elif mapping is None: + self.inference_provider_mapping = None + else: + raise ValueError( + f"Unexpected type for `inferenceProviderMapping`. Expecting `dict` or `list`. Got {mapping}." 
+            )
+
+        self.tags = kwargs.pop("tags", None)
+        self.pipeline_tag = kwargs.pop("pipeline_tag", None)
+        self.mask_token = kwargs.pop("mask_token", None)
+        self.trending_score = kwargs.pop("trendingScore", None)
+
+        card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
+        self.card_data = (
+            ModelCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
+        )
+
+        self.widget_data = kwargs.pop("widgetData", None)
+        self.model_index = kwargs.pop("model-index", None) or kwargs.pop("model_index", None)
+        self.config = kwargs.pop("config", None)
+        transformers_info = kwargs.pop("transformersInfo", None) or kwargs.pop("transformers_info", None)
+        self.transformers_info = TransformersInfo(**transformers_info) if transformers_info else None
+        siblings = kwargs.pop("siblings", None)
+        self.siblings = (
+            [
+                RepoSibling(
+                    rfilename=sibling["rfilename"],
+                    size=sibling.get("size"),
+                    blob_id=sibling.get("blobId"),
+                    lfs=(
+                        BlobLfsInfo(
+                            size=sibling["lfs"]["size"],
+                            sha256=sibling["lfs"]["sha256"],
+                            pointer_size=sibling["lfs"]["pointerSize"],
+                        )
+                        if sibling.get("lfs")
+                        else None
+                    ),
+                )
+                for sibling in siblings
+            ]
+            if siblings is not None
+            else None
+        )
+        self.spaces = kwargs.pop("spaces", None)
+        safetensors = kwargs.pop("safetensors", None)
+        self.safetensors = (
+            SafeTensorsInfo(
+                parameters=safetensors["parameters"],
+                total=safetensors["total"],
+            )
+            if safetensors
+            else None
+        )
+        self.security_repo_status = kwargs.pop("securityRepoStatus", None)
+        self.xet_enabled = kwargs.pop("xetEnabled", None)
+        # backwards compatibility
+        self.lastModified = self.last_modified
+        self.cardData = self.card_data
+        self.transformersInfo = self.transformers_info
+        self.__dict__.update(**kwargs)
+
+
+@dataclass
+class DatasetInfo:
+    """
+    Contains information about a dataset on the Hub. This object is returned by [`dataset_info`] and [`list_datasets`].
+
+    Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made.
+    In general, the more specific the query, the more information is returned. On the contrary, when listing datasets
+    using [`list_datasets`] only a subset of the attributes are returned.
+
+    Attributes:
+        id (`str`):
+            ID of dataset.
+        author (`str`):
+            Author of the dataset.
+        sha (`str`):
+            Repo SHA at this particular revision.
+        created_at (`datetime`, *optional*):
+            Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`,
+            corresponding to the date when we began to store creation dates.
+        last_modified (`datetime`, *optional*):
+            Date of last commit to the repo.
+        private (`bool`):
+            Is the repo private.
+        disabled (`bool`, *optional*):
+            Is the repo disabled.
+        gated (`Literal["auto", "manual", False]`, *optional*):
+            Is the repo gated.
+            If so, whether there is manual or automatic approval.
+        downloads (`int`):
+            Number of downloads of the dataset over the last 30 days.
+        downloads_all_time (`int`):
+            Cumulative number of downloads of the dataset since its creation.
+        likes (`int`):
+            Number of likes of the dataset.
+        tags (`List[str]`):
+            List of tags of the dataset.
+        card_data (`DatasetCardData`, *optional*):
+            Dataset Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object.
+        siblings (`List[RepoSibling]`):
+            List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset.
+        paperswithcode_id (`str`, *optional*):
+            Papers with code ID of the dataset.
+ trending_score (`int`, *optional*): + Trending score of the dataset. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: Optional[bool] + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + downloads: Optional[int] + downloads_all_time: Optional[int] + likes: Optional[int] + paperswithcode_id: Optional[str] + tags: Optional[List[str]] + trending_score: Optional[int] + card_data: Optional[DatasetCardData] + siblings: Optional[List[RepoSibling]] + xet_enabled: Optional[bool] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + self.private = kwargs.pop("private", None) + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.downloads = kwargs.pop("downloads", None) + self.downloads_all_time = kwargs.pop("downloadsAllTime", None) + self.likes = kwargs.pop("likes", None) + self.paperswithcode_id = kwargs.pop("paperswithcode_id", None) + self.tags = kwargs.pop("tags", None) + self.trending_score = kwargs.pop("trendingScore", None) + + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + DatasetCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings is not None + else None + ) + self.xet_enabled = kwargs.pop("xetEnabled", None) + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.__dict__.update(**kwargs) + + +@dataclass +class SpaceInfo: + """ + Contains information about a Space on the Hub. This object is returned by [`space_info`] and [`list_spaces`]. + + + + Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + In general, the more specific the query, the more information is returned. On the contrary, when listing spaces + using [`list_spaces`] only a subset of the attributes are returned. + + + + Attributes: + id (`str`): + ID of the Space. + author (`str`, *optional*): + Author of the Space. + sha (`str`, *optional*): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + disabled (`bool`, *optional*): + Is the Space disabled. + host (`str`, *optional*): + Host URL of the Space. 
+ subdomain (`str`, *optional*): + Subdomain of the Space. + likes (`int`): + Number of likes of the Space. + tags (`List[str]`): + List of tags of the Space. + siblings (`List[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space. + card_data (`SpaceCardData`, *optional*): + Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object. + runtime (`SpaceRuntime`, *optional*): + Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object. + sdk (`str`, *optional*): + SDK used by the Space. + models (`List[str]`, *optional*): + List of models used by the Space. + datasets (`List[str]`, *optional*): + List of datasets used by the Space. + trending_score (`int`, *optional*): + Trending score of the Space. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: Optional[bool] + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + host: Optional[str] + subdomain: Optional[str] + likes: Optional[int] + sdk: Optional[str] + tags: Optional[List[str]] + siblings: Optional[List[RepoSibling]] + trending_score: Optional[int] + card_data: Optional[SpaceCardData] + runtime: Optional[SpaceRuntime] + models: Optional[List[str]] + datasets: Optional[List[str]] + xet_enabled: Optional[bool] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + self.private = kwargs.pop("private", None) + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.host = kwargs.pop("host", None) + self.subdomain = kwargs.pop("subdomain", None) + self.likes = kwargs.pop("likes", None) + self.sdk = kwargs.pop("sdk", None) + self.tags = kwargs.pop("tags", None) + self.trending_score = kwargs.pop("trendingScore", None) + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + SpaceCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings is not None + else None + ) + runtime = kwargs.pop("runtime", None) + self.runtime = SpaceRuntime(runtime) if runtime else None + self.models = kwargs.pop("models", None) + self.datasets = kwargs.pop("datasets", None) + self.xet_enabled = kwargs.pop("xetEnabled", None) + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.__dict__.update(**kwargs) + + +@dataclass +class CollectionItem: + """ + Contains information about an item of a Collection (model, dataset, Space, paper or collection). + + Attributes: + item_object_id (`str`): + Unique ID of the item in the collection. + item_id (`str`): + ID of the underlying object on the Hub. 
Can be either a repo_id, a paper id or a collection slug. + e.g. `"jbilcke-hf/ai-comic-factory"`, `"2307.09288"`, `"celinah/cerebras-function-calling-682607169c35fbfa98b30b9a"`. + item_type (`str`): + Type of the underlying object. Can be one of `"model"`, `"dataset"`, `"space"`, `"paper"` or `"collection"`. + position (`int`): + Position of the item in the collection. + note (`str`, *optional*): + Note associated with the item, as plain text. + """ + + item_object_id: str # id in database + item_id: str # repo_id or paper id + item_type: str + position: int + note: Optional[str] = None + + def __init__( + self, + _id: str, + id: str, + type: CollectionItemType_T, + position: int, + note: Optional[Dict] = None, + **kwargs, + ) -> None: + self.item_object_id: str = _id # id in database + self.item_id: str = id # repo_id or paper id + # if the item is a collection, override item_id with the slug + slug = kwargs.get("slug") + if slug is not None: + self.item_id = slug # collection slug + self.item_type: CollectionItemType_T = type + self.position: int = position + self.note: str = note["text"] if note is not None else None + + +@dataclass +class Collection: + """ + Contains information about a Collection on the Hub. + + Attributes: + slug (`str`): + Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection. E.g. `"Recent models"`. + owner (`str`): + Owner of the collection. E.g. `"TheBloke"`. + items (`List[CollectionItem]`): + List of items in the collection. + last_updated (`datetime`): + Date of the last update of the collection. + position (`int`): + Position of the collection in the list of collections of the owner. + private (`bool`): + Whether the collection is private or not. + theme (`str`): + Theme of the collection. E.g. `"green"`. + upvotes (`int`): + Number of upvotes of the collection. + description (`str`, *optional*): + Description of the collection, as plain text. + url (`str`): + (property) URL of the collection on the Hub. + """ + + slug: str + title: str + owner: str + items: List[CollectionItem] + last_updated: datetime + position: int + private: bool + theme: str + upvotes: int + description: Optional[str] = None + + def __init__(self, **kwargs) -> None: + self.slug = kwargs.pop("slug") + self.title = kwargs.pop("title") + self.owner = kwargs.pop("owner") + self.items = [CollectionItem(**item) for item in kwargs.pop("items")] + self.last_updated = parse_datetime(kwargs.pop("lastUpdated")) + self.position = kwargs.pop("position") + self.private = kwargs.pop("private") + self.theme = kwargs.pop("theme") + self.upvotes = kwargs.pop("upvotes") + self.description = kwargs.pop("description", None) + endpoint = kwargs.pop("endpoint", None) + if endpoint is None: + endpoint = constants.ENDPOINT + self._url = f"{endpoint}/collections/{self.slug}" + + @property + def url(self) -> str: + """Returns the URL of the collection on the Hub.""" + return self._url + + +@dataclass +class GitRefInfo: + """ + Contains information about a git reference for a repo on the Hub. + + Attributes: + name (`str`): + Name of the reference (e.g. tag name or branch name). + ref (`str`): + Full git ref on the Hub (e.g. `"refs/heads/main"` or `"refs/tags/v1.0"`). + target_commit (`str`): + OID of the target commit for the ref (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) + """ + + name: str + ref: str + target_commit: str + + +@dataclass +class GitRefs: + """ + Contains information about all git references for a repo on the Hub. 
+
+    Object is returned by [`list_repo_refs`].
+
+    Attributes:
+        branches (`List[GitRefInfo]`):
+            A list of [`GitRefInfo`] containing information about branches on the repo.
+        converts (`List[GitRefInfo]`):
+            A list of [`GitRefInfo`] containing information about "convert" refs on the repo.
+            Converts are refs used (internally) to push preprocessed data in Dataset repos.
+        tags (`List[GitRefInfo]`):
+            A list of [`GitRefInfo`] containing information about tags on the repo.
+        pull_requests (`List[GitRefInfo]`, *optional*):
+            A list of [`GitRefInfo`] containing information about pull requests on the repo.
+            Only returned if `include_prs=True` is set.
+    """
+
+    branches: List[GitRefInfo]
+    converts: List[GitRefInfo]
+    tags: List[GitRefInfo]
+    pull_requests: Optional[List[GitRefInfo]] = None
+
+
+@dataclass
+class GitCommitInfo:
+    """
+    Contains information about a git commit for a repo on the Hub. Check out [`list_repo_commits`] for more details.
+
+    Attributes:
+        commit_id (`str`):
+            OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`)
+        authors (`List[str]`):
+            List of authors of the commit.
+        created_at (`datetime`):
+            Datetime when the commit was created.
+        title (`str`):
+            Title of the commit. This is a free-text value entered by the authors.
+        message (`str`):
+            Description of the commit. This is a free-text value entered by the authors.
+        formatted_title (`str`, *optional*):
+            Title of the commit formatted as HTML. Only returned if `formatted=True` is set.
+        formatted_message (`str`, *optional*):
+            Description of the commit formatted as HTML. Only returned if `formatted=True` is set.
+    """
+
+    commit_id: str
+
+    authors: List[str]
+    created_at: datetime
+    title: str
+    message: str
+
+    formatted_title: Optional[str]
+    formatted_message: Optional[str]
+
+
+@dataclass
+class UserLikes:
+    """
+    Contains information about a user's likes on the Hub.
+
+    Attributes:
+        user (`str`):
+            Name of the user for which we fetched the likes.
+        total (`int`):
+            Total number of likes.
+        datasets (`List[str]`):
+            List of datasets liked by the user (as repo_ids).
+        models (`List[str]`):
+            List of models liked by the user (as repo_ids).
+        spaces (`List[str]`):
+            List of spaces liked by the user (as repo_ids).
+    """
+
+    # Metadata
+    user: str
+    total: int
+
+    # User likes
+    datasets: List[str]
+    models: List[str]
+    spaces: List[str]
+
+
+@dataclass
+class Organization:
+    """
+    Contains information about an organization on the Hub.
+
+    Attributes:
+        avatar_url (`str`):
+            URL of the organization's avatar.
+        name (`str`):
+            Name of the organization on the Hub (unique).
+        fullname (`str`):
+            Organization's full name.
+    """
+
+    avatar_url: str
+    name: str
+    fullname: str
+
+    def __init__(self, **kwargs) -> None:
+        self.avatar_url = kwargs.pop("avatarUrl", "")
+        self.name = kwargs.pop("name", "")
+        self.fullname = kwargs.pop("fullname", "")
+
+        # forward compatibility
+        self.__dict__.update(**kwargs)
+
+
+@dataclass
+class User:
+    """
+    Contains information about a user on the Hub.
+
+    Attributes:
+        username (`str`):
+            Name of the user on the Hub (unique).
+        fullname (`str`):
+            User's full name.
+        avatar_url (`str`):
+            URL of the user's avatar.
+        details (`str`, *optional*):
+            User's details.
+        is_following (`bool`, *optional*):
+            Whether the authenticated user is following this user.
+        is_pro (`bool`, *optional*):
+            Whether the user is a pro user.
+        num_models (`int`, *optional*):
+            Number of models created by the user.
+        num_datasets (`int`, *optional*):
+            Number of datasets created by the user.
+        num_spaces (`int`, *optional*):
+            Number of spaces created by the user.
+        num_discussions (`int`, *optional*):
+            Number of discussions initiated by the user.
+        num_papers (`int`, *optional*):
+            Number of papers authored by the user.
+        num_upvotes (`int`, *optional*):
+            Number of upvotes received by the user.
+        num_likes (`int`, *optional*):
+            Number of likes given by the user.
+        num_following (`int`, *optional*):
+            Number of users this user is following.
+        num_followers (`int`, *optional*):
+            Number of users following this user.
+        orgs (list of [`Organization`]):
+            List of organizations the user is part of.
+    """
+
+    # Metadata
+    username: str
+    fullname: str
+    avatar_url: str
+    details: Optional[str] = None
+    is_following: Optional[bool] = None
+    is_pro: Optional[bool] = None
+    num_models: Optional[int] = None
+    num_datasets: Optional[int] = None
+    num_spaces: Optional[int] = None
+    num_discussions: Optional[int] = None
+    num_papers: Optional[int] = None
+    num_upvotes: Optional[int] = None
+    num_likes: Optional[int] = None
+    num_following: Optional[int] = None
+    num_followers: Optional[int] = None
+    orgs: List[Organization] = field(default_factory=list)
+
+    def __init__(self, **kwargs) -> None:
+        self.username = kwargs.pop("user", "")
+        self.fullname = kwargs.pop("fullname", "")
+        self.avatar_url = kwargs.pop("avatarUrl", "")
+        self.is_following = kwargs.pop("isFollowing", None)
+        self.is_pro = kwargs.pop("isPro", None)
+        self.details = kwargs.pop("details", None)
+        self.num_models = kwargs.pop("numModels", None)
+        self.num_datasets = kwargs.pop("numDatasets", None)
+        self.num_spaces = kwargs.pop("numSpaces", None)
+        self.num_discussions = kwargs.pop("numDiscussions", None)
+        self.num_papers = kwargs.pop("numPapers", None)
+        self.num_upvotes = kwargs.pop("numUpvotes", None)
+        self.num_likes = kwargs.pop("numLikes", None)
+        self.num_following = kwargs.pop("numFollowing", None)
+        self.num_followers = kwargs.pop("numFollowers", None)
+        self.user_type = kwargs.pop("type", None)
+        self.orgs = [Organization(**org) for org in kwargs.pop("orgs", [])]
+
+        # forward compatibility
+        self.__dict__.update(**kwargs)
+
+
+@dataclass
+class PaperInfo:
+    """
+    Contains information about a paper on the Hub.
+
+    Attributes:
+        id (`str`):
+            arXiv paper ID.
+        authors (`List[str]`, *optional*):
+            Names of paper authors.
+        published_at (`datetime`, *optional*):
+            Date paper published.
+        title (`str`, *optional*):
+            Title of the paper.
+        summary (`str`, *optional*):
+            Summary of the paper.
+        upvotes (`int`, *optional*):
+            Number of upvotes for the paper on the Hub.
+        discussion_id (`str`, *optional*):
+            Discussion ID for the paper on the Hub.
+        source (`str`, *optional*):
+            Source of the paper.
+        comments (`int`, *optional*):
+            Number of comments for the paper on the Hub.
+        submitted_at (`datetime`, *optional*):
+            Date paper appeared in daily papers on the Hub.
+        submitted_by (`User`, *optional*):
+            Information about who submitted the daily paper.
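+
+    Example (illustrative sketch; `"2307.09288"` is the arXiv ID of the Llama 2 paper):
+
+    ```py
+    >>> from huggingface_hub import HfApi
+    >>> api = HfApi()
+    >>> paper = api.paper_info("2307.09288")
+    >>> paper.title
+    'Llama 2: Open Foundation and Fine-Tuned Chat Models'
+    ```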
+ """ + + id: str + authors: Optional[List[str]] + published_at: Optional[datetime] + title: Optional[str] + summary: Optional[str] + upvotes: Optional[int] + discussion_id: Optional[str] + source: Optional[str] + comments: Optional[int] + submitted_at: Optional[datetime] + submitted_by: Optional[User] + + def __init__(self, **kwargs) -> None: + paper = kwargs.pop("paper", {}) + self.id = kwargs.pop("id", None) or paper.pop("id", None) + authors = paper.pop("authors", None) or kwargs.pop("authors", None) + self.authors = [author.pop("name", None) for author in authors] if authors else None + published_at = paper.pop("publishedAt", None) or kwargs.pop("publishedAt", None) + self.published_at = parse_datetime(published_at) if published_at else None + self.title = kwargs.pop("title", None) + self.source = kwargs.pop("source", None) + self.summary = paper.pop("summary", None) or kwargs.pop("summary", None) + self.upvotes = paper.pop("upvotes", None) or kwargs.pop("upvotes", None) + self.discussion_id = paper.pop("discussionId", None) or kwargs.pop("discussionId", None) + self.comments = kwargs.pop("numComments", 0) + submitted_at = kwargs.pop("publishedAt", None) or kwargs.pop("submittedOnDailyAt", None) + self.submitted_at = parse_datetime(submitted_at) if submitted_at else None + submitted_by = kwargs.pop("submittedBy", None) or kwargs.pop("submittedOnDailyBy", None) + self.submitted_by = User(**submitted_by) if submitted_by else None + + # forward compatibility + self.__dict__.update(**kwargs) + + +@dataclass +class LFSFileInfo: + """ + Contains information about a file stored as LFS on a repo on the Hub. + + Used in the context of listing and permanently deleting LFS files from a repo to free-up space. + See [`list_lfs_files`] and [`permanently_delete_lfs_files`] for more details. + + Git LFS files are tracked using SHA-256 object IDs, rather than file paths, to optimize performance + This approach is necessary because a single object can be referenced by multiple paths across different commits, + making it impractical to search and resolve these connections. Check out [our documentation](https://huggingface.co/docs/hub/storage-limits#advanced-track-lfs-file-references) + to learn how to know which filename(s) is(are) associated with each SHA. + + Attributes: + file_oid (`str`): + SHA-256 object ID of the file. This is the identifier to pass when permanently deleting the file. + filename (`str`): + Possible filename for the LFS object. See the note above for more information. + oid (`str`): + OID of the LFS object. + pushed_at (`datetime`): + Date the LFS object was pushed to the repo. + ref (`str`, *optional*): + Ref where the LFS object has been pushed (if any). + size (`int`): + Size of the LFS object. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> lfs_files = api.list_lfs_files("username/my-cool-repo") + + # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. + # e.g. 
select only LFS files in the "checkpoints" folder + >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) + + # Permanently delete LFS files + >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) + ``` + """ + + file_oid: str + filename: str + oid: str + pushed_at: datetime + ref: Optional[str] + size: int + + def __init__(self, **kwargs) -> None: + self.file_oid = kwargs.pop("fileOid") + self.filename = kwargs.pop("filename") + self.oid = kwargs.pop("oid") + self.pushed_at = parse_datetime(kwargs.pop("pushedAt")) + self.ref = kwargs.pop("ref", None) + self.size = kwargs.pop("size") + + # forward compatibility + self.__dict__.update(**kwargs) + + +def future_compatible(fn: CallableT) -> CallableT: + """Wrap a method of `HfApi` to handle `run_as_future=True`. + + A method flagged as "future_compatible" will be called in a thread if `run_as_future=True` and return a + `concurrent.futures.Future` instance. Otherwise, it will be called normally and return the result. + """ + sig = inspect.signature(fn) + args_params = list(sig.parameters)[1:] # remove "self" from list + + @wraps(fn) + def _inner(self, *args, **kwargs): + # Get `run_as_future` value if provided (default to False) + if "run_as_future" in kwargs: + run_as_future = kwargs["run_as_future"] + kwargs["run_as_future"] = False # avoid recursion error + else: + run_as_future = False + for param, value in zip(args_params, args): + if param == "run_as_future": + run_as_future = value + break + + # Call the function in a thread if `run_as_future=True` + if run_as_future: + return self.run_as_future(fn, self, *args, **kwargs) + + # Otherwise, call the function normally + return fn(self, *args, **kwargs) + + _inner.is_future_compatible = True # type: ignore + return _inner # type: ignore + + +class HfApi: + """ + Client to interact with the Hugging Face Hub via HTTP. + + The client is initialized with some high-level settings used in all requests + made to the Hub (HF endpoint, authentication, user agents...). Using the `HfApi` + client is preferred but not mandatory as all of its public methods are exposed + directly at the root of `huggingface_hub`. + + Args: + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. Will be added to + the user-agent header. Example: `"transformers"`. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. Will be added + to the user-agent header. Example: `"4.24.0"`. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. It will + be completed with information about the installed packages. + headers (`dict`, *optional*): + Additional headers to be sent with each request. Example: `{"X-My-Header": "value"}`. + Headers passed here are taking precedence over the default headers. 
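+
+    Example (illustrative sketch; the library name and header below are placeholders):
+
+    ```py
+    >>> from huggingface_hub import HfApi
+    >>> api = HfApi(
+    ...     library_name="my-library",
+    ...     library_version="0.0.1",
+    ...     headers={"X-My-Header": "value"},
+    ... )
+    >>> api.whoami()  # requires a saved token, e.g. via `hf auth login`
+    ```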
+ """ + + def __init__( + self, + endpoint: Optional[str] = None, + token: Union[str, bool, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + headers: Optional[Dict[str, str]] = None, + ) -> None: + self.endpoint = endpoint if endpoint is not None else constants.ENDPOINT + self.token = token + self.library_name = library_name + self.library_version = library_version + self.user_agent = user_agent + self.headers = headers + self._thread_pool: Optional[ThreadPoolExecutor] = None + + def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: + """ + Run a method in the background and return a Future instance. + + The main goal is to run methods without blocking the main thread (e.g. to push data during a training). + Background jobs are queued to preserve order but are not ran in parallel. If you need to speed-up your scripts + by parallelizing lots of call to the API, you must setup and use your own [ThreadPoolExecutor](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor). + + Note: Most-used methods like [`upload_file`], [`upload_folder`] and [`create_commit`] have a `run_as_future: bool` + argument to directly call them in the background. This is equivalent to calling `api.run_as_future(...)` on them + but less verbose. + + Args: + fn (`Callable`): + The method to run in the background. + *args, **kwargs: + Arguments with which the method will be called. + + Return: + `Future`: a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) instance to + get the result of the task. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> future = api.run_as_future(api.whoami) # instant + >>> future.done() + False + >>> future.result() # wait until complete and return result + (...) + >>> future.done() + True + ``` + """ + if self._thread_pool is None: + self._thread_pool = ThreadPoolExecutor(max_workers=1) + self._thread_pool + return self._thread_pool.submit(fn, *args, **kwargs) + + @validate_hf_hub_args + def whoami(self, token: Union[bool, str, None] = None) -> Dict: + """ + Call HF API to know "whoami". + + Args: + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + # Get the effective token using the helper function get_token + effective_token = token or self.token or get_token() or True + r = get_session().get( + f"{self.endpoint}/api/whoami-v2", + headers=self._build_hf_headers(token=effective_token), + ) + try: + hf_raise_for_status(r) + except HTTPError as e: + if e.response.status_code == 401: + error_message = "Invalid user token." + # Check which token is the effective one and generate the error message accordingly + if effective_token == _get_token_from_google_colab(): + error_message += " The token from Google Colab vault is invalid. Please update it from the UI." + elif effective_token == _get_token_from_environment(): + error_message += ( + " The token from HF_TOKEN environment variable is invalid. " + "Note that HF_TOKEN takes precedence over `hf auth login`." + ) + elif effective_token == _get_token_from_file(): + error_message += " The token stored is invalid. Please run `hf auth login` to update it." 
+                raise HTTPError(error_message, request=e.request, response=e.response) from e
+            raise
+        return r.json()
+
+    @_deprecate_method(
+        version="1.0",
+        message=(
+            "Permissions are more complex than when `get_token_permission` was first introduced. "
+            "OAuth and fine-grained tokens allow for more detailed permissions. "
+            "If you need to know the permissions associated with a token, please use `whoami` and check the `'auth'` key."
+        ),
+    )
+    def get_token_permission(
+        self, token: Union[bool, str, None] = None
+    ) -> Literal["read", "write", "fineGrained", None]:
+        """
+        Check if a given `token` is valid and return its permissions.
+
+        This method is deprecated and will be removed in version 1.0. Permissions are more complex than when
+        `get_token_permission` was first introduced. OAuth and fine-grained tokens allow for more detailed permissions.
+        If you need to know the permissions associated with a token, please use `whoami` and check the `'auth'` key.
+
+        For more details about tokens, please refer to https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens.
+
+        Args:
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `Literal["read", "write", "fineGrained", None]`: Permission granted by the token ("read" or "write"). Returns `None` if no
+            token passed, if token is invalid or if role is not returned by the server. This typically happens when the token is an OAuth token.
+        """
+        try:
+            return self.whoami(token=token)["auth"]["accessToken"]["role"]
+        except (LocalTokenNotFoundError, HTTPError, KeyError):
+            return None
+
+    def get_model_tags(self) -> Dict:
+        """
+        List all valid model tags as a nested namespace object.
+        """
+        path = f"{self.endpoint}/api/models-tags-by-type"
+        r = get_session().get(path)
+        hf_raise_for_status(r)
+        return r.json()
+
+    def get_dataset_tags(self) -> Dict:
+        """
+        List all valid dataset tags as a nested namespace object.
+        """
+        path = f"{self.endpoint}/api/datasets-tags-by-type"
+        r = get_session().get(path)
+        hf_raise_for_status(r)
+        return r.json()
+
+    @validate_hf_hub_args
+    def list_models(
+        self,
+        *,
+        # Search-query parameter
+        filter: Union[str, Iterable[str], None] = None,
+        author: Optional[str] = None,
+        gated: Optional[bool] = None,
+        inference: Optional[Literal["warm"]] = None,
+        inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", List["PROVIDER_T"]]] = None,
+        library: Optional[Union[str, List[str]]] = None,
+        language: Optional[Union[str, List[str]]] = None,
+        model_name: Optional[str] = None,
+        task: Optional[Union[str, List[str]]] = None,
+        trained_dataset: Optional[Union[str, List[str]]] = None,
+        tags: Optional[Union[str, List[str]]] = None,
+        search: Optional[str] = None,
+        pipeline_tag: Optional[str] = None,
+        emissions_thresholds: Optional[Tuple[float, float]] = None,
+        # Sorting and pagination parameters
+        sort: Union[Literal["last_modified"], str, None] = None,
+        direction: Optional[Literal[-1]] = None,
+        limit: Optional[int] = None,
+        # Additional data to fetch
+        expand: Optional[List[ExpandModelProperty_T]] = None,
+        full: Optional[bool] = None,
+        cardData: bool = False,
+        fetch_config: bool = False,
+        token: Union[bool, str, None] = None,
+    ) -> Iterable[ModelInfo]:
+        """
+        List models hosted on the Hugging Face Hub, given some filters.
+
+        Args:
+            filter (`str` or `Iterable[str]`, *optional*):
+                A string or list of strings to filter models on the Hub.
+            author (`str`, *optional*):
+                A string which identifies the author (user or organization) of the
+                returned models.
+            gated (`bool`, *optional*):
+                A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
+                If `gated=True` is passed, only gated models are returned.
+                If `gated=False` is passed, only non-gated models are returned.
+            inference (`Literal["warm"]`, *optional*):
+                If "warm", filter models on the Hub currently served by at least one provider.
+            inference_provider (`Literal["all"]` or `str`, *optional*):
+                A string to filter models on the Hub that are served by a specific provider.
+                Pass `"all"` to get all models served by at least one provider.
+            library (`str` or `List`, *optional*):
+                A string or list of strings of foundational libraries models were
+                originally trained from, such as pytorch, tensorflow, or allennlp.
+            language (`str` or `List`, *optional*):
+                A string or list of strings of languages, both by name and country
+                code, such as "en" or "English".
+            model_name (`str`, *optional*):
+                A string that contains complete or partial names for models on the
+                Hub, such as "bert" or "bert-base-cased".
+            task (`str` or `List`, *optional*):
+                A string or list of strings of tasks models were designed for, such
+                as: "fill-mask" or "automatic-speech-recognition".
+            trained_dataset (`str` or `List`, *optional*):
+                A string tag or a list of string tags of the trained dataset for a
+                model on the Hub.
+            tags (`str` or `List`, *optional*):
+                A string tag or a list of tags to filter models on the Hub by, such
+                as `text-generation` or `spacy`.
+            search (`str`, *optional*):
+                A string that will be contained in the returned model ids.
+            pipeline_tag (`str`, *optional*):
+                A string pipeline tag to filter models on the Hub by, such as `summarization`.
+            emissions_thresholds (`Tuple`, *optional*):
+                A tuple of two ints or floats representing a minimum and maximum
+                carbon footprint, in grams, used to filter the resulting models.
+            sort (`Literal["last_modified"]` or `str`, *optional*):
+                The key with which to sort the resulting models. Possible values are "last_modified", "trending_score",
+                "created_at", "downloads" and "likes".
+            direction (`Literal[-1]` or `int`, *optional*):
+                Direction in which to sort. The value `-1` sorts by descending
+                order while all other values sort by ascending order.
+            limit (`int`, *optional*):
+                The limit on the number of models fetched. Leaving this option
+                to `None` fetches all models.
+            expand (`List[ExpandModelProperty_T]`, *optional*):
+                List properties to return in the response. When used, only the properties in the list will be returned.
+                This parameter cannot be used if `full`, `cardData` or `fetch_config` are passed.
+                Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"resourceGroup"` and `"xetEnabled"`.
+            full (`bool`, *optional*):
+                Whether to fetch all model data, including the `last_modified`,
+                the `sha`, the files and the `tags`. This is set to `True` by
+                default when using a filter.
+            cardData (`bool`, *optional*):
+                Whether to grab the metadata for the model as well. Can contain
+                useful information such as carbon emissions, metrics, and
+                datasets trained on.
+            fetch_config (`bool`, *optional*):
+                Whether to fetch the model configs as well. This is not included
+                in `full` due to its size.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects.
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import HfApi
+
+        >>> api = HfApi()
+
+        # List all models
+        >>> api.list_models()
+
+        # List text classification models
+        >>> api.list_models(filter="text-classification")
+
+        # List models from the KerasHub library
+        >>> api.list_models(filter="keras-hub")
+
+        # List models served by Cohere
+        >>> api.list_models(inference_provider="cohere")
+
+        # List models with "bert" in their name
+        >>> api.list_models(search="bert")
+
+        # List models with "bert" in their name and pushed by google
+        >>> api.list_models(search="bert", author="google")
+        ```
+        """
+        if expand and (full or cardData or fetch_config):
+            raise ValueError("`expand` cannot be used if `full`, `cardData` or `fetch_config` are passed.")
+
+        if emissions_thresholds is not None and not cardData:
+            raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.")
+
+        path = f"{self.endpoint}/api/models"
+        headers = self._build_hf_headers(token=token)
+        params: Dict[str, Any] = {}
+
+        # Build the filter list
+        filter_list: List[str] = []
+        if filter:
+            filter_list.extend([filter] if isinstance(filter, str) else filter)
+        if library:
+            filter_list.extend([library] if isinstance(library, str) else library)
+        if task:
+            filter_list.extend([task] if isinstance(task, str) else task)
+        if trained_dataset:
+            if isinstance(trained_dataset, str):
+                trained_dataset = [trained_dataset]
+            for dataset in trained_dataset:
+                if not dataset.startswith("dataset:"):
+                    dataset = f"dataset:{dataset}"
+                filter_list.append(dataset)
+        if language:
+            filter_list.extend([language] if isinstance(language, str) else language)
+        if tags:
+            filter_list.extend([tags] if isinstance(tags, str) else tags)
+        if len(filter_list) > 0:
+            params["filter"] = filter_list
+
+        # Handle other query params
+        if author:
+            params["author"] = author
+        if gated is not None:
+            params["gated"] = gated
+        if inference is not None:
+            params["inference"] = inference
+        if inference_provider is not None:
+            params["inference_provider"] = inference_provider
+        if pipeline_tag:
+            params["pipeline_tag"] = pipeline_tag
+        search_list = []
+        if model_name:
+            search_list.append(model_name)
+        if search:
+            search_list.append(search)
+        if len(search_list) > 0:
+            params["search"] = search_list
+        if sort is not None:
+            params["sort"] = (
+                "lastModified"
+                if sort == "last_modified"
+                else "trendingScore"
+                if sort == "trending_score"
+                else "createdAt"
+                if sort == "created_at"
+                else sort
+            )
+        if direction is not None:
+            params["direction"] = direction
+        if limit is not None:
+            params["limit"] = limit
+
+        # Request additional data
+        if full:
+            params["full"] = True
+        if fetch_config:
+            params["config"] = True
+        if cardData:
+            params["cardData"] = True
+        if expand:
+            params["expand"] = expand
+
+        # `items` is a generator
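+        # Pagination is lazy: `paginate` yields items page by page as the
+        # generator is consumed, so the `islice` below prevents fetching
+        # pages beyond `limit`.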
items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + model_info = ModelInfo(**item) + if emissions_thresholds is None or _is_emission_within_threshold(model_info, *emissions_thresholds): + yield model_info + + @validate_hf_hub_args + def list_datasets( + self, + *, + # Search-query parameter + filter: Union[str, Iterable[str], None] = None, + author: Optional[str] = None, + benchmark: Optional[Union[str, List[str]]] = None, + dataset_name: Optional[str] = None, + gated: Optional[bool] = None, + language_creators: Optional[Union[str, List[str]]] = None, + language: Optional[Union[str, List[str]]] = None, + multilinguality: Optional[Union[str, List[str]]] = None, + size_categories: Optional[Union[str, List[str]]] = None, + tags: Optional[Union[str, List[str]]] = None, + task_categories: Optional[Union[str, List[str]]] = None, + task_ids: Optional[Union[str, List[str]]] = None, + search: Optional[str] = None, + # Sorting and pagination parameters + sort: Optional[Union[Literal["last_modified"], str]] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + # Additional data to fetch + expand: Optional[List[ExpandDatasetProperty_T]] = None, + full: Optional[bool] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[DatasetInfo]: + """ + List datasets hosted on the Huggingface Hub, given some filters. + + Args: + filter (`str` or `Iterable[str]`, *optional*): + A string or list of string to filter datasets on the hub. + author (`str`, *optional*): + A string which identify the author of the returned datasets. + benchmark (`str` or `List`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub by their official benchmark. + dataset_name (`str`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub by its name, such as `SQAC` or `wikineural` + gated (`bool`, *optional*): + A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned. + If `gated=True` is passed, only gated datasets are returned. + If `gated=False` is passed, only non-gated datasets are returned. + language_creators (`str` or `List`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub with how the data was curated, such as `crowdsourced` or + `machine_generated`. + language (`str` or `List`, *optional*): + A string or list of strings representing a two-character language to + filter datasets by on the Hub. + multilinguality (`str` or `List`, *optional*): + A string or list of strings representing a filter for datasets that + contain multiple languages. + size_categories (`str` or `List`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub by the size of the dataset such as `100K>> from huggingface_hub import HfApi + + >>> api = HfApi() + + # List all datasets + >>> api.list_datasets() + + + # List only the text classification datasets + >>> api.list_datasets(filter="task_categories:text-classification") + + + # List only the datasets in russian for language modeling + >>> api.list_datasets( + ... filter=("language:ru", "task_ids:language-modeling") + ... 
) + + # List FiftyOne datasets (identified by the tag "fiftyone" in dataset card) + >>> api.list_datasets(tags="fiftyone") + ``` + + Example usage with the `search` argument: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + # List all datasets with "text" in their name + >>> api.list_datasets(search="text") + + # List all datasets with "text" in their name made by google + >>> api.list_datasets(search="text", author="google") + ``` + """ + if expand and full: + raise ValueError("`expand` cannot be used if `full` is passed.") + + path = f"{self.endpoint}/api/datasets" + headers = self._build_hf_headers(token=token) + params: Dict[str, Any] = {} + + # Build `filter` list + filter_list = [] + if filter is not None: + if isinstance(filter, str): + filter_list.append(filter) + else: + filter_list.extend(filter) + for key, value in ( + ("benchmark", benchmark), + ("language_creators", language_creators), + ("language", language), + ("multilinguality", multilinguality), + ("size_categories", size_categories), + ("task_categories", task_categories), + ("task_ids", task_ids), + ): + if value: + if isinstance(value, str): + value = [value] + for value_item in value: + if not value_item.startswith(f"{key}:"): + data = f"{key}:{value_item}" + filter_list.append(data) + if tags is not None: + filter_list.extend([tags] if isinstance(tags, str) else tags) + if len(filter_list) > 0: + params["filter"] = filter_list + + # Handle other query params + if author: + params["author"] = author + if gated is not None: + params["gated"] = gated + search_list = [] + if dataset_name: + search_list.append(dataset_name) + if search: + search_list.append(search) + if len(search_list) > 0: + params["search"] = search_list + if sort is not None: + params["sort"] = ( + "lastModified" + if sort == "last_modified" + else "trendingScore" + if sort == "trending_score" + else "createdAt" + if sort == "created_at" + else sort + ) + if direction is not None: + params["direction"] = direction + if limit is not None: + params["limit"] = limit + + # Request additional data + if expand: + params["expand"] = expand + if full: + params["full"] = True + + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + yield DatasetInfo(**item) + + @validate_hf_hub_args + def list_spaces( + self, + *, + # Search-query parameter + filter: Union[str, Iterable[str], None] = None, + author: Optional[str] = None, + search: Optional[str] = None, + datasets: Union[str, Iterable[str], None] = None, + models: Union[str, Iterable[str], None] = None, + linked: bool = False, + # Sorting and pagination parameters + sort: Union[Literal["last_modified"], str, None] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + # Additional data to fetch + expand: Optional[List[ExpandSpaceProperty_T]] = None, + full: Optional[bool] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[SpaceInfo]: + """ + List spaces hosted on the Hugging Face Hub, given some filters. + + Args: + filter (`str` or `Iterable`, *optional*): + A string tag or list of tags that can be used to identify Spaces on the Hub. + author (`str`, *optional*): + A string which identifies the author of the returned Spaces. + search (`str`, *optional*): + A string that will be contained in the returned Spaces.
+ datasets (`str` or `Iterable`, *optional*): + Whether to return Spaces that make use of a dataset. + The name of a specific dataset can be passed as a string. + models (`str` or `Iterable`, *optional*): + Whether to return Spaces that make use of a model. + The name of a specific model can be passed as a string. + linked (`bool`, *optional*): + Whether to return Spaces that make use of either a model or a dataset. + sort (`Literal["last_modified"]` or `str`, *optional*): + The key with which to sort the resulting models. Possible values are "last_modified", "trending_score", + "created_at" and "likes". + direction (`Literal[-1]` or `int`, *optional*): + Direction in which to sort. The value `-1` sorts by descending + order while all other values sort by ascending order. + limit (`int`, *optional*): + The limit on the number of Spaces fetched. Leaving this option + to `None` fetches all Spaces. + expand (`List[ExpandSpaceProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `full` is passed. + Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. + full (`bool`, *optional*): + Whether to fetch all Spaces data, including the `last_modified`, `siblings` + and `card_data` fields. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[SpaceInfo]`: an iterable of [`huggingface_hub.hf_api.SpaceInfo`] objects. + """ + if expand and full: + raise ValueError("`expand` cannot be used if `full` is passed.") + + path = f"{self.endpoint}/api/spaces" + headers = self._build_hf_headers(token=token) + params: Dict[str, Any] = {} + if filter is not None: + params["filter"] = filter + if author is not None: + params["author"] = author + if search is not None: + params["search"] = search + if sort is not None: + params["sort"] = ( + "lastModified" + if sort == "last_modified" + else "trendingScore" + if sort == "trending_score" + else "createdAt" + if sort == "created_at" + else sort + ) + if direction is not None: + params["direction"] = direction + if limit is not None: + params["limit"] = limit + if linked: + params["linked"] = True + if datasets is not None: + params["datasets"] = datasets + if models is not None: + params["models"] = models + + # Request additional data + if expand: + params["expand"] = expand + if full: + params["full"] = True + + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + yield SpaceInfo(**item) + + @validate_hf_hub_args + def unlike( + self, + repo_id: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Unlike a given repo on the Hub (e.g. remove from favorite list). + + To prevent spam usage, it is not possible to `like` a repository from a script. + + See also [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to unlike. 
Example: `"user/my-cool-model"`. + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if unliking a dataset or space, `None` or + `"model"` if unliking a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + + Example: + ```python + >>> from huggingface_hub import list_liked_repos, unlike + >>> "gpt2" in list_liked_repos().models # we assume you have already liked gpt2 + True + >>> unlike("gpt2") + >>> "gpt2" in list_liked_repos().models + False + ``` + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + response = get_session().delete( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(response) + + @validate_hf_hub_args + def list_liked_repos( + self, + user: Optional[str] = None, + *, + token: Union[bool, str, None] = None, + ) -> UserLikes: + """ + List all public repos liked by a user on huggingface.co. + + This list is public so token is optional. If `user` is not passed, it defaults to + the logged in user. + + See also [`unlike`]. + + Args: + user (`str`, *optional*): + Name of the user for which you want to fetch the likes. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`UserLikes`]: object containing the user name and 3 lists of repo ids (1 for + models, 1 for datasets and 1 for Spaces). + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `user` is not passed and no token found (either from argument or from machine). + + Example: + ```python + >>> from huggingface_hub import list_liked_repos + + >>> likes = list_liked_repos("julien-c") + + >>> likes.user + "julien-c" + + >>> likes.models + ["osanseviero/streamlit_1.15", "Xhaheen/ChatGPT_HF", ...] + ``` + """ + # User is either provided explicitly or retrieved from current token. + if user is None: + me = self.whoami(token=token) + if me["type"] == "user": + user = me["name"] + else: + raise ValueError( + "Cannot list liked repos. You must provide a 'user' as input or be logged in as a user." + ) + + path = f"{self.endpoint}/api/users/{user}/likes" + headers = self._build_hf_headers(token=token) + + likes = list(paginate(path, params={}, headers=headers)) + # Looping over a list of items similar to: + # { + # 'createdAt': '2021-09-09T21:53:27.000Z', + # 'repo': { + # 'name': 'PaddlePaddle/PaddleOCR', + # 'type': 'space' + # } + # } + # Let's loop 3 times over the received list. Less efficient but more straightforward to read. 
+ return UserLikes( + user=user, + total=len(likes), + models=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "model"], + datasets=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "dataset"], + spaces=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "space"], + ) + + @validate_hf_hub_args + def list_repo_likers( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[User]: + """ + List all users who liked a given repo on the Hugging Face Hub. + + See also [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to retrieve. Example: `"user/my-cool-model"`. + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing likers of a dataset or a + space, `None` or `"model"` if listing likers of a model. Default is + `None`. + + Returns: + `Iterable[User]`: an iterable of [`huggingface_hub.hf_api.User`] objects. + """ + + # Construct the API endpoint + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/likers" + for liker in paginate(path, params={}, headers=self._build_hf_headers(token=token)): + yield User(username=liker["user"], fullname=liker["fullname"], avatar_url=liker["avatarUrl"]) + + @validate_hf_hub_args + def model_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + securityStatus: Optional[bool] = None, + files_metadata: bool = False, + expand: Optional[List[ExpandModelProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> ModelInfo: + """ + Get info on one specific model on huggingface.co. + + Model can be private if you pass an acceptable token or are logged in. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the model repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + securityStatus (`bool`, *optional*): + Whether to retrieve the security status from the model + repository as well. The security status will be returned in the `security_repo_status` field. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + expand (`List[ExpandModelProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `securityStatus` or `files_metadata` are passed. + Possible values are `"author"`, `"baseModels"`, `"cardData"`, `"childrenModelCount"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`.
+ token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`huggingface_hub.hf_api.ModelInfo`]: The model repository information. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + + + """ + if expand and (securityStatus or files_metadata): + raise ValueError("`expand` cannot be used if `securityStatus` or `files_metadata` are set.") + + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/models/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}") + ) + params: Dict = {} + if securityStatus: + params["securityStatus"] = True + if files_metadata: + params["blobs"] = True + if expand: + params["expand"] = expand + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return ModelInfo(**data) + + @validate_hf_hub_args + def dataset_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + expand: Optional[List[ExpandDatasetProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> DatasetInfo: + """ + Get info on one specific dataset on huggingface.co. + + Dataset can be private if you pass an acceptable token. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the dataset repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + expand (`List[ExpandDatasetProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `files_metadata` is passed. + Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"`, `"tags"`, `"trendingScore"`,`"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`hf_api.DatasetInfo`]: The dataset repository information. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. 
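+ + Example (an illustrative sketch; the dataset id is an arbitrary public dataset): + + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> info = api.dataset_info("allenai/c4") + >>> info.sha # commit hash of the `main` revision + ``` +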
+ + + """ + if expand and files_metadata: + raise ValueError("`expand` cannot be used if `files_metadata` is set.") + + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/datasets/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") + ) + params: Dict = {} + if files_metadata: + params["blobs"] = True + if expand: + params["expand"] = expand + + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return DatasetInfo(**data) + + @validate_hf_hub_args + def space_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + expand: Optional[List[ExpandSpaceProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> SpaceInfo: + """ + Get info on one specific Space on huggingface.co. + + Space can be private if you pass an acceptable token. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the space repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + expand (`List[ExpandSpaceProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `full` is passed. + Possible values are `"author"`, `"cardData"`, `"createdAt"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`~hf_api.SpaceInfo`]: The space repository information. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. 
+ + + """ + if expand and files_metadata: + raise ValueError("`expand` cannot be used if `files_metadata` is set.") + + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/spaces/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") + ) + params: Dict = {} + if files_metadata: + params["blobs"] = True + if expand: + params["expand"] = expand + + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return SpaceInfo(**data) + + @validate_hf_hub_args + def repo_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + expand: Optional[Union[ExpandModelProperty_T, ExpandDatasetProperty_T, ExpandSpaceProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> Union[ModelInfo, DatasetInfo, SpaceInfo]: + """ + Get the info object for a given repo of a given type. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the repository from which to get the + information. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + expand (`ExpandModelProperty_T` or `ExpandDatasetProperty_T` or `ExpandSpaceProperty_T`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `files_metadata` is passed. + For an exhaustive list of available properties, check out [`model_info`], [`dataset_info`] or [`space_info`]. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Union[SpaceInfo, DatasetInfo, ModelInfo]`: The repository information, as a + [`huggingface_hub.hf_api.DatasetInfo`], [`huggingface_hub.hf_api.ModelInfo`] + or [`huggingface_hub.hf_api.SpaceInfo`] object. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. 
+ + + """ + if repo_type is None or repo_type == "model": + method = self.model_info + elif repo_type == "dataset": + method = self.dataset_info # type: ignore + elif repo_type == "space": + method = self.space_info # type: ignore + else: + raise ValueError("Unsupported repo type.") + return method( + repo_id, + revision=revision, + token=token, + timeout=timeout, + expand=expand, # type: ignore[arg-type] + files_metadata=files_metadata, + ) + + @validate_hf_hub_args + def repo_exists( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> bool: + """ + Checks if a repository exists on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + True if the repository exists, False otherwise. + + Examples: + ```py + >>> from huggingface_hub import repo_exists + >>> repo_exists("google/gemma-7b") + True + >>> repo_exists("google/not-a-repo") + False + ``` + """ + try: + self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) + return True + except GatedRepoError: + return True # we don't have access but it exists + except RepositoryNotFoundError: + return False + + @validate_hf_hub_args + def revision_exists( + self, + repo_id: str, + revision: str, + *, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> bool: + """ + Checks if a specific revision exists on a repo on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`): + The revision of the repository to check. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + True if the repository and the revision exists, False otherwise. + + Examples: + ```py + >>> from huggingface_hub import revision_exists + >>> revision_exists("google/gemma-7b", "float16") + True + >>> revision_exists("google/gemma-7b", "not-a-revision") + False + ``` + """ + try: + self.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) + return True + except RevisionNotFoundError: + return False + except RepositoryNotFoundError: + return False + + @validate_hf_hub_args + def file_exists( + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> bool: + """ + Checks if a file exists in a repository on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. 
+ filename (`str`): + The name of the file to check, for example: + `"config.json"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + True if the file exists, False otherwise. + + Examples: + ```py + >>> from huggingface_hub import file_exists + >>> file_exists("bigcode/starcoder", "config.json") + True + >>> file_exists("bigcode/starcoder", "not-a-file") + False + >>> file_exists("bigcode/not-a-repo", "config.json") + False + ``` + """ + url = hf_hub_url( + repo_id=repo_id, repo_type=repo_type, revision=revision, filename=filename, endpoint=self.endpoint + ) + try: + if token is None: + token = self.token + get_hf_file_metadata(url, token=token) + return True + except GatedRepoError: # raise specifically on gated repo + raise + except (RepositoryNotFoundError, EntryNotFoundError, RevisionNotFoundError): + return False + + @validate_hf_hub_args + def list_repo_files( + self, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> List[str]: + """ + Get the list of files in a given repo. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + revision (`str`, *optional*): + The revision of the repository from which to get the information. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to + a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `List[str]`: the list of files in a given repository. + """ + return [ + f.rfilename + for f in self.list_repo_tree( + repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type, token=token + ) + if isinstance(f, RepoFile) + ] + + @validate_hf_hub_args + def list_repo_tree( + self, + repo_id: str, + path_in_repo: Optional[str] = None, + *, + recursive: bool = False, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> Iterable[Union[RepoFile, RepoFolder]]: + """ + List a repo tree's files and folders and get information about them. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + path_in_repo (`str`, *optional*): + Relative path of the tree (folder) in the repo, for example: + `"checkpoints/1fec34a/results"`. Will default to the root tree (folder) of the repository. + recursive (`bool`, *optional*, defaults to `False`): + Whether to list tree's files and folders recursively. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the tree's files and folders (e.g. 
last commit and files' security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the tree. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the tree (`"model"`, `"dataset"` or `"space"`). + Defaults to `"model"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[Union[RepoFile, RepoFolder]]`: + The information about the tree's files and folders, as an iterable of [`RepoFile`] and [`RepoFolder`] objects. The order of the files and folders is + not guaranteed. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + [`~utils.EntryNotFoundError`]: + If the tree (folder) does not exist (error 404) on the repo. + + Examples: + + Get information about a repo's tree. + ```py + >>> from huggingface_hub import list_repo_tree + >>> repo_tree = list_repo_tree("lysandre/arxiv-nlp") + >>> repo_tree + + >>> list(repo_tree) + [ + RepoFile(path='.gitattributes', size=391, blob_id='ae8c63daedbd4206d7d40126955d4e6ab1c80f8f', lfs=None, last_commit=None, security=None), + RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None), + RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None), + RepoFile( + path='flax_model.msgpack', size=497764107, blob_id='8095a62ccb4d806da7666fcda07467e2d150218e', + lfs={'size': 497764107, 'sha256': 'd88b0d6a6ff9c3f8151f9d3228f57092aaea997f09af009eefd7373a77b5abb9', 'pointer_size': 134}, last_commit=None, security=None + ), + RepoFile(path='merges.txt', size=456318, blob_id='226b0752cac7789c48f0cb3ec53eda48b7be36cc', lfs=None, last_commit=None, security=None), + RepoFile( + path='pytorch_model.bin', size=548123560, blob_id='64eaa9c526867e404b68f2c5d66fd78e27026523', + lfs={'size': 548123560, 'sha256': '9be78edb5b928eba33aa88f431551348f7466ba9f5ef3daf1d552398722a5436', 'pointer_size': 134}, last_commit=None, security=None + ), + RepoFile(path='vocab.json', size=898669, blob_id='b00361fece0387ca34b4b8b8539ed830d644dbeb', lfs=None, last_commit=None, security=None) + ] + ``` + + Get even more information about a repo's tree (last commit and files' security scan results) + ```py + >>> from huggingface_hub import list_repo_tree + >>> repo_tree = list_repo_tree("prompthero/openjourney-v4", expand=True) + >>> list(repo_tree) + [ + RepoFolder( + path='feature_extractor', + tree_id='aa536c4ea18073388b5b0bc791057a7296a00398', + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) + } + ), + RepoFolder( + path='safety_checker', + tree_id='65aef9d787e5557373fdf714d6c34d4fcdd70440', +
last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) + } + ), + RepoFile( + path='model_index.json', + size=582, + blob_id='d3d7c1e8c3e78eeb1640b8e2041ee256e24c9ee1', + lfs=None, + last_commit={ + 'oid': 'b195ed2d503f3eb29637050a886d77bd81d35f0e', + 'title': 'Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`. (#54)', + 'date': datetime.datetime(2023, 5, 15, 21, 41, 59, tzinfo=datetime.timezone.utc) + }, + security={ + 'safe': True, + 'av_scan': {'virusFound': False, 'virusNames': None}, + 'pickle_import_scan': None + } + ) + ... + ] + ``` + """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + encoded_path_in_repo = "/" + quote(path_in_repo, safe="") if path_in_repo else "" + tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path_in_repo}" + for path_info in paginate(path=tree_url, headers=headers, params={"recursive": recursive, "expand": expand}): + yield (RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info)) + + @validate_hf_hub_args + def list_repo_refs( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + include_pull_requests: bool = False, + token: Union[str, bool, None] = None, + ) -> GitRefs: + """ + Get the list of refs of a given repo (both tags and branches). + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing refs from a dataset or a Space, + `None` or `"model"` if listing from a model. Default is `None`. + include_pull_requests (`bool`, *optional*): + Whether to include refs from pull requests in the list. Defaults to `False`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> api.list_repo_refs("gpt2") + GitRefs(branches=[GitRefInfo(name='main', ref='refs/heads/main', target_commit='e7da7f221d5bf496a48136c0cd264e630fe9fcc8')], converts=[], tags=[]) + + >>> api.list_repo_refs("bigcode/the-stack", repo_type='dataset') + GitRefs( + branches=[ + GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), + GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') + ], + converts=[], + tags=[ + GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') + ] + ) + ``` + + Returns: + [`GitRefs`]: object containing all information about branches and tags for a + repo on the Hub. 
+ """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + response = get_session().get( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs", + headers=self._build_hf_headers(token=token), + params={"include_prs": 1} if include_pull_requests else {}, + ) + hf_raise_for_status(response) + data = response.json() + + def _format_as_git_ref_info(item: Dict) -> GitRefInfo: + return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"]) + + return GitRefs( + branches=[_format_as_git_ref_info(item) for item in data["branches"]], + converts=[_format_as_git_ref_info(item) for item in data["converts"]], + tags=[_format_as_git_ref_info(item) for item in data["tags"]], + pull_requests=[_format_as_git_ref_info(item) for item in data["pullRequests"]] + if include_pull_requests + else None, + ) + + @validate_hf_hub_args + def list_repo_commits( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + revision: Optional[str] = None, + formatted: bool = False, + ) -> List[GitCommitInfo]: + """ + Get the list of commits of a given revision for a repo on the Hub. + + Commits are sorted by date (last commit first). + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if + listing from a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + formatted (`bool`): + Whether to return the HTML-formatted title and description of the commits. Defaults to False. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + + # Commits are sorted by date (last commit first) + >>> initial_commit = api.list_repo_commits("gpt2")[-1] + + # Initial commit is always a system commit containing the `.gitattributes` file. + >>> initial_commit + GitCommitInfo( + commit_id='9b865efde13a30c13e0a33e536cf3e4a5a9d71d8', + authors=['system'], + created_at=datetime.datetime(2019, 2, 18, 10, 36, 15, tzinfo=datetime.timezone.utc), + title='initial commit', + message='', + formatted_title=None, + formatted_message=None + ) + + # Create an empty branch by deriving from initial commit + >>> api.create_branch("gpt2", "new_empty_branch", revision=initial_commit.commit_id) + ``` + + Returns: + List[[`GitCommitInfo`]]: list of objects containing information about the commits for a repo on the Hub. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + + # Paginate over results and return the list of commits. 
+ return [ + GitCommitInfo( + commit_id=item["id"], + authors=[author["user"] for author in item["authors"]], + created_at=parse_datetime(item["date"]), + title=item["title"], + message=item["message"], + formatted_title=item.get("formatted", {}).get("title"), + formatted_message=item.get("formatted", {}).get("message"), + ) + for item in paginate( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/commits/{revision}", + headers=self._build_hf_headers(token=token), + params={"expand[]": "formatted"} if formatted else {}, + ) + ] + + @validate_hf_hub_args + def get_paths_info( + self, + repo_id: str, + paths: Union[List[str], str], + *, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> List[Union[RepoFile, RepoFolder]]: + """ + Get information about a repo's paths. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + paths (`Union[List[str], str]`): + The paths to get information about. If a path does not exist, it is ignored without raising + an exception. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the paths (e.g. last commit and files' security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`). + Defaults to `"model"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `List[Union[RepoFile, RepoFolder]]`: + The information about the paths, as a list of [`RepoFile`] and [`RepoFolder`] objects. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo.
+ + Example: + ```py + >>> from huggingface_hub import get_paths_info + >>> paths_info = get_paths_info("allenai/c4", ["README.md", "en"], repo_type="dataset") + >>> paths_info + [ + RepoFile(path='README.md', size=2379, blob_id='f84cb4c97182890fc1dbdeaf1a6a468fd27b4fff', lfs=None, last_commit=None, security=None), + RepoFolder(path='en', tree_id='dc943c4c40f53d02b31ced1defa7e5f438d5862e', last_commit=None) + ] + ``` + """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + response = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}", + data={ + "paths": paths if isinstance(paths, list) else [paths], + "expand": expand, + }, + headers=headers, + ) + hf_raise_for_status(response) + paths_info = response.json() + return [ + RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info) + for path_info in paths_info + ] + + @validate_hf_hub_args + def super_squash_history( + self, + repo_id: str, + *, + branch: Optional[str] = None, + commit_message: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> None: + """Squash commit history on a branch for a repo on the Hub. + + Squashing the repo history is useful when you know you'll make hundreds of commits and you don't want to + clutter the history. Squashing commits can only be performed from the head of a branch. + + + + Once squashed, the commit history cannot be retrieved. This is a non-revertible operation. + + + + + + Once the history of a branch has been squashed, it is not possible to merge it back into another branch since + their history will have diverged. + + + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + branch (`str`, *optional*): + The branch to squash. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The commit message to use for the squashed commit. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if + listing from a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If the branch to squash cannot be found. + [`~utils.BadRequestError`]: + If invalid reference for a branch. You cannot squash history on tags. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + + # Create repo + >>> repo_id = api.create_repo("test-squash").repo_id + + # Make a lot of commits. 
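+ # (each `upload_file` call below creates one commit on the branch)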
+ >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"content") + >>> api.upload_file(repo_id=repo_id, path_in_repo="lfs.bin", path_or_fileobj=b"content") + >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"another_content") + + # Squash history + >>> api.super_squash_history(repo_id=repo_id) + ``` + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + if branch is None: + branch = constants.DEFAULT_REVISION + + # Prepare request + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/super-squash/{quote(branch, safe='')}" + headers = self._build_hf_headers(token=token) + commit_message = commit_message or f"Super-squash branch '{branch}' using huggingface_hub" + + # Super-squash + response = get_session().post(url=url, headers=headers, json={"message": commit_message}) + hf_raise_for_status(response) + + @validate_hf_hub_args + def list_lfs_files( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[LFSFileInfo]: + """ + List all LFS files in a repo on the Hub. + + This is primarily useful to count how much storage a repo is using and to eventually clean up large files + with [`permanently_delete_lfs_files`]. Note that this would be a permanent action that will affect all commits + referencing these deleted files and that cannot be undone. + + Args: + repo_id (`str`): + The repository for which you are listing LFS files. + repo_type (`str`, *optional*): + Type of repository. Set to `"dataset"` or `"space"` if listing from a dataset or space, `None` or + `"model"` if listing from a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[LFSFileInfo]`: An iterator of [`LFSFileInfo`] objects. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> lfs_files = api.list_lfs_files("username/my-cool-repo") + + # Filter the files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. + # e.g. select only LFS files in the "checkpoints" folder + >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) + + # Permanently delete LFS files + >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) + ``` + """ + # Prepare request + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/lfs-files" + headers = self._build_hf_headers(token=token) + + # Paginate over LFS items + for item in paginate(url, params={}, headers=headers): + yield LFSFileInfo(**item) + + @validate_hf_hub_args + def permanently_delete_lfs_files( + self, + repo_id: str, + lfs_files: Iterable[LFSFileInfo], + *, + rewrite_history: bool = True, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Permanently delete LFS files from a repo on the Hub. + + + + This is a permanent action that will affect all commits referencing the deleted files and might corrupt your + repository. This is a non-revertible operation. Use it only if you know what you are doing.
+ + + + Args: + repo_id (`str`): + The repository from which you are deleting LFS files. + lfs_files (`Iterable[LFSFileInfo]`): + An iterable of [`LFSFileInfo`] items to permanently delete from the repo. Use [`list_lfs_files`] to list + all LFS files from a repo. + rewrite_history (`bool`, *optional*, defaults to `True`): + Whether to rewrite repository history to remove file pointers referencing the deleted LFS files (recommended). + repo_type (`str`, *optional*): + Type of repository. Set to `"dataset"` or `"space"` if deleting from a dataset or space, `None` or + `"model"` if deleting from a model. Default is `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> lfs_files = api.list_lfs_files("username/my-cool-repo") + + # Filter the files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. + # e.g. select only LFS files in the "checkpoints" folder + >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) + + # Permanently delete LFS files + >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) + ``` + """ + # Prepare request + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/lfs-files/batch" + headers = self._build_hf_headers(token=token) + + # Delete LFS items by batches of 1000 + for batch in chunk_iterable(lfs_files, 1000): + shas = [item.file_oid for item in batch] + if len(shas) == 0: + return + payload = { + "deletions": { + "sha": shas, + "rewriteHistory": rewrite_history, + } + } + response = get_session().post(url, headers=headers, json=payload) + hf_raise_for_status(response) + + @validate_hf_hub_args + def create_repo( + self, + repo_id: str, + *, + token: Union[str, bool, None] = None, + private: Optional[bool] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + resource_group_id: Optional[str] = None, + space_sdk: Optional[str] = None, + space_hardware: Optional[SpaceHardware] = None, + space_storage: Optional[SpaceStorage] = None, + space_sleep_time: Optional[int] = None, + space_secrets: Optional[List[Dict[str, str]]] = None, + space_variables: Optional[List[Dict[str, str]]] = None, + ) -> RepoUrl: + """Create an empty repo on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo already exists.
+ resource_group_id (`str`, *optional*): + Resource group in which to create the repo. Resource groups are only available for Enterprise Hub organizations and + allow defining which members of the organization can access the resource. The ID of a resource group + can be found in the URL of the resource's page on the Hub (e.g. `"66670e5163145ca562cb1988"`). + To learn more about resource groups, see https://huggingface.co/docs/hub/en/security-resource-groups. + space_sdk (`str`, *optional*): + Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", "docker", or "static". + space_hardware (`SpaceHardware` or `str`, *optional*): + Choice of Hardware if repo_type is "space". See [`SpaceHardware`] for a complete list. + space_storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + space_sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + space_secrets (`List[Dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + space_variables (`List[Dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + """ + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + path = f"{self.endpoint}/api/repos/create" + + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + + json: Dict[str, Any] = {"name": name, "organization": organization} + if private is not None: + json["private"] = private + if repo_type is not None: + json["type"] = repo_type + if repo_type == "space": + if space_sdk is None: + raise ValueError( + "No space_sdk provided. `create_repo` expects space_sdk to be one" + f" of {constants.SPACES_SDK_TYPES} when repo_type is 'space'." + ) + if space_sdk not in constants.SPACES_SDK_TYPES: + raise ValueError(f"Invalid space_sdk. 
Please choose one of {constants.SPACES_SDK_TYPES}.") + json["sdk"] = space_sdk + + if space_sdk is not None and repo_type != "space": + warnings.warn("Ignoring provided space_sdk because repo_type is not 'space'.") + + function_args = [ + "space_hardware", + "space_storage", + "space_sleep_time", + "space_secrets", + "space_variables", + ] + json_keys = ["hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] + values = [space_hardware, space_storage, space_sleep_time, space_secrets, space_variables] + + if repo_type == "space": + json.update({k: v for k, v in zip(json_keys, values) if v is not None}) + else: + provided_space_args = [key for key, value in zip(function_args, values) if value is not None] + + if provided_space_args: + warnings.warn(f"Ignoring provided {', '.join(provided_space_args)} because repo_type is not 'space'.") + + if getattr(self, "_lfsmultipartthresh", None): + # Testing purposes only. + # See https://github.com/huggingface/huggingface_hub/pull/733/files#r820604472 + json["lfsmultipartthresh"] = self._lfsmultipartthresh # type: ignore + + if resource_group_id is not None: + json["resourceGroupId"] = resource_group_id + + headers = self._build_hf_headers(token=token) + while True: + r = get_session().post(path, headers=headers, json=json) + if r.status_code == 409 and "Cannot create repo: another conflicting operation is in progress" in r.text: + # Since https://github.com/huggingface/moon-landing/pull/7272 (private repo), it is not possible to + # concurrently create repos on the Hub for the same user. This is rarely an issue, except when running + # tests. To avoid any inconvenience, we retry to create the repo for this specific error. + # NOTE: This could have been fixed directly in the tests but adding it here should fix CIs for all + # dependent libraries. + # NOTE: If a fix is implemented server-side, we should be able to remove this retry mechanism. + logger.debug("Create repo failed due to a concurrency issue. Retrying...") + continue + break + + try: + hf_raise_for_status(r) + except HTTPError as err: + if exist_ok and err.response.status_code == 409: + # Repo already exists and `exist_ok=True` + pass + elif exist_ok and err.response.status_code == 403: + # No write permission on the namespace but repo might already exist + try: + self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) + if repo_type is None or repo_type == constants.REPO_TYPE_MODEL: + return RepoUrl(f"{self.endpoint}/{repo_id}") + return RepoUrl(f"{self.endpoint}/{repo_type}/{repo_id}") + except HfHubHTTPError: + raise err + else: + raise + + d = r.json() + return RepoUrl(d["url"], endpoint=self.endpoint) + + @validate_hf_hub_args + def delete_repo( + self, + repo_id: str, + *, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + missing_ok: bool = False, + ) -> None: + """ + Delete a repo from the Hugging Face Hub. CAUTION: this is irreversible. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if deleting a dataset or + space, `None` or `"model"` if deleting a model.
+ missing_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo does not exist. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to delete from cannot be found and `missing_ok` is set to False (default). + """ + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + path = f"{self.endpoint}/api/repos/delete" + + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + + json = {"name": name, "organization": organization} + if repo_type is not None: + json["type"] = repo_type + + headers = self._build_hf_headers(token=token) + r = get_session().delete(path, headers=headers, json=json) + try: + hf_raise_for_status(r) + except RepositoryNotFoundError: + if not missing_ok: + raise + + @_deprecate_method(version="0.32", message="Please use `update_repo_settings` instead.") + @validate_hf_hub_args + def update_repo_visibility( + self, + repo_id: str, + private: bool = False, + *, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + ) -> Dict[str, bool]: + """Update the visibility setting of a repository. + + Deprecated. Use `update_repo_settings` instead. + + Args: + repo_id (`str`, *optional*): + A namespace (user or an organization) and a repo name separated by a `/`. + private (`bool`, *optional*, defaults to `False`): + Whether the repository should be private. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: + The HTTP response in json. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL # default repo type + + r = get_session().put( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", + headers=self._build_hf_headers(token=token), + json={"private": private}, + ) + hf_raise_for_status(r) + return r.json() + + @validate_hf_hub_args + def update_repo_settings( + self, + repo_id: str, + *, + gated: Optional[Literal["auto", "manual", False]] = None, + private: Optional[bool] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + xet_enabled: Optional[bool] = None, + ) -> None: + """ + Update the settings of a repository, including gated access and visibility. + + To give more control over how repos are used, the Hub allows repo authors to enable + access requests for their repos, and also to set the visibility of the repo to private. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a /. + gated (`Literal["auto", "manual", False]`, *optional*): + The gated status for the repository. If set to `None` (default), the `gated` setting of the repository won't be updated. 
+ * "auto": The repository is gated, and access requests are automatically approved or denied based on predefined criteria. + * "manual": The repository is gated, and access requests require manual approval. + * False : The repository is not gated, and anyone can access it. + private (`bool`, *optional*): + Whether the repository should be private. + token (`Union[str, bool, None]`, *optional*): + A valid user access token (string). Defaults to the locally saved token, + which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass False. + repo_type (`str`, *optional*): + The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`). + Defaults to `"model"`. + xet_enabled (`bool`, *optional*): + Whether the repository should be enabled for Xet Storage. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If gated is not one of "auto", "manual", or False. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If repo_type is not one of the values in constants.REPO_TYPES. + [`~utils.HfHubHTTPError`]: + If the request to the Hugging Face Hub API fails. + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + """ + + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL # default repo type + + # Prepare the JSON payload for the PUT request + payload: Dict = {} + + if gated is not None: + if gated not in ["auto", "manual", False]: + raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.") + payload["gated"] = gated + + if private is not None: + payload["private"] = private + + if xet_enabled is not None: + payload["xetEnabled"] = xet_enabled + + if len(payload) == 0: + raise ValueError("At least one setting must be updated.") + + # Build headers + headers = self._build_hf_headers(token=token) + + r = get_session().put( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", + headers=headers, + json=payload, + ) + hf_raise_for_status(r) + + def move_repo( + self, + from_id: str, + to_id: str, + *, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ): + """ + Moving a repository from namespace1/repo_name1 to namespace2/repo_name2 + + Note there are certain limitations. For more information about moving + repositories, please see + https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo. + + Args: + from_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. Original repository identifier. + to_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. Final repository identifier. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if len(from_id.split("/")) != 2: + raise ValueError(f"Invalid repo_id: {from_id}. It should have a namespace (:namespace:/:repo_name:)") + + if len(to_id.split("/")) != 2: + raise ValueError(f"Invalid repo_id: {to_id}. It should have a namespace (:namespace:/:repo_name:)") + + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL # Hub won't accept `None`. + + json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type} + + path = f"{self.endpoint}/api/repos/move" + headers = self._build_hf_headers(token=token) + r = get_session().post(path, headers=headers, json=json) + try: + hf_raise_for_status(r) + except HfHubHTTPError as e: + e.append_to_message( + "\nFor additional documentation please see" + " https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo." + ) + raise + + @overload + def create_commit( # type: ignore + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: Literal[False] = ..., + ) -> CommitInfo: ... + + @overload + def create_commit( + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: Literal[True] = ..., + ) -> Future[CommitInfo]: ... + + @validate_hf_hub_args + @future_compatible + def create_commit( + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: bool = False, + ) -> Union[CommitInfo, Future[CommitInfo]]: + """ + Creates a commit in the given repo, deleting & uploading files as needed. + + + + The input list of `CommitOperation` will be mutated during the commit process. Do not reuse the same objects + for multiple commits. + + + + + + `create_commit` assumes that the repo already exists on the Hub. If you get a + Client error 404, please make sure you are authenticated and that `repo_id` and + `repo_type` are set correctly. If repo does not exist, create it first using + [`~hf_api.create_repo`]. + + + + + + `create_commit` is limited to 25k LFS files and a 1GB payload for regular files. + + + + Args: + repo_id (`str`): + The repository in which the commit will be created, for example: + `"username/custom_transformers"` + + operations (`Iterable` of [`~hf_api.CommitOperation`]): + An iterable of operations to include in the commit, either: + + - [`~hf_api.CommitOperationAdd`] to upload a file + - [`~hf_api.CommitOperationDelete`] to delete a file + - [`~hf_api.CommitOperationCopy`] to copy a file + + Operation objects will be mutated to include information relative to the upload. 
Do not reuse the + same objects for multiple commits. + + commit_message (`str`): + The summary (first line) of the commit that will be created. + + commit_description (`str`, *optional*): + The description of the commit that will be created + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + + num_threads (`int`, *optional*): + Number of concurrent threads for uploading files. Defaults to 5. + Setting it to 2 means at most 2 files will be uploaded concurrently. + + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. + Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, + the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` + is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` + ensures the repo has not changed before committing the changes, and can be especially useful + if the repo is updated / committed to concurrently. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. + + Returns: + [`CommitInfo`] or `Future`: + Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit + url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will + contain the result when executed. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If commit message is empty. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If parent commit is not a valid commit OID. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If a README.md file with an invalid metadata section is committed. In this case, the commit will fail + early, before trying to upload any file. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `create_pr` is `True` and revision is neither `None` nor `"main"`. + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + """ + if parent_commit is not None and not constants.REGEX_COMMIT_OID.fullmatch(parent_commit): + raise ValueError( + f"`parent_commit` is not a valid commit OID. 
It must match the following regex: {constants.REGEX_COMMIT_OID}"
+            )
+
+        if commit_message is None or len(commit_message) == 0:
+            raise ValueError("`commit_message` can't be empty, please pass a value.")
+
+        commit_description = commit_description if commit_description is not None else ""
+        repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL
+        if repo_type not in constants.REPO_TYPES:
+            raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
+        unquoted_revision = revision or constants.DEFAULT_REVISION
+        revision = quote(unquoted_revision, safe="")
+        create_pr = create_pr if create_pr is not None else False
+
+        headers = self._build_hf_headers(token=token)
+
+        operations = list(operations)
+        additions = [op for op in operations if isinstance(op, CommitOperationAdd)]
+        copies = [op for op in operations if isinstance(op, CommitOperationCopy)]
+        nb_additions = len(additions)
+        nb_copies = len(copies)
+        nb_deletions = len(operations) - nb_additions - nb_copies
+
+        for addition in additions:
+            if addition._is_committed:
+                raise ValueError(
+                    f"CommitOperationAdd {addition} has already been committed and cannot be reused. Please create a"
+                    " new CommitOperationAdd object if you want to create a new commit."
+                )
+
+        if repo_type != "dataset":
+            for addition in additions:
+                if addition.path_in_repo.endswith((".arrow", ".parquet")):
+                    warnings.warn(
+                        f"It seems that you are about to commit a data file ({addition.path_in_repo}) to a {repo_type}"
+                        " repository. Are you sure this is intended? If you are trying to upload a dataset, please"
+                        " set `repo_type='dataset'` or `--repo-type=dataset` in a CLI."
+                    )
+
+        logger.debug(
+            f"About to commit to the hub: {nb_additions} addition(s), {nb_copies} copy(ies) and"
+            f" {nb_deletions} deletion(s)."
+        )
+
+        # If updating a README.md file, make sure the metadata format is valid.
+        # It's better to fail early than to fail after all the files have been uploaded.
+        for addition in additions:
+            if addition.path_in_repo == "README.md":
+                with addition.as_file() as file:
+                    content = file.read().decode()
+                self._validate_yaml(content, repo_type=repo_type, token=token)
+                # Skip other additions after `README.md` has been processed
+                break
+
+        # Warn if the same file is updated twice, or updated and then deleted, in a single commit
+        _warn_on_overwriting_operations(operations)
+
+        self.preupload_lfs_files(
+            repo_id=repo_id,
+            additions=additions,
+            token=token,
+            repo_type=repo_type,
+            revision=unquoted_revision,  # first-class methods take unquoted revision
+            create_pr=create_pr,
+            num_threads=num_threads,
+            free_memory=False,  # do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for "normal" users
+        )
+
+        files_to_copy = _fetch_files_to_copy(
+            copies=copies,
+            repo_type=repo_type,
+            repo_id=repo_id,
+            headers=headers,
+            revision=unquoted_revision,
+            endpoint=self.endpoint,
+        )
+        # Remove no-op operations (files that have not changed)
+        operations_without_no_op = []
+        for operation in operations:
+            if (
+                isinstance(operation, CommitOperationAdd)
+                and operation._remote_oid is not None
+                and operation._remote_oid == operation._local_oid
+            ):
+                # File already exists on the Hub and has not changed: we can skip it.
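+                # (`_remote_oid` is the object ID of the file currently on the Hub and `_local_oid` the object ID
+                # of the local content; equal IDs mean identical content, so the addition can safely be dropped.)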
+ logger.debug(f"Skipping upload for '{operation.path_in_repo}' as the file has not changed.") + continue + if ( + isinstance(operation, CommitOperationCopy) + and operation._dest_oid is not None + and operation._dest_oid == operation._src_oid + ): + # Source and destination files are identical - skip + logger.debug( + f"Skipping copy for '{operation.src_path_in_repo}' -> '{operation.path_in_repo}' as the content of the source file is the same as the destination file." + ) + continue + operations_without_no_op.append(operation) + if len(operations) != len(operations_without_no_op): + logger.info( + f"Removing {len(operations) - len(operations_without_no_op)} file(s) from commit that have not changed." + ) + + # Return early if empty commit + if len(operations_without_no_op) == 0: + logger.warning("No files have been modified since last commit. Skipping to prevent empty commit.") + + # Get latest commit info + try: + info = self.repo_info(repo_id=repo_id, repo_type=repo_type, revision=unquoted_revision, token=token) + except RepositoryNotFoundError as e: + e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) + raise + + # Return commit info based on latest commit + url_prefix = self.endpoint + if repo_type is not None and repo_type != constants.REPO_TYPE_MODEL: + url_prefix = f"{url_prefix}/{repo_type}s" + return CommitInfo( + commit_url=f"{url_prefix}/{repo_id}/commit/{info.sha}", + commit_message=commit_message, + commit_description=commit_description, + oid=info.sha, # type: ignore[arg-type] + ) + + commit_payload = _prepare_commit_payload( + operations=operations, + files_to_copy=files_to_copy, + commit_message=commit_message, + commit_description=commit_description, + parent_commit=parent_commit, + ) + commit_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/commit/{revision}" + + def _payload_as_ndjson() -> Iterable[bytes]: + for item in commit_payload: + yield json.dumps(item).encode() + yield b"\n" + + headers = { + # See https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073 + "Content-Type": "application/x-ndjson", + **headers, + } + data = b"".join(_payload_as_ndjson()) + params = {"create_pr": "1"} if create_pr else None + + try: + commit_resp = get_session().post(url=commit_url, headers=headers, data=data, params=params) + hf_raise_for_status(commit_resp, endpoint_name="commit") + except RepositoryNotFoundError as e: + e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) + raise + except EntryNotFoundError as e: + if nb_deletions > 0 and "A file with this name doesn't exist" in str(e): + e.append_to_message( + "\nMake sure to differentiate file and folder paths in delete" + " operations with a trailing '/' or using `is_folder=True/False`." 
+                )
+            raise
+
+        # Mark additions as committed (cannot be reused in another commit)
+        for addition in additions:
+            addition._is_committed = True
+
+        commit_data = commit_resp.json()
+        return CommitInfo(
+            commit_url=commit_data["commitUrl"],
+            commit_message=commit_message,
+            commit_description=commit_description,
+            oid=commit_data["commitOid"],
+            pr_url=commit_data["pullRequestUrl"] if create_pr else None,
+        )
+
+    def preupload_lfs_files(
+        self,
+        repo_id: str,
+        additions: Iterable[CommitOperationAdd],
+        *,
+        token: Union[str, bool, None] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        num_threads: int = 5,
+        free_memory: bool = True,
+        gitignore_content: Optional[str] = None,
+    ):
+        """Pre-upload LFS files to S3 in preparation for a future commit.
+
+        This method is useful if you are generating the files to upload on-the-fly and you don't want to store them
+        in memory before uploading them all at once.
+
+
+
+        This is a power-user method. You shouldn't need to call it directly to make a normal commit.
+        Use [`create_commit`] directly instead.
+
+
+
+
+
+        Commit operations will be mutated during the process. In particular, the attached `path_or_fileobj` will be
+        removed after the upload to save memory (and replaced by an empty `bytes` object). Do not reuse the same
+        objects except to pass them to [`create_commit`]. If you don't want to remove the attached content from the
+        commit operation object, pass `free_memory=False`.
+
+
+
+        Args:
+            repo_id (`str`):
+                The repository in which you will commit the files, for example: `"username/custom_transformers"`.
+
+            additions (`Iterable` of [`CommitOperationAdd`]):
+                The list of files to upload. Warning: the objects in this list will be mutated to include information
+                relative to the upload. Do not reuse the same objects for multiple commits.
+
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+            repo_type (`str`, *optional*):
+                The type of repository to upload to (e.g. `"model"` -default-, `"dataset"` or `"space"`).
+
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the `"main"` branch.
+
+            create_pr (`boolean`, *optional*):
+                Whether or not you plan to create a Pull Request with that commit. Defaults to `False`.
+
+            num_threads (`int`, *optional*):
+                Number of concurrent threads for uploading files. Defaults to 5.
+                Setting it to 2 means at most 2 files will be uploaded concurrently.
+
+            gitignore_content (`str`, *optional*):
+                The content of the `.gitignore` file to know which files should be ignored. The order of priority
+                is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present
+                in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub
+                (if any).
+
+        Example:
+        ```py
+        >>> from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit, create_repo
+
+        >>> repo_id = create_repo("test_preupload").repo_id
+
+        # Generate and preupload LFS files one by one
+        >>> operations = []  # List of all `CommitOperationAdd` objects that will be generated
+        >>> for i in range(5):
+        ...     content = ...  # generate binary content
+        ...
addition = CommitOperationAdd(path_in_repo=f"shard_{i}_of_5.bin", path_or_fileobj=content) + ... preupload_lfs_files(repo_id, additions=[addition]) # upload + free memory + ... operations.append(addition) + + # Create commit + >>> create_commit(repo_id, operations=operations, commit_message="Commit all shards") + ``` + """ + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + create_pr = create_pr if create_pr is not None else False + headers = self._build_hf_headers(token=token) + + # Check if a `gitignore` file is being committed to the Hub. + additions = list(additions) + if gitignore_content is None: + for addition in additions: + if addition.path_in_repo == ".gitignore": + with addition.as_file() as f: + gitignore_content = f.read().decode() + break + + # Filter out already uploaded files + new_additions = [addition for addition in additions if not addition._is_uploaded] + + # Check which new files are LFS + # For some items, we might have already fetched the upload mode (in case of upload_large_folder) + additions_no_upload_mode = [addition for addition in new_additions if addition._upload_mode is None] + if len(additions_no_upload_mode) > 0: + try: + _fetch_upload_modes( + additions=additions_no_upload_mode, + repo_type=repo_type, + repo_id=repo_id, + headers=headers, + revision=revision, + endpoint=self.endpoint, + create_pr=create_pr or False, + gitignore_content=gitignore_content, + ) + except RepositoryNotFoundError as e: + e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) + raise + + # Filter out regular files + new_lfs_additions = [addition for addition in new_additions if addition._upload_mode == "lfs"] + + # Filter out files listed in .gitignore + new_lfs_additions_to_upload = [] + for addition in new_lfs_additions: + if addition._should_ignore: + logger.debug(f"Skipping upload for LFS file '{addition.path_in_repo}' (ignored by gitignore file).") + else: + new_lfs_additions_to_upload.append(addition) + if len(new_lfs_additions) != len(new_lfs_additions_to_upload): + logger.info( + f"Skipped upload for {len(new_lfs_additions) - len(new_lfs_additions_to_upload)} LFS file(s) " + "(ignored by gitignore file)." + ) + # Prepare upload parameters + upload_kwargs = { + "additions": new_lfs_additions_to_upload, + "repo_type": repo_type, + "repo_id": repo_id, + "headers": headers, + "endpoint": self.endpoint, + # If `create_pr`, we don't want to check user permission on the revision as users with read permission + # should still be able to create PRs even if they don't have write permission on the target branch of the + # PR (i.e. `revision`). + "revision": revision if not create_pr else None, + } + # Upload files using Xet protocol if all of the following are true: + # - xet is enabled for the repo, + # - the files are provided as str or paths objects, + # - the library is installed. + # Otherwise, default back to LFS. 
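+        # (Note: the check below adds one extra `repo_info` call per invocation; `expand="xetEnabled"` keeps the
+        # response minimal, and `is_xet_available()` checks that the optional `hf_xet` package can be used.)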
+        xet_enabled = self.repo_info(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            revision=unquote(revision) if revision is not None else revision,
+            expand="xetEnabled",
+            token=token,
+        ).xet_enabled
+        has_buffered_io_data = any(
+            isinstance(addition.path_or_fileobj, io.BufferedIOBase) for addition in new_lfs_additions_to_upload
+        )
+        if xet_enabled and not has_buffered_io_data and is_xet_available():
+            logger.debug("Uploading files using Xet Storage...")
+            _upload_xet_files(**upload_kwargs, create_pr=create_pr)  # type: ignore [arg-type]
+        else:
+            if xet_enabled and is_xet_available():
+                if has_buffered_io_data:
+                    logger.warning(
+                        "Uploading files as a binary IO buffer is not supported by Xet Storage. "
+                        "Falling back to HTTP upload."
+                    )
+            _upload_lfs_files(**upload_kwargs, num_threads=num_threads)  # type: ignore [arg-type]
+        for addition in new_lfs_additions_to_upload:
+            addition._is_uploaded = True
+            if free_memory:
+                addition.path_or_fileobj = b""
+
+    @overload
+    def upload_file(  # type: ignore
+        self,
+        *,
+        path_or_fileobj: Union[str, Path, bytes, BinaryIO],
+        path_in_repo: str,
+        repo_id: str,
+        token: Union[str, bool, None] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        parent_commit: Optional[str] = None,
+        run_as_future: Literal[False] = ...,
+    ) -> CommitInfo: ...
+
+    @overload
+    def upload_file(
+        self,
+        *,
+        path_or_fileobj: Union[str, Path, bytes, BinaryIO],
+        path_in_repo: str,
+        repo_id: str,
+        token: Union[str, bool, None] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        parent_commit: Optional[str] = None,
+        run_as_future: Literal[True] = ...,
+    ) -> Future[CommitInfo]: ...
+
+    @validate_hf_hub_args
+    @future_compatible
+    def upload_file(
+        self,
+        *,
+        path_or_fileobj: Union[str, Path, bytes, BinaryIO],
+        path_in_repo: str,
+        repo_id: str,
+        token: Union[str, bool, None] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        parent_commit: Optional[str] = None,
+        run_as_future: bool = False,
+    ) -> Union[CommitInfo, Future[CommitInfo]]:
+        """
+        Upload a local file (up to 50 GB) to the given repo. The upload is done
+        through an HTTP POST request, and doesn't require git or git-lfs to be
+        installed.
+
+        Args:
+            path_or_fileobj (`str`, `Path`, `bytes`, or `IO`):
+                Path to a file on the local machine or binary data stream /
+                fileobj / buffer.
+            path_in_repo (`str`):
+                Relative filepath in the repo, for example:
+                `"checkpoints/1fec34a/weights.bin"`
+            repo_id (`str`):
+                The repository to which the file will be uploaded, for example:
+                `"username/custom_transformers"`
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+            repo_type (`str`, *optional*):
+                Set to `"dataset"` or `"space"` if uploading to a dataset or
+                space, `None` or `"model"` if uploading to a model. Default is
+                `None`.
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the `"main"` branch.
+            commit_message (`str`, *optional*):
+                The summary / title / first line of the generated commit.
+            commit_description (`str`, *optional*):
+                The description of the generated commit.
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request with that commit. Defaults to `False`.
+                If `revision` is not set, PR is opened against the `"main"` branch. If
+                `revision` is set and is a branch, PR is opened against this branch. If
+                `revision` is set and is not a branch name (example: a commit oid), a
+                `RevisionNotFoundError` is returned by the server.
+            parent_commit (`str`, *optional*):
+                The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
+                If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+                If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
+                Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
+                especially useful if the repo is updated / committed to concurrently.
+            run_as_future (`bool`, *optional*):
+                Whether or not to run this method in the background. Background jobs are run sequentially without
+                blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects)
+                object. Defaults to `False`.
+
+
+        Returns:
+            [`CommitInfo`] or `Future`:
+                Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit
+                url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will
+                contain the result when executed.
+
+
+        Raises the following errors:
+
+            - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+              if the HuggingFace API returned an error
+            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+              if some parameter value is invalid
+            - [`~utils.RepositoryNotFoundError`]
+              If the repository to download from cannot be found. This may be because it doesn't exist,
+              or because it is set to `private` and you do not have access.
+            - [`~utils.RevisionNotFoundError`]
+              If the revision to download from cannot be found.
+
+
+
+
+
+        `upload_file` assumes that the repo already exists on the Hub. If you get a
+        Client error 404, please make sure you are authenticated and that `repo_id` and
+        `repo_type` are set correctly. If repo does not exist, create it first using
+        [`~hf_api.create_repo`].
+
+
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import upload_file
+
+        >>> with open("./local/filepath", "rb") as fobj:
+        ...     upload_file(
+        ...         path_or_fileobj=fobj,
+        ...         path_in_repo="remote/file/path.h5",
+        ...         repo_id="username/my-dataset",
+        ...         repo_type="dataset",
+        ...         token="my_token",
+        ...     )
+        "https://huggingface.co/datasets/username/my-dataset/blob/main/remote/file/path.h5"
+
+        >>> upload_file(
+        ...     path_or_fileobj=".\\\\local\\\\file\\\\path",
+        ...     path_in_repo="remote/file/path.h5",
+        ...     repo_id="username/my-model",
+        ...     token="my_token",
+        ... )
+        "https://huggingface.co/username/my-model/blob/main/remote/file/path.h5"
+
+        >>> upload_file(
+        ...     path_or_fileobj=".\\\\local\\\\file\\\\path",
+        ...     path_in_repo="remote/file/path.h5",
+        ...     repo_id="username/my-model",
+        ...     token="my_token",
+        ...     create_pr=True,
+        ...
) + "https://huggingface.co/username/my-model/blob/refs%2Fpr%2F1/remote/file/path.h5" + ``` + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + + commit_message = ( + commit_message if commit_message is not None else f"Upload {path_in_repo} with huggingface_hub" + ) + operation = CommitOperationAdd( + path_or_fileobj=path_or_fileobj, + path_in_repo=path_in_repo, + ) + + commit_info = self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + operations=[operation], + commit_message=commit_message, + commit_description=commit_description, + token=token, + revision=revision, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + if commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + revision = revision if revision is not None else constants.DEFAULT_REVISION + + return CommitInfo( + commit_url=commit_info.commit_url, + commit_message=commit_info.commit_message, + commit_description=commit_info.commit_description, + oid=commit_info.oid, + pr_url=commit_info.pr_url, + # Similar to `hf_hub_url` but it's "blob" instead of "resolve" + # TODO: remove this in v1.0 + _url=f"{self.endpoint}/{repo_id}/blob/{revision}/{path_in_repo}", + ) + + @overload + def upload_folder( # type: ignore + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + run_as_future: Literal[False] = ..., + ) -> CommitInfo: ... + + @overload + def upload_folder( # type: ignore + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + run_as_future: Literal[True] = ..., + ) -> Future[CommitInfo]: ... + + @validate_hf_hub_args + @future_compatible + def upload_folder( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + run_as_future: bool = False, + ) -> Union[CommitInfo, Future[CommitInfo]]: + """ + Upload a local folder to the given repo. 
The upload is done through HTTP requests, and doesn't require git or git-lfs to be installed.
+
+        The structure of the folder will be preserved. Files with the same name already present in the repository will
+        be overwritten. Others will be left untouched.
+
+        Use the `allow_patterns` and `ignore_patterns` arguments to specify which files to upload. These parameters
+        accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as
+        documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). If both `allow_patterns` and
+        `ignore_patterns` are provided, both constraints apply. By default, all files from the folder are uploaded.
+
+        Use the `delete_patterns` argument to specify remote files you want to delete. Input type is the same as for
+        `allow_patterns` (see above). If `path_in_repo` is also provided, the patterns are matched against paths
+        relative to this folder. For example, `upload_folder(..., path_in_repo="experiment", delete_patterns="logs/*")`
+        will delete any remote file under `./experiment/logs/`. Note that the `.gitattributes` file will not be deleted
+        even if it matches the patterns.
+
+        Any `.git/` folder present in any subdirectory will be ignored. However, please be aware that the `.gitignore`
+        file is not taken into account.
+
+        Uses `HfApi.create_commit` under the hood.
+
+        Args:
+            repo_id (`str`):
+                The repository to which the file will be uploaded, for example:
+                `"username/custom_transformers"`
+            folder_path (`str` or `Path`):
+                Path to the folder to upload on the local file system
+            path_in_repo (`str`, *optional*):
+                Relative path of the directory in the repo, for example:
+                `"checkpoints/1fec34a/results"`. Will default to the root folder of the repository.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+            repo_type (`str`, *optional*):
+                Set to `"dataset"` or `"space"` if uploading to a dataset or
+                space, `None` or `"model"` if uploading to a model. Default is
+                `None`.
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the `"main"` branch.
+            commit_message (`str`, *optional*):
+                The summary / title / first line of the generated commit. Defaults to:
+                `"Upload folder using huggingface_hub"`
+            commit_description (`str`, *optional*):
+                The description of the generated commit
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not
+                set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened
+                against this branch. If `revision` is set and is not a branch name (example: a commit oid), a
+                `RevisionNotFoundError` is returned by the server.
+            parent_commit (`str`, *optional*):
+                The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
+                If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+                If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
+                Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
+                especially useful if the repo is updated / committed to concurrently.
+            allow_patterns (`List[str]` or `str`, *optional*):
+                If provided, only files matching at least one pattern are uploaded.
+            ignore_patterns (`List[str]` or `str`, *optional*):
+                If provided, files matching any of the patterns are not uploaded.
+            delete_patterns (`List[str]` or `str`, *optional*):
+                If provided, remote files matching any of the patterns will be deleted from the repo while committing
+                new files. This is useful if you don't know which files have already been uploaded.
+                Note: to avoid discrepancies the `.gitattributes` file is not deleted even if it matches the pattern.
+            run_as_future (`bool`, *optional*):
+                Whether or not to run this method in the background. Background jobs are run sequentially without
+                blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects)
+                object. Defaults to `False`.
+
+        Returns:
+            [`CommitInfo`] or `Future`:
+                Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit
+                url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will
+                contain the result when executed.
+
+
+
+        Raises the following errors:
+
+            - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+              if the HuggingFace API returned an error
+            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+              if some parameter value is invalid
+
+
+
+
+
+        `upload_folder` assumes that the repo already exists on the Hub. If you get a Client error 404, please make
+        sure you are authenticated and that `repo_id` and `repo_type` are set correctly. If repo does not exist, create
+        it first using [`~hf_api.create_repo`].
+
+
+
+
+
+        When dealing with a large folder (thousands of files or hundreds of GB), we recommend using [`~hf_api.upload_large_folder`] instead.
+
+
+
+        Example:
+
+        ```python
+        # Upload checkpoints folder except the log files
+        >>> upload_folder(
+        ...     folder_path="local/checkpoints",
+        ...     path_in_repo="remote/experiment/checkpoints",
+        ...     repo_id="username/my-dataset",
+        ...     repo_type="dataset",
+        ...     token="my_token",
+        ...     ignore_patterns="**/logs/*.txt",
+        ... )
+        # "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints"
+
+        # Upload checkpoints folder including logs while deleting existing logs from the repo
+        # Useful if you don't know exactly which log files have already been pushed
+        >>> upload_folder(
+        ...     folder_path="local/checkpoints",
+        ...     path_in_repo="remote/experiment/checkpoints",
+        ...     repo_id="username/my-dataset",
+        ...     repo_type="dataset",
+        ...     token="my_token",
+        ...     delete_patterns="**/logs/*.txt",
+        ... )
+        "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints"
+
+        # Upload checkpoints folder while creating a PR
+        >>> upload_folder(
+        ...     folder_path="local/checkpoints",
+        ...     path_in_repo="remote/experiment/checkpoints",
+        ...     repo_id="username/my-dataset",
+        ...     repo_type="dataset",
+        ...     token="my_token",
+        ...     create_pr=True,
+        ... )
+        "https://huggingface.co/datasets/username/my-dataset/tree/refs%2Fpr%2F1/remote/experiment/checkpoints"
+
+        ```
+        """
+        if repo_type not in constants.REPO_TYPES:
+            raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
+
+        # By default, upload folder to the root directory in repo.
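+        # (`path_in_repo=None` is normalized to `""` below; both mean the root of the repository.)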
+ if path_in_repo is None: + path_in_repo = "" + + # Do not upload .git folder + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + ignore_patterns += DEFAULT_IGNORE_PATTERNS + + delete_operations = self._prepare_folder_deletions( + repo_id=repo_id, + repo_type=repo_type, + revision=constants.DEFAULT_REVISION if create_pr else revision, + token=token, + path_in_repo=path_in_repo, + delete_patterns=delete_patterns, + ) + add_operations = self._prepare_upload_folder_additions( + folder_path, + path_in_repo, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + token=token, + repo_type=repo_type, + ) + + # Optimize operations: if some files will be overwritten, we don't need to delete them first + if len(add_operations) > 0: + added_paths = set(op.path_in_repo for op in add_operations) + delete_operations = [ + delete_op for delete_op in delete_operations if delete_op.path_in_repo not in added_paths + ] + commit_operations = delete_operations + add_operations + + commit_message = commit_message or "Upload folder using huggingface_hub" + + commit_info = self.create_commit( + repo_type=repo_type, + repo_id=repo_id, + operations=commit_operations, + commit_message=commit_message, + commit_description=commit_description, + token=token, + revision=revision, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + # Create url to uploaded folder (for legacy return value) + if create_pr and commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + revision = revision if revision is not None else constants.DEFAULT_REVISION + + return CommitInfo( + commit_url=commit_info.commit_url, + commit_message=commit_info.commit_message, + commit_description=commit_info.commit_description, + oid=commit_info.oid, + pr_url=commit_info.pr_url, + # Similar to `hf_hub_url` but it's "tree" instead of "resolve" + # TODO: remove this in v1.0 + _url=f"{self.endpoint}/{repo_id}/tree/{revision}/{path_in_repo}", + ) + + @validate_hf_hub_args + def delete_file( + self, + path_in_repo: str, + repo_id: str, + *, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ) -> CommitInfo: + """ + Deletes a file in the given repo. + + Args: + path_in_repo (`str`): + Relative filepath in the repo, for example: + `"checkpoints/1fec34a/weights.bin"` + repo_id (`str`): + The repository from which the file will be deleted, for example: + `"username/custom_transformers"` + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or + space, `None` or `"model"` if in a model. Default is `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. 
Defaults to `f"Delete {path_in_repo} with huggingface_hub"`.
+            commit_description (`str`, *optional*):
+                The description of the generated commit.
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request with that commit. Defaults to `False`.
+                If `revision` is not set, PR is opened against the `"main"` branch. If
+                `revision` is set and is a branch, PR is opened against this branch. If
+                `revision` is set and is not a branch name (example: a commit oid), a
+                `RevisionNotFoundError` is returned by the server.
+            parent_commit (`str`, *optional*):
+                The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
+                If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+                If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
+                Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
+                especially useful if the repo is updated / committed to concurrently.
+
+
+
+        Raises the following errors:
+
+            - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+              if the HuggingFace API returned an error
+            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+              if some parameter value is invalid
+            - [`~utils.RepositoryNotFoundError`]
+              If the repository to download from cannot be found. This may be because it doesn't exist,
+              or because it is set to `private` and you do not have access.
+            - [`~utils.RevisionNotFoundError`]
+              If the revision to download from cannot be found.
+            - [`~utils.EntryNotFoundError`]
+              If the file to download cannot be found.
+
+
+
+        """
+        commit_message = (
+            commit_message if commit_message is not None else f"Delete {path_in_repo} with huggingface_hub"
+        )
+
+        operations = [CommitOperationDelete(path_in_repo=path_in_repo)]
+
+        return self.create_commit(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            token=token,
+            operations=operations,
+            revision=revision,
+            commit_message=commit_message,
+            commit_description=commit_description,
+            create_pr=create_pr,
+            parent_commit=parent_commit,
+        )
+
+    @validate_hf_hub_args
+    def delete_files(
+        self,
+        repo_id: str,
+        delete_patterns: List[str],
+        *,
+        token: Union[bool, str, None] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        parent_commit: Optional[str] = None,
+    ) -> CommitInfo:
+        """
+        Delete files from a repository on the Hub.
+
+        If a folder path is provided, the entire folder is deleted as well as
+        all files it contains.
+
+        Args:
+            repo_id (`str`):
+                The repository from which the files will be deleted, for example:
+                `"username/custom_transformers"`
+            delete_patterns (`List[str]`):
+                List of files or folders to delete. Each string can either be
+                a file path, a folder path or a Unix shell-style wildcard.
+                E.g. `["file.txt", "folder/", "data/*.parquet"]`
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+            repo_type (`str`, *optional*):
+                Type of the repo to delete files from. Can be `"model"`,
+                `"dataset"` or `"space"`. Defaults to `"model"`.
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the `"main"` branch.
+            commit_message (`str`, *optional*):
+                The summary (first line) of the generated commit. Defaults to
+                `f"Delete files {' '.join(delete_patterns)} with huggingface_hub"`.
+            commit_description (`str`, *optional*):
+                The description of the generated commit.
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request with that commit. Defaults to `False`.
+                If `revision` is not set, PR is opened against the `"main"` branch. If
+                `revision` is set and is a branch, PR is opened against this branch. If
+                `revision` is set and is not a branch name (example: a commit oid), a
+                `RevisionNotFoundError` is returned by the server.
+            parent_commit (`str`, *optional*):
+                The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
+                If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+                If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
+                Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
+                especially useful if the repo is updated / committed to concurrently.
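+
+        Example (an illustrative sketch; the repo id and patterns are placeholders):
+
+        ```python
+        >>> from huggingface_hub import HfApi
+
+        >>> api = HfApi()
+        # Delete a single file, an entire folder and all parquet files under data/ in one commit
+        >>> api.delete_files(
+        ...     repo_id="username/my-dataset",
+        ...     repo_type="dataset",
+        ...     delete_patterns=["file.txt", "folder/", "data/*.parquet"],
+        ... )
+        ```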
+        """
+        operations = self._prepare_folder_deletions(
+            repo_id=repo_id, repo_type=repo_type, delete_patterns=delete_patterns, path_in_repo="", revision=revision
+        )
+
+        if commit_message is None:
+            commit_message = f"Delete files {' '.join(delete_patterns)} with huggingface_hub"
+
+        return self.create_commit(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            token=token,
+            operations=operations,
+            revision=revision,
+            commit_message=commit_message,
+            commit_description=commit_description,
+            create_pr=create_pr,
+            parent_commit=parent_commit,
+        )
+
+    @validate_hf_hub_args
+    def delete_folder(
+        self,
+        path_in_repo: str,
+        repo_id: str,
+        *,
+        token: Union[bool, str, None] = None,
+        repo_type: Optional[str] = None,
+        revision: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        parent_commit: Optional[str] = None,
+    ) -> CommitInfo:
+        """
+        Deletes a folder in the given repo.
+
+        Simple wrapper around [`create_commit`] method.
+
+        Args:
+            path_in_repo (`str`):
+                Relative folder path in the repo, for example: `"checkpoints/1fec34a"`.
+            repo_id (`str`):
+                The repository from which the folder will be deleted, for example:
+                `"username/custom_transformers"`
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+            repo_type (`str`, *optional*):
+                Set to `"dataset"` or `"space"` if the folder is in a dataset or
+                space, `None` or `"model"` if in a model. Default is `None`.
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the `"main"` branch.
+            commit_message (`str`, *optional*):
+                The summary / title / first line of the generated commit. Defaults to
+                `f"Delete folder {path_in_repo} with huggingface_hub"`.
+            commit_description (`str`, *optional*):
+                The description of the generated commit.
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request with that commit. Defaults to `False`.
+                If `revision` is not set, PR is opened against the `"main"` branch. If
+                `revision` is set and is a branch, PR is opened against this branch. If
+                `revision` is set and is not a branch name (example: a commit oid), a
+                `RevisionNotFoundError` is returned by the server.
+            parent_commit (`str`, *optional*):
+                The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
+                If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+                If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
+                Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
+                especially useful if the repo is updated / committed to concurrently.
+        """
+        return self.create_commit(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            token=token,
+            operations=[CommitOperationDelete(path_in_repo=path_in_repo, is_folder=True)],
+            revision=revision,
+            commit_message=(
+                commit_message if commit_message is not None else f"Delete folder {path_in_repo} with huggingface_hub"
+            ),
+            commit_description=commit_description,
+            create_pr=create_pr,
+            parent_commit=parent_commit,
+        )
+
+    def upload_large_folder(
+        self,
+        repo_id: str,
+        folder_path: Union[str, Path],
+        *,
+        repo_type: str,  # Repo type is required!
+        revision: Optional[str] = None,
+        private: Optional[bool] = None,
+        allow_patterns: Optional[Union[List[str], str]] = None,
+        ignore_patterns: Optional[Union[List[str], str]] = None,
+        num_workers: Optional[int] = None,
+        print_report: bool = True,
+        print_report_every: int = 60,
+    ) -> None:
+        """Upload a large folder to the Hub in the most resilient way possible.
+
+        Several workers are started to upload files in an optimized way. Before being committed to a repo, files must be
+        hashed and be pre-uploaded if they are LFS files. Workers will perform these tasks for each file in the folder.
+        At each step, some metadata information about the upload process is saved in the folder under `.cache/.huggingface/`
+        to be able to resume the process if interrupted. The whole process might result in several commits.
+
+        Args:
+            repo_id (`str`):
+                The repository to which the file will be uploaded.
+                E.g. `"HuggingFaceTB/smollm-corpus"`.
+            folder_path (`str` or `Path`):
+                Path to the folder to upload on the local file system.
+            repo_type (`str`):
+                Type of the repository. Must be one of `"model"`, `"dataset"` or `"space"`.
+                Unlike in all other `HfApi` methods, `repo_type` is explicitly required here. This is to avoid
+                any mistake when uploading a large folder to the Hub, and therefore prevent having to re-upload
+                everything.
+            revision (`str`, `optional`):
+                The branch to commit to. If not provided, the `main` branch will be used.
+            private (`bool`, `optional`):
+                Whether the repository should be private.
+                If `None` (default), the repo will be public unless the organization's default is private.
+            allow_patterns (`List[str]` or `str`, *optional*):
+                If provided, only files matching at least one pattern are uploaded.
+            ignore_patterns (`List[str]` or `str`, *optional*):
+                If provided, files matching any of the patterns are not uploaded.
+            num_workers (`int`, *optional*):
+                Number of workers to start. Defaults to `os.cpu_count() - 2` (minimum 2).
+                A higher number of workers may speed up the process if your machine allows it.
+ slower connection, it is recommended to keep the number of workers low to ensure better resumability.
+ Indeed, partially uploaded files will have to be completely re-uploaded if the process is interrupted.
+ print_report (`bool`, *optional*):
+ Whether to print a report of the upload progress. Defaults to `True`.
+ Report is printed to `sys.stdout` every `print_report_every` seconds (60 by default) and overwrites the previous report.
+ print_report_every (`int`, *optional*):
+ Frequency at which the report is printed. Defaults to 60 seconds.
+
+
+
+ A few things to keep in mind:
+ - Repository limits still apply: https://huggingface.co/docs/hub/repositories-recommendations
+ - Do not start several processes in parallel.
+ - You can interrupt and resume the process at any time.
+ - Do not upload the same folder to several repositories. If you need to do so, you must delete the local `.cache/.huggingface/` folder first.
+
+
+
+
+
+ While much more robust for uploading large folders, `upload_large_folder` is more limited than [`upload_folder`] feature-wise. In practice:
+ - you cannot set a custom `path_in_repo`. If you want to upload to a subfolder, you need to set the proper structure locally.
+ - you cannot set a custom `commit_message` and `commit_description` since multiple commits are created.
+ - you cannot delete from the repo while uploading. Please make a separate commit first.
+ - you cannot create a PR directly. Please create a PR first (from the UI or using [`create_pull_request`]) and then commit to it by passing `revision`.
+
+
+
+ **Technical details:**
+
+ The `upload_large_folder` process is as follows:
+ 1. (Check parameters and setup.)
+ 2. Create repo if missing.
+ 3. List local files to upload.
+ 4. Run validation checks and display warnings if repository limits might be exceeded:
+ - Warns if the total number of files exceeds 100k (recommended limit).
+ - Warns if any folder contains more than 10k files (recommended limit).
+ - Warns about files larger than 20GB (recommended) or 50GB (hard limit).
+ 5. Start workers. Workers can perform the following tasks:
+ - Hash a file.
+ - Get upload mode (regular or LFS) for a list of files.
+ - Pre-upload an LFS file.
+ - Commit a bunch of files.
+ Once a worker finishes a task, it will move on to the next task based on the priority list (see below) until
+ all files are uploaded and committed.
+ 6. While workers are up, regularly print a report to `sys.stdout`.
+
+ Order of priority:
+ 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file).
+ 2. Commit if at least 150 files are ready to commit.
+ 3. Get upload mode if at least 10 files have been hashed.
+ 4. Pre-upload LFS file if at least 1 file and no worker is pre-uploading.
+ 5. Hash file if at least 1 file and no worker is hashing.
+ 6. Get upload mode if at least 1 file and no worker is getting upload mode.
+ 7. Pre-upload LFS file if at least 1 file (exception: if hf_transfer is enabled, only 1 worker can pre-upload LFS at a time).
+ 8. Hash file if at least 1 file to hash.
+ 9. Get upload mode if at least 1 file to get upload mode.
+ 10. Commit if at least 1 file to commit and at least 1 min since last commit attempt.
+ 11. Commit if at least 1 file to commit and all other queues are empty.
+
+ Special rules:
+ - If `hf_transfer` is enabled, only 1 LFS uploader at a time. Otherwise the CPU would be bloated by `hf_transfer`.
+ - Only one worker can commit at a time.
+ - If no tasks are available, the worker waits for 10 seconds before checking again. + """ + return upload_large_folder_internal( + self, + repo_id=repo_id, + folder_path=folder_path, + repo_type=repo_type, + revision=revision, + private=private, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + num_workers=num_workers, + print_report=print_report, + print_report_every=print_report_every, + ) + + @validate_hf_hub_args + def get_hf_file_metadata( + self, + *, + url: str, + token: Union[bool, str, None] = None, + proxies: Optional[Dict] = None, + timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, + ) -> HfFileMetadata: + """Fetch metadata of a file versioned on the Hub for a given url. + + Args: + url (`str`): + File url, for example returned by [`hf_hub_url`]. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. + timeout (`float`, *optional*, defaults to 10): + How many seconds to wait for the server to send metadata before giving up. + + Returns: + A [`HfFileMetadata`] object containing metadata such as location, etag, size and commit_hash. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + + return get_hf_file_metadata( + url=url, + token=token, + proxies=proxies, + timeout=timeout, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + endpoint=self.endpoint, + ) + + @validate_hf_hub_args + def hf_hub_download( + self, + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + token: Union[bool, str, None] = None, + local_files_only: bool = False, + # Deprecated args + resume_download: Optional[bool] = None, + force_filename: Optional[str] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + ) -> str: + """Download a given file if it's not already present in the local cache. + + The new cache file layout looks like this: + - The cache directory contains one subfolder per repo_id (namespaced by repo type) + - inside each repo folder: + - refs is a list of the latest known revision => commit_hash pairs + - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on + whether they're LFS files or not) + - snapshots contains one subfolder per commit, each "commit" contains the subset of the files + that have been resolved at that particular commit. Each filename is a symlink to the blob + at that particular commit. + + ``` + [ 96] . 
+ └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + ``` + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this + option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` + to store some metadata related to the downloaded files. While this mechanism is not as robust as the main + cache-system, it's optimized for regularly pulling the latest version of a repository. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + subfolder (`str`, *optional*): + An optional value corresponding to a folder inside the repository. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded file will be placed under this directory. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in + the local cache. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + + Returns: + `str`: Local path of file or if networking is off, last version of file cached on disk. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + [`~utils.EntryNotFoundError`] + If the file to download cannot be found. + [`~utils.LocalEntryNotFoundError`] + If network is disabled or unavailable and file is not found in cache. 
+ [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+ If `token=True` but the token cannot be found.
+ [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+ If ETag cannot be determined.
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ If some parameter value is invalid.
+ """
+ from .file_download import hf_hub_download
+
+ if token is None:
+ # Cannot do `token = token or self.token` as token can be `False`.
+ token = self.token
+
+ return hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder,
+ repo_type=repo_type,
+ revision=revision,
+ endpoint=self.endpoint,
+ library_name=self.library_name,
+ library_version=self.library_version,
+ cache_dir=cache_dir,
+ local_dir=local_dir,
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ user_agent=self.user_agent,
+ force_download=force_download,
+ force_filename=force_filename,
+ proxies=proxies,
+ etag_timeout=etag_timeout,
+ resume_download=resume_download,
+ token=token,
+ headers=self.headers,
+ local_files_only=local_files_only,
+ )
+
+ @validate_hf_hub_args
+ def snapshot_download(
+ self,
+ repo_id: str,
+ *,
+ repo_type: Optional[str] = None,
+ revision: Optional[str] = None,
+ cache_dir: Union[str, Path, None] = None,
+ local_dir: Union[str, Path, None] = None,
+ proxies: Optional[Dict] = None,
+ etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+ force_download: bool = False,
+ token: Union[bool, str, None] = None,
+ local_files_only: bool = False,
+ allow_patterns: Optional[Union[List[str], str]] = None,
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ max_workers: int = 8,
+ tqdm_class: Optional[Type[base_tqdm]] = None,
+ # Deprecated args
+ local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+ resume_download: Optional[bool] = None,
+ ) -> str:
+ """Download repo files.
+
+ Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
+ a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
+ to keep their actual filename relative to that folder. You can also filter which files to download using
+ `allow_patterns` and `ignore_patterns`.
+
+ If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
+ option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
+ to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
+ cache-system, it's optimized for regularly pulling the latest version of a repository.
+
+ An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
+ configured. It is also not possible to filter which files to download when cloning a repository using git.
+
+ Args:
+ repo_id (`str`):
+ A user or an organization name and a repo name separated by a `/`.
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if downloading from a dataset or space,
+ `None` or `"model"` if downloading from a model. Default is `None`.
+ revision (`str`, *optional*):
+ An optional Git revision id which can be a branch name, a tag, or a
+ commit hash.
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ local_dir (`str` or `Path`, *optional*):
+ If provided, the downloaded files will be placed under this directory.
+ proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in the local cache. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are downloaded. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not downloaded. + max_workers (`int`, *optional*): + Number of concurrent threads to download files (1 thread = 1 file download). + Defaults to 8. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. + Note that the `tqdm_class` is not passed to each individual download. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + + Returns: + `str`: folder path of the repo snapshot. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` and the token cannot be found. + [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid. + """ + from ._snapshot_download import snapshot_download + + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + + return snapshot_download( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + force_download=force_download, + token=token, + local_files_only=local_files_only, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + max_workers=max_workers, + tqdm_class=tqdm_class, + ) + + def get_safetensors_metadata( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> SafetensorsRepoMetadata: + """ + Parse metadata for a safetensors repo on the Hub. + + We first check if the repo has a single safetensors file or a sharded safetensors repo. 
If it's a single + safetensors file, we parse the metadata from this file. If it's a sharded safetensors repo, we parse the + metadata from the index file and then parse the metadata from each shard. + + To parse metadata from a single safetensors file, use [`parse_safetensors_file_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a + model. Default is `None`. + revision (`str`, *optional*): + The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the + head of the `"main"` branch. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`SafetensorsRepoMetadata`]: information related to safetensors repo. + + Raises: + [`NotASafetensorsRepoError`] + If the repo is not a safetensors repo i.e. doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file. + [`SafetensorsParsingError`] + If a safetensors file header couldn't be parsed correctly. + + Example: + ```py + # Parse repo with single weights file + >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m") + >>> metadata + SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...}, + files_metadata={'model.safetensors': SafetensorsFileMetadata(...)} + ) + >>> metadata.files_metadata["model.safetensors"].metadata + {'format': 'pt'} + + # Parse repo with sharded model + >>> metadata = get_safetensors_metadata("bigscience/bloom") + Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s] + >>> metadata + SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...}) + >>> len(metadata.files_metadata) + 72 # All safetensors files have been fetched + + # Parse repo with sharded model + >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5") + NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files. 
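+
+ # (Illustrative addition; the tensor name and shard filename are hypothetical.)
+ # For the sharded repo above, the weight map gives the shard that stores each tensor:
+ >>> metadata.weight_map["word_embeddings.weight"]
+ 'model_00001-of-00072.safetensors'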
+ ``` + """ + if self.file_exists( # Single safetensors file => non-sharded model + repo_id=repo_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ): + file_metadata = self.parse_safetensors_file_metadata( + repo_id=repo_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ) + return SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={ + tensor_name: constants.SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys() + }, + files_metadata={constants.SAFETENSORS_SINGLE_FILE: file_metadata}, + ) + elif self.file_exists( # Multiple safetensors files => sharded with index + repo_id=repo_id, + filename=constants.SAFETENSORS_INDEX_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ): + # Fetch index + index_file = self.hf_hub_download( + repo_id=repo_id, + filename=constants.SAFETENSORS_INDEX_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ) + with open(index_file) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + + # Fetch metadata per shard + files_metadata = {} + + def _parse(filename: str) -> None: + files_metadata[filename] = self.parse_safetensors_file_metadata( + repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, token=token + ) + + thread_map( + _parse, + set(weight_map.values()), + desc="Parse safetensors files", + tqdm_class=hf_tqdm, + ) + + return SafetensorsRepoMetadata( + metadata=index.get("metadata", None), + sharded=True, + weight_map=weight_map, + files_metadata=files_metadata, + ) + else: + # Not a safetensors repo + raise NotASafetensorsRepoError( + f"'{repo_id}' is not a safetensors repo. Couldn't find '{constants.SAFETENSORS_INDEX_FILE}' or '{constants.SAFETENSORS_SINGLE_FILE}' files." + ) + + def parse_safetensors_file_metadata( + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> SafetensorsFileMetadata: + """ + Parse metadata from a safetensors file on the Hub. + + To parse metadata from all safetensors files in a repo at once, use [`get_safetensors_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a + model. Default is `None`. + revision (`str`, *optional*): + The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the + head of the `"main"` branch. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`SafetensorsFileMetadata`]: information related to a safetensors file. + + Raises: + [`NotASafetensorsRepoError`]: + If the repo is not a safetensors repo i.e. doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file. + [`SafetensorsParsingError`]: + If a safetensors file header couldn't be parsed correctly. 
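+
+ Example (an illustrative sketch; the repo name, tensor name and dtype are hypothetical):
+ ```python
+ >>> metadata = HfApi().parse_safetensors_file_metadata(
+ ...     repo_id="user/my-model", filename="model.safetensors"
+ ... )
+ >>> metadata.metadata
+ {'format': 'pt'}
+ >>> {name: info.dtype for name, info in metadata.tensors.items()}
+ {'embedding.weight': 'F32'}
+ ```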
+ """ + url = hf_hub_url( + repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, endpoint=self.endpoint + ) + _headers = self._build_hf_headers(token=token) + + # 1. Fetch first 100kb + # Empirically, 97% of safetensors files have a metadata size < 100kb (over the top 1000 models on the Hub). + # We assume fetching 100kb is faster than making 2 GET requests. Therefore we always fetch the first 100kb to + # avoid the 2nd GET in most cases. + # See https://github.com/huggingface/huggingface_hub/pull/1855#discussion_r1404286419. + response = get_session().get(url, headers={**_headers, "range": "bytes=0-100000"}) + hf_raise_for_status(response) + + # 2. Parse metadata size + metadata_size = struct.unpack(" constants.SAFETENSORS_MAX_HEADER_LENGTH: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or constants.DEFAULT_REVISION}'): safetensors header is too big. Maximum supported size is " + f"{constants.SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})." + ) + + # 3.a. Get metadata from payload + if metadata_size <= 100000: + metadata_as_bytes = response.content[8 : 8 + metadata_size] + else: # 3.b. Request full metadata + response = get_session().get(url, headers={**_headers, "range": f"bytes=8-{metadata_size + 7}"}) + hf_raise_for_status(response) + metadata_as_bytes = response.content + + # 4. Parse json header + try: + metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore")) + except json.JSONDecodeError as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or constants.DEFAULT_REVISION}'): header is not json-encoded string. Please make sure this is a " + "correctly formatted safetensors file." + ) from e + + try: + return SafetensorsFileMetadata( + metadata=metadata_as_dict.get("__metadata__", {}), + tensors={ + key: TensorInfo( + dtype=tensor["dtype"], + shape=tensor["shape"], + data_offsets=tuple(tensor["data_offsets"]), # type: ignore + ) + for key, tensor in metadata_as_dict.items() + if key != "__metadata__" + }, + ) + except (KeyError, IndexError) as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or constants.DEFAULT_REVISION}'): header format not recognized. Please make sure this is a correctly" + " formatted safetensors file." + ) from e + + @validate_hf_hub_args + def create_branch( + self, + repo_id: str, + *, + branch: str, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Create a new branch for a repo on the Hub, starting from the specified revision (defaults to `main`). + To find a revision suiting your needs, you can use [`list_repo_refs`] or [`list_repo_commits`]. + + Args: + repo_id (`str`): + The repository in which the branch will be created. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to create. + + revision (`str`, *optional*): + The git revision to create the branch from. It can be a branch name or + the OID/SHA of a commit, as a hexadecimal string. Defaults to the head + of the `"main"` branch. + + token (Union[bool, str, None], optional): + A valid user access token (string). 
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if creating a branch on a dataset or
+ space, `None` or `"model"` if creating a branch on a model. Default is `None`.
+
+ exist_ok (`bool`, *optional*, defaults to `False`):
+ If `True`, do not raise an error if branch already exists.
+
+ Raises:
+ [`~utils.RepositoryNotFoundError`]:
+ If repository is not found (error 404): wrong repo_id/repo_type, private
+ but not authenticated or repo does not exist.
+ [`~utils.BadRequestError`]:
+ If invalid reference for a branch. Ex: `refs/pr/5` or `refs/foo/bar`.
+ [`~utils.HfHubHTTPError`]:
+ If the branch already exists on the repo (error 409) and `exist_ok` is
+ set to `False`.
+ """
+ if repo_type is None:
+ repo_type = constants.REPO_TYPE_MODEL
+ branch = quote(branch, safe="")
+
+ # Prepare request
+ branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}"
+ headers = self._build_hf_headers(token=token)
+ payload = {}
+ if revision is not None:
+ payload["startingPoint"] = revision
+
+ # Create branch
+ response = get_session().post(url=branch_url, headers=headers, json=payload)
+ try:
+ hf_raise_for_status(response)
+ except HfHubHTTPError as e:
+ if exist_ok and e.response.status_code == 409:
+ return
+ elif exist_ok and e.response.status_code == 403:
+ # No write permission on the namespace but branch might already exist
+ try:
+ refs = self.list_repo_refs(repo_id=repo_id, repo_type=repo_type, token=token)
+ for branch_ref in refs.branches:
+ if branch_ref.name == branch:
+ return # Branch already exists => do not raise
+ except HfHubHTTPError:
+ pass # We raise the original error if the branch does not exist
+ raise
+
+ @validate_hf_hub_args
+ def delete_branch(
+ self,
+ repo_id: str,
+ *,
+ branch: str,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ ) -> None:
+ """
+ Delete a branch from a repo on the Hub.
+
+ Args:
+ repo_id (`str`):
+ The repository in which a branch will be deleted.
+ Example: `"user/my-cool-model"`.
+
+ branch (`str`):
+ The name of the branch to delete.
+
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if deleting a branch from a dataset or
+ space, `None` or `"model"` if deleting from a model. Default is `None`.
+
+ Raises:
+ [`~utils.RepositoryNotFoundError`]:
+ If repository is not found (error 404): wrong repo_id/repo_type, private
+ but not authenticated or repo does not exist.
+ [`~utils.HfHubHTTPError`]:
+ If trying to delete a protected branch. Ex: `main` cannot be deleted.
+ [`~utils.HfHubHTTPError`]:
+ If trying to delete a branch that does not exist.
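+
+ Example (an illustrative sketch; the repo and branch names are hypothetical):
+ ```python
+ >>> from huggingface_hub import HfApi
+ >>> api = HfApi()
+ >>> api.create_branch(repo_id="user/my-cool-model", branch="experiment-1")
+ >>> api.delete_branch(repo_id="user/my-cool-model", branch="experiment-1")
+ ```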
+
+ """
+ if repo_type is None:
+ repo_type = constants.REPO_TYPE_MODEL
+ branch = quote(branch, safe="")
+
+ # Prepare request
+ branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}"
+ headers = self._build_hf_headers(token=token)
+
+ # Delete branch
+ response = get_session().delete(url=branch_url, headers=headers)
+ hf_raise_for_status(response)
+
+ @validate_hf_hub_args
+ def create_tag(
+ self,
+ repo_id: str,
+ *,
+ tag: str,
+ tag_message: Optional[str] = None,
+ revision: Optional[str] = None,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ exist_ok: bool = False,
+ ) -> None:
+ """
+ Tag a given commit of a repo on the Hub.
+
+ Args:
+ repo_id (`str`):
+ The repository in which a commit will be tagged.
+ Example: `"user/my-cool-model"`.
+
+ tag (`str`):
+ The name of the tag to create.
+
+ tag_message (`str`, *optional*):
+ The description of the tag to create.
+
+ revision (`str`, *optional*):
+ The git revision to tag. It can be a branch name or the OID/SHA of a
+ commit, as a hexadecimal string. Shorthands (7 first characters) are
+ also supported. Defaults to the head of the `"main"` branch.
+
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if tagging a dataset or
+ space, `None` or `"model"` if tagging a model. Default is
+ `None`.
+
+ exist_ok (`bool`, *optional*, defaults to `False`):
+ If `True`, do not raise an error if tag already exists.
+
+ Raises:
+ [`~utils.RepositoryNotFoundError`]:
+ If repository is not found (error 404): wrong repo_id/repo_type, private
+ but not authenticated or repo does not exist.
+ [`~utils.RevisionNotFoundError`]:
+ If revision is not found (error 404) on the repo.
+ [`~utils.HfHubHTTPError`]:
+ If the tag already exists on the repo (error 409) and `exist_ok` is
+ set to `False`.
+ """
+ if repo_type is None:
+ repo_type = constants.REPO_TYPE_MODEL
+ revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION
+
+ # Prepare request
+ tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}"
+ headers = self._build_hf_headers(token=token)
+ payload = {"tag": tag}
+ if tag_message is not None:
+ payload["message"] = tag_message
+
+ # Tag
+ response = get_session().post(url=tag_url, headers=headers, json=payload)
+ try:
+ hf_raise_for_status(response)
+ except HfHubHTTPError as e:
+ if not (e.response.status_code == 409 and exist_ok):
+ raise
+
+ @validate_hf_hub_args
+ def delete_tag(
+ self,
+ repo_id: str,
+ *,
+ tag: str,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ ) -> None:
+ """
+ Delete a tag from a repo on the Hub.
+
+ Args:
+ repo_id (`str`):
+ The repository in which a tag will be deleted.
+ Example: `"user/my-cool-model"`.
+
+ tag (`str`):
+ The name of the tag to delete.
+
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+ + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or + `"model"` if tagging a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If tag is not found. + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + tag = quote(tag, safe="") + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{tag}" + headers = self._build_hf_headers(token=token) + + # Un-tag + response = get_session().delete(url=tag_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def get_full_repo_name( + self, + model_id: str, + *, + organization: Optional[str] = None, + token: Union[bool, str, None] = None, + ): + """ + Returns the repository name for a given model ID and optional + organization. + + Args: + model_id (`str`): + The name of the model. + organization (`str`, *optional*): + If passed, the repository name will be in the organization + namespace instead of the user namespace. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `str`: The repository name in the user's namespace + ({username}/{model_id}) if no organization is passed, and under the + organization namespace ({organization}/{model_id}) otherwise. + """ + if organization is None: + if "/" in model_id: + username = model_id.split("/")[0] + else: + username = self.whoami(token=token)["name"] # type: ignore + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + @validate_hf_hub_args + def get_repo_discussions( + self, + repo_id: str, + *, + author: Optional[str] = None, + discussion_type: Optional[constants.DiscussionTypeFilter] = None, + discussion_status: Optional[constants.DiscussionStatusFilter] = None, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterator[Discussion]: + """ + Fetches Discussions and Pull Requests for the given repo. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + author (`str`, *optional*): + Pass a value to filter by discussion author. `None` means no filter. + Default is `None`. + discussion_type (`str`, *optional*): + Set to `"pull_request"` to fetch only pull requests, `"discussion"` + to fetch only discussions. Set to `"all"` or `None` to fetch both. + Default is `None`. + discussion_status (`str`, *optional*): + Set to `"open"` (respectively `"closed"`) to fetch only open + (respectively closed) discussions. Set to `"all"` or `None` + to fetch both. + Default is `None`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if fetching from a dataset or + space, `None` or `"model"` if fetching from a model. Default is + `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterator[Discussion]`: An iterator of [`Discussion`] objects. 
+ + Example: + Collecting all discussions of a repo in a list: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) + ``` + + Iterating over discussions of a repo: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> for discussion in get_repo_discussions(repo_id="bert-base-uncased"): + ... print(discussion.num, discussion.title) + ``` + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in constants.DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {constants.DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in constants.DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {constants.DISCUSSION_STATUS}") + + headers = self._build_hf_headers(token=token) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions" + + params: Dict[str, Union[str, int]] = {} + if discussion_type is not None: + params["type"] = discussion_type + if discussion_status is not None: + params["status"] = discussion_status + if author is not None: + params["author"] = author + + def _fetch_discussion_page(page_index: int): + params["p"] = page_index + resp = get_session().get(path, headers=headers, params=params) + hf_raise_for_status(resp) + paginated_discussions = resp.json() + total = paginated_discussions["count"] + start = paginated_discussions["start"] + discussions = paginated_discussions["discussions"] + has_next = (start + len(discussions)) < total + return discussions, has_next + + has_next, page_index = True, 0 + + while has_next: + discussions, has_next = _fetch_discussion_page(page_index=page_index) + for discussion in discussions: + yield Discussion( + title=discussion["title"], + num=discussion["num"], + author=discussion.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion["createdAt"]), + status=discussion["status"], + repo_id=discussion["repo"]["name"], + repo_type=discussion["repo"]["type"], + is_pull_request=discussion["isPullRequest"], + endpoint=self.endpoint, + ) + page_index = page_index + 1 + + @validate_hf_hub_args + def get_discussion_details( + self, + repo_id: str, + discussion_num: int, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> DiscussionWithDetails: + """Fetches a Discussion's / Pull Request 's details from the Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
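+
+ Example (an illustrative sketch; the repo name and discussion number are hypothetical):
+
+ ```python
+ >>> from huggingface_hub import HfApi
+ >>> details = HfApi().get_discussion_details(repo_id="username/repo_name", discussion_num=34)
+ >>> details.is_pull_request, details.status
+ (True, 'open')
+ ```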
+ + Returns: [`DiscussionWithDetails`] + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" + headers = self._build_hf_headers(token=token) + resp = get_session().get(path, params={"diff": "1"}, headers=headers) + hf_raise_for_status(resp) + + discussion_details = resp.json() + is_pull_request = discussion_details["isPullRequest"] + + target_branch = discussion_details["changes"]["base"] if is_pull_request else None + conflicting_files = discussion_details["filesWithConflicts"] if is_pull_request else None + merge_commit_oid = discussion_details["changes"].get("mergeCommitId", None) if is_pull_request else None + + return DiscussionWithDetails( + title=discussion_details["title"], + num=discussion_details["num"], + author=discussion_details.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion_details["createdAt"]), + status=discussion_details["status"], + repo_id=discussion_details["repo"]["name"], + repo_type=discussion_details["repo"]["type"], + is_pull_request=discussion_details["isPullRequest"], + events=[deserialize_event(evt) for evt in discussion_details["events"]], + conflicting_files=conflicting_files, + target_branch=target_branch, + merge_commit_oid=merge_commit_oid, + diff=discussion_details.get("diff"), + endpoint=self.endpoint, + ) + + @validate_hf_hub_args + def create_discussion( + self, + repo_id: str, + title: str, + *, + token: Union[bool, str, None] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + pull_request: bool = False, + ) -> DiscussionWithDetails: + """Creates a Discussion or Pull Request. + + Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + pull_request (`bool`, *optional*): + Whether to create a Pull Request or discussion. If `True`, creates a Pull Request. + If `False`, creates a discussion. Defaults to `False`. 
+ repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + if description is not None: + description = description.strip() + description = ( + description + if description + else ( + f"{'Pull Request' if pull_request else 'Discussion'} opened with the" + " [huggingface_hub Python" + " library](https://huggingface.co/docs/huggingface_hub)" + ) + ) + + headers = self._build_hf_headers(token=token) + resp = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions", + json={ + "title": title.strip(), + "description": description, + "pullRequest": pull_request, + }, + headers=headers, + ) + hf_raise_for_status(resp) + num = resp.json()["num"] + return self.get_discussion_details( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=num, + token=token, + ) + + @validate_hf_hub_args + def create_pull_request( + self, + repo_id: str, + title: str, + *, + token: Union[bool, str, None] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionWithDetails: + """Creates a Pull Request . Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]; + + This is a wrapper around [`HfApi.create_discussion`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. 
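+
+ Example (an illustrative sketch; the repo name and PR number are hypothetical):
+
+ ```python
+ >>> from huggingface_hub import HfApi
+ >>> pr = HfApi().create_pull_request(repo_id="username/repo_name", title="Fix tokenizer config")
+ >>> pr.num  # the draft PR can then be committed to by passing `revision=f"refs/pr/{pr.num}"`
+ 42
+ ```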
+
+ """
+ return self.create_discussion(
+ repo_id=repo_id,
+ title=title,
+ token=token,
+ description=description,
+ repo_type=repo_type,
+ pull_request=True,
+ )
+
+ def _post_discussion_changes(
+ self,
+ *,
+ repo_id: str,
+ discussion_num: int,
+ resource: str,
+ body: Optional[dict] = None,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ ) -> requests.Response:
+ """Internal utility to POST changes to a Discussion or Pull Request"""
+ if not isinstance(discussion_num, int) or discussion_num <= 0:
+ raise ValueError("Invalid discussion_num, must be a positive integer")
+ if repo_type not in constants.REPO_TYPES:
+ raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
+ if repo_type is None:
+ repo_type = constants.REPO_TYPE_MODEL
+ repo_id = f"{repo_type}s/{repo_id}"
+
+ path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}"
+
+ headers = self._build_hf_headers(token=token)
+ resp = requests.post(path, headers=headers, json=body)
+ hf_raise_for_status(resp)
+ return resp
+
+ @validate_hf_hub_args
+ def comment_discussion(
+ self,
+ repo_id: str,
+ discussion_num: int,
+ comment: str,
+ *,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ ) -> DiscussionComment:
+ """Creates a new comment on the given Discussion.
+
+ Args:
+ repo_id (`str`):
+ A namespace (user or an organization) and a repo name separated
+ by a `/`.
+ discussion_num (`int`):
+ The number of the Discussion or Pull Request. Must be a strictly positive integer.
+ comment (`str`):
+ The content of the comment to create. Comments support markdown formatting.
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if uploading to a dataset or
+ space, `None` or `"model"` if uploading to a model. Default is
+ `None`.
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ Returns:
+ [`DiscussionComment`]: the newly created comment
+
+
+ Examples:
+ ```python
+
+ >>> comment = \"\"\"
+ ... Hello @otheruser!
+ ...
+ ... # This is a title
+ ...
+ ... **This is bold**, *this is italic* and ~this is strikethrough~
+ ... And [this](http://url) is a link
+ ... \"\"\"
+
+ >>> HfApi().comment_discussion(
+ ...     repo_id="username/repo_name",
+ ...     discussion_num=34,
+ ...     comment=comment
+ ... )
+ # DiscussionComment(id='deadbeef0000000', type='comment', ...)
+
+ ```
+
+
+
+ Raises the following errors:
+
+ - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+ if the HuggingFace API returned an error
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if some parameter value is invalid
+ - [`~utils.RepositoryNotFoundError`]
+ If the repository to download from cannot be found. This may be because it doesn't exist,
+ or because it is set to `private` and you do not have access.
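+
+ The `id` of the returned [`DiscussionComment`] can later be passed as `comment_id`
+ to [`edit_discussion_comment`] or [`hide_discussion_comment`].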
+
+
+ """
+ resp = self._post_discussion_changes(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ discussion_num=discussion_num,
+ token=token,
+ resource="comment",
+ body={"comment": comment},
+ )
+ return deserialize_event(resp.json()["newMessage"]) # type: ignore
+
+ @validate_hf_hub_args
+ def rename_discussion(
+ self,
+ repo_id: str,
+ discussion_num: int,
+ new_title: str,
+ *,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ ) -> DiscussionTitleChange:
+ """Renames a Discussion.
+
+ Args:
+ repo_id (`str`):
+ A namespace (user or an organization) and a repo name separated
+ by a `/`.
+ discussion_num (`int`):
+ The number of the Discussion or Pull Request. Must be a strictly positive integer.
+ new_title (`str`):
+ The new title for the discussion.
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if uploading to a dataset or
+ space, `None` or `"model"` if uploading to a model. Default is
+ `None`.
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ Returns:
+ [`DiscussionTitleChange`]: the title change event
+
+
+ Examples:
+ ```python
+ >>> new_title = "New title, fixing a typo"
+ >>> HfApi().rename_discussion(
+ ...     repo_id="username/repo_name",
+ ...     discussion_num=34,
+ ...     new_title=new_title
+ ... )
+ # DiscussionTitleChange(id='deadbeef0000000', type='title-change', ...)
+
+ ```
+
+
+
+ Raises the following errors:
+
+ - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+ if the HuggingFace API returned an error
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if some parameter value is invalid
+ - [`~utils.RepositoryNotFoundError`]
+ If the repository to download from cannot be found. This may be because it doesn't exist,
+ or because it is set to `private` and you do not have access.
+
+
+ """
+ resp = self._post_discussion_changes(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ discussion_num=discussion_num,
+ token=token,
+ resource="title",
+ body={"title": new_title},
+ )
+ return deserialize_event(resp.json()["newTitle"]) # type: ignore
+
+ @validate_hf_hub_args
+ def change_discussion_status(
+ self,
+ repo_id: str,
+ discussion_num: int,
+ new_status: Literal["open", "closed"],
+ *,
+ token: Union[bool, str, None] = None,
+ comment: Optional[str] = None,
+ repo_type: Optional[str] = None,
+ ) -> DiscussionStatusChange:
+ """Closes or re-opens a Discussion or Pull Request.
+
+ Args:
+ repo_id (`str`):
+ A namespace (user or an organization) and a repo name separated
+ by a `/`.
+ discussion_num (`int`):
+ The number of the Discussion or Pull Request. Must be a strictly positive integer.
+ new_status (`str`):
+ The new status for the discussion, either `"open"` or `"closed"`.
+ comment (`str`, *optional*):
+ An optional comment to post with the status change.
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if uploading to a dataset or
+ space, `None` or `"model"` if uploading to a model. Default is
+ `None`.
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ Returns:
+ [`DiscussionStatusChange`]: the status change event
+
+
+ Examples:
+ ```python
+ >>> HfApi().change_discussion_status(
+ ...     repo_id="username/repo_name",
+ ...     discussion_num=34,
+ ...     new_status="closed"
+ ... )
+ # DiscussionStatusChange(id='deadbeef0000000', type='status-change', ...)
+
+ ```
+
+
+
+ Raises the following errors:
+
+ - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+ if the HuggingFace API returned an error
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if some parameter value is invalid
+ - [`~utils.RepositoryNotFoundError`]
+ If the repository to download from cannot be found. This may be because it doesn't exist,
+ or because it is set to `private` and you do not have access.
+
+
+ """
+ if new_status not in ["open", "closed"]:
+ raise ValueError("Invalid status, valid statuses are: 'open' and 'closed'")
+ body: Dict[str, str] = {"status": new_status}
+ if comment and comment.strip():
+ body["comment"] = comment.strip()
+ resp = self._post_discussion_changes(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ discussion_num=discussion_num,
+ token=token,
+ resource="status",
+ body=body,
+ )
+ return deserialize_event(resp.json()["newStatus"]) # type: ignore
+
+ @validate_hf_hub_args
+ def merge_pull_request(
+ self,
+ repo_id: str,
+ discussion_num: int,
+ *,
+ token: Union[bool, str, None] = None,
+ comment: Optional[str] = None,
+ repo_type: Optional[str] = None,
+ ):
+ """Merges a Pull Request.
+
+ Args:
+ repo_id (`str`):
+ A namespace (user or an organization) and a repo name separated
+ by a `/`.
+ discussion_num (`int`):
+ The number of the Discussion or Pull Request. Must be a strictly positive integer.
+ comment (`str`, *optional*):
+ An optional comment to post with the status change.
+ repo_type (`str`, *optional*):
+ Set to `"dataset"` or `"space"` if uploading to a dataset or
+ space, `None` or `"model"` if uploading to a model. Default is
+ `None`.
+ token (Union[bool, str, None], optional):
+ A valid user access token (string). Defaults to the locally saved
+ token, which is the recommended method for authentication (see
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+ To disable authentication, pass `False`.
+
+ Returns:
+ [`DiscussionStatusChange`]: the status change event
+
+
+
+ Raises the following errors:
+
+ - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+ if the HuggingFace API returned an error
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if some parameter value is invalid
+ - [`~utils.RepositoryNotFoundError`]
+ If the repository to download from cannot be found. This may be because it doesn't exist,
+ or because it is set to `private` and you do not have access.
+
+
+ """
+ self._post_discussion_changes(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ discussion_num=discussion_num,
+ token=token,
+ resource="merge",
+ body={"comment": comment.strip()} if comment and comment.strip() else None,
+ )
+
+ @validate_hf_hub_args
+ def edit_discussion_comment(
+ self,
+ repo_id: str,
+ discussion_num: int,
+ comment_id: str,
+ new_content: str,
+ *,
+ token: Union[bool, str, None] = None,
+ repo_type: Optional[str] = None,
+ ) -> DiscussionComment:
+ """Edits a comment on a Discussion / Pull Request.
+
+ Args:
+ repo_id (`str`):
+ A namespace (user or an organization) and a repo name separated
+ by a `/`.
+            discussion_num (`int`):
+                The number of the Discussion or Pull Request. Must be a strictly positive integer.
+            comment_id (`str`):
+                The ID of the comment to edit.
+            new_content (`str`):
+                The new content of the comment. Comments support markdown formatting.
+            repo_type (`str`, *optional*):
+                Set to `"dataset"` or `"space"` if uploading to a dataset or
+                space, `None` or `"model"` if uploading to a model. Default is
+                `None`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            [`DiscussionComment`]: the edited comment
+
+
+        Raises the following errors:
+
+            - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+              if the HuggingFace API returned an error
+            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+              if some parameter value is invalid
+            - [`~utils.RepositoryNotFoundError`]
+              If the repository to download from cannot be found. This may be because it doesn't exist,
+              or because it is set to `private` and you do not have access.
+
+
+        """
+        resp = self._post_discussion_changes(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            discussion_num=discussion_num,
+            token=token,
+            resource=f"comment/{comment_id.lower()}/edit",
+            body={"content": new_content},
+        )
+        return deserialize_event(resp.json()["updatedComment"])  # type: ignore
+
+    @validate_hf_hub_args
+    def hide_discussion_comment(
+        self,
+        repo_id: str,
+        discussion_num: int,
+        comment_id: str,
+        *,
+        token: Union[bool, str, None] = None,
+        repo_type: Optional[str] = None,
+    ) -> DiscussionComment:
+        """Hides a comment on a Discussion / Pull Request.
+
+
+        Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible.
+
+
+        Args:
+            repo_id (`str`):
+                A namespace (user or an organization) and a repo name separated
+                by a `/`.
+            discussion_num (`int`):
+                The number of the Discussion or Pull Request. Must be a strictly positive integer.
+            comment_id (`str`):
+                The ID of the comment to hide.
+            repo_type (`str`, *optional*):
+                Set to `"dataset"` or `"space"` if uploading to a dataset or
+                space, `None` or `"model"` if uploading to a model. Default is
+                `None`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            [`DiscussionComment`]: the hidden comment
+
+
+        Raises the following errors:
+
+            - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+              if the HuggingFace API returned an error
+            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+              if some parameter value is invalid
+            - [`~utils.RepositoryNotFoundError`]
+              If the repository to download from cannot be found. This may be because it doesn't exist,
+              or because it is set to `private` and you do not have access.
+
+
+        """
+        warnings.warn(
+            "Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible.",
+            UserWarning,
+        )
+        resp = self._post_discussion_changes(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            discussion_num=discussion_num,
+            token=token,
+            resource=f"comment/{comment_id.lower()}/hide",
+        )
+        return deserialize_event(resp.json()["updatedComment"])  # type: ignore
+
+    @validate_hf_hub_args
+    def add_space_secret(
+        self,
+        repo_id: str,
+        key: str,
+        value: str,
+        *,
+        description: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> None:
+        """Adds or updates a secret in a Space.
+
+        Secrets allow you to set secret keys or tokens for a Space without hardcoding them.
+        For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets.
+
+        Args:
+            repo_id (`str`):
+                ID of the repo to update. Example: `"bigcode/in-the-stack"`.
+            key (`str`):
+                Secret key. Example: `"GITHUB_API_KEY"`.
+            value (`str`):
+                Secret value. Example: `"your_github_api_key"`.
+            description (`str`, *optional*):
+                Secret description. Example: `"Github API key to access the Github API"`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+        """
+        payload = {"key": key, "value": value}
+        if description is not None:
+            payload["description"] = description
+        r = get_session().post(
+            f"{self.endpoint}/api/spaces/{repo_id}/secrets",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        hf_raise_for_status(r)
+
+    @validate_hf_hub_args
+    def delete_space_secret(self, repo_id: str, key: str, *, token: Union[bool, str, None] = None) -> None:
+        """Deletes a secret from a Space.
+
+        Secrets allow you to set secret keys or tokens for a Space without hardcoding them.
+        For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets.
+
+        Args:
+            repo_id (`str`):
+                ID of the repo to update. Example: `"bigcode/in-the-stack"`.
+            key (`str`):
+                Secret key. Example: `"GITHUB_API_KEY"`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+        """
+        r = get_session().delete(
+            f"{self.endpoint}/api/spaces/{repo_id}/secrets",
+            headers=self._build_hf_headers(token=token),
+            json={"key": key},
+        )
+        hf_raise_for_status(r)
+
+    @validate_hf_hub_args
+    def get_space_variables(self, repo_id: str, *, token: Union[bool, str, None] = None) -> Dict[str, SpaceVariable]:
+        """Gets all variables from a Space.
+
+        Variables allow you to set environment variables for a Space without hardcoding them.
+        For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables
+
+        Args:
+            repo_id (`str`):
+                ID of the repo to query. Example: `"bigcode/in-the-stack"`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
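+
+        Example (an illustrative sketch; the repo ID and variable name below are placeholders):
+            ```python
+            >>> from huggingface_hub import HfApi
+            >>> variables = HfApi().get_space_variables("my-username/my-space")
+            >>> sorted(variables)  # keys are the variable names
+            ['MODEL_REPO_ID']
+            ```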
+ """ + r = get_session().get( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def add_space_variable( + self, + repo_id: str, + key: str, + value: str, + *, + description: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Dict[str, SpaceVariable]: + """Adds or updates a variable in a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + value (`str`): + Variable value. Example: `"the_model_repo_id"`. + description (`str`): + Description of the variable. Example: `"Model Repo ID of the implemented model"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + payload = {"key": key, "value": value} + if description is not None: + payload["description"] = description + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def delete_space_variable( + self, repo_id: str, key: str, *, token: Union[bool, str, None] = None + ) -> Dict[str, SpaceVariable]: + """Deletes a variable from a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + r = get_session().delete( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + json={"key": key}, + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def get_space_runtime(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime: + """Gets runtime information about a Space. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. 
+ """ + r = get_session().get( + f"{self.endpoint}/api/spaces/{repo_id}/runtime", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def request_space_hardware( + self, + repo_id: str, + hardware: SpaceHardware, + *, + token: Union[bool, str, None] = None, + sleep_time: Optional[int] = None, + ) -> SpaceRuntime: + """Request new hardware for a Space. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + hardware (`str` or [`SpaceHardware`]): + Hardware on which to run the Space. Example: `"t4-medium"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + + + It is also possible to request hardware directly when creating the Space repo! See [`create_repo`] for details. + + + """ + if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + payload: Dict[str, Any] = {"flavor": hardware} + if sleep_time is not None: + payload["sleepTimeSeconds"] = sleep_time + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/hardware", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def set_space_sleep_time( + self, repo_id: str, sleep_time: int, *, token: Union[bool, str, None] = None + ) -> SpaceRuntime: + """Set a custom sleep time for a Space running on upgraded hardware.. + + Your Space will go to sleep after X seconds of inactivity. You are not billed when your Space is in "sleep" + mode. If a new visitor lands on your Space, it will "wake it up". Only upgraded hardware can have a + configurable sleep time. To know more about the sleep stage, please refer to + https://huggingface.co/docs/hub/spaces-gpus#sleep-time. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to pause (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). 
+                To disable authentication, pass `False`.
+        Returns:
+            [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
+
+
+
+        It is also possible to set a custom sleep time when requesting hardware with [`request_space_hardware`].
+
+
+        """
+        r = get_session().post(
+            f"{self.endpoint}/api/spaces/{repo_id}/sleeptime",
+            headers=self._build_hf_headers(token=token),
+            json={"seconds": sleep_time},
+        )
+        hf_raise_for_status(r)
+        runtime = SpaceRuntime(r.json())
+
+        hardware = runtime.requested_hardware or runtime.hardware
+        if hardware == SpaceHardware.CPU_BASIC:
+            warnings.warn(
+                "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more"
+                " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if"
+                " you want to set a custom sleep time, you need to upgrade to a paid Hardware.",
+                UserWarning,
+            )
+        return runtime
+
+    @validate_hf_hub_args
+    def pause_space(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime:
+        """Pause your Space.
+
+        A paused Space stops executing until manually restarted by its owner. This is different from the sleeping
+        state into which free Spaces go after 48h of inactivity. Paused time is not billed to your account, no matter the
+        hardware you've selected. To restart your Space, use [`restart_space`] or go to your Space settings page.
+
+        For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause).
+
+        Args:
+            repo_id (`str`):
+                ID of the Space to pause. Example: `"Salesforce/BLIP2"`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            [`SpaceRuntime`]: Runtime information about your Space including `stage=PAUSED` and requested hardware.
+
+        Raises:
+            [`~utils.RepositoryNotFoundError`]:
+                If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you
+                are not authenticated.
+            [`~utils.HfHubHTTPError`]:
+                403 Forbidden: only the owner of a Space can pause it. If you want to manage a Space that you don't
+                own, either ask the owner by opening a Discussion or duplicate the Space.
+            [`~utils.BadRequestError`]:
+                If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide
+                a static Space, you can set it to private.
+        """
+        r = get_session().post(
+            f"{self.endpoint}/api/spaces/{repo_id}/pause", headers=self._build_hf_headers(token=token)
+        )
+        hf_raise_for_status(r)
+        return SpaceRuntime(r.json())
+
+    @validate_hf_hub_args
+    def restart_space(
+        self, repo_id: str, *, token: Union[bool, str, None] = None, factory_reboot: bool = False
+    ) -> SpaceRuntime:
+        """Restart your Space.
+
+        This is the only way to programmatically restart a Space if you've put it on Pause (see [`pause_space`]). You
+        must be the owner of the Space to restart it. If you are using upgraded hardware, your account will be
+        billed as soon as the Space is restarted. You can trigger a restart no matter the current state of a Space.
+
+        For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause).
+
+        Args:
+            repo_id (`str`):
+                ID of the Space to restart. Example: `"Salesforce/BLIP2"`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string).
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + factory_reboot (`bool`, *optional*): + If `True`, the Space will be rebuilt from scratch without caching any requirements. + + Returns: + [`SpaceRuntime`]: Runtime information about your Space. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you + are not authenticated. + [`~utils.HfHubHTTPError`]: + 403 Forbidden: only the owner of a Space can restart it. If you want to restart a Space that you don't + own, either ask the owner by opening a Discussion or duplicate the Space. + [`~utils.BadRequestError`]: + If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide + a static Space, you can set it to private. + """ + params = {} + if factory_reboot: + params["factory"] = "true" + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/restart", headers=self._build_hf_headers(token=token), params=params + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def duplicate_space( + self, + from_id: str, + to_id: Optional[str] = None, + *, + private: Optional[bool] = None, + token: Union[bool, str, None] = None, + exist_ok: bool = False, + hardware: Optional[SpaceHardware] = None, + storage: Optional[SpaceStorage] = None, + sleep_time: Optional[int] = None, + secrets: Optional[List[Dict[str, str]]] = None, + variables: Optional[List[Dict[str, str]]] = None, + ) -> RepoUrl: + """Duplicate a Space. + + Programmatically duplicate a Space. The new Space will be created in your account and will be in the same state + as the original Space (running or paused). You can duplicate a Space no matter the current state of a Space. + + Args: + from_id (`str`): + ID of the Space to duplicate. Example: `"pharma/CLIP-Interrogator"`. + to_id (`str`, *optional*): + ID of the new Space. Example: `"dog/CLIP-Interrogator"`. If not provided, the new Space will have the same + name as the original Space, but in your account. + private (`bool`, *optional*): + Whether the new Space should be private or not. Defaults to the same privacy as the original Space. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo already exists. + hardware (`SpaceHardware` or `str`, *optional*): + Choice of Hardware. Example: `"t4-medium"`. See [`SpaceHardware`] for a complete list. + storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + secrets (`List[Dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. 
Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + variables (`List[Dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If one of `from_id` or `to_id` cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + If the HuggingFace API returned an error + + Example: + ```python + >>> from huggingface_hub import duplicate_space + + # Duplicate a Space to your account + >>> duplicate_space("multimodalart/dreambooth-training") + RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) + + # Can set custom destination id and visibility flag. + >>> duplicate_space("multimodalart/dreambooth-training", to_id="my-dreambooth", private=True) + RepoUrl('https://huggingface.co/spaces/nateraw/my-dreambooth',...) + ``` + """ + # Parse to_id if provided + parsed_to_id = RepoUrl(to_id) if to_id is not None else None + + # Infer target repo_id + to_namespace = ( # set namespace manually or default to username + parsed_to_id.namespace + if parsed_to_id is not None and parsed_to_id.namespace is not None + else self.whoami(token)["name"] + ) + to_repo_name = parsed_to_id.repo_name if to_id is not None else RepoUrl(from_id).repo_name # type: ignore + + # repository must be a valid repo_id (namespace/repo_name). + payload: Dict[str, Any] = {"repository": f"{to_namespace}/{to_repo_name}"} + + keys = ["private", "hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] + values = [private, hardware, storage, sleep_time, secrets, variables] + payload.update({k: v for k, v in zip(keys, values) if v is not None}) + + if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + + r = get_session().post( + f"{self.endpoint}/api/spaces/{from_id}/duplicate", + headers=self._build_hf_headers(token=token), + json=payload, + ) + + try: + hf_raise_for_status(r) + except HTTPError as err: + if exist_ok and err.response.status_code == 409: + # Repo already exists and `exist_ok=True` + pass + else: + raise + + return RepoUrl(r.json()["url"], endpoint=self.endpoint) + + @validate_hf_hub_args + def request_space_storage( + self, + repo_id: str, + storage: SpaceStorage, + *, + token: Union[bool, str, None] = None, + ) -> SpaceRuntime: + """Request persistent storage for a Space. + + Args: + repo_id (`str`): + ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`. + storage (`str` or [`SpaceStorage`]): + Storage tier. Either 'small', 'medium', or 'large'. 
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+        Returns:
+            [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
+
+
+
+        It is not possible to decrease persistent storage after it's granted. To do so, you must delete it
+        via [`delete_space_storage`].
+
+
+        """
+        payload: Dict[str, SpaceStorage] = {"tier": storage}
+        r = get_session().post(
+            f"{self.endpoint}/api/spaces/{repo_id}/storage",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        hf_raise_for_status(r)
+        return SpaceRuntime(r.json())
+
+    @validate_hf_hub_args
+    def delete_space_storage(
+        self,
+        repo_id: str,
+        *,
+        token: Union[bool, str, None] = None,
+    ) -> SpaceRuntime:
+        """Delete persistent storage for a Space.
+
+        Args:
+            repo_id (`str`):
+                ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+        Returns:
+            [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
+        Raises:
+            [`BadRequestError`]
+                If the Space has no persistent storage.
+
+        """
+        r = get_session().delete(
+            f"{self.endpoint}/api/spaces/{repo_id}/storage",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(r)
+        return SpaceRuntime(r.json())
+
+    #######################
+    # Inference Endpoints #
+    #######################
+
+    def list_inference_endpoints(
+        self, namespace: Optional[str] = None, *, token: Union[bool, str, None] = None
+    ) -> List[InferenceEndpoint]:
+        """Lists all inference endpoints for the given namespace.
+
+        Args:
+            namespace (`str`, *optional*):
+                The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints
+                from all namespaces (i.e. personal namespace and all orgs the user belongs to).
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace.
+
+        Example:
+            ```python
+            >>> from huggingface_hub import HfApi
+            >>> api = HfApi()
+            >>> api.list_inference_endpoints()
+            [InferenceEndpoint(name='my-endpoint', ...), ...]
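+
+            # (illustrative) pass namespace="*" to also list endpoints from all orgs you belong to,
+            # as documented in the `namespace` argument above
+            >>> api.list_inference_endpoints(namespace="*")
+            [InferenceEndpoint(name='my-endpoint', ...), ...]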
+            ```
+        """
+        # Special case: list all endpoints for all namespaces the user has access to
+        if namespace == "*":
+            user = self.whoami(token=token)
+
+            # List personal endpoints first
+            endpoints: List[InferenceEndpoint] = self.list_inference_endpoints(
+                namespace=self._get_namespace(token=token), token=token
+            )
+
+            # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access)
+            for org in user.get("orgs", []):
+                try:
+                    endpoints += self.list_inference_endpoints(namespace=org["name"], token=token)
+                except HfHubHTTPError as error:
+                    if error.response.status_code == 401:  # either no billing or the user doesn't have access
+                        logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error)
+                    else:
+                        raise
+
+            return endpoints
+
+        # Normal case: list endpoints for a specific namespace
+        namespace = namespace or self._get_namespace(token=token)
+
+        response = get_session().get(
+            f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+        return [
+            InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token)
+            for endpoint in response.json()["items"]
+        ]
+
+    def create_inference_endpoint(
+        self,
+        name: str,
+        *,
+        repository: str,
+        framework: str,
+        accelerator: str,
+        instance_size: str,
+        instance_type: str,
+        region: str,
+        vendor: str,
+        account_id: Optional[str] = None,
+        min_replica: int = 1,
+        max_replica: int = 1,
+        scale_to_zero_timeout: Optional[int] = None,
+        revision: Optional[str] = None,
+        task: Optional[str] = None,
+        custom_image: Optional[Dict] = None,
+        env: Optional[Dict[str, str]] = None,
+        secrets: Optional[Dict[str, str]] = None,
+        type: InferenceEndpointType = InferenceEndpointType.PROTECTED,
+        domain: Optional[str] = None,
+        path: Optional[str] = None,
+        cache_http_responses: Optional[bool] = None,
+        tags: Optional[List[str]] = None,
+        namespace: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> InferenceEndpoint:
+        """Create a new Inference Endpoint.
+
+        Args:
+            name (`str`):
+                The unique name for the new Inference Endpoint.
+            repository (`str`):
+                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
+            framework (`str`):
+                The machine learning framework used for the model (e.g. `"custom"`).
+            accelerator (`str`):
+                The hardware accelerator to be used for inference (e.g. `"cpu"`).
+            instance_size (`str`):
+                The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
+            instance_type (`str`):
+                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
+            region (`str`):
+                The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`).
+            vendor (`str`):
+                The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`).
+            account_id (`str`, *optional*):
+                The account ID used to link a VPC to a private Inference Endpoint (if applicable).
+            min_replica (`int`, *optional*):
+                The minimum number of replicas (instances) to keep running for the Inference Endpoint. To enable
+                scaling to zero, set this value to 0 and adjust `scale_to_zero_timeout` accordingly. Defaults to 1.
+            max_replica (`int`, *optional*):
+                The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
+            scale_to_zero_timeout (`int`, *optional*):
+                The duration in minutes before an inactive endpoint is scaled to zero, or no scaling to zero if
+                set to None and `min_replica` is not 0. Defaults to None.
+            revision (`str`, *optional*):
+                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
+            task (`str`, *optional*):
+                The task on which to deploy the model (e.g. `"text-classification"`).
+            custom_image (`Dict`, *optional*):
+                A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
+                Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
+            env (`Dict[str, str]`, *optional*):
+                Non-secret environment variables to inject in the container environment.
+            secrets (`Dict[str, str]`, *optional*):
+                Secret values to inject in the container environment.
+            type ([`InferenceEndpointType`], *optional*):
+                The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`.
+            domain (`str`, *optional*):
+                The custom domain for the Inference Endpoint deployment. If set up, the Inference Endpoint will be available at this domain (e.g. `"my-new-domain.cool-website.woof"`).
+            path (`str`, *optional*):
+                The custom path to the deployed model, should start with a `/` (e.g. `"/models/google-bert/bert-base-uncased"`).
+            cache_http_responses (`bool`, *optional*):
+                Whether to cache HTTP responses from the Inference Endpoint. Defaults to `False`.
+            tags (`List[str]`, *optional*):
+                A list of tags to associate with the Inference Endpoint.
+            namespace (`str`, *optional*):
+                The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            [`InferenceEndpoint`]: information about the new Inference Endpoint.
+
+        Example:
+            ```python
+            >>> from huggingface_hub import HfApi
+            >>> api = HfApi()
+            >>> endpoint = api.create_inference_endpoint(
+            ...     "my-endpoint-name",
+            ...     repository="gpt2",
+            ...     framework="pytorch",
+            ...     task="text-generation",
+            ...     accelerator="cpu",
+            ...     vendor="aws",
+            ...     region="us-east-1",
+            ...     type="protected",
+            ...     instance_size="x2",
+            ...     instance_type="intel-icl",
+            ... )
+            >>> endpoint
+            InferenceEndpoint(name='my-endpoint-name', status="pending",...)
+
+            # Run inference on the endpoint
+            >>> endpoint.client.text_generation(...)
+            "..."
+            ```
+
+            ```python
+            # Start an Inference Endpoint running Zephyr-7b-beta on TGI
+            >>> from huggingface_hub import HfApi
+            >>> api = HfApi()
+            >>> endpoint = api.create_inference_endpoint(
+            ...     "aws-zephyr-7b-beta-0486",
+            ...     repository="HuggingFaceH4/zephyr-7b-beta",
+            ...     framework="pytorch",
+            ...     task="text-generation",
+            ...     accelerator="gpu",
+            ...     vendor="aws",
+            ...     region="us-east-1",
+            ...     type="protected",
+            ...     instance_size="x1",
+            ...     instance_type="nvidia-a10g",
+            ...     env={
+            ...         "MAX_BATCH_PREFILL_TOKENS": "2048",
+            ...         "MAX_INPUT_LENGTH": "1024",
+            ...         "MAX_TOTAL_TOKENS": "1512",
+            ...         "MODEL_ID": "/repository"
+            ...     },
+            ...     custom_image={
+            ...         "health_route": "/health",
+            ...         "url": "ghcr.io/huggingface/text-generation-inference:1.1.0",
+            ...     },
+            ...     secrets={"MY_SECRET_KEY": "secret_value"},
+            ...     tags=["dev", "text-generation"],
+            ... )
+            ```
+
+            ```python
+            # Start an Inference Endpoint running ProsusAI/finbert while scaling to zero in 15 minutes
+            >>> from huggingface_hub import HfApi
+            >>> api = HfApi()
+            >>> endpoint = api.create_inference_endpoint(
+            ...     "finbert-classifier",
+            ...     repository="ProsusAI/finbert",
+            ...     framework="pytorch",
+            ...     task="text-classification",
+            ...     min_replica=0,
+            ...     scale_to_zero_timeout=15,
+            ...     accelerator="cpu",
+            ...     vendor="aws",
+            ...     region="us-east-1",
+            ...     type="protected",
+            ...     instance_size="x2",
+            ...     instance_type="intel-icl",
+            ... )
+            >>> endpoint.wait(timeout=300)
+            # Run inference on the endpoint
+            >>> endpoint.client.text_classification(...)
+            TextClassificationOutputElement(label='positive', score=0.8983615040779114)
+            ```
+
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        if custom_image is not None:
+            image = (
+                custom_image
+                if next(iter(custom_image)) in constants.INFERENCE_ENDPOINT_IMAGE_KEYS
+                else {"custom": custom_image}
+            )
+        else:
+            image = {"huggingface": {}}
+
+        payload: Dict = {
+            "accountId": account_id,
+            "compute": {
+                "accelerator": accelerator,
+                "instanceSize": instance_size,
+                "instanceType": instance_type,
+                "scaling": {
+                    "maxReplica": max_replica,
+                    "minReplica": min_replica,
+                    "scaleToZeroTimeout": scale_to_zero_timeout,
+                },
+            },
+            "model": {
+                "framework": framework,
+                "repository": repository,
+                "revision": revision,
+                "task": task,
+                "image": image,
+            },
+            "name": name,
+            "provider": {
+                "region": region,
+                "vendor": vendor,
+            },
+            "type": type,
+        }
+        if env:
+            payload["model"]["env"] = env
+        if secrets:
+            payload["model"]["secrets"] = secrets
+        if domain is not None or path is not None:
+            payload["route"] = {}
+            if domain is not None:
+                payload["route"]["domain"] = domain
+            if path is not None:
+                payload["route"]["path"] = path
+        if cache_http_responses is not None:
+            payload["cacheHttpResponses"] = cache_http_responses
+        if tags is not None:
+            payload["tags"] = tags
+
+        response = get_session().post(
+            f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    @experimental
+    @validate_hf_hub_args
+    def create_inference_endpoint_from_catalog(
+        self,
+        repo_id: str,
+        *,
+        name: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+        namespace: Optional[str] = None,
+    ) -> InferenceEndpoint:
+        """Create a new Inference Endpoint from a model in the Hugging Face Inference Catalog.
+
+        The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference
+        and for which default configurations have been tested. See https://endpoints.huggingface.co/catalog for a list
+        of available models in the catalog.
+
+        Args:
+            repo_id (`str`):
+                The ID of the model in the catalog to deploy as an Inference Endpoint.
+            name (`str`, *optional*):
+                The unique name for the new Inference Endpoint. If not provided, a random name will be generated.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+            namespace (`str`, *optional*):
+                The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace.
+ + Returns: + [`InferenceEndpoint`]: information about the new Inference Endpoint. + + + + `create_inference_endpoint_from_catalog` is experimental. Its API is subject to change in the future. Please provide feedback + if you have any suggestions or requests. + + + """ + token = token or self.token or get_token() + payload: Dict = { + "namespace": namespace or self._get_namespace(token=token), + "repoId": repo_id, + } + if name is not None: + payload["endpointName"] = name + + response = get_session().post( + f"{constants.INFERENCE_CATALOG_ENDPOINT}/deploy", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + data = response.json()["endpoint"] + return InferenceEndpoint.from_raw(data, namespace=data["name"], token=token) + + @experimental + @validate_hf_hub_args + def list_inference_catalog(self, *, token: Union[bool, str, None] = None) -> List[str]: + """List models available in the Hugging Face Inference Catalog. + + The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference + and for which default configurations have been tested. See https://endpoints.huggingface.co/catalog for a list + of available models in the catalog. + + Use [`create_inference_endpoint_from_catalog`] to deploy a model from the catalog. + + Args: + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + + Returns: + List[`str`]: A list of model IDs available in the catalog. + + + `list_inference_catalog` is experimental. Its API is subject to change in the future. Please provide feedback + if you have any suggestions or requests. + + + """ + response = get_session().get( + f"{constants.INFERENCE_CATALOG_ENDPOINT}/repo-list", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + return response.json()["models"] + + def get_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Get information about an Inference Endpoint. + + Args: + name (`str`): + The name of the Inference Endpoint to retrieve information about. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the requested Inference Endpoint. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) + + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) 
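+
+            # (illustrative) the returned object can also manage the endpoint directly;
+            # `InferenceEndpoint.pause` is assumed here based on the object methods
+            # (update/delete/resume/scale_to_zero) referenced elsewhere in this module
+            >>> endpoint.pause()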
+ ``` + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def update_inference_endpoint( + self, + name: str, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + scale_to_zero_timeout: Optional[int] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[Dict] = None, + env: Optional[Dict[str, str]] = None, + secrets: Optional[Dict[str, str]] = None, + # Route update + domain: Optional[str] = None, + path: Optional[str] = None, + # Other + cache_http_responses: Optional[bool] = None, + tags: Optional[List[str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Update an Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, the route, or any combination. + All arguments are optional but at least one must be provided. + + For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`]. + + Args: + name (`str`): + The name of the Inference Endpoint to update. + + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero. + + repository (`str`, *optional*): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`, *optional*): + The machine learning framework used for the model (e.g. `"custom"`). + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`Dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + env (`Dict[str, str]`, *optional*): + Non-secret environment variables to inject in the container environment + secrets (`Dict[str, str]`, *optional*): + Secret values to inject in the container environment. + + domain (`str`, *optional*): + The custom domain for the Inference Endpoint deployment, if setup the inference endpoint will be available at this domain (e.g. `"my-new-domain.cool-website.woof"`). + path (`str`, *optional*): + The custom path to the deployed model, should start with a `/` (e.g. 
`"/models/google-bert/bert-base-uncased"`). + + cache_http_responses (`bool`, *optional*): + Whether to cache HTTP responses from the Inference Endpoint. + tags (`List[str]`, *optional*): + A list of tags to associate with the Inference Endpoint. + + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the updated Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + # Populate only the fields that are not None + payload: Dict = defaultdict(lambda: defaultdict(dict)) + if accelerator is not None: + payload["compute"]["accelerator"] = accelerator + if instance_size is not None: + payload["compute"]["instanceSize"] = instance_size + if instance_type is not None: + payload["compute"]["instanceType"] = instance_type + if max_replica is not None: + payload["compute"]["scaling"]["maxReplica"] = max_replica + if min_replica is not None: + payload["compute"]["scaling"]["minReplica"] = min_replica + if scale_to_zero_timeout is not None: + payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout + if repository is not None: + payload["model"]["repository"] = repository + if framework is not None: + payload["model"]["framework"] = framework + if revision is not None: + payload["model"]["revision"] = revision + if task is not None: + payload["model"]["task"] = task + if custom_image is not None: + payload["model"]["image"] = {"custom": custom_image} + if env is not None: + payload["model"]["env"] = env + if secrets is not None: + payload["model"]["secrets"] = secrets + if domain is not None: + payload["route"]["domain"] = domain + if path is not None: + payload["route"]["path"] = path + if cache_http_responses is not None: + payload["cacheHttpResponses"] = cache_http_responses + if tags is not None: + payload["tags"] = tags + + response = get_session().put( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def delete_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """Delete an Inference Endpoint. + + This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`]. + + For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`]. + + Args: + name (`str`): + The name of the Inference Endpoint to delete. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ """ + namespace = namespace or self._get_namespace(token=token) + response = get_session().delete( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + def pause_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Pause an Inference Endpoint. + + A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`]. + This is different than scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`], which + would be automatically restarted when a request is made to it. + + For convenience, you can also pause an Inference Endpoint using [`pause_inference_endpoint`]. + + Args: + name (`str`): + The name of the Inference Endpoint to pause. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the paused Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def resume_inference_endpoint( + self, + name: str, + *, + namespace: Optional[str] = None, + running_ok: bool = True, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Resume an Inference Endpoint. + + For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`]. + + Args: + name (`str`): + The name of the Inference Endpoint to resume. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the resumed Inference Endpoint. 
+ """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(response) + except HfHubHTTPError as error: + # If already running (and it's ok), then fetch current status and return + if running_ok and error.response.status_code == 400 and "already running" in error.response.text: + return self.get_inference_endpoint(name, namespace=namespace, token=token) + # Otherwise, raise the error + raise + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def scale_to_zero_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Scale Inference Endpoint to zero. + + An Inference Endpoint scaled to zero will not be charged. It will be resume on the next request to it, with a + cold start delay. This is different than pausing the Inference Endpoint with [`pause_inference_endpoint`], which + would require a manual resume with [`resume_inference_endpoint`]. + + For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`]. + + Args: + name (`str`): + The name of the Inference Endpoint to scale to zero. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def _get_namespace(self, token: Union[bool, str, None] = None) -> str: + """Get the default namespace for the current user.""" + me = self.whoami(token=token) + if me["type"] == "user": + return me["name"] + else: + raise ValueError( + "Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a" + " user." + ) + + ######################## + # Collection Endpoints # + ######################## + @validate_hf_hub_args + def list_collections( + self, + *, + owner: Union[List[str], str, None] = None, + item: Union[List[str], str, None] = None, + sort: Optional[Literal["lastModified", "trending", "upvotes"]] = None, + limit: Optional[int] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[Collection]: + """List collections on the Huggingface Hub, given some filters. + + + + When listing collections, the item list per collection is truncated to 4 items maximum. To retrieve all items + from a collection, you must use [`get_collection`]. + + + + Args: + owner (`List[str]` or `str`, *optional*): + Filter by owner's username. + item (`List[str]` or `str`, *optional*): + Filter collections containing a particular items. Example: `"models/teknium/OpenHermes-2.5-Mistral-7B"`, `"datasets/squad"` or `"papers/2311.12983"`. 
+ sort (`Literal["lastModified", "trending", "upvotes"]`, *optional*): + Sort collections by last modified, trending or upvotes. + limit (`int`, *optional*): + Maximum number of collections to be returned. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[Collection]`: an iterable of [`Collection`] objects. + """ + # Construct the API endpoint + path = f"{self.endpoint}/api/collections" + headers = self._build_hf_headers(token=token) + params: Dict = {} + if owner is not None: + params.update({"owner": owner}) + if item is not None: + params.update({"item": item}) + if sort is not None: + params.update({"sort": sort}) + if limit is not None: + params.update({"limit": limit}) + + # Paginate over the results until limit is reached + items = paginate(path, headers=headers, params=params) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + + # Parse as Collection and return + for position, collection_data in enumerate(items): + yield Collection(position=position, **collection_data) + + def get_collection(self, collection_slug: str, *, token: Union[bool, str, None] = None) -> Collection: + """Gets information about a Collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection of the Hub. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import get_collection + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + >>> collection.title + 'Recent models' + >>> len(collection.items) + 37 + >>> collection.items[0] + CollectionItem( + item_object_id='651446103cd773a050bf64c2', + item_id='TheBloke/U-Amethyst-20B-AWQ', + item_type='model', + position=88, + note=None + ) + ``` + """ + r = get_session().get( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def create_collection( + self, + title: str, + *, + namespace: Optional[str] = None, + description: Optional[str] = None, + private: bool = False, + exists_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> Collection: + """Create a new Collection on the Hub. + + Args: + title (`str`): + Title of the collection to create. Example: `"Recent models"`. + namespace (`str`, *optional*): + Namespace of the collection to create (username or org). Will default to the owner name. + description (`str`, *optional*): + Description of the collection to create. + private (`bool`, *optional*): + Whether the collection should be private or not. Defaults to `False` (i.e. public collection). + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if collection already exists. + token (Union[bool, str, None], optional): + A valid user access token (string). 
Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns: [`Collection`]
+
+        Example:
+
+        ```py
+        >>> from huggingface_hub import create_collection
+        >>> collection = create_collection(
+        ...     title="ICCV 2023",
+        ...     description="Portfolio of models, papers and demos I presented at ICCV 2023",
+        ... )
+        >>> collection.slug
+        "username/iccv-2023-64f9a55bb3115b4f513ec026"
+        ```
+        """
+        if namespace is None:
+            namespace = self.whoami(token)["name"]
+
+        payload = {
+            "title": title,
+            "namespace": namespace,
+            "private": private,
+        }
+        if description is not None:
+            payload["description"] = description
+
+        r = get_session().post(
+            f"{self.endpoint}/api/collections", headers=self._build_hf_headers(token=token), json=payload
+        )
+        try:
+            hf_raise_for_status(r)
+        except HTTPError as err:
+            if exists_ok and err.response.status_code == 409:
+                # Collection already exists and `exists_ok=True`
+                slug = r.json()["slug"]
+                return self.get_collection(slug, token=token)
+            else:
+                raise
+        return Collection(**{**r.json(), "endpoint": self.endpoint})
+
+    def update_collection_metadata(
+        self,
+        collection_slug: str,
+        *,
+        title: Optional[str] = None,
+        description: Optional[str] = None,
+        position: Optional[int] = None,
+        private: Optional[bool] = None,
+        theme: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> Collection:
+        """Update metadata of a collection on the Hub.
+
+        All arguments are optional. Only provided metadata will be updated.
+
+        Args:
+            collection_slug (`str`):
+                Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
+            title (`str`, *optional*):
+                Title of the collection to update.
+            description (`str`, *optional*):
+                Description of the collection to update.
+            position (`int`, *optional*):
+                New position of the collection in the list of collections of the user.
+            private (`bool`, *optional*):
+                Whether the collection should be private or not.
+            theme (`str`, *optional*):
+                Theme of the collection on the Hub.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns: [`Collection`]
+
+        Example:
+
+        ```py
+        >>> from huggingface_hub import update_collection_metadata
+        >>> collection = update_collection_metadata(
+        ...     collection_slug="username/iccv-2023-64f9a55bb3115b4f513ec026",
+        ...     title="ICCV Oct. 2023",
+        ...     description="Portfolio of models, datasets, papers and demos I presented at ICCV Oct. 2023",
+        ...     private=False,
+        ...     theme="pink",
) + >>> collection.slug + "username/iccv-oct-2023-64f9a55bb3115b4f513ec026" + # ^collection slug got updated but not the trailing ID + ``` + """ + payload = { + "position": position, + "private": private, + "theme": theme, + "title": title, + "description": description, + } + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + return Collection(**{**r.json()["data"], "endpoint": self.endpoint}) + + def delete_collection( + self, collection_slug: str, *, missing_ok: bool = False, token: Union[bool, str, None] = None + ) -> None: + """Delete a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to delete. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if collection doesn't exists. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + + ```py + >>> from huggingface_hub import delete_collection + >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) + ``` + + + + This is a non-revertible action. A deleted collection cannot be restored. + + + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Collection doesn't exists and `missing_ok=True` + return + else: + raise + + def add_collection_item( + self, + collection_slug: str, + item_id: str, + item_type: CollectionItemType_T, + *, + note: Optional[str] = None, + exists_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> Collection: + """Add an item to a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_id (`str`): + ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`) + or a paper id (e.g. `"2307.09288"`). + item_type (`str`): + Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if item already exists. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. 
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the item you try to add to the collection does not exist on the Hub.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 409 if the item you try to add to the collection is already in the collection (and `exists_ok=False`).
+
+        Example:
+
+        ```py
+        >>> from huggingface_hub import add_collection_item
+        >>> collection = add_collection_item(
+        ...     collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab",
+        ...     item_id="pierre-loic/climate-news-articles",
+        ...     item_type="dataset"
+        ... )
+        >>> collection.items[-1].item_id
+        "pierre-loic/climate-news-articles"
+        # ^item got added to the collection on last position
+
+        # Add item with a note
+        >>> add_collection_item(
+        ...     collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab",
+        ...     item_id="datasets/climate_fever",
+        ...     item_type="dataset",
+        ...     note="This dataset adopts the FEVER methodology that consists of 1,535 real-world claims regarding climate-change collected on the internet."
+        ... )
+        (...)
+        ```
+        """
+        payload: Dict[str, Any] = {"item": {"id": item_id, "type": item_type}}
+        if note is not None:
+            payload["note"] = note
+        r = get_session().post(
+            f"{self.endpoint}/api/collections/{collection_slug}/items",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        try:
+            hf_raise_for_status(r)
+        except HTTPError as err:
+            if exists_ok and err.response.status_code == 409:
+                # Item already exists and `exists_ok=True`
+                return self.get_collection(collection_slug, token=token)
+            else:
+                raise
+        return Collection(**{**r.json(), "endpoint": self.endpoint})
+
+    def update_collection_item(
+        self,
+        collection_slug: str,
+        item_object_id: str,
+        *,
+        note: Optional[str] = None,
+        position: Optional[int] = None,
+        token: Union[bool, str, None] = None,
+    ) -> None:
+        """Update an item in a collection.
+
+        Args:
+            collection_slug (`str`):
+                Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
+            item_object_id (`str`):
+                ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id).
+                It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`.
+            note (`str`, *optional*):
+                A note to attach to the item in the collection. The maximum size for a note is 500 characters.
+            position (`int`, *optional*):
+                New position of the item in the collection.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Example:
+
+        ```py
+        >>> from huggingface_hub import get_collection, update_collection_item
+
+        # Get collection first
+        >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026")
+
+        # Update item based on its ID (add note + update position)
+        >>> update_collection_item(
+        ...     collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026",
+        ...     item_object_id=collection.items[-1].item_object_id,
+        ...     note="Newly updated model!",
+        ...     position=0,
) + ``` + """ + payload = {"position": position, "note": note} + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + + def delete_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + missing_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> None: + """Delete an item from a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if item doesn't exists. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + + ```py + >>> from huggingface_hub import get_collection, delete_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Delete item based on its ID + >>> delete_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... item_object_id=collection.items[-1].item_object_id, + ... ) + ``` + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Item already deleted and `missing_ok=True` + return + else: + raise + + ########################## + # Manage access requests # + ########################## + + @validate_hf_hub_args + def list_pending_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> List[AccessRequest]: + """ + Get pending access requests for a given gated repo. + + A pending request means the user has requested access to the repo but the request has not been processed yet. + If the approval mode is automatic, this list should be empty. Pending requests can be accepted or rejected + using [`accept_access_request`] and [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. 
+            If the gated repo has a custom form, the `fields` attribute will be populated with user's answers.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
+                or `admin` role in the organization the repo belongs to or if you passed a `read` token.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import list_pending_access_requests, accept_access_request
+
+        # List pending requests
+        >>> requests = list_pending_access_requests("meta-llama/Llama-2-7b")
+        >>> len(requests)
+        411
+        >>> requests
+        [
+            AccessRequest(
+                username='clem',
+                fullname='Clem 🤗',
+                email='***',
+                timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc),
+                status='pending',
+                fields=None,
+            ),
+            ...
+        ]
+
+        # Accept Clem's request
+        >>> accept_access_request("meta-llama/Llama-2-7b", "clem")
+        ```
+        """
+        return self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
+
+    @validate_hf_hub_args
+    def list_accepted_access_requests(
+        self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
+    ) -> List[AccessRequest]:
+        """
+        Get accepted access requests for a given gated repo.
+
+        An accepted request means the user has requested access to the repo and the request has been accepted. The user
+        can download any file of the repo. If the approval mode is automatic, this list should contain all requests
+        by default. Accepted requests can be cancelled or rejected at any time using [`cancel_access_request`] and
+        [`reject_access_request`]. A cancelled request will go back to the pending list while a rejected request will
+        go to the rejected list. In both cases, the user will lose access to the repo.
+
+        For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
+
+        Args:
+            repo_id (`str`):
+                The id of the repo to get access requests for.
+            repo_type (`str`, *optional*):
+                The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`.
+                Defaults to `model`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each item contains a `username`, `email`,
+            `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
+            be populated with user's answers.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
+                or `admin` role in the organization the repo belongs to or if you passed a `read` token.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import list_accepted_access_requests
+
+        >>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b")
+        >>> len(requests)
+        411
+        >>> requests
+        [
+            AccessRequest(
+                username='clem',
+                fullname='Clem 🤗',
+                email='***',
+                timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc),
+                status='accepted',
+                fields=None,
+            ),
+            ...
+        ]
+        ```
+        """
+        return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
+
+    @validate_hf_hub_args
+    def list_rejected_access_requests(
+        self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
+    ) -> List[AccessRequest]:
+        """
+        Get rejected access requests for a given gated repo.
+
+        A rejected request means the user has requested access to the repo and the request has been explicitly rejected
+        by a repo owner (either you or another user from your organization). The user cannot download any file of the
+        repo. Rejected requests can be accepted or cancelled at any time using [`accept_access_request`] and
+        [`cancel_access_request`]. A cancelled request will go back to the pending list while an accepted request will
+        go to the accepted list.
+
+        For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
+
+        Args:
+            repo_id (`str`):
+                The id of the repo to get access requests for.
+            repo_type (`str`, *optional*):
+                The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`.
+                Defaults to `model`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each item contains a `username`, `email`,
+            `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
+            be populated with user's answers.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
+                or `admin` role in the organization the repo belongs to or if you passed a `read` token.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import list_rejected_access_requests
+
+        >>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b")
+        >>> len(requests)
+        411
+        >>> requests
+        [
+            AccessRequest(
+                username='clem',
+                fullname='Clem 🤗',
+                email='***',
+                timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc),
+                status='rejected',
+                fields=None,
+            ),
+            ...
+        ]
+        ```
+        """
+        return self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)
+
+    def _list_access_requests(
+        self,
+        repo_id: str,
+        status: Literal["accepted", "rejected", "pending"],
+        repo_type: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> List[AccessRequest]:
+        if repo_type not in constants.REPO_TYPES:
+            raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
+        if repo_type is None:
+            repo_type = constants.REPO_TYPE_MODEL
+
+        response = get_session().get(
+            f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+        return [
+            AccessRequest(
+                username=request["user"]["user"],
+                fullname=request["user"]["fullname"],
+                email=request["user"].get("email"),
+                status=request["status"],
+                timestamp=parse_datetime(request["timestamp"]),
+                fields=request.get("fields"),  # only if custom fields in form
+            )
+            for request in response.json()
+        ]
+
+    @validate_hf_hub_args
+    def cancel_access_request(
+        self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
+    ) -> None:
+        """
+        Cancel an access request from a user for a given gated repo.
+
+        A cancelled request will go back to the pending list and the user will lose access to the repo.
+
+        For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
+
+        Args:
+            repo_id (`str`):
+                The id of the repo to cancel access request for.
+            user (`str`):
+                The username of the user whose access request should be cancelled.
+            repo_type (`str`, *optional*):
+                The type of the repo to cancel access request for. Must be one of `model`, `dataset` or `space`.
+                Defaults to `model`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
+                or `admin` role in the organization the repo belongs to or if you passed a `read` token.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user does not exist on the Hub.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user access request cannot be found.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user access request is already in the pending list.
+        """
+        self._handle_access_request(repo_id, user, "pending", repo_type=repo_type, token=token)
+
+    @validate_hf_hub_args
+    def accept_access_request(
+        self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
+    ) -> None:
+        """
+        Accept an access request from a user for a given gated repo.
+
+        Once the request is accepted, the user will be able to download any file of the repo and access the community
+        tab. If the approval mode is automatic, you don't have to accept requests manually.
+        An accepted request can be cancelled or rejected at any time using [`cancel_access_request`] and
+        [`reject_access_request`].
+
+        For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
+
+        Args:
+            repo_id (`str`):
+                The id of the repo to accept access request for.
+            user (`str`):
+                The username of the user whose access request should be accepted.
+            repo_type (`str`, *optional*):
+                The type of the repo to accept access request for. Must be one of `model`, `dataset` or `space`.
+                Defaults to `model`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
+                or `admin` role in the organization the repo belongs to or if you passed a `read` token.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user does not exist on the Hub.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user access request cannot be found.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user access request is already in the accepted list.
+        """
+        self._handle_access_request(repo_id, user, "accepted", repo_type=repo_type, token=token)
+
+    @validate_hf_hub_args
+    def reject_access_request(
+        self,
+        repo_id: str,
+        user: str,
+        *,
+        repo_type: Optional[str] = None,
+        rejection_reason: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> None:
+        """
+        Reject an access request from a user for a given gated repo.
+
+        A rejected request will go to the rejected list. The user cannot download any file of the repo. Rejected
+        requests can be accepted or cancelled at any time using [`accept_access_request`] and [`cancel_access_request`].
+        A cancelled request will go back to the pending list while an accepted request will go to the accepted list.
+
+        For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
+
+        Args:
+            repo_id (`str`):
+                The id of the repo to reject access request for.
+            user (`str`):
+                The username of the user whose access request should be rejected.
+            repo_type (`str`, *optional*):
+                The type of the repo to reject access request for. Must be one of `model`, `dataset` or `space`.
+                Defaults to `model`.
+            rejection_reason (`str`, *optional*):
+                Optional rejection reason that will be visible to the user (max 200 characters).
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo.
+                This can be the case if you don't have `write` or `admin` role in the organization the repo
+                belongs to or if you passed a `read` token.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user does not exist on the Hub.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user access request cannot be found.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user access request is already in the rejected list.
+        """
+        self._handle_access_request(
+            repo_id, user, "rejected", repo_type=repo_type, rejection_reason=rejection_reason, token=token
+        )
+
+    @validate_hf_hub_args
+    def _handle_access_request(
+        self,
+        repo_id: str,
+        user: str,
+        status: Literal["accepted", "rejected", "pending"],
+        repo_type: Optional[str] = None,
+        rejection_reason: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> None:
+        if repo_type not in constants.REPO_TYPES:
+            raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
+        if repo_type is None:
+            repo_type = constants.REPO_TYPE_MODEL
+
+        payload = {"user": user, "status": status}
+
+        if rejection_reason is not None:
+            if status != "rejected":
+                raise ValueError("`rejection_reason` can only be passed when rejecting an access request.")
+            payload["rejectionReason"] = rejection_reason
+
+        response = get_session().post(
+            f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/handle",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        hf_raise_for_status(response)
+
+    @validate_hf_hub_args
+    def grant_access(
+        self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
+    ) -> None:
+        """
+        Grant access to a user for a given gated repo.
+
+        Granting access does not require the user to send an access request themselves. The user is automatically
+        added to the accepted list, meaning they can download the files of the repo. You can revoke the granted
+        access at any time using [`cancel_access_request`] or [`reject_access_request`].
+
+        For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
+
+        Args:
+            repo_id (`str`):
+                The id of the repo to grant access to.
+            user (`str`):
+                The username of the user to grant access.
+            repo_type (`str`, *optional*):
+                The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`.
+                Defaults to `model`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the repo is not gated.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 400 if the user already has access to the repo.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
+                or `admin` role in the organization the repo belongs to or if you passed a `read` token.
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 if the user does not exist on the Hub.
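+
+        Example (an illustrative sketch; the repo id and username below are placeholders):
+
+        ```py
+        >>> from huggingface_hub import grant_access
+        # Grant access to a gated repo without waiting for a request from the user
+        >>> grant_access("your-org/your-gated-model", "some-user")
+        ```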
+ """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + response = get_session().post( + f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/grant", + headers=self._build_hf_headers(token=token), + json={"user": user}, + ) + hf_raise_for_status(response) + return response.json() + + ################### + # Manage webhooks # + ################### + + @validate_hf_hub_args + def get_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: + """Get a webhook by its id. + + Args: + webhook_id (`str`): + The unique identifier of the webhook to get. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the webhook. + + Example: + ```python + >>> from huggingface_hub import get_webhook + >>> webhook = get_webhook("654bbbc16f2ec14d77f109cc") + >>> print(webhook) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + secret="my-secret", + domains=["repo", "discussion"], + disabled=False, + ) + ``` + """ + response = get_session().get( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data["url"], + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def list_webhooks(self, *, token: Union[bool, str, None] = None) -> List[WebhookInfo]: + """List all configured webhooks. + + Args: + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `List[WebhookInfo]`: + List of webhook info objects. 
+ + Example: + ```python + >>> from huggingface_hub import list_webhooks + >>> webhooks = list_webhooks() + >>> len(webhooks) + 2 + >>> webhooks[0] + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + secret="my-secret", + domains=["repo", "discussion"], + disabled=False, + ) + ``` + """ + response = get_session().get( + f"{constants.ENDPOINT}/api/settings/webhooks", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhooks_data = response.json() + + return [ + WebhookInfo( + id=webhook["id"], + url=webhook["url"], + watched=[WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook["watched"]], + domains=webhook["domains"], + secret=webhook.get("secret"), + disabled=webhook["disabled"], + ) + for webhook in webhooks_data + ] + + @validate_hf_hub_args + def create_webhook( + self, + *, + url: str, + watched: List[Union[Dict, WebhookWatchedItem]], + domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, + secret: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> WebhookInfo: + """Create a new webhook. + + Args: + url (`str`): + URL to send the payload to. + watched (`List[WebhookWatchedItem]`): + List of [`WebhookWatchedItem`] to be watched by the webhook. It can be users, orgs, models, datasets or spaces. + Watched items can also be provided as plain dictionaries. + domains (`List[Literal["repo", "discussion"]]`, optional): + List of domains to watch. It can be "repo", "discussion" or both. + secret (`str`, optional): + A secret to sign the payload with. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the newly created webhook. + + Example: + ```python + >>> from huggingface_hub import create_webhook + >>> payload = create_webhook( + ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], + ... url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + ... domains=["repo", "discussion"], + ... secret="my-secret", + ... 
) + >>> print(payload) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=False, + ) + ``` + """ + watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] + + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks", + json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data["url"], + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def update_webhook( + self, + webhook_id: str, + *, + url: Optional[str] = None, + watched: Optional[List[Union[Dict, WebhookWatchedItem]]] = None, + domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, + secret: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> WebhookInfo: + """Update an existing webhook. + + Args: + webhook_id (`str`): + The unique identifier of the webhook to be updated. + url (`str`, optional): + The URL to which the payload will be sent. + watched (`List[WebhookWatchedItem]`, optional): + List of items to watch. It can be users, orgs, models, datasets, or spaces. + Refer to [`WebhookWatchedItem`] for more details. Watched items can also be provided as plain dictionaries. + domains (`List[Literal["repo", "discussion"]]`, optional): + The domains to watch. This can include "repo", "discussion", or both. + secret (`str`, optional): + A secret to sign the payload with, providing an additional layer of security. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the updated webhook. + + Example: + ```python + >>> from huggingface_hub import update_webhook + >>> updated_payload = update_webhook( + ... webhook_id="654bbbc16f2ec14d77f109cc", + ... url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], + ... domains=["repo"], + ... secret="my-secret", + ... 
) + >>> print(updated_payload) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo"], + secret="my-secret", + disabled=False, + ``` + """ + if watched is None: + watched = [] + watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] + + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data["url"], + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def enable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: + """Enable a webhook (makes it "active"). + + Args: + webhook_id (`str`): + The unique identifier of the webhook to enable. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the enabled webhook. + + Example: + ```python + >>> from huggingface_hub import enable_webhook + >>> enabled_webhook = enable_webhook("654bbbc16f2ec14d77f109cc") + >>> enabled_webhook + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=False, + ) + ``` + """ + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/enable", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data["url"], + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def disable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: + """Disable a webhook (makes it "disabled"). + + Args: + webhook_id (`str`): + The unique identifier of the webhook to disable. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the disabled webhook. 
+ + Example: + ```python + >>> from huggingface_hub import disable_webhook + >>> disabled_webhook = disable_webhook("654bbbc16f2ec14d77f109cc") + >>> disabled_webhook + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=True, + ) + ``` + """ + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/disable", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data["url"], + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def delete_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> None: + """Delete a webhook. + + Args: + webhook_id (`str`): + The unique identifier of the webhook to delete. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `None` + + Example: + ```python + >>> from huggingface_hub import delete_webhook + >>> delete_webhook("654bbbc16f2ec14d77f109cc") + ``` + """ + response = get_session().delete( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + ############# + # Internals # + ############# + + def _build_hf_headers( + self, + token: Union[bool, str, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + ) -> Dict[str, str]: + """ + Alias for [`build_hf_headers`] that uses the token from [`HfApi`] client + when `token` is not provided. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + return build_hf_headers( + token=token, + library_name=library_name or self.library_name, + library_version=library_version or self.library_version, + user_agent=user_agent or self.user_agent, + headers=self.headers, + ) + + def _prepare_folder_deletions( + self, + repo_id: str, + repo_type: Optional[str], + revision: Optional[str], + path_in_repo: str, + delete_patterns: Optional[Union[List[str], str]], + token: Union[bool, str, None] = None, + ) -> List[CommitOperationDelete]: + """Generate the list of Delete operations for a commit to delete files from a repo. + + List remote files and match them against the `delete_patterns` constraints. Returns a list of [`CommitOperationDelete`] + with the matching items. + + Note: `.gitattributes` file is essential to make a repo work properly on the Hub. This file will always be + kept even if it matches the `delete_patterns` constraints. 
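+
+        Example (an illustrative sketch of this internal helper; the repo id and pattern below are placeholders):
+
+        ```py
+        # Build delete operations for every remote markdown file under "docs/";
+        # ".gitattributes" would be kept even if it matched the pattern.
+        >>> operations = HfApi()._prepare_folder_deletions(
+        ...     repo_id="username/my-model",
+        ...     repo_type="model",
+        ...     revision="main",
+        ...     path_in_repo="docs",
+        ...     delete_patterns="*.md",
+        ... )
+        ```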
+ """ + if delete_patterns is None: + # If no delete patterns, no need to list and filter remote files + return [] + + # List remote files + filenames = self.list_repo_files(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) + + # Compute relative path in repo + if path_in_repo and path_in_repo not in (".", "./"): + path_in_repo = path_in_repo.strip("/") + "/" # harmonize + relpath_to_abspath = { + file[len(path_in_repo) :]: file for file in filenames if file.startswith(path_in_repo) + } + else: + relpath_to_abspath = {file: file for file in filenames} + + # Apply filter on relative paths and return + return [ + CommitOperationDelete(path_in_repo=relpath_to_abspath[relpath], is_folder=False) + for relpath in filter_repo_objects(relpath_to_abspath.keys(), allow_patterns=delete_patterns) + if relpath_to_abspath[relpath] != ".gitattributes" + ] + + def _prepare_upload_folder_additions( + self, + folder_path: Union[str, Path], + path_in_repo: str, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> List[CommitOperationAdd]: + """Generate the list of Add operations for a commit to upload a folder. + + Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist) + constraints are discarded. + """ + + folder_path = Path(folder_path).expanduser().resolve() + if not folder_path.is_dir(): + raise ValueError(f"Provided path: '{folder_path}' is not a directory") + + # List files from folder + relpath_to_abspath = { + path.relative_to(folder_path).as_posix(): path + for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic + if path.is_file() + } + + # Filter files + # Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering. + filtered_repo_objects = list( + filter_repo_objects( + relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns + ) + ) + + prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else "" + + # If updating a README.md file, make sure the metadata format is valid + # It's better to fail early than to fail after all the files have been hashed. + if "README.md" in filtered_repo_objects: + self._validate_yaml( + content=relpath_to_abspath["README.md"].read_text(encoding="utf8"), + repo_type=repo_type, + token=token, + ) + if len(filtered_repo_objects) > 30: + log = logger.warning if len(filtered_repo_objects) > 200 else logger.info + log( + "It seems you are trying to upload a large folder at once. This might take some time and then fail if " + "the folder is too large. For such cases, it is recommended to upload in smaller batches or to use " + "`HfApi().upload_large_folder(...)`/`hf upload-large-folder` instead. For more details, " + "check out https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#upload-a-large-folder." + ) + + logger.info(f"Start hashing {len(filtered_repo_objects)} files.") + operations = [ + CommitOperationAdd( + path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk + path_in_repo=prefix + relpath, # "absolute" path in repo + ) + for relpath in filtered_repo_objects + ] + logger.info(f"Finished hashing {len(filtered_repo_objects)} files.") + return operations + + def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None): + """ + Validate YAML from `README.md`, used before file hashing and upload. 
+
+        Args:
+            content (`str`):
+                Content of `README.md` to validate.
+            repo_type (`str`, *optional*):
+                The type of the repo whose `README.md` is being validated. Must be one of `model`, `dataset` or
+                `space`. Defaults to `model`.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Raises:
+            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+              if YAML is invalid
+        """
+        repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL
+        headers = self._build_hf_headers(token=token)
+
+        response = get_session().post(
+            f"{self.endpoint}/api/validate-yaml",
+            json={"content": content, "repoType": repo_type},
+            headers=headers,
+        )
+        # Handle warnings (example: empty metadata)
+        response_content = response.json()
+        message = "\n".join([f"- {warning.get('message')}" for warning in response_content.get("warnings", [])])
+        if message:
+            warnings.warn(f"Warnings while validating metadata in README.md:\n{message}")
+
+        # Raise on errors
+        try:
+            hf_raise_for_status(response)
+        except BadRequestError as e:
+            errors = response_content.get("errors", [])
+            message = "\n".join([f"- {error.get('message')}" for error in errors])
+            raise ValueError(f"Invalid metadata in README.md.\n{message}") from e
+
+    def get_user_overview(self, username: str, token: Union[bool, str, None] = None) -> User:
+        """
+        Get an overview of a user on the Hub.
+
+        Args:
+            username (`str`):
+                Username of the user to get an overview of.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `User`: A [`User`] object with the user's overview.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 If the user does not exist on the Hub.
+        """
+        r = get_session().get(
+            f"{constants.ENDPOINT}/api/users/{username}/overview", headers=self._build_hf_headers(token=token)
+        )
+        hf_raise_for_status(r)
+        return User(**r.json())
+
+    def list_organization_members(self, organization: str, token: Union[bool, str, None] = None) -> Iterable[User]:
+        """
+        List the members of an organization on the Hub.
+
+        Args:
+            organization (`str`):
+                Name of the organization to get the members of.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `Iterable[User]`: A list of [`User`] objects with the members of the organization.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 If the organization does not exist on the Hub.
+
+        """
+        for member in paginate(
+            path=f"{constants.ENDPOINT}/api/organizations/{organization}/members",
+            params={},
+            headers=self._build_hf_headers(token=token),
+        ):
+            yield User(**member)
+
+    def list_user_followers(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]:
+        """
+        Get the list of followers of a user on the Hub.
+
+        Args:
+            username (`str`):
+                Username of the user to get the followers of.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `Iterable[User]`: A list of [`User`] objects with the followers of the user.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 If the user does not exist on the Hub.
+
+        """
+        for follower in paginate(
+            path=f"{constants.ENDPOINT}/api/users/{username}/followers",
+            params={},
+            headers=self._build_hf_headers(token=token),
+        ):
+            yield User(**follower)
+
+    def list_user_following(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]:
+        """
+        Get the list of users followed by a user on the Hub.
+
+        Args:
+            username (`str`):
+                Username of the user whose followed users should be listed.
+            token (Union[bool, str, None], optional):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `Iterable[User]`: A list of [`User`] objects with the users followed by the user.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 If the user does not exist on the Hub.
+
+        """
+        for followed_user in paginate(
+            path=f"{constants.ENDPOINT}/api/users/{username}/following",
+            params={},
+            headers=self._build_hf_headers(token=token),
+        ):
+            yield User(**followed_user)
+
+    def list_papers(
+        self,
+        *,
+        query: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> Iterable[PaperInfo]:
+        """
+        List daily papers on the Hugging Face Hub given a search query.
+
+        Args:
+            query (`str`, *optional*):
+                A search query string to find papers.
+                If provided, returns papers that match the query.
+            token (Union[bool, str, None], *optional*):
+                A valid user access token (string). Defaults to the locally saved
+                token, which is the recommended method for authentication (see
+                https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+                To disable authentication, pass `False`.
+
+        Returns:
+            `Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects.
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import HfApi
+
+        >>> api = HfApi()
+
+        # List all papers with "attention" in their title
+        >>> api.list_papers(query="attention")
+        ```
+        """
+        path = f"{self.endpoint}/api/papers/search"
+        params = {}
+        if query:
+            params["q"] = query
+        r = get_session().get(
+            path,
+            params=params,
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(r)
+        for paper in r.json():
+            yield PaperInfo(**paper)
+
+    def paper_info(self, id: str) -> PaperInfo:
+        """
+        Get information for a paper on the Hub.
+
+        Args:
+            id (`str`):
+                ArXiv id of the paper.
+
+        Returns:
+            `PaperInfo`: A `PaperInfo` object.
+
+        Raises:
+            [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
+                HTTP 404 If the paper does not exist on the Hub.
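+
+        Example (illustrative; the arXiv id below is a placeholder):
+
+        ```python
+        >>> from huggingface_hub import HfApi
+        >>> api = HfApi()
+        >>> paper = api.paper_info("2311.12983")
+        ```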
+ """ + path = f"{self.endpoint}/api/papers/{id}" + r = get_session().get(path) + hf_raise_for_status(r) + return PaperInfo(**r.json()) + + def auth_check( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """ + Check if the provided user token has access to a specific repository on the Hugging Face Hub. + + This method verifies whether the user, authenticated via the provided token, has access to the specified + repository. If the repository is not found or if the user lacks the required permissions to access it, + the method raises an appropriate exception. + + Args: + repo_id (`str`): + The repository to check for access. Format should be `"user/repo_name"`. + Example: `"user/my-cool-model"`. + + repo_type (`str`, *optional*): + The type of the repository. Should be one of `"model"`, `"dataset"`, or `"space"`. + If not specified, the default is `"model"`. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Raises: + [`~utils.RepositoryNotFoundError`]: + Raised if the repository does not exist, is private, or the user does not have access. This can + occur if the `repo_id` or `repo_type` is incorrect or if the repository is private but the user + is not authenticated. + + [`~utils.GatedRepoError`]: + Raised if the repository exists but is gated and the user is not authorized to access it. + + Example: + Check if the user has access to a repository: + + ```python + >>> from huggingface_hub import auth_check + >>> from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + + try: + auth_check("user/my-cool-model") + except GatedRepoError: + # Handle gated repository error + print("You do not have permission to access this gated repository.") + except RepositoryNotFoundError: + # Handle repository not found error + print("The repository was not found or you do not have access.") + ``` + + In this example: + - If the user has access, the method completes successfully. + - If the repository is gated or does not exist, appropriate exceptions are raised, allowing the user + to handle them accordingly. + """ + headers = self._build_hf_headers(token=token) + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/auth-check" + r = get_session().get(path, headers=headers) + hf_raise_for_status(r) + + def run_job( + self, + *, + image: str, + command: List[str], + env: Optional[Dict[str, Any]] = None, + secrets: Optional[Dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Run compute Jobs on Hugging Face infrastructure. + + Args: + image (`str`): + The Docker image to use. + Examples: `"ubuntu"`, `"python:3.12"`, `"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"`. + Example with an image from a Space: `"hf.co/spaces/lhoestq/duckdb"`. + + command (`List[str]`): + The command to run. Example: `["echo", "hello"]`. + + env (`Dict[str, Any]`, *optional*): + Defines the environment variables for the Job. 
+
+            secrets (`Dict[str, Any]`, *optional*):
+                Defines the secret environment variables for the Job.
+
+            flavor (`str`, *optional*):
+                Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values.
+                Defaults to `"cpu-basic"`.
+
+            timeout (`Union[int, float, str]`, *optional*):
+                Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days).
+                Example: `300` or `"5m"` for 5 minutes.
+
+            namespace (`str`, *optional*):
+                The namespace where the Job will be created. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
+
+        Example:
+            Run your first Job:
+
+            ```python
+            >>> from huggingface_hub import run_job
+            >>> run_job(image="python:3.12", command=["python", "-c", "print('Hello from HF compute!')"])
+            ```
+
+            Run a GPU Job:
+
+            ```python
+            >>> from huggingface_hub import run_job
+            >>> image = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
+            >>> command = ["python", "-c", "import torch; print(f'This code ran with the following GPU: {torch.cuda.get_device_name()}')"]
+            >>> run_job(image=image, command=command, flavor="a10g-small")
+            ```
+
+        """
+        if flavor is None:
+            flavor = SpaceHardware.CPU_BASIC
+
+        # prepare payload to send to HF Jobs API
+        input_json: Dict[str, Any] = {
+            "command": command,
+            "arguments": [],
+            "environment": env or {},
+            "flavor": flavor,
+        }
+        # secrets are optional
+        if secrets:
+            input_json["secrets"] = secrets
+        # timeout is optional
+        if timeout:
+            time_units_factors = {"s": 1, "m": 60, "h": 3600, "d": 3600 * 24}
+            if isinstance(timeout, str) and timeout[-1] in time_units_factors:
+                input_json["timeoutSeconds"] = int(float(timeout[:-1]) * time_units_factors[timeout[-1]])
+            else:
+                input_json["timeoutSeconds"] = int(timeout)
+        # input is either from docker hub or from HF spaces
+        for prefix in (
+            "https://huggingface.co/spaces/",
+            "https://hf.co/spaces/",
+            "huggingface.co/spaces/",
+            "hf.co/spaces/",
+        ):
+            if image.startswith(prefix):
+                input_json["spaceId"] = image[len(prefix) :]
+                break
+        else:
+            input_json["dockerImage"] = image
+        if namespace is None:
+            namespace = self.whoami(token=token)["name"]
+        response = get_session().post(
+            f"{self.endpoint}/api/jobs/{namespace}",
+            json=input_json,
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+        job_info = response.json()
+        return JobInfo(**job_info, endpoint=self.endpoint)
+
+    def fetch_job_logs(
+        self,
+        *,
+        job_id: str,
+        namespace: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> Iterable[str]:
+        """
+        Fetch all the logs from a compute Job on Hugging Face infrastructure.
+
+        Args:
+            job_id (`str`):
+                ID of the Job.
+
+            namespace (`str`, *optional*):
+                The namespace where the Job is running. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
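+
+        Returns:
+            `Iterable[str]`: An iterator over the Job's log lines, yielded as they are fetched.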
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import fetch_job_logs, run_job
+        >>> job = run_job(image="python:3.12", command=["python", "-c", "print('Hello from HF compute!')"])
+        >>> for log in fetch_job_logs(job_id=job.id):
+        ...     print(log)
+        Hello from HF compute!
+        ```
+        """
+        if namespace is None:
+            namespace = self.whoami(token=token)["name"]
+        logging_finished = logging_started = False
+        job_finished = False
+        # - We need to retry because sometimes the /logs doesn't return logs when the job just started.
+        #   (for example it can return only two lines: one for "Job started" and one empty line)
+        # - Timeouts can happen in case of build errors
+        # - ChunkedEncodingError can happen in case of stopped logging in the middle of streaming
+        # - Infinite empty log stream can happen in case of build error
+        #   (the logs stream is infinite and empty except for the Job started message)
+        # - there is a ": keep-alive" every 30 seconds
+
+        # We don't use http_backoff since we need to check ourselves if ConnectionError.__context__ is a TimeoutError
+        max_retries = 5
+        min_wait_time = 1
+        max_wait_time = 10
+        sleep_time = 0
+        for _ in range(max_retries):
+            time.sleep(sleep_time)
+            sleep_time = min(max_wait_time, max(min_wait_time, sleep_time * 2))
+            try:
+                resp = get_session().get(
+                    f"{self.endpoint}/api/jobs/{namespace}/{job_id}/logs",
+                    headers=self._build_hf_headers(token=token),
+                    stream=True,
+                    timeout=120,
+                )
+                log = None
+                for line in resp.iter_lines(chunk_size=1):
+                    line = line.decode("utf-8")
+                    if line and line.startswith("data: {"):
+                        data = json.loads(line[len("data: ") :])
+                        # timestamp = data["timestamp"]
+                        if not data["data"].startswith("===== Job started"):
+                            logging_started = True
+                            log = data["data"]
+                            yield log
+                logging_finished = logging_started
+            except requests.exceptions.ChunkedEncodingError:
+                # Response ended prematurely
+                break
+            except KeyboardInterrupt:
+                break
+            except requests.exceptions.ConnectionError as err:
+                is_timeout = err.__context__ and isinstance(getattr(err.__context__, "__cause__", None), TimeoutError)
+                if logging_started or not is_timeout:
+                    raise
+            if logging_finished or job_finished:
+                break
+            job_status = (
+                get_session()
+                .get(
+                    f"{self.endpoint}/api/jobs/{namespace}/{job_id}",
+                    headers=self._build_hf_headers(token=token),
+                )
+                .json()
+            )
+            if "status" in job_status and job_status["status"]["stage"] not in ("RUNNING", "UPDATING"):
+                job_finished = True
+
+    def list_jobs(
+        self,
+        *,
+        timeout: Optional[int] = None,
+        namespace: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> List[JobInfo]:
+        """
+        List compute Jobs on Hugging Face infrastructure.
+
+        Args:
+            timeout (`int`, *optional*):
+                Timeout in seconds for the request to the Hub.
+
+            namespace (`str`, *optional*):
+                The namespace from where it lists the jobs. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
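+
+        Returns:
+            `List[JobInfo]`: A list of [`JobInfo`] objects, one per Job in the namespace.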
+ """ + if namespace is None: + namespace = whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}", + headers=self._build_hf_headers(token=token), + timeout=timeout, + ) + response.raise_for_status() + return [JobInfo(**job_info, endpoint=self.endpoint) for job_info in response.json()] + + def inspect_job( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Inspect a compute Job on Hugging Face infrastructure. + + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + + ```python + >>> from huggingface_hub import inspect_job, run_job + >>> job = run_job("python:3.12", ["python", "-c" ,"print('Hello from HF compute!')"]) + >>> inspect_job(job.job_id) + JobInfo( + id='68780d00bbe36d38803f645f', + created_at=datetime.datetime(2025, 7, 16, 20, 35, 12, 808000, tzinfo=datetime.timezone.utc), + docker_image='python:3.12', + space_id=None, + command=['python', '-c', "print('Hello from HF compute!')"], + arguments=[], + environment={}, + secrets={}, + flavor='cpu-basic', + status=JobStatus(stage='RUNNING', message=None) + ) + ``` + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}", + headers=self._build_hf_headers(token=token), + ) + response.raise_for_status() + return JobInfo(**response.json(), endpoint=self.endpoint) + + def cancel_job( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Cancel a compute Job on Hugging Face infrastructure. + + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + get_session().post( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}/cancel", + headers=self._build_hf_headers(token=token), + ).raise_for_status() + + @experimental + def run_uv_job( + self, + script: str, + *, + script_args: Optional[List[str]] = None, + dependencies: Optional[List[str]] = None, + python: Optional[str] = None, + image: Optional[str] = None, + env: Optional[Dict[str, Any]] = None, + secrets: Optional[Dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + _repo: Optional[str] = None, + ) -> JobInfo: + """ + Run a UV script Job on Hugging Face infrastructure. + + Args: + script (`str`): + Path or URL of the UV script. + + script_args (`List[str]`, *optional*) + Arguments to pass to the script. 
+
+            dependencies (`List[str]`, *optional*):
+                Dependencies to use to run the UV script.
+
+            python (`str`, *optional*):
+                Use a specific Python version. Default is 3.12.
+
+            image (`str`, *optional*, defaults to "ghcr.io/astral-sh/uv:python3.12-bookworm"):
+                Use a custom Docker image with `uv` installed.
+
+            env (`Dict[str, Any]`, *optional*):
+                Defines the environment variables for the Job.
+
+            secrets (`Dict[str, Any]`, *optional*):
+                Defines the secret environment variables for the Job.
+
+            flavor (`str`, *optional*):
+                Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values.
+                Defaults to `"cpu-basic"`.
+
+            timeout (`Union[int, float, str]`, *optional*):
+                Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days).
+                Example: `300` or `"5m"` for 5 minutes.
+
+            namespace (`str`, *optional*):
+                The namespace where the Job will be created. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import run_uv_job
+        >>> script = "https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/trl/scripts/sft.py"
+        >>> run_uv_job(script, dependencies=["trl"], flavor="a10g-small")
+        ```
+        """
+        image = image or "ghcr.io/astral-sh/uv:python3.12-bookworm"
+        env = env or {}
+        secrets = secrets or {}
+
+        # Build command
+        uv_args = []
+        if dependencies:
+            for dependency in dependencies:
+                uv_args += ["--with", dependency]
+        if python:
+            uv_args += ["--python", python]
+        script_args = script_args or []
+
+        if namespace is None:
+            namespace = self.whoami(token=token)["name"]
+
+        if script.startswith("http://") or script.startswith("https://"):
+            # Direct URL execution - no upload needed
+            command = ["uv", "run"] + uv_args + [script] + script_args
+        else:
+            # Local file - upload to HF
+            script_path = Path(script)
+            filename = script_path.name
+            # Parse repo (prefix with the namespace if none is given)
+            if _repo:
+                repo_id = _repo
+                if "/" not in repo_id:
+                    repo_id = f"{namespace}/{repo_id}"
+            else:
+                repo_id = f"{namespace}/hf-cli-jobs-uv-run-scripts"
+
+            # Create repo if needed
+            try:
+                self.repo_info(repo_id, repo_type="dataset")
+                logger.debug(f"Using existing repository: {repo_id}")
+            except RepositoryNotFoundError:
+                logger.info(f"Creating repository: {repo_id}")
+                self.create_repo(repo_id, repo_type="dataset", private=True, exist_ok=True)
+
+            # Upload script
+            logger.info(f"Uploading {script_path.name} to {repo_id}...")
+            with open(script_path, "r") as f:
+                script_content = f.read()
+
+            self.upload_file(
+                path_or_fileobj=script_content.encode(),
+                path_in_repo=filename,
+                repo_id=repo_id,
+                repo_type="dataset",
+            )
+
+            script_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"
+            repo_url = f"https://huggingface.co/datasets/{repo_id}"
+
+            logger.debug(f"✓ Script uploaded to: {repo_url}/blob/main/{filename}")
+
+            # Create and upload minimal README
+            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC")
+            readme_content = dedent(
+                f"""
+                ---
+                tags:
+                - hf-cli-jobs-uv-script
+                - ephemeral
+                viewer: false
+                ---
+
+                # UV Script: {filename}
+
+                Executed via `hf jobs uv run` on {timestamp}
+
+                ## Run this script
+
+                ```bash
+                hf jobs uv run {filename}
+                ```
+
+                ---
+                *Created with 
[hf jobs](https://huggingface.co/docs/huggingface_hub/main/en/guides/jobs)* + """ + ) + self.upload_file( + path_or_fileobj=readme_content.encode(), + path_in_repo="README.md", + repo_id=repo_id, + repo_type="dataset", + ) + + secrets["UV_SCRIPT_HF_TOKEN"] = token or self.token or get_token() + env["UV_SCRIPT_URL"] = script_url + + pre_command = ( + dedent( + """ + import urllib.request + import os + from pathlib import Path + o = urllib.request.build_opener() + o.addheaders = [("Authorization", "Bearer " + os.environ["UV_SCRIPT_HF_TOKEN"])] + Path("/tmp/script.py").write_bytes(o.open(os.environ["UV_SCRIPT_URL"]).read()) + """ + ) + .strip() + .replace('"', r"\"") + .split("\n") + ) + pre_command = ["python", "-c", '"' + "; ".join(pre_command) + '"'] + command = ["uv", "run"] + uv_args + ["/tmp/script.py"] + script_args + command = ["bash", "-c", " ".join(pre_command) + " && " + " ".join(command)] + + # Create RunCommand args + return self.run_job( + image=image, + command=command, + env=env, + secrets=secrets, + flavor=flavor, + timeout=timeout, + namespace=namespace, + token=token, + ) + + +def _parse_revision_from_pr_url(pr_url: str) -> str: + """Safely parse revision number from a PR url. + + Example: + ```py + >>> _parse_revision_from_pr_url("https://huggingface.co/bigscience/bloom/discussions/2") + "refs/pr/2" + ``` + """ + re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) + if re_match is None: + raise RuntimeError(f"Unexpected response from the hub, expected a Pull Request URL but got: '{pr_url}'") + return f"refs/pr/{re_match[1]}" + + +api = HfApi() + +whoami = api.whoami +auth_check = api.auth_check +get_token_permission = api.get_token_permission + +list_models = api.list_models +model_info = api.model_info + +list_datasets = api.list_datasets +dataset_info = api.dataset_info + +list_spaces = api.list_spaces +space_info = api.space_info + +list_papers = api.list_papers +paper_info = api.paper_info + +repo_exists = api.repo_exists +revision_exists = api.revision_exists +file_exists = api.file_exists +repo_info = api.repo_info +list_repo_files = api.list_repo_files +list_repo_refs = api.list_repo_refs +list_repo_commits = api.list_repo_commits +list_repo_tree = api.list_repo_tree +get_paths_info = api.get_paths_info + +get_model_tags = api.get_model_tags +get_dataset_tags = api.get_dataset_tags + +create_commit = api.create_commit +create_repo = api.create_repo +delete_repo = api.delete_repo +update_repo_visibility = api.update_repo_visibility +update_repo_settings = api.update_repo_settings +move_repo = api.move_repo +upload_file = api.upload_file +upload_folder = api.upload_folder +delete_file = api.delete_file +delete_folder = api.delete_folder +delete_files = api.delete_files +upload_large_folder = api.upload_large_folder +preupload_lfs_files = api.preupload_lfs_files +create_branch = api.create_branch +delete_branch = api.delete_branch +create_tag = api.create_tag +delete_tag = api.delete_tag +get_full_repo_name = api.get_full_repo_name + +# Danger-zone API +super_squash_history = api.super_squash_history +list_lfs_files = api.list_lfs_files +permanently_delete_lfs_files = api.permanently_delete_lfs_files + +# Safetensors helpers +get_safetensors_metadata = api.get_safetensors_metadata +parse_safetensors_file_metadata = api.parse_safetensors_file_metadata + +# Background jobs +run_as_future = api.run_as_future + +# Activity API +list_liked_repos = api.list_liked_repos +list_repo_likers = api.list_repo_likers +unlike = api.unlike + +# Community API +get_discussion_details = 
api.get_discussion_details
+get_repo_discussions = api.get_repo_discussions
+create_discussion = api.create_discussion
+create_pull_request = api.create_pull_request
+change_discussion_status = api.change_discussion_status
+comment_discussion = api.comment_discussion
+edit_discussion_comment = api.edit_discussion_comment
+rename_discussion = api.rename_discussion
+merge_pull_request = api.merge_pull_request
+
+# Space API
+add_space_secret = api.add_space_secret
+delete_space_secret = api.delete_space_secret
+get_space_variables = api.get_space_variables
+add_space_variable = api.add_space_variable
+delete_space_variable = api.delete_space_variable
+get_space_runtime = api.get_space_runtime
+request_space_hardware = api.request_space_hardware
+set_space_sleep_time = api.set_space_sleep_time
+pause_space = api.pause_space
+restart_space = api.restart_space
+duplicate_space = api.duplicate_space
+request_space_storage = api.request_space_storage
+delete_space_storage = api.delete_space_storage
+
+# Inference Endpoint API
+list_inference_endpoints = api.list_inference_endpoints
+create_inference_endpoint = api.create_inference_endpoint
+get_inference_endpoint = api.get_inference_endpoint
+update_inference_endpoint = api.update_inference_endpoint
+delete_inference_endpoint = api.delete_inference_endpoint
+pause_inference_endpoint = api.pause_inference_endpoint
+resume_inference_endpoint = api.resume_inference_endpoint
+scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint
+create_inference_endpoint_from_catalog = api.create_inference_endpoint_from_catalog
+list_inference_catalog = api.list_inference_catalog
+
+# Collections API
+get_collection = api.get_collection
+list_collections = api.list_collections
+create_collection = api.create_collection
+update_collection_metadata = api.update_collection_metadata
+delete_collection = api.delete_collection
+add_collection_item = api.add_collection_item
+update_collection_item = api.update_collection_item
+delete_collection_item = api.delete_collection_item
+
+# Access requests API
+list_pending_access_requests = api.list_pending_access_requests
+list_accepted_access_requests = api.list_accepted_access_requests
+list_rejected_access_requests = api.list_rejected_access_requests
+cancel_access_request = api.cancel_access_request
+accept_access_request = api.accept_access_request
+reject_access_request = api.reject_access_request
+grant_access = api.grant_access
+
+# Webhooks API
+create_webhook = api.create_webhook
+disable_webhook = api.disable_webhook
+delete_webhook = api.delete_webhook
+enable_webhook = api.enable_webhook
+get_webhook = api.get_webhook
+list_webhooks = api.list_webhooks
+update_webhook = api.update_webhook
+
+
+# User API
+get_user_overview = api.get_user_overview
+list_organization_members = api.list_organization_members
+list_user_followers = api.list_user_followers
+list_user_following = api.list_user_following
+
+# Jobs API
+run_job = api.run_job
+fetch_job_logs = api.fetch_job_logs
+list_jobs = api.list_jobs
+inspect_job = api.inspect_job
+cancel_job = api.cancel_job
+run_uv_job = api.run_uv_job
diff --git a/phivenv/Lib/site-packages/huggingface_hub/hf_file_system.py b/phivenv/Lib/site-packages/huggingface_hub/hf_file_system.py
new file mode 100644
index 0000000000000000000000000000000000000000..050a1d9bf9c0dc150420dd144a44c87e6701785f
--- /dev/null
+++ b/phivenv/Lib/site-packages/huggingface_hub/hf_file_system.py
@@ -0,0 +1,1134 @@
+import os
+import re
+import tempfile
+from collections import deque
+from dataclasses import dataclass, field
+from datetime import datetime
+from itertools import chain
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union
+from urllib.parse import quote, unquote
+
+import fsspec
+from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
+from fsspec.utils import isfilelike
+from requests import Response
+
+from . import constants
+from ._commit_api import CommitOperationCopy, CommitOperationDelete
+from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from .file_download import hf_hub_url, http_get
+from .hf_api import HfApi, LastCommitInfo, RepoFile
+from .utils import HFValidationError, hf_raise_for_status, http_backoff
+
+
+# Regex used to match special revisions with "/" in them (see #1710)
+SPECIAL_REFS_REVISION_REGEX = re.compile(
+    r"""
+    (^refs\/convert\/\w+)  # `refs/convert/parquet` revisions
+    |
+    (^refs\/pr\/\d+)  # PR revisions
+    """,
+    re.VERBOSE,
+)
+
+
+@dataclass
+class HfFileSystemResolvedPath:
+    """Data structure containing information about a resolved Hugging Face file system path."""
+
+    repo_type: str
+    repo_id: str
+    revision: str
+    path_in_repo: str
+    # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision.
+    # Used to reconstruct the unresolved path to return to the user.
+    _raw_revision: Optional[str] = field(default=None, repr=False)
+
+    def unresolve(self) -> str:
+        repo_path = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
+        if self._raw_revision:
+            return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/")
+        elif self.revision != constants.DEFAULT_REVISION:
+            return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/")
+        else:
+            return f"{repo_path}/{self.path_in_repo}".rstrip("/")
+
+
+class HfFileSystem(fsspec.AbstractFileSystem):
+    """
+    Access a remote Hugging Face Hub repository as if it were a local file system.
+
+    [`HfFileSystem`] provides fsspec compatibility, which is useful for libraries that require it (e.g., reading
+    Hugging Face datasets directly with `pandas`). However, it introduces additional overhead due to this compatibility
+    layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible.
+
+    Args:
+        token (`str` or `bool`, *optional*):
+            A valid user access token (string). Defaults to the locally saved
+            token, which is the recommended method for authentication (see
+            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+            To disable authentication, pass `False`.
+        endpoint (`str`, *optional*):
+            Endpoint of the Hub. Defaults to `https://huggingface.co`.
+
+    Usage:
+
+    ```python
+    >>> from huggingface_hub import HfFileSystem
+
+    >>> fs = HfFileSystem()
+
+    >>> # List files
+    >>> fs.glob("my-username/my-model/*.bin")
+    ['my-username/my-model/pytorch_model.bin']
+    >>> fs.ls("datasets/my-username/my-dataset", detail=False)
+    ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']
+
+    >>> # Read/write files
+    >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
+    ...     data = f.read()
+    >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
+    ...     
f.write(data)
+    ```
+    """
+
+    root_marker = ""
+    protocol = "hf"
+
+    def __init__(
+        self,
+        *args,
+        endpoint: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+        **storage_options,
+    ):
+        super().__init__(*args, **storage_options)
+        self.endpoint = endpoint or constants.ENDPOINT
+        self.token = token
+        self._api = HfApi(endpoint=endpoint, token=token)
+        # Maps (repo_type, repo_id, revision) to a 2-tuple with:
+        #  * the 1st element indicating whether the repository and the revision exist
+        #  * the 2nd element being the exception raised if the repository or revision doesn't exist
+        self._repo_and_revision_exists_cache: Dict[
+            Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
+        ] = {}
+
+    def _repo_and_revision_exist(
+        self, repo_type: str, repo_id: str, revision: Optional[str]
+    ) -> Tuple[bool, Optional[Exception]]:
+        if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
+            try:
+                self._api.repo_info(
+                    repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT
+                )
+            except (RepositoryNotFoundError, HFValidationError) as e:
+                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
+                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
+            except RevisionNotFoundError as e:
+                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
+                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
+            else:
+                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
+                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
+        return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]
+
+    def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
+        """
+        Resolve a Hugging Face file system path into its components.
+
+        Args:
+            path (`str`):
+                Path to resolve.
+            revision (`str`, *optional*):
+                The revision of the repo to resolve. Defaults to the revision specified in the path.
+
+        Returns:
+            [`HfFileSystemResolvedPath`]: Resolved path information containing `repo_type`, `repo_id`, `revision` and `path_in_repo`.
+
+        Raises:
+            `ValueError`:
+                If path contains conflicting revision information.
+            `NotImplementedError`:
+                If trying to list repositories.
+        """
+
+        def _align_revision_in_path_with_revision(
+            revision_in_path: Optional[str], revision: Optional[str]
+        ) -> Optional[str]:
+            if revision is not None:
+                if revision_in_path is not None and revision_in_path != revision:
+                    raise ValueError(
+                        f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
+                        " are not the same."
+ ) + else: + revision = revision_in_path + return revision + + path = self._strip_protocol(path) + if not path: + # can't list repositories at root + raise NotImplementedError("Access to repositories lists is not implemented.") + elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values(): + if "/" not in path: + # can't list repositories at the repository type level + raise NotImplementedError("Access to repositories lists is not implemented.") + repo_type, path = path.split("/", 1) + repo_type = constants.REPO_TYPES_MAPPING[repo_type] + else: + repo_type = constants.REPO_TYPE_MODEL + if path.count("/") > 0: + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + if "/" in revision_in_path: + match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path) + if match is not None and revision in (None, match.group()): + # Handle `refs/convert/parquet` and PR revisions separately + path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/") + revision_in_path = match.group() + else: + revision_in_path, path_in_repo = revision_in_path.split("/", 1) + else: + path_in_repo = "" + revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + _raise_file_not_found(path, err) + else: + revision_in_path = None + repo_id_with_namespace = "/".join(path.split("/")[:2]) + path_in_repo_with_namespace = "/".join(path.split("/")[2:]) + repo_id_without_namespace = path.split("/")[0] + path_in_repo_without_namespace = "/".join(path.split("/")[1:]) + repo_id = repo_id_with_namespace + path_in_repo = path_in_repo_with_namespace + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + if isinstance(err, (RepositoryNotFoundError, HFValidationError)): + repo_id = repo_id_without_namespace + path_in_repo = path_in_repo_without_namespace + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + _raise_file_not_found(path, err) + else: + _raise_file_not_found(path, err) + else: + repo_id = path + path_in_repo = "" + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) + else: + revision_in_path = None + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + raise NotImplementedError("Access to repositories lists is not implemented.") + + revision = revision if revision is not None else constants.DEFAULT_REVISION + return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path) + + def invalidate_cache(self, path: Optional[str] = None) -> None: + """ + Clear the cache for a given path. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.invalidate_cache). + + Args: + path (`str`, *optional*): + Path to clear from cache. If not provided, clear the entire cache. 
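+
+        Example (the repo name is illustrative):
+
+        ```python
+        >>> from huggingface_hub import HfFileSystem
+        >>> fs = HfFileSystem()
+        >>> fs.invalidate_cache("datasets/my-username/my-dataset")
+        ```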
+ + """ + if not path: + self.dircache.clear() + self._repo_and_revision_exists_cache.clear() + else: + resolved_path = self.resolve_path(path) + path = resolved_path.unresolve() + while path: + self.dircache.pop(path, None) + path = self._parent(path) + + # Only clear repo cache if path is to repo root + if not resolved_path.path_in_repo: + self._repo_and_revision_exists_cache.pop((resolved_path.repo_type, resolved_path.repo_id, None), None) + self._repo_and_revision_exists_cache.pop( + (resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision), None + ) + + def _open( + self, + path: str, + mode: str = "rb", + revision: Optional[str] = None, + block_size: Optional[int] = None, + **kwargs, + ) -> "HfFileSystemFile": + if "a" in mode: + raise NotImplementedError("Appending to remote files is not yet supported.") + if block_size == 0: + return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs) + else: + return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs) + + def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path = self.resolve_path(path, revision=revision) + self._api.delete_file( + path_in_repo=resolved_path.path_in_repo, + repo_id=resolved_path.repo_id, + token=self.token, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message"), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def rm( + self, + path: str, + recursive: bool = False, + maxdepth: Optional[int] = None, + revision: Optional[str] = None, + **kwargs, + ) -> None: + """ + Delete files from a repository. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.rm). + + + + Note: When possible, use `HfApi.delete_file()` for better performance. + + + + Args: + path (`str`): + Path to delete. + recursive (`bool`, *optional*): + If True, delete directory and all its contents. Defaults to False. + maxdepth (`int`, *optional*): + Maximum number of subdirectories to visit when deleting recursively. + revision (`str`, *optional*): + The git revision to delete from. + + """ + resolved_path = self.resolve_path(path, revision=revision) + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision) + paths_in_repo = [self.resolve_path(path).path_in_repo for path in paths if not self.isdir(path)] + operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo] + commit_message = f"Delete {path} " + commit_message += "recursively " if recursive else "" + commit_message += f"up to depth {maxdepth} " if maxdepth is not None else "" + # TODO: use `commit_description` to list all the deleted paths? + self._api.create_commit( + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + token=self.token, + operations=operations, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def ls( + self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs + ) -> List[Union[str, Dict[str, Any]]]: + """ + List the contents of a directory. 
+ + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.ls). + + + + Note: When possible, use `HfApi.list_repo_tree()` for better performance. + + + + Args: + path (`str`): + Path to the directory. + detail (`bool`, *optional*): + If True, returns a list of dictionaries containing file information. If False, + returns a list of file paths. Defaults to True. + refresh (`bool`, *optional*): + If True, bypass the cache and fetch the latest data. Defaults to False. + revision (`str`, *optional*): + The git revision to list from. + + Returns: + `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information + dictionaries (if detail=True). + """ + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + try: + out = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs) + except EntryNotFoundError: + # Path could be a file + if not resolved_path.path_in_repo: + _raise_file_not_found(path, None) + out = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs) + out = [o for o in out if o["name"] == path] + if len(out) == 0: + _raise_file_not_found(path, None) + return out if detail else [o["name"] for o in out] + + def _ls_tree( + self, + path: str, + recursive: bool = False, + refresh: bool = False, + revision: Optional[str] = None, + expand_info: bool = False, + ): + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + root_path = HfFileSystemResolvedPath( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + path_in_repo="", + _raw_revision=resolved_path._raw_revision, + ).unresolve() + + out = [] + if path in self.dircache and not refresh: + cached_path_infos = self.dircache[path] + out.extend(cached_path_infos) + dirs_not_in_dircache = [] + if recursive: + # Use BFS to traverse the cache and build the "recursive "output + # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same) + dirs_to_visit = deque( + [path_info for path_info in cached_path_infos if path_info["type"] == "directory"] + ) + while dirs_to_visit: + dir_info = dirs_to_visit.popleft() + if dir_info["name"] not in self.dircache: + dirs_not_in_dircache.append(dir_info["name"]) + else: + cached_path_infos = self.dircache[dir_info["name"]] + out.extend(cached_path_infos) + dirs_to_visit.extend( + [path_info for path_info in cached_path_infos if path_info["type"] == "directory"] + ) + + dirs_not_expanded = [] + if expand_info: + # Check if there are directories with non-expanded entries + dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None] + + if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded): + # If the dircache is incomplete, find the common path of the missing and non-expanded entries + # and extend the output with the result of `_ls_tree(common_path, recursive=True)` + common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded) + # Get the parent directory if the common prefix itself is not a directory + common_path = ( + common_prefix.rstrip("/") + if common_prefix.endswith("/") + or common_prefix == root_path + or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded) + else self._parent(common_prefix) + ) + out = [o for o in out if not o["name"].startswith(common_path + "/")] + 
for cached_path in self.dircache: + if cached_path.startswith(common_path + "/"): + self.dircache.pop(cached_path, None) + self.dircache.pop(common_path, None) + out.extend( + self._ls_tree( + common_path, + recursive=recursive, + refresh=True, + revision=revision, + expand_info=expand_info, + ) + ) + else: + tree = self._api.list_repo_tree( + resolved_path.repo_id, + resolved_path.path_in_repo, + recursive=recursive, + expand=expand_info, + revision=resolved_path.revision, + repo_type=resolved_path.repo_type, + ) + for path_info in tree: + if isinstance(path_info, RepoFile): + cache_path_info = { + "name": root_path + "/" + path_info.path, + "size": path_info.size, + "type": "file", + "blob_id": path_info.blob_id, + "lfs": path_info.lfs, + "last_commit": path_info.last_commit, + "security": path_info.security, + } + else: + cache_path_info = { + "name": root_path + "/" + path_info.path, + "size": 0, + "type": "directory", + "tree_id": path_info.tree_id, + "last_commit": path_info.last_commit, + } + parent_path = self._parent(cache_path_info["name"]) + self.dircache.setdefault(parent_path, []).append(cache_path_info) + out.append(cache_path_info) + return out + + def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]: + """ + Return all files below the given path. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.walk). + + Args: + path (`str`): + Root path to list files from. + + Returns: + `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples. + """ + path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() + yield from super().walk(path, *args, **kwargs) + + def glob(self, path: str, **kwargs) -> List[str]: + """ + Find files by glob-matching. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob). + + Args: + path (`str`): + Path pattern to match. + + Returns: + `List[str]`: List of paths matching the pattern. + """ + path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() + return super().glob(path, **kwargs) + + def find( + self, + path: str, + maxdepth: Optional[int] = None, + withdirs: bool = False, + detail: bool = False, + refresh: bool = False, + revision: Optional[str] = None, + **kwargs, + ) -> Union[List[str], Dict[str, Dict[str, Any]]]: + """ + List all files below path. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.find). + + Args: + path (`str`): + Root path to list files from. + maxdepth (`int`, *optional*): + Maximum depth to descend into subdirectories. + withdirs (`bool`, *optional*): + Include directory paths in the output. Defaults to False. + detail (`bool`, *optional*): + If True, returns a dict mapping paths to file information. Defaults to False. + refresh (`bool`, *optional*): + If True, bypass the cache and fetch the latest data. Defaults to False. + revision (`str`, *optional*): + The git revision to list from. + + Returns: + `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information. 
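+
+        Example (the repo name is illustrative):
+
+        ```python
+        >>> from huggingface_hub import HfFileSystem
+        >>> fs = HfFileSystem()
+        >>> fs.find("my-username/my-model", detail=False)
+        ['my-username/my-model/.gitattributes', 'my-username/my-model/pytorch_model.bin']
+        ```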
+ """ + if maxdepth: + return super().find( + path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs + ) + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + try: + out = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved_path.revision, **kwargs) + except EntryNotFoundError: + # Path could be a file + if self.info(path, revision=revision, **kwargs)["type"] == "file": + out = {path: {}} + else: + out = {} + else: + if not withdirs: + out = [o for o in out if o["type"] != "directory"] + else: + # If `withdirs=True`, include the directory itself to be consistent with the spec + path_info = self.info(path, revision=resolved_path.revision, **kwargs) + out = [path_info] + out if path_info["type"] == "directory" else out + out = {o["name"]: o for o in out} + names = sorted(out) + if not detail: + return names + else: + return {name: out[name] for name in names} + + def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None: + """ + Copy a file within or between repositories. + + + + Note: When possible, use `HfApi.upload_file()` for better performance. + + + + Args: + path1 (`str`): + Source path to copy from. + path2 (`str`): + Destination path to copy to. + revision (`str`, *optional*): + The git revision to copy from. + + """ + resolved_path1 = self.resolve_path(path1, revision=revision) + resolved_path2 = self.resolve_path(path2, revision=revision) + + same_repo = ( + resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id + ) + + if same_repo: + commit_message = f"Copy {path1} to {path2}" + self._api.create_commit( + repo_id=resolved_path1.repo_id, + repo_type=resolved_path1.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description", ""), + operations=[ + CommitOperationCopy( + src_path_in_repo=resolved_path1.path_in_repo, + path_in_repo=resolved_path2.path_in_repo, + src_revision=resolved_path1.revision, + ) + ], + ) + else: + with self.open(path1, "rb", revision=resolved_path1.revision) as f: + content = f.read() + commit_message = f"Copy {path1} to {path2}" + self._api.upload_file( + path_or_fileobj=content, + path_in_repo=resolved_path2.path_in_repo, + repo_id=resolved_path2.repo_id, + token=self.token, + repo_type=resolved_path2.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path1.unresolve()) + self.invalidate_cache(path=resolved_path2.unresolve()) + + def modified(self, path: str, **kwargs) -> datetime: + """ + Get the last modified time of a file. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.modified). + + Args: + path (`str`): + Path to the file. + + Returns: + `datetime`: Last commit date of the file. + """ + info = self.info(path, **{**kwargs, "expand_info": True}) + return info["last_commit"]["date"] + + def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]: + """ + Get information about a file or directory. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.info). 
+ + + + Note: When possible, use `HfApi.get_paths_info()` or `HfApi.repo_info()` for better performance. + + + + Args: + path (`str`): + Path to get info for. + refresh (`bool`, *optional*): + If True, bypass the cache and fetch the latest data. Defaults to False. + revision (`str`, *optional*): + The git revision to get info from. + + Returns: + `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.). + + """ + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + expand_info = kwargs.get( + "expand_info", False + ) # don't expose it as a parameter in the public API to follow the spec + if not resolved_path.path_in_repo: + # Path is the root directory + out = { + "name": path, + "size": 0, + "type": "directory", + "last_commit": None, + } + if expand_info: + last_commit = self._api.list_repo_commits( + resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision + )[-1] + out = { + **out, + "tree_id": None, # TODO: tree_id of the root directory? + "last_commit": LastCommitInfo( + oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at + ), + } + else: + out = None + parent_path = self._parent(path) + if not expand_info and parent_path not in self.dircache: + # Fill the cache with cheap call + self.ls(parent_path) + if parent_path in self.dircache: + # Check if the path is in the cache + out1 = [o for o in self.dircache[parent_path] if o["name"] == path] + if not out1: + _raise_file_not_found(path, None) + out = out1[0] + if refresh or out is None or (expand_info and out and out["last_commit"] is None): + paths_info = self._api.get_paths_info( + resolved_path.repo_id, + resolved_path.path_in_repo, + expand=expand_info, + revision=resolved_path.revision, + repo_type=resolved_path.repo_type, + ) + if not paths_info: + _raise_file_not_found(path, None) + path_info = paths_info[0] + root_path = HfFileSystemResolvedPath( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + path_in_repo="", + _raw_revision=resolved_path._raw_revision, + ).unresolve() + if isinstance(path_info, RepoFile): + out = { + "name": root_path + "/" + path_info.path, + "size": path_info.size, + "type": "file", + "blob_id": path_info.blob_id, + "lfs": path_info.lfs, + "last_commit": path_info.last_commit, + "security": path_info.security, + } + else: + out = { + "name": root_path + "/" + path_info.path, + "size": 0, + "type": "directory", + "tree_id": path_info.tree_id, + "last_commit": path_info.last_commit, + } + if not expand_info: + out = {k: out[k] for k in ["name", "size", "type"]} + assert out is not None + return out + + def exists(self, path, **kwargs): + """ + Check if a file exists. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists). + + + + Note: When possible, use `HfApi.file_exists()` for better performance. + + + + Args: + path (`str`): + Path to check. + + Returns: + `bool`: True if file exists, False otherwise. + """ + try: + if kwargs.get("refresh", False): + self.invalidate_cache(path) + + self.info(path, **kwargs) + return True + except: # noqa: E722 + return False + + def isdir(self, path): + """ + Check if a path is a directory. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isdir). + + Args: + path (`str`): + Path to check. 
+ + Returns: + `bool`: True if path is a directory, False otherwise. + """ + try: + return self.info(path)["type"] == "directory" + except OSError: + return False + + def isfile(self, path): + """ + Check if a path is a file. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isfile). + + Args: + path (`str`): + Path to check. + + Returns: + `bool`: True if path is a file, False otherwise. + """ + try: + return self.info(path)["type"] == "file" + except: # noqa: E722 + return False + + def url(self, path: str) -> str: + """ + Get the HTTP URL of the given path. + + Args: + path (`str`): + Path to get URL for. + + Returns: + `str`: HTTP URL to access the file or directory on the Hub. + """ + resolved_path = self.resolve_path(path) + url = hf_hub_url( + resolved_path.repo_id, + resolved_path.path_in_repo, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + endpoint=self.endpoint, + ) + if self.isdir(path): + url = url.replace("/resolve/", "/tree/", 1) + return url + + def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None: + """ + Copy single remote file to local. + + + + Note: When possible, use `HfApi.hf_hub_download()` for better performance. + + + + Args: + rpath (`str`): + Remote path to download from. + lpath (`str`): + Local path to download to. + callback (`Callback`, *optional*): + Optional callback to track download progress. Defaults to no callback. + outfile (`IO`, *optional*): + Optional file-like object to write to. If provided, `lpath` is ignored. + + """ + revision = kwargs.get("revision") + unhandled_kwargs = set(kwargs.keys()) - {"revision"} + if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0: + # for now, let's not handle custom callbacks + # and let's not handle custom kwargs + return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs) + + # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883 + if isfilelike(lpath): + outfile = lpath + elif self.isdir(rpath): + os.makedirs(lpath, exist_ok=True) + return None + + if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object + os.makedirs(os.path.dirname(lpath), exist_ok=True) + + # Open file if not already open + close_file = False + if outfile is None: + outfile = open(lpath, "wb") + close_file = True + initial_pos = outfile.tell() + + # Custom implementation of `get_file` to use `http_get`. 
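+        # It streams the remote file directly into `outfile` and reports progress via the fsspec callback.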
+ resolve_remote_path = self.resolve_path(rpath, revision=revision) + expected_size = self.info(rpath, revision=revision)["size"] + callback.set_size(expected_size) + try: + http_get( + url=hf_hub_url( + repo_id=resolve_remote_path.repo_id, + revision=resolve_remote_path.revision, + filename=resolve_remote_path.path_in_repo, + repo_type=resolve_remote_path.repo_type, + endpoint=self.endpoint, + ), + temp_file=outfile, + displayed_filename=rpath, + expected_size=expected_size, + resume_size=0, + headers=self._api._build_hf_headers(), + _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None, + ) + outfile.seek(initial_pos) + finally: + # Close file only if we opened it ourselves + if close_file: + outfile.close() + + @property + def transaction(self): + """A context within which files are committed together upon exit + + Requires the file class to implement `.commit()` and `.discard()` + for the normal and exception cases. + """ + # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231 + # See https://github.com/huggingface/huggingface_hub/issues/1733 + raise NotImplementedError("Transactional commits are not supported.") + + def start_transaction(self): + """Begin write transaction for deferring files, non-context version""" + # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241 + # See https://github.com/huggingface/huggingface_hub/issues/1733 + raise NotImplementedError("Transactional commits are not supported.") + + +class HfFileSystemFile(fsspec.spec.AbstractBufferedFile): + def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs): + try: + self.resolved_path = fs.resolve_path(path, revision=revision) + except FileNotFoundError as e: + if "w" in kwargs.get("mode", ""): + raise FileNotFoundError( + f"{e}.\nMake sure the repository and revision exist before writing data." + ) from e + raise + super().__init__(fs, self.resolved_path.unresolve(), **kwargs) + self.fs: HfFileSystem + + def __del__(self): + if not hasattr(self, "resolved_path"): + # Means that the constructor failed. Nothing to do. 
+ return + return super().__del__() + + def _fetch_range(self, start: int, end: int) -> bytes: + headers = { + "range": f"bytes={start}-{end - 1}", + **self.fs._api._build_hf_headers(), + } + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + r = http_backoff( + "GET", + url, + headers=headers, + retry_on_status_codes=(500, 502, 503, 504), + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + hf_raise_for_status(r) + return r.content + + def _initiate_upload(self) -> None: + self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False) + + def _upload_chunk(self, final: bool = False) -> None: + self.buffer.seek(0) + block = self.buffer.read() + self.temp_file.write(block) + if final: + self.temp_file.close() + self.fs._api.upload_file( + path_or_fileobj=self.temp_file.name, + path_in_repo=self.resolved_path.path_in_repo, + repo_id=self.resolved_path.repo_id, + token=self.fs.token, + repo_type=self.resolved_path.repo_type, + revision=self.resolved_path.revision, + commit_message=self.kwargs.get("commit_message"), + commit_description=self.kwargs.get("commit_description"), + ) + os.remove(self.temp_file.name) + self.fs.invalidate_cache( + path=self.resolved_path.unresolve(), + ) + + def read(self, length=-1): + """Read remote file. + + If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems and if + `hf_transfer` is not enabled, the file is loaded in memory directly. Otherwise, the file is downloaded to a + temporary file and read from there. + """ + if self.mode == "rb" and (length is None or length == -1) and self.loc == 0: + with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming + out = f.read() + self.loc += len(out) + return out + return super().read(length) + + def url(self) -> str: + return self.fs.url(self.path) + + +class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile): + def __init__( + self, + fs: HfFileSystem, + path: str, + mode: str = "rb", + revision: Optional[str] = None, + block_size: int = 0, + cache_type: str = "none", + **kwargs, + ): + if block_size != 0: + raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}") + if cache_type != "none": + raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}") + if "w" in mode: + raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'") + try: + self.resolved_path = fs.resolve_path(path, revision=revision) + except FileNotFoundError as e: + if "w" in kwargs.get("mode", ""): + raise FileNotFoundError( + f"{e}.\nMake sure the repository and revision exist before writing data." 
+ ) from e + # avoid an unnecessary .info() call to instantiate .details + self.details = {"name": self.resolved_path.unresolve(), "size": None} + super().__init__( + fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs + ) + self.response: Optional[Response] = None + self.fs: HfFileSystem + + def seek(self, loc: int, whence: int = 0): + if loc == 0 and whence == 1: + return + if loc == self.loc and whence == 0: + return + raise ValueError("Cannot seek streaming HF file") + + def read(self, length: int = -1): + read_args = (length,) if length >= 0 else () + if self.response is None: + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + self.response = http_backoff( + "GET", + url, + headers=self.fs._api._build_hf_headers(), + retry_on_status_codes=(500, 502, 503, 504), + stream=True, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + hf_raise_for_status(self.response) + try: + out = self.response.raw.read(*read_args) + except Exception: + self.response.close() + + # Retry by recreating the connection + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + self.response = http_backoff( + "GET", + url, + headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()}, + retry_on_status_codes=(500, 502, 503, 504), + stream=True, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + hf_raise_for_status(self.response) + try: + out = self.response.raw.read(*read_args) + except Exception: + self.response.close() + raise + self.loc += len(out) + return out + + def url(self) -> str: + return self.fs.url(self.path) + + def __del__(self): + if not hasattr(self, "resolved_path"): + # Means that the constructor failed. Nothing to do. + return + return super().__del__() + + def __reduce__(self): + return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name) + + +def safe_revision(revision: str) -> str: + return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision) + + +def safe_quote(s: str) -> str: + return quote(s, safe="") + + +def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn: + msg = path + if isinstance(err, RepositoryNotFoundError): + msg = f"{path} (repository not found)" + elif isinstance(err, RevisionNotFoundError): + msg = f"{path} (revision not found)" + elif isinstance(err, HFValidationError): + msg = f"{path} (invalid repository id)" + raise FileNotFoundError(msg) from err + + +def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str): + return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type) diff --git a/phivenv/Lib/site-packages/huggingface_hub/hub_mixin.py b/phivenv/Lib/site-packages/huggingface_hub/hub_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..e095dfc86541ecb98b7e463ee8c07e6f5922c011 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/hub_mixin.py @@ -0,0 +1,851 @@ +import inspect +import json +import os +from dataclasses import Field, asdict, dataclass, is_dataclass +from pathlib import Path +from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union + +import packaging.version + +from . 
import constants +from .errors import EntryNotFoundError, HfHubHTTPError +from .file_download import hf_hub_download +from .hf_api import HfApi +from .repocard import ModelCard, ModelCardData +from .utils import ( + SoftTemporaryDirectory, + is_jsonable, + is_safetensors_available, + is_simple_optional_type, + is_torch_available, + logging, + unwrap_simple_optional_type, + validate_hf_hub_args, +) + + +if is_torch_available(): + import torch # type: ignore + +if is_safetensors_available(): + import safetensors + from safetensors.torch import load_model as load_model_as_safetensor + from safetensors.torch import save_model as save_model_as_safetensor + + +logger = logging.get_logger(__name__) + + +# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[Dict[str, Field]] + + +# Generic variable that is either ModelHubMixin or a subclass thereof +T = TypeVar("T", bound="ModelHubMixin") +# Generic variable to represent an args type +ARGS_T = TypeVar("ARGS_T") +ENCODER_T = Callable[[ARGS_T], Any] +DECODER_T = Callable[[Any], ARGS_T] +CODER_T = Tuple[ENCODER_T, DECODER_T] + + +DEFAULT_MODEL_CARD = """ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration: +- Code: {{ repo_url | default("[More Information Needed]", true) }} +- Paper: {{ paper_url | default("[More Information Needed]", true) }} +- Docs: {{ docs_url | default("[More Information Needed]", true) }} +""" + + +@dataclass +class MixinInfo: + model_card_template: str + model_card_data: ModelCardData + docs_url: Optional[str] = None + paper_url: Optional[str] = None + repo_url: Optional[str] = None + + +class ModelHubMixin: + """ + A generic mixin to integrate ANY machine learning framework with the Hub. + + To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models + have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example + of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. + + When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to + `__init__` but to the class definition itself. This is useful to define metadata about the library integrating + [`ModelHubMixin`]. + + For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations). + + Args: + repo_url (`str`, *optional*): + URL of the library repository. Used to generate model card. + paper_url (`str`, *optional*): + URL of the library paper. Used to generate model card. + docs_url (`str`, *optional*): + URL of the library documentation. Used to generate model card. + model_card_template (`str`, *optional*): + Template of the model card. Used to generate model card. Defaults to a generic template. + language (`str` or `List[str]`, *optional*): + Language supported by the library. Used to generate model card. 
+        library_name (`str`, *optional*):
+            Name of the library integrating ModelHubMixin. Used to generate model card.
+        license (`str`, *optional*):
+            License of the library integrating ModelHubMixin. Used to generate model card.
+            E.g.: "apache-2.0"
+        license_name (`str`, *optional*):
+            Name of the license of the library integrating ModelHubMixin. Used to generate model card.
+            Only used if `license` is set to `other`.
+            E.g.: "coqui-public-model-license".
+        license_link (`str`, *optional*):
+            URL to the license of the library integrating ModelHubMixin. Used to generate model card.
+            Only used if `license` is set to `other` and `license_name` is set.
+            E.g.: "https://coqui.ai/cpml".
+        pipeline_tag (`str`, *optional*):
+            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
+        tags (`List[str]`, *optional*):
+            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
+        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
+            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
+            jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc.
+
+    Example:
+
+    ```python
+    >>> from huggingface_hub import ModelHubMixin
+
+    # Inherit from ModelHubMixin
+    >>> class MyCustomModel(
+    ...     ModelHubMixin,
+    ...     library_name="my-library",
+    ...     tags=["computer-vision"],
+    ...     repo_url="https://github.com/huggingface/my-cool-library",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
+    ...     docs_url="https://huggingface.co/docs/my-cool-library",
+    ...     # ^ optional metadata to generate model card
+    ... ):
+    ...     def __init__(self, size: int = 512, device: str = "cpu"):
+    ...         # define how to initialize your model
+    ...         super().__init__()
+    ...         ...
+    ...
+    ...     def _save_pretrained(self, save_directory: Path) -> None:
+    ...         # define how to serialize your model
+    ...         ...
+    ...
+    ...     @classmethod
+    ...     def from_pretrained(
+    ...         cls: Type[T],
+    ...         pretrained_model_name_or_path: Union[str, Path],
+    ...         *,
+    ...         force_download: bool = False,
+    ...         resume_download: Optional[bool] = None,
+    ...         proxies: Optional[Dict] = None,
+    ...         token: Optional[Union[str, bool]] = None,
+    ...         cache_dir: Optional[Union[str, Path]] = None,
+    ...         local_files_only: bool = False,
+    ...         revision: Optional[str] = None,
+    ...         **model_kwargs,
+    ...     ) -> T:
+    ...         # define how to deserialize your model
+    ...         ...
+
+    >>> model = MyCustomModel(size=256, device="gpu")
+
+    # Save model weights to local directory
+    >>> model.save_pretrained("my-awesome-model")
+
+    # Push model weights to the Hub
+    >>> model.push_to_hub("my-awesome-model")
+
+    # Download and initialize weights from the Hub
+    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
+    >>> reloaded_model.size
+    256
+
+    # Model card has been correctly populated
+    >>> from huggingface_hub import ModelCard
+    >>> card = ModelCard.load("username/my-awesome-model")
+    >>> card.data.tags
+    ["computer-vision", "model_hub_mixin"]
+    >>> card.data.library_name
+    "my-library"
+    ```
+    """
+
+    _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
+    # ^ optional config attribute automatically set in `from_pretrained`
+    _hub_mixin_info: MixinInfo
+    # ^ information about the library integrating ModelHubMixin (used to generate model card)
+    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
+    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
+    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
+    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
+    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
+    # ^ internal values to handle config
+
+    def __init_subclass__(
+        cls,
+        *,
+        # Generic info for model card
+        repo_url: Optional[str] = None,
+        paper_url: Optional[str] = None,
+        docs_url: Optional[str] = None,
+        # Model card template
+        model_card_template: str = DEFAULT_MODEL_CARD,
+        # Model card metadata
+        language: Optional[List[str]] = None,
+        library_name: Optional[str] = None,
+        license: Optional[str] = None,
+        license_name: Optional[str] = None,
+        license_link: Optional[str] = None,
+        pipeline_tag: Optional[str] = None,
+        tags: Optional[List[str]] = None,
+        # How to encode/decode arguments with custom type into a JSON config?
+        coders: Optional[
+            Dict[Type, CODER_T]
+            # Key is a type.
+            # Value is a tuple (encoder, decoder).
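+            # (i.e. the encoder must return something `json.dumps` accepts, and the
+            # decoder must rebuild the original object from that encoded value.)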
+ # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))} + ] = None, + ) -> None: + """Inspect __init__ signature only once when subclassing + handle modelcard.""" + super().__init_subclass__() + + # Will be reused when creating modelcard + tags = tags or [] + tags.append("model_hub_mixin") + + # Initialize MixinInfo if not existent + info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData()) + + # If parent class has a MixinInfo, inherit from it as a copy + if hasattr(cls, "_hub_mixin_info"): + # Inherit model card template from parent class if not explicitly set + if model_card_template == DEFAULT_MODEL_CARD: + info.model_card_template = cls._hub_mixin_info.model_card_template + + # Inherit from parent model card data + info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict()) + + # Inherit other info + info.docs_url = cls._hub_mixin_info.docs_url + info.paper_url = cls._hub_mixin_info.paper_url + info.repo_url = cls._hub_mixin_info.repo_url + cls._hub_mixin_info = info + + # Update MixinInfo with metadata + if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD: + info.model_card_template = model_card_template + if repo_url is not None: + info.repo_url = repo_url + if paper_url is not None: + info.paper_url = paper_url + if docs_url is not None: + info.docs_url = docs_url + if language is not None: + info.model_card_data.language = language + if library_name is not None: + info.model_card_data.library_name = library_name + if license is not None: + info.model_card_data.license = license + if license_name is not None: + info.model_card_data.license_name = license_name + if license_link is not None: + info.model_card_data.license_link = license_link + if pipeline_tag is not None: + info.model_card_data.pipeline_tag = pipeline_tag + if tags is not None: + if info.model_card_data.tags is not None: + info.model_card_data.tags.extend(tags) + else: + info.model_card_data.tags = tags + + info.model_card_data.tags = sorted(set(info.model_card_data.tags)) + + # Handle encoders/decoders for args + cls._hub_mixin_coders = coders or {} + cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys()) + + # Inspect __init__ signature to handle config + cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters) + cls._hub_mixin_jsonable_default_values = { + param.name: cls._encode_arg(param.default) + for param in cls._hub_mixin_init_parameters.values() + if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default) + } + cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters + + def __new__(cls: Type[T], *args, **kwargs) -> T: + """Create a new instance of the class and handle config. + + 3 cases: + - If `self._hub_mixin_config` is already set, do nothing. + - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`. + - Otherwise, build `self._hub_mixin_config` from default values and passed values. 
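+
+        An illustrative sketch of the third case (hypothetical `MyModel`, not
+        actual library code):
+
+        ```python
+        # def __init__(self, size: int = 512, device: str = "cpu"): ...
+        model = MyModel(size=256)
+        assert model._hub_mixin_config == {"size": 256, "device": "cpu"}
+        # jsonable defaults are filled in first, then passed values override them
+        ```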
+ """ + instance = super().__new__(cls) + + # If `config` is already set, return early + if instance._hub_mixin_config is not None: + return instance + + # Infer passed values + passed_values = { + **{ + key: value + for key, value in zip( + # [1:] to skip `self` parameter + list(cls._hub_mixin_init_parameters)[1:], + args, + ) + }, + **kwargs, + } + + # If config passed as dataclass => set it and return early + if is_dataclass(passed_values.get("config")): + instance._hub_mixin_config = passed_values["config"] + return instance + + # Otherwise, build config from default + passed values + init_config = { + # default values + **cls._hub_mixin_jsonable_default_values, + # passed values + **{ + key: cls._encode_arg(value) # Encode custom types as jsonable value + for key, value in passed_values.items() + if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder + }, + } + passed_config = init_config.pop("config", {}) + + # Populate `init_config` with provided config + if isinstance(passed_config, dict): + init_config.update(passed_config) + + # Set `config` attribute and return + if init_config != {}: + instance._hub_mixin_config = init_config + return instance + + @classmethod + def _is_jsonable(cls, value: Any) -> bool: + """Check if a value is JSON serializable.""" + if is_dataclass(value): + return True + if isinstance(value, cls._hub_mixin_jsonable_custom_types): + return True + return is_jsonable(value) + + @classmethod + def _encode_arg(cls, arg: Any) -> Any: + """Encode an argument into a JSON serializable format.""" + if is_dataclass(arg): + return asdict(arg) # type: ignore[arg-type] + for type_, (encoder, _) in cls._hub_mixin_coders.items(): + if isinstance(arg, type_): + if arg is None: + return None + return encoder(arg) + return arg + + @classmethod + def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]: + """Decode a JSON serializable value into an argument.""" + if is_simple_optional_type(expected_type): + if value is None: + return None + expected_type = unwrap_simple_optional_type(expected_type) + # Dataclass => handle it + if is_dataclass(expected_type): + return _load_dataclass(expected_type, value) # type: ignore[return-value] + # Otherwise => check custom decoders + for type_, (_, decoder) in cls._hub_mixin_coders.items(): + if inspect.isclass(expected_type) and issubclass(expected_type, type_): + return decoder(value) + # Otherwise => don't decode + return value + + def save_pretrained( + self, + save_directory: Union[str, Path], + *, + config: Optional[Union[dict, DataclassInstance]] = None, + repo_id: Optional[str] = None, + push_to_hub: bool = False, + model_card_kwargs: Optional[Dict[str, Any]] = None, + **push_to_hub_kwargs, + ) -> Optional[str]: + """ + Save weights in local directory. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the model weights and configuration will be saved. + config (`dict` or `DataclassInstance`, *optional*): + Model configuration specified as a key/value dictionary or a dataclass instance. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Huggingface Hub after saving it. + repo_id (`str`, *optional*): + ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if + not provided. + model_card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. 
+ push_to_hub_kwargs: + Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. + Returns: + `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. + """ + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json + # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite + # an existing config.json if it was not saved by `_save_pretrained`. + config_path = save_directory / constants.CONFIG_NAME + config_path.unlink(missing_ok=True) + + # save model weights/files (framework-specific) + self._save_pretrained(save_directory) + + # save config (if provided and if not serialized yet in `_save_pretrained`) + if config is None: + config = self._hub_mixin_config + if config is not None: + if is_dataclass(config): + config = asdict(config) # type: ignore[arg-type] + if not config_path.exists(): + config_str = json.dumps(config, sort_keys=True, indent=2) + config_path.write_text(config_str) + + # save model card + model_card_path = save_directory / "README.md" + model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {} + if not model_card_path.exists(): # do not overwrite if already exists + self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md") + + # push to the Hub if required + if push_to_hub: + kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input + if config is not None: # kwarg for `push_to_hub` + kwargs["config"] = config + if repo_id is None: + repo_id = save_directory.name # Defaults to `save_directory` name + return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs) + return None + + def _save_pretrained(self, save_directory: Path) -> None: + """ + Overwrite this method in subclass to define how to save your model. + Check out our [integration guide](../guides/integrations) for instructions. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the model weights and configuration will be saved. + """ + raise NotImplementedError + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls: Type[T], + pretrained_model_name_or_path: Union[str, Path], + *, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict] = None, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[Union[str, Path]] = None, + local_files_only: bool = False, + revision: Optional[str] = None, + **model_kwargs, + ) -> T: + """ + Download a model from the Huggingface Hub and instantiate it. + + Args: + pretrained_model_name_or_path (`str`, `Path`): + - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`. + - Or a path to a `directory` containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`. + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. + Defaults to the latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `hf auth login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs (`Dict`, *optional*): + Additional kwargs to pass to the model during initialization. + """ + model_id = str(pretrained_model_name_or_path) + config_file: Optional[str] = None + if os.path.isdir(model_id): + if constants.CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, constants.CONFIG_NAME) + else: + logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=constants.CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}") + + # Read config + config = None + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + + # Decode custom types in config + for key, value in config.items(): + if key in cls._hub_mixin_init_parameters: + expected_type = cls._hub_mixin_init_parameters[key].annotation + if expected_type is not inspect.Parameter.empty: + config[key] = cls._decode_arg(expected_type, value) + + # Populate model_kwargs from config + for param in cls._hub_mixin_init_parameters.values(): + if param.name not in model_kwargs and param.name in config: + model_kwargs[param.name] = config[param.name] + + # Check if `config` argument was passed at init + if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs: + # Decode `config` argument if it was passed + config_annotation = cls._hub_mixin_init_parameters["config"].annotation + config = cls._decode_arg(config_annotation, config) + + # Forward config to model initialization + model_kwargs["config"] = config + + # Inject config if `**kwargs` are expected + if is_dataclass(cls): + for key in cls.__dataclass_fields__: + if key not in model_kwargs and key in config: + model_kwargs[key] = config[key] + elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()): + for key, value in config.items(): + if key not in model_kwargs: + model_kwargs[key] = value + + # Finally, also inject if `_from_pretrained` expects it + if cls._hub_mixin_inject_config and "config" not in model_kwargs: + model_kwargs["config"] = config + + instance = cls._from_pretrained( + model_id=str(model_id), + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + **model_kwargs, + ) + + # Implicitly set the config as instance attribute if not already set by the class + # This way `config` will be available when calling `save_pretrained` or `push_to_hub`. 
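+        # (An empty dict is treated like an unset config here, so a config loaded
+        # from config.json is not masked by a default-constructed empty one.)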
+ if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})): + instance._hub_mixin_config = config + + return instance + + @classmethod + def _from_pretrained( + cls: Type[T], + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Optional[Union[str, bool]], + **model_kwargs, + ) -> T: + """Overwrite this method in subclass to define how to load your model from pretrained. + + Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most + args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this + method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location` + parameter to set on which device the model should be loaded. + + Check out our [integration guide](../guides/integrations) for more instructions. + + Args: + model_id (`str`): + ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`). + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the + latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `hf auth login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs: + Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method. + """ + raise NotImplementedError + + @validate_hf_hub_args + def push_to_hub( + self, + repo_id: str, + *, + config: Optional[Union[dict, DataclassInstance]] = None, + commit_message: str = "Push model using huggingface_hub.", + private: Optional[bool] = None, + token: Optional[str] = None, + branch: Optional[str] = None, + create_pr: Optional[bool] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + model_card_kwargs: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Upload model checkpoint to the Hub. + + Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use + `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more + details. + + Args: + repo_id (`str`): + ID of the repository to push to (example: `"username/my-model"`). + config (`dict` or `DataclassInstance`, *optional*): + Model configuration specified as a key/value dictionary or a dataclass instance. + commit_message (`str`, *optional*): + Message to commit while pushing. + private (`bool`, *optional*): + Whether the repository created should be private. 
+ If `None` (default), the repo will be public unless the organization's default is private. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `hf auth login`. + branch (`str`, *optional*): + The git branch on which to push the model. This defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + delete_patterns (`List[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo. + model_card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. + + Returns: + The url of the commit of your model in the given repository. + """ + api = HfApi(token=token) + repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp: + saved_path = Path(tmp) / repo_id + self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs) + return api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message=commit_message, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) + + def generate_model_card(self, *args, **kwargs) -> ModelCard: + card = ModelCard.from_template( + card_data=self._hub_mixin_info.model_card_data, + template_str=self._hub_mixin_info.model_card_template, + repo_url=self._hub_mixin_info.repo_url, + paper_url=self._hub_mixin_info.paper_url, + docs_url=self._hub_mixin_info.docs_url, + **kwargs, + ) + return card + + +class PyTorchModelHubMixin(ModelHubMixin): + """ + Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model + is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model, + you should first set it back in training mode with `model.train()`. + + See [`ModelHubMixin`] for more details on how to use the mixin. + + Example: + + ```python + >>> import torch + >>> import torch.nn as nn + >>> from huggingface_hub import PyTorchModelHubMixin + + >>> class MyModel( + ... nn.Module, + ... PyTorchModelHubMixin, + ... library_name="keras-nlp", + ... repo_url="https://github.com/keras-team/keras-nlp", + ... paper_url="https://arxiv.org/abs/2304.12244", + ... docs_url="https://keras.io/keras_nlp/", + ... # ^ optional metadata to generate model card + ... ): + ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4): + ... super().__init__() + ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size)) + ... self.linear = nn.Linear(output_size, vocab_size) + + ... def forward(self, x): + ... 
return self.linear(x + self.param) + >>> model = MyModel(hidden_size=256) + + # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + + # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + + # Download and initialize weights from the Hub + >>> model = MyModel.from_pretrained("username/my-awesome-model") + >>> model.hidden_size + 256 + ``` + """ + + def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None: + tags = tags or [] + tags.append("pytorch_model_hub_mixin") + kwargs["tags"] = tags + return super().__init_subclass__(*args, **kwargs) + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights from a Pytorch model to a local directory.""" + model_to_save = self.module if hasattr(self, "module") else self # type: ignore + save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE)) # type: ignore [arg-type] + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + model = cls(**model_kwargs) + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE) + return cls._load_as_safetensor(model, model_file, map_location, strict) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_safetensor(model, model_file, map_location, strict) + except EntryNotFoundError: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.PYTORCH_WEIGHTS_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_pickle(model, model_file, map_location, strict) + + @classmethod + def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True) + model.load_state_dict(state_dict, strict=strict) # type: ignore + model.eval() # type: ignore + return model + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): # type: ignore [attr-defined] + load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] + if map_location != "cpu": + logger.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." 
+ ) + model.to(map_location) # type: ignore [attr-defined] + else: + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) # type: ignore [arg-type] + return model + + +def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance: + """Load a dataclass instance from a dictionary. + + Fields not expected by the dataclass are ignored. + """ + return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__}) diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference_api.py b/phivenv/Lib/site-packages/huggingface_hub/inference_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f895fcc61c3867838b013ecd3f6789cbc010b5b3 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference_api.py @@ -0,0 +1,217 @@ +import io +from typing import Any, Dict, List, Optional, Union + +from . import constants +from .hf_api import HfApi +from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args +from .utils._deprecation import _deprecate_method + + +logger = logging.get_logger(__name__) + + +ALL_TASKS = [ + # NLP + "text-classification", + "token-classification", + "table-question-answering", + "question-answering", + "zero-shot-classification", + "translation", + "summarization", + "conversational", + "feature-extraction", + "text-generation", + "text2text-generation", + "fill-mask", + "sentence-similarity", + # Audio + "text-to-speech", + "automatic-speech-recognition", + "audio-to-audio", + "audio-classification", + "voice-activity-detection", + # Computer vision + "image-classification", + "object-detection", + "image-segmentation", + "text-to-image", + "image-to-image", + # Others + "tabular-classification", + "tabular-regression", +] + + +class InferenceApi: + """Client to configure requests and make calls to the HuggingFace Inference API. + + Example: + + ```python + >>> from huggingface_hub.inference_api import InferenceApi + + >>> # Mask-fill example + >>> inference = InferenceApi("bert-base-uncased") + >>> inference(inputs="The goal of life is [MASK].") + [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] + + >>> # Question Answering example + >>> inference = InferenceApi("deepset/roberta-base-squad2") + >>> inputs = { + ... "question": "What's my name?", + ... "context": "My name is Clara and I live in Berkeley.", + ... } + >>> inference(inputs) + {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'} + + >>> # Zero-shot example + >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli") + >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" 
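+    >>> # zero-shot classification scores `inputs` against each candidate label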
+ >>> params = {"candidate_labels": ["refund", "legal", "faq"]} + >>> inference(inputs, params) + {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} + + >>> # Overriding configured task + >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction") + + >>> # Text-to-image + >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1") + >>> inference("cat") + + + >>> # Return as raw response to parse the output yourself + >>> inference = InferenceApi("mio/amadeus") + >>> response = inference("hello world", raw_response=True) + >>> response.headers + {"Content-Type": "audio/flac", ...} + >>> response.content # raw bytes from server + b'(...)' + ``` + """ + + @validate_hf_hub_args + @_deprecate_method( + version="1.0", + message=( + "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out" + " this guide to learn how to convert your script to use it:" + " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client." + ), + ) + def __init__( + self, + repo_id: str, + task: Optional[str] = None, + token: Optional[str] = None, + gpu: bool = False, + ): + """Inits headers and API call information. + + Args: + repo_id (``str``): + Id of repository (e.g. `user/bert-base-uncased`). + task (``str``, `optional`, defaults ``None``): + Whether to force a task instead of using task specified in the + repository. + token (`str`, `optional`): + The API token to use as HTTP bearer authorization. This is not + the authentication token. You can find the token in + https://huggingface.co/settings/token. Alternatively, you can + find both your organizations and personal API tokens using + `HfApi().whoami(token)`. + gpu (`bool`, `optional`, defaults `False`): + Whether to use GPU instead of CPU for inference(requires Startup + plan at least). + """ + self.options = {"wait_for_model": True, "use_gpu": gpu} + self.headers = build_hf_headers(token=token) + + # Configure task + model_info = HfApi(token=token).model_info(repo_id=repo_id) + if not model_info.pipeline_tag and not task: + raise ValueError( + "Task not specified in the repository. Please add it to the model card" + " using pipeline_tag" + " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)" + ) + + if task and task != model_info.pipeline_tag: + if task not in ALL_TASKS: + raise ValueError(f"Invalid task {task}. Make sure it's valid.") + + logger.warning( + "You're using a different task than the one specified in the" + " repository. Be sure to know what you're doing :)" + ) + self.task = task + else: + assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None" + self.task = model_info.pipeline_tag + + self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}" + + def __repr__(self): + # Do not add headers to repr to avoid leaking token. + return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})" + + def __call__( + self, + inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None, + params: Optional[Dict] = None, + data: Optional[bytes] = None, + raw_response: bool = False, + ) -> Any: + """Make a call to the Inference API. + + Args: + inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*): + Inputs for the prediction. 
+ params (`Dict`, *optional*): + Additional parameters for the models. Will be sent as `parameters` in the + payload. + data (`bytes`, *optional*): + Bytes content of the request. In this case, leave `inputs` and `params` empty. + raw_response (`bool`, defaults to `False`): + If `True`, the raw `Response` object is returned. You can parse its content + as preferred. By default, the content is parsed into a more practical format + (json dictionary or PIL Image for example). + """ + # Build payload + payload: Dict[str, Any] = { + "options": self.options, + } + if inputs: + payload["inputs"] = inputs + if params: + payload["parameters"] = params + + # Make API call + response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data) + + # Let the user handle the response + if raw_response: + return response + + # By default, parse the response for the user. + content_type = response.headers.get("Content-Type") or "" + if content_type.startswith("image"): + if not is_pillow_available(): + raise ImportError( + f"Task '{self.task}' returned as image but Pillow is not installed." + " Please install it (`pip install Pillow`) or pass" + " `raw_response=True` to get the raw `Response` object and parse" + " the image by yourself." + ) + + from PIL import Image + + return Image.open(io.BytesIO(response.content)) + elif content_type == "application/json": + return response.json() + else: + raise NotImplementedError( + f"{content_type} output type is not implemented yet. You can pass" + " `raw_response=True` to get the raw `Response` object and parse the" + " output by yourself." + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/keras_mixin.py b/phivenv/Lib/site-packages/huggingface_hub/keras_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..45d0eaf8a7737a7c5dac029f16607871e4bfccd4 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/keras_mixin.py @@ -0,0 +1,500 @@ +import collections.abc as collections +import json +import os +import warnings +from functools import wraps +from pathlib import Path +from shutil import copytree +from typing import Any, Dict, List, Optional, Union + +from huggingface_hub import ModelHubMixin, snapshot_download +from huggingface_hub.utils import ( + get_tf_version, + is_graphviz_available, + is_pydot_available, + is_tf_available, + yaml_dump, +) + +from . import constants +from .hf_api import HfApi +from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args +from .utils._typing import CallableT + + +logger = logging.get_logger(__name__) + +keras = None +if is_tf_available(): + # Depending on which version of TensorFlow is installed, we need to import + # keras from the correct location. + # See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1. + # Note: saving a keras model only works with Keras<3.0. + try: + import tf_keras as keras # type: ignore + except ImportError: + import tensorflow as tf # type: ignore + + keras = tf.keras + + +def _requires_keras_2_model(fn: CallableT) -> CallableT: + # Wrapper to raise if user tries to save a Keras 3.x model + @wraps(fn) + def _inner(model, *args, **kwargs): + if not hasattr(model, "history"): # hacky way to check if model is Keras 2.x + raise NotImplementedError( + f"Cannot use '{fn.__name__}': Keras 3.x is not supported." + " Please save models manually and upload them using `upload_folder` or `hf upload`." 
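+                # (The `hasattr(model, "history")` probe above is the "hacky" Keras 2
+                # check mentioned there: Keras 2 models expose a `history` attribute,
+                # while models reaching this branch do not.)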
+ ) + return fn(model, *args, **kwargs) + + return _inner # type: ignore [return-value] + + +def _flatten_dict(dictionary, parent_key=""): + """Flatten a nested dictionary. + Reference: https://stackoverflow.com/a/6027615/10319735 + + Args: + dictionary (`dict`): + The nested dictionary to be flattened. + parent_key (`str`): + The parent key to be prefixed to the children keys. + Necessary for recursing over the nested dictionary. + + Returns: + The flattened dictionary. + """ + items = [] + for key, value in dictionary.items(): + new_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, collections.MutableMapping): + items.extend( + _flatten_dict( + value, + new_key, + ).items() + ) + else: + items.append((new_key, value)) + return dict(items) + + +def _create_hyperparameter_table(model): + """Parse hyperparameter dictionary into a markdown table.""" + table = None + if model.optimizer is not None: + optimizer_params = model.optimizer.get_config() + # flatten the configuration + optimizer_params = _flatten_dict(optimizer_params) + optimizer_params["training_precision"] = keras.mixed_precision.global_policy().name + table = "| Hyperparameters | Value |\n| :-- | :-- |\n" + for key, value in optimizer_params.items(): + table += f"| {key} | {value} |\n" + return table + + +def _plot_network(model, save_directory): + keras.utils.plot_model( + model, + to_file=f"{save_directory}/model.png", + show_shapes=False, + show_dtype=False, + show_layer_names=True, + rankdir="TB", + expand_nested=False, + dpi=96, + layer_range=None, + ) + + +def _create_model_card( + model, + repo_dir: Path, + plot_model: bool = True, + metadata: Optional[dict] = None, +): + """ + Creates a model card for the repository. + + Do not overwrite an existing README.md file. + """ + readme_path = repo_dir / "README.md" + if readme_path.exists(): + return + + hyperparameters = _create_hyperparameter_table(model) + if plot_model and is_graphviz_available() and is_pydot_available(): + _plot_network(model, repo_dir) + if metadata is None: + metadata = {} + metadata["library_name"] = "keras" + model_card: str = "---\n" + model_card += yaml_dump(metadata, default_flow_style=False) + model_card += "---\n" + model_card += "\n## Model description\n\nMore information needed\n" + model_card += "\n## Intended uses & limitations\n\nMore information needed\n" + model_card += "\n## Training and evaluation data\n\nMore information needed\n" + if hyperparameters is not None: + model_card += "\n## Training procedure\n" + model_card += "\n### Training hyperparameters\n" + model_card += "\nThe following hyperparameters were used during training:\n\n" + model_card += hyperparameters + model_card += "\n" + if plot_model and os.path.exists(f"{repo_dir}/model.png"): + model_card += "\n ## Model Plot\n" + model_card += "\n
" + model_card += "\nView Model Plot\n" + path_to_plot = "./model.png" + model_card += f"\n![Model Image]({path_to_plot})\n" + model_card += "\n
" + + readme_path.write_text(model_card) + + +@_requires_keras_2_model +def save_pretrained_keras( + model, + save_directory: Union[str, Path], + config: Optional[Dict[str, Any]] = None, + include_optimizer: bool = False, + plot_model: bool = True, + tags: Optional[Union[list, str]] = None, + **model_save_kwargs, +): + """ + Saves a Keras model to save_directory in SavedModel format. Use this if + you're using the Functional or Sequential APIs. + + Args: + model (`Keras.Model`): + The [Keras + model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) + you'd like to save. The model must be compiled and built. + save_directory (`str` or `Path`): + Specify directory in which you want to save the Keras model. + config (`dict`, *optional*): + Configuration object to be saved alongside the model weights. + include_optimizer(`bool`, *optional*, defaults to `False`): + Whether or not to include optimizer in serialization. + plot_model (`bool`, *optional*, defaults to `True`): + Setting this to `True` will plot the model and put it in the model + card. Requires graphviz and pydot to be installed. + tags (Union[`str`,`list`], *optional*): + List of tags that are related to model or string of a single tag. See example tags + [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1). + model_save_kwargs(`dict`, *optional*): + model_save_kwargs will be passed to + [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model). + """ + if keras is None: + raise ImportError("Called a Tensorflow-specific function but could not import it.") + + if not model.built: + raise ValueError("Model should be built before trying to save") + + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + # saving config + if config: + if not isinstance(config, dict): + raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'") + + with (save_directory / constants.CONFIG_NAME).open("w") as f: + json.dump(config, f) + + metadata = {} + if isinstance(tags, list): + metadata["tags"] = tags + elif isinstance(tags, str): + metadata["tags"] = [tags] + + task_name = model_save_kwargs.pop("task_name", None) + if task_name is not None: + warnings.warn( + "`task_name` input argument is deprecated. Pass `tags` instead.", + FutureWarning, + ) + if "tags" in metadata: + metadata["tags"].append(task_name) + else: + metadata["tags"] = [task_name] + + if model.history is not None: + if model.history.history != {}: + path = save_directory / "history.json" + if path.exists(): + warnings.warn( + "`history.json` file already exists, it will be overwritten by the history of this version.", + UserWarning, + ) + with path.open("w", encoding="utf-8") as f: + json.dump(model.history.history, f, indent=2, sort_keys=True) + + _create_model_card(model, save_directory, plot_model, metadata) + keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs) + + +def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin": + r""" + Instantiate a pretrained Keras model from a pre-trained model from the Hub. + The model is expected to be in `SavedModel` format. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + - A string, the `model id` of a pretrained model hosted inside a + model repo on huggingface.co. 
Valid model ids can be located + at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like + `dbmdz/bert-base-german-cased`. + - You can add `revision` by appending `@` at the end of model_id + simply like this: `dbmdz/bert-base-german-cased@main` Revision + is the specific model version to use. It can be a branch name, + a tag name, or a commit id, since we use a git-based system + for storing models and other artifacts on huggingface.co, so + `revision` can be any identifier allowed by git. + - A path to a `directory` containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., + `./my_model_directory/`. + - `None` if you are both providing the configuration and state + dictionary (resp. with keyword arguments `config` and + `state_dict`). + force_download (`bool`, *optional*, defaults to `False`): + Whether to force the (re-)download of the model weights and + configuration files, overriding the cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., + `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The + proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If + `True`, will use the token generated when running `transformers-cli + login` (stored in `~/.huggingface`). + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model + configuration should be cached if the standard cache should not be + used. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only look at local files (i.e., do not try to download + the model). + model_kwargs (`Dict`, *optional*): + model_kwargs will be passed to the model during initialization + + + + Passing `token=True` is required when you want to use a private + model. + + + """ + return KerasModelHubMixin.from_pretrained(*args, **kwargs) + + +@validate_hf_hub_args +@_requires_keras_2_model +def push_to_hub_keras( + model, + repo_id: str, + *, + config: Optional[dict] = None, + commit_message: str = "Push Keras model using huggingface_hub.", + private: Optional[bool] = None, + api_endpoint: Optional[str] = None, + token: Optional[str] = None, + branch: Optional[str] = None, + create_pr: Optional[bool] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + log_dir: Optional[str] = None, + include_optimizer: bool = False, + tags: Optional[Union[list, str]] = None, + plot_model: bool = True, + **model_save_kwargs, +): + """ + Upload model checkpoint to the Hub. + + Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use + `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more + details. + + Args: + model (`Keras.Model`): + The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the + Hub. The model must be compiled and built. + repo_id (`str`): + ID of the repository to push to (example: `"username/my-model"`). + commit_message (`str`, *optional*, defaults to "Add Keras model"): + Message to commit while pushing. + private (`bool`, *optional*): + Whether the repository created should be private. 
+ If `None` (default), the repo will be public unless the organization's default is private. + api_endpoint (`str`, *optional*): + The API endpoint to use when pushing the model to the hub. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If + not set, will use the token set when logging in with + `hf auth login` (stored in `~/.huggingface`). + branch (`str`, *optional*): + The git branch on which to push the model. This defaults to + the default branch as specified in your repository, which + defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. + Defaults to `False`. + config (`dict`, *optional*): + Configuration object to be saved alongside the model weights. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + delete_patterns (`List[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo. + log_dir (`str`, *optional*): + TensorBoard logging directory to be pushed. The Hub automatically + hosts and displays a TensorBoard instance if log files are included + in the repository. + include_optimizer (`bool`, *optional*, defaults to `False`): + Whether or not to include optimizer during serialization. + tags (Union[`list`, `str`], *optional*): + List of tags that are related to model or string of a single tag. See example tags + [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1). + plot_model (`bool`, *optional*, defaults to `True`): + Setting this to `True` will plot the model and put it in the model + card. Requires graphviz and pydot to be installed. + model_save_kwargs(`dict`, *optional*): + model_save_kwargs will be passed to + [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model). + + Returns: + The url of the commit of your model in the given repository. + """ + api = HfApi(endpoint=api_endpoint) + repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp: + saved_path = Path(tmp) / repo_id + save_pretrained_keras( + model, + saved_path, + config=config, + include_optimizer=include_optimizer, + tags=tags, + plot_model=plot_model, + **model_save_kwargs, + ) + + # If `log_dir` provided, delete remote logs and upload new ones + if log_dir is not None: + delete_patterns = ( + [] + if delete_patterns is None + else ( + [delete_patterns] # convert `delete_patterns` to a list + if isinstance(delete_patterns, str) + else delete_patterns + ) + ) + delete_patterns.append("logs/*") + copytree(log_dir, saved_path / "logs") + + return api.upload_folder( + repo_type="model", + repo_id=repo_id, + folder_path=saved_path, + commit_message=commit_message, + token=token, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) + + +class KerasModelHubMixin(ModelHubMixin): + """ + Implementation of [`ModelHubMixin`] to provide model Hub upload/download + capabilities to Keras models. 
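+    Saving goes through [`save_pretrained_keras`] (SavedModel format, Keras 2.x
+    only) and loading through `keras.models.load_model`. A minimal sketch of the
+    round trip, assuming a built and compiled `MyModel` as defined below:
+
+    ```python
+    model.save_pretrained("my-awesome-model")  # delegates to save_pretrained_keras
+    reloaded = MyModel.from_pretrained("my-awesome-model")  # keras.models.load_model under the hood
+    ```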
+ + + ```python + >>> import tensorflow as tf + >>> from huggingface_hub import KerasModelHubMixin + + + >>> class MyModel(tf.keras.Model, KerasModelHubMixin): + ... def __init__(self, **kwargs): + ... super().__init__() + ... self.config = kwargs.pop("config", None) + ... self.dummy_inputs = ... + ... self.layer = ... + + ... def call(self, *args): + ... return ... + + + >>> # Initialize and compile the model as you normally would + >>> model = MyModel() + >>> model.compile(...) + >>> # Build the graph by training it or passing dummy inputs + >>> _ = model(model.dummy_inputs) + >>> # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + >>> # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + >>> # Download and initialize weights from the Hub + >>> model = MyModel.from_pretrained("username/super-cool-model") + ``` + """ + + def _save_pretrained(self, save_directory): + save_pretrained_keras(self, save_directory) + + @classmethod + def _from_pretrained( + cls, + model_id, + revision, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + token, + config: Optional[Dict[str, Any]] = None, + **model_kwargs, + ): + """Here we just call [`from_pretrained_keras`] function so both the mixin and + functional APIs stay in sync. + + TODO - Some args above aren't used since we are calling + snapshot_download instead of hf_hub_download. + """ + if keras is None: + raise ImportError("Called a TensorFlow-specific function but could not import it.") + + # Root is either a local filepath matching model_id or a cached snapshot + if not os.path.isdir(model_id): + storage_folder = snapshot_download( + repo_id=model_id, + revision=revision, + cache_dir=cache_dir, + library_name="keras", + library_version=get_tf_version(), + ) + else: + storage_folder = model_id + + # TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here... + model = keras.models.load_model(storage_folder) + + # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir. + model.config = config + + return model diff --git a/phivenv/Lib/site-packages/huggingface_hub/lfs.py b/phivenv/Lib/site-packages/huggingface_hub/lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d4f36829dfe941ce60c8b711c1cc912e8c324a --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/lfs.py @@ -0,0 +1,460 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Git LFS related type definitions and utilities""" + +import inspect +import io +import re +import warnings +from dataclasses import dataclass +from math import ceil +from os.path import getsize +from pathlib import Path +from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict +from urllib.parse import unquote + +from huggingface_hub import constants + +from .utils import ( + build_hf_headers, + fix_hf_endpoint_in_url, + get_session, + hf_raise_for_status, + http_backoff, + logging, + tqdm, + validate_hf_hub_args, +) +from .utils._lfs import SliceFileObj +from .utils.sha import sha256, sha_fileobj +from .utils.tqdm import is_tqdm_disabled + + +if TYPE_CHECKING: + from ._commit_api import CommitOperationAdd + +logger = logging.get_logger(__name__) + +OID_REGEX = re.compile(r"^[0-9a-f]{40}$") + +LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload" + +LFS_HEADERS = { + "Accept": "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", +} + + +@dataclass +class UploadInfo: + """ + Dataclass holding required information to determine whether a blob + should be uploaded to the hub using the LFS protocol or the regular protocol + + Args: + sha256 (`bytes`): + SHA256 hash of the blob + size (`int`): + Size in bytes of the blob + sample (`bytes`): + First 512 bytes of the blob + """ + + sha256: bytes + size: int + sample: bytes + + @classmethod + def from_path(cls, path: str): + size = getsize(path) + with io.open(path, "rb") as file: + sample = file.peek(512)[:512] + sha = sha_fileobj(file) + return cls(size=size, sha256=sha, sample=sample) + + @classmethod + def from_bytes(cls, data: bytes): + sha = sha256(data).digest() + return cls(size=len(data), sample=data[:512], sha256=sha) + + @classmethod + def from_fileobj(cls, fileobj: BinaryIO): + sample = fileobj.read(512) + fileobj.seek(0, io.SEEK_SET) + sha = sha_fileobj(fileobj) + size = fileobj.tell() + fileobj.seek(0, io.SEEK_SET) + return cls(size=size, sha256=sha, sample=sample) + + +@validate_hf_hub_args +def post_lfs_batch_info( + upload_infos: Iterable[UploadInfo], + token: Optional[str], + repo_type: str, + repo_id: str, + revision: Optional[str] = None, + endpoint: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> Tuple[List[dict], List[dict]]: + """ + Requests the LFS batch endpoint to retrieve upload instructions + + Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + + Args: + upload_infos (`Iterable` of `UploadInfo`): + `UploadInfo` for the files that are being uploaded, typically obtained + from `CommitOperationAdd.upload_info` + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The git revision to upload to. + headers (`dict`, *optional*): + Additional headers to include in the request + + Returns: + `LfsBatchInfo`: 2-tuple: + - First element is the list of upload instructions from the server + - Second element is an list of errors, if any + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If an argument is invalid or the server response is malformed. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the server returned an error. 
+ """ + endpoint = endpoint if endpoint is not None else constants.ENDPOINT + url_prefix = "" + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + url_prefix = constants.REPO_TYPES_URL_PREFIXES[repo_type] + batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch" + payload: Dict = { + "operation": "upload", + "transfers": ["basic", "multipart"], + "objects": [ + { + "oid": upload.sha256.hex(), + "size": upload.size, + } + for upload in upload_infos + ], + "hash_algo": "sha256", + } + if revision is not None: + payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted' + + headers = { + **LFS_HEADERS, + **build_hf_headers(token=token), + **(headers or {}), + } + resp = get_session().post(batch_url, headers=headers, json=payload) + hf_raise_for_status(resp) + batch_info = resp.json() + + objects = batch_info.get("objects", None) + if not isinstance(objects, list): + raise ValueError("Malformed response from server") + + return ( + [_validate_batch_actions(obj) for obj in objects if "error" not in obj], + [_validate_batch_error(obj) for obj in objects if "error" in obj], + ) + + +class PayloadPartT(TypedDict): + partNumber: int + etag: str + + +class CompletionPayloadT(TypedDict): + """Payload that will be sent to the Hub when uploading multi-part.""" + + oid: str + parts: List[PayloadPartT] + + +def lfs_upload( + operation: "CommitOperationAdd", + lfs_batch_action: Dict, + token: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + endpoint: Optional[str] = None, +) -> None: + """ + Handles uploading a given object to the Hub with the LFS protocol. + + Can be a No-op if the content of the file is already present on the hub large file storage. + + Args: + operation (`CommitOperationAdd`): + The add operation triggering this upload. + lfs_batch_action (`dict`): + Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for + more details. + headers (`dict`, *optional*): + Headers to include in the request, including authentication and user agent headers. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `lfs_batch_action` is improperly formatted + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the upload resulted in an error + """ + # 0. If LFS file is already present, skip upload + _validate_batch_actions(lfs_batch_action) + actions = lfs_batch_action.get("actions") + if actions is None: + # The file was already uploaded + logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload") + return + + # 1. Validate server response (check required keys in dict) + upload_action = lfs_batch_action["actions"]["upload"] + _validate_lfs_action(upload_action) + verify_action = lfs_batch_action["actions"].get("verify") + if verify_action is not None: + _validate_lfs_action(verify_action) + + # 2. Upload file (either single part or multi-part) + header = upload_action.get("header", {}) + chunk_size = header.get("chunk_size") + upload_url = fix_hf_endpoint_in_url(upload_action["href"], endpoint=endpoint) + if chunk_size is not None: + try: + chunk_size = int(chunk_size) + except (ValueError, TypeError): + raise ValueError( + f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'." 
+ ) + _upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_url) + else: + _upload_single_part(operation=operation, upload_url=upload_url) + + # 3. Verify upload went well + if verify_action is not None: + _validate_lfs_action(verify_action) + verify_url = fix_hf_endpoint_in_url(verify_action["href"], endpoint) + verify_resp = get_session().post( + verify_url, + headers=build_hf_headers(token=token, headers=headers), + json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size}, + ) + hf_raise_for_status(verify_resp) + logger.debug(f"{operation.path_in_repo}: Upload successful") + + +def _validate_lfs_action(lfs_action: dict): + """validates response from the LFS batch endpoint""" + if not ( + isinstance(lfs_action.get("href"), str) + and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict)) + ): + raise ValueError("lfs_action is improperly formatted") + return lfs_action + + +def _validate_batch_actions(lfs_batch_actions: dict): + """validates response from the LFS batch endpoint""" + if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)): + raise ValueError("lfs_batch_actions is improperly formatted") + + upload_action = lfs_batch_actions.get("actions", {}).get("upload") + verify_action = lfs_batch_actions.get("actions", {}).get("verify") + if upload_action is not None: + _validate_lfs_action(upload_action) + if verify_action is not None: + _validate_lfs_action(verify_action) + return lfs_batch_actions + + +def _validate_batch_error(lfs_batch_error: dict): + """validates response from the LFS batch endpoint""" + if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)): + raise ValueError("lfs_batch_error is improperly formatted") + error_info = lfs_batch_error.get("error") + if not ( + isinstance(error_info, dict) + and isinstance(error_info.get("message"), str) + and isinstance(error_info.get("code"), int) + ): + raise ValueError("lfs_batch_error is improperly formatted") + return lfs_batch_error + + +def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None: + """ + Uploads `fileobj` as a single PUT HTTP request (basic LFS transfer protocol) + + Args: + upload_url (`str`): + The URL to PUT the file to. + fileobj: + The file-like object holding the data to upload. + + Returns: `requests.Response` + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the upload resulted in an error. + """ + with operation.as_file(with_tqdm=True) as fileobj: + # S3 might raise a transient 500 error -> let's retry if that happens + response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 502, 503, 504)) + hf_raise_for_status(response) + + +def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size: int, upload_url: str) -> None: + """ + Uploads file using HF multipart LFS transfer protocol. + """ + # 1. Get upload URLs for each part + sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size) + + # 2. 
Upload parts (either with hf_transfer or in pure Python) + use_hf_transfer = constants.HF_HUB_ENABLE_HF_TRANSFER + if ( + constants.HF_HUB_ENABLE_HF_TRANSFER + and not isinstance(operation.path_or_fileobj, str) + and not isinstance(operation.path_or_fileobj, Path) + ): + warnings.warn( + "hf_transfer is enabled but does not support uploading from bytes or BinaryIO, falling back to regular" + " upload" + ) + use_hf_transfer = False + + response_headers = ( + _upload_parts_hf_transfer(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size) + if use_hf_transfer + else _upload_parts_iteratively(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size) + ) + + # 3. Send completion request + completion_res = get_session().post( + upload_url, + json=_get_completion_payload(response_headers, operation.upload_info.sha256.hex()), + headers=LFS_HEADERS, + ) + hf_raise_for_status(completion_res) + + +def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: int) -> List[str]: + sorted_part_upload_urls = [ + upload_url + for _, upload_url in sorted( + [ + (int(part_num, 10), upload_url) + for part_num, upload_url in header.items() + if part_num.isdigit() and len(part_num) > 0 + ], + key=lambda t: t[0], + ) + ] + num_parts = len(sorted_part_upload_urls) + if num_parts != ceil(upload_info.size / chunk_size): + raise ValueError("Invalid server response to upload large LFS file") + return sorted_part_upload_urls + + +def _get_completion_payload(response_headers: List[Dict], oid: str) -> CompletionPayloadT: + parts: List[PayloadPartT] = [] + for part_number, header in enumerate(response_headers): + etag = header.get("etag") + if etag is None or etag == "": + raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_number + 1}") + parts.append( + { + "partNumber": part_number + 1, + "etag": etag, + } + ) + return {"oid": oid, "parts": parts} + + +def _upload_parts_iteratively( + operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int +) -> List[Dict]: + headers = [] + with operation.as_file(with_tqdm=True) as fileobj: + for part_idx, part_upload_url in enumerate(sorted_parts_urls): + with SliceFileObj( + fileobj, + seek_from=chunk_size * part_idx, + read_limit=chunk_size, + ) as fileobj_slice: + # S3 might raise a transient 500 error -> let's retry if that happens + part_upload_res = http_backoff( + "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 502, 503, 504) + ) + hf_raise_for_status(part_upload_res) + headers.append(part_upload_res.headers) + return headers # type: ignore + + +def _upload_parts_hf_transfer( + operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int +) -> List[Dict]: + # Upload file using an external Rust-based package. Upload is faster but support less features (no progress bars). + try: + from hf_transfer import multipart_upload + except ImportError: + raise ValueError( + "Fast uploading using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is" + " not available in your environment. Try `pip install hf_transfer`." + ) + + supports_callback = "callback" in inspect.signature(multipart_upload).parameters + if not supports_callback: + warnings.warn( + "You are using an outdated version of `hf_transfer`. Consider upgrading to latest version to enable progress bars using `pip install -U hf_transfer`." 
+ ) + + total = operation.upload_info.size + desc = operation.path_in_repo + if len(desc) > 40: + desc = f"(…){desc[-40:]}" + + with tqdm( + unit="B", + unit_scale=True, + total=total, + initial=0, + desc=desc, + disable=is_tqdm_disabled(logger.getEffectiveLevel()), + name="huggingface_hub.lfs_upload", + ) as progress: + try: + output = multipart_upload( + file_path=operation.path_or_fileobj, + parts_urls=sorted_parts_urls, + chunk_size=chunk_size, + max_files=128, + parallel_failures=127, # could be removed + max_retries=5, + **({"callback": progress.update} if supports_callback else {}), + ) + except Exception as e: + raise RuntimeError( + "An error occurred while uploading using `hf_transfer`. Consider disabling HF_HUB_ENABLE_HF_TRANSFER for" + " better error handling." + ) from e + if not supports_callback: + progress.update(total) + return output diff --git a/phivenv/Lib/site-packages/huggingface_hub/py.typed b/phivenv/Lib/site-packages/huggingface_hub/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/huggingface_hub/repocard.py b/phivenv/Lib/site-packages/huggingface_hub/repocard.py new file mode 100644 index 0000000000000000000000000000000000000000..df4fb81aa1b78095913a7e25c3a8c529283cebee --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/repocard.py @@ -0,0 +1,830 @@ +import os +import re +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Type, Union + +import requests +import yaml + +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.hf_api import upload_file +from huggingface_hub.repocard_data import ( + CardData, + DatasetCardData, + EvalResult, + ModelCardData, + SpaceCardData, + eval_results_to_model_index, + model_index_to_eval_results, +) +from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump + +from . import constants +from .errors import EntryNotFoundError +from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args + + +logger = logging.get_logger(__name__) + + +TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md" +TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md" + +# exact same regex as in the Hub server. Please keep in sync. +# See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18 +REGEX_YAML_BLOCK = re.compile(r"^(\s*---[\r\n]+)([\S\s]*?)([\r\n]+---(\r\n|\n|$))") + + +class RepoCard: + card_data_class = CardData + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "model" + + def __init__(self, content: str, ignore_metadata_errors: bool = False): + """Initialize a RepoCard from string content. The content should be a + Markdown file with a YAML block at the beginning and a Markdown body. + + Args: + content (`str`): The content of the Markdown file. + + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> text = ''' + ... --- + ... language: en + ... license: mit + ... --- + ... + ... # My repo + ... ''' + >>> card = RepoCard(text) + >>> card.data.to_dict() + {'language': 'en', 'license': 'mit'} + >>> card.text + '\\n# My repo\\n' + + ``` + + Raises the following error: + + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + when the content of the repo card metadata is not a dictionary. + + + """ + + # Set the content of the RepoCard, as well as underlying .data and .text attributes. 
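+        # Assigning to `self.content` below triggers the `content` property setter, which
+        # splits the optional YAML front matter (parsed into `self.data`) from the
+        # Markdown body (kept in `self.text`).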
+        # See the `content` property setter for more details.
+        self.ignore_metadata_errors = ignore_metadata_errors
+        self.content = content
+
+    @property
+    def content(self):
+        """The content of the RepoCard, including the YAML block and the Markdown body."""
+        line_break = _detect_line_ending(self._content) or "\n"
+        return f"---{line_break}{self.data.to_yaml(line_break=line_break, original_order=self._original_order)}{line_break}---{line_break}{self.text}"
+
+    @content.setter
+    def content(self, content: str):
+        """Set the content of the RepoCard."""
+        self._content = content
+
+        match = REGEX_YAML_BLOCK.search(content)
+        if match:
+            # Metadata found in the YAML block
+            yaml_block = match.group(2)
+            self.text = content[match.end() :]
+            data_dict = yaml.safe_load(yaml_block)
+
+            if data_dict is None:
+                data_dict = {}
+
+            # The YAML block's data should be a dictionary
+            if not isinstance(data_dict, dict):
+                raise ValueError("repo card metadata block should be a dict")
+        else:
+            # Model card without metadata... create empty metadata
+            logger.warning("Repo card metadata block was not found. Setting CardData to empty.")
+            data_dict = {}
+            self.text = content
+
+        self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors)
+        self._original_order = list(data_dict.keys())
+
+    def __str__(self):
+        return self.content
+
+    def save(self, filepath: Union[Path, str]):
+        r"""Save a RepoCard to a file.
+
+        Args:
+            filepath (`Union[Path, str]`): Filepath to the markdown file to save.
+
+        Example:
+        ```python
+        >>> from huggingface_hub.repocard import RepoCard
+        >>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card")
+        >>> card.save("/tmp/test.md")
+
+        ```
+        """
+        filepath = Path(filepath)
+        filepath.parent.mkdir(parents=True, exist_ok=True)
+        # Preserve newlines as in the existing file.
+        with open(filepath, mode="w", newline="", encoding="utf-8") as f:
+            f.write(str(self))
+
+    @classmethod
+    def load(
+        cls,
+        repo_id_or_path: Union[str, Path],
+        repo_type: Optional[str] = None,
+        token: Optional[str] = None,
+        ignore_metadata_errors: bool = False,
+    ):
+        """Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath.
+
+        Args:
+            repo_id_or_path (`Union[str, Path]`):
+                The repo ID associated with a Hugging Face Hub repo or a local filepath.
+            repo_type (`str`, *optional*):
+                The type of Hugging Face repo to push to. Defaults to None, which will use "model". Other options
+                are "dataset" and "space". Not used when loading from a local filepath. If this is called from a child
+                class, the default value will be the child class's `repo_type`.
+            token (`str`, *optional*):
+                Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
+            ignore_metadata_errors (`bool`):
+                If True, errors while parsing the metadata section will be ignored. Some information might be lost during
+                the process. Use it at your own risk.
+
+        Returns:
+            [`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's
+                README.md file or filepath.
+ + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> card = RepoCard.load("nateraw/food") + >>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"] + + ``` + """ + + if Path(repo_id_or_path).is_file(): + card_path = Path(repo_id_or_path) + elif isinstance(repo_id_or_path, str): + card_path = Path( + hf_hub_download( + repo_id_or_path, + constants.REPOCARD_NAME, + repo_type=repo_type or cls.repo_type, + token=token, + ) + ) + else: + raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).") + + # Preserve newlines in the existing file. + with card_path.open(mode="r", newline="", encoding="utf-8") as f: + return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors) + + def validate(self, repo_type: Optional[str] = None): + """Validates card against Hugging Face Hub's card validation logic. + Using this function requires access to the internet, so it is only called + internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`]. + + Args: + repo_type (`str`, *optional*, defaults to "model"): + The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". + If this function is called from a child class, the default will be the child class's `repo_type`. + + + Raises the following errors: + + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if the card fails validation checks. + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the request to the Hub API fails for any other reason. + + + """ + + # If repo type is provided, otherwise, use the repo type of the card. + repo_type = repo_type or self.repo_type + + body = { + "repoType": repo_type, + "content": str(self), + } + headers = {"Accept": "text/plain"} + + try: + r = get_session().post("https://huggingface.co/api/validate-yaml", body, headers=headers) + r.raise_for_status() + except requests.exceptions.HTTPError as exc: + if r.status_code == 400: + raise ValueError(r.text) + else: + raise exc + + def push_to_hub( + self, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ): + """Push a RepoCard to a Hugging Face Hub repo. + + Args: + repo_id (`str`): + The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food". + token (`str`, *optional*): + Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to + the stored token. + repo_type (`str`, *optional*, defaults to "model"): + The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this + function is called by a child class, it will default to the child class's `repo_type`. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. + commit_description (`str`, *optional*) + The description of the generated commit. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + create_pr (`bool`, *optional*): + Whether or not to create a Pull Request with this commit. Defaults to `False`. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. 
+ If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + Returns: + `str`: URL of the commit which updated the card metadata. + """ + + # If repo type is provided, otherwise, use the repo type of the card. + repo_type = repo_type or self.repo_type + + # Validate card before pushing to hub + self.validate(repo_type=repo_type) + + with SoftTemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) / constants.REPOCARD_NAME + tmp_path.write_text(str(self), encoding="utf-8") + url = upload_file( + path_or_fileobj=str(tmp_path), + path_in_repo=constants.REPOCARD_NAME, + repo_id=repo_id, + token=token, + repo_type=repo_type, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + revision=revision, + parent_commit=parent_commit, + ) + return url + + @classmethod + def from_template( + cls, + card_data: CardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a RepoCard from a template. By default, it uses the default template. + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.CardData`): + A huggingface_hub.CardData instance containing the metadata you want to include in the YAML + header of the repo card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the + template. + """ + if is_jinja_available(): + import jinja2 + else: + raise ImportError( + "Using RepoCard.from_template requires Jinja2 to be installed. Please" + " install it with `pip install Jinja2`." + ) + + kwargs = card_data.to_dict().copy() + kwargs.update(template_kwargs) # Template_kwargs have priority + + if template_path is not None: + template_str = Path(template_path).read_text() + if template_str is None: + template_str = Path(cls.default_template_path).read_text() + template = jinja2.Template(template_str) + content = template.render(card_data=card_data.to_yaml(), **kwargs) + return cls(content) + + +class ModelCard(RepoCard): + card_data_class = ModelCardData + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "model" + + @classmethod + def from_template( # type: ignore # violates Liskov property but easier to use + cls, + card_data: ModelCardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here: + https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.ModelCardData`): + A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML + header of the model card on the Hugging Face Hub. 
+ template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the + template. + + Example: + ```python + >>> from huggingface_hub import ModelCard, ModelCardData, EvalResult + + >>> # Using the Default Template + >>> card_data = ModelCardData( + ... language='en', + ... license='mit', + ... library_name='timm', + ... tags=['image-classification', 'resnet'], + ... datasets=['beans'], + ... metrics=['accuracy'], + ... ) + >>> card = ModelCard.from_template( + ... card_data, + ... model_description='This model does x + y...' + ... ) + + >>> # Including Evaluation Results + >>> card_data = ModelCardData( + ... language='en', + ... tags=['image-classification', 'resnet'], + ... eval_results=[ + ... EvalResult( + ... task_type='image-classification', + ... dataset_type='beans', + ... dataset_name='Beans', + ... metric_type='accuracy', + ... metric_value=0.9, + ... ), + ... ], + ... model_name='my-cool-model', + ... ) + >>> card = ModelCard.from_template(card_data) + + >>> # Using a Custom Template + >>> card_data = ModelCardData( + ... language='en', + ... tags=['image-classification', 'resnet'] + ... ) + >>> card = ModelCard.from_template( + ... card_data=card_data, + ... template_path='./src/huggingface_hub/templates/modelcard_template.md', + ... custom_template_var='custom value', # will be replaced in template if it exists + ... ) + + ``` + """ + return super().from_template(card_data, template_path, template_str, **template_kwargs) + + +class DatasetCard(RepoCard): + card_data_class = DatasetCardData + default_template_path = TEMPLATE_DATASETCARD_PATH + repo_type = "dataset" + + @classmethod + def from_template( # type: ignore # violates Liskov property but easier to use + cls, + card_data: DatasetCardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here: + https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.DatasetCardData`): + A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML + header of the dataset card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the + template. + + Example: + ```python + >>> from huggingface_hub import DatasetCard, DatasetCardData + + >>> # Using the Default Template + >>> card_data = DatasetCardData( + ... language='en', + ... license='mit', + ... annotations_creators='crowdsourced', + ... task_categories=['text-classification'], + ... task_ids=['sentiment-classification', 'text-scoring'], + ... multilinguality='monolingual', + ... pretty_name='My Text Classification Dataset', + ... ) + >>> card = DatasetCard.from_template( + ... card_data, + ... pretty_name=card_data.pretty_name, + ... ) + + >>> # Using a Custom Template + >>> card_data = DatasetCardData( + ... 
language='en',
+        ...     license='mit',
+        ... )
+        >>> card = DatasetCard.from_template(
+        ...     card_data=card_data,
+        ...     template_path='./src/huggingface_hub/templates/datasetcard_template.md',
+        ...     custom_template_var='custom value',  # will be replaced in template if it exists
+        ... )
+
+        ```
+        """
+        return super().from_template(card_data, template_path, template_str, **template_kwargs)
+
+
+class SpaceCard(RepoCard):
+    card_data_class = SpaceCardData
+    default_template_path = TEMPLATE_MODELCARD_PATH
+    repo_type = "space"
+
+
+def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]:  # noqa: F722
+    """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines.
+
+    Uses same implementation as in Hub server, keep it in sync.
+
+    Returns:
+        str: The detected line ending of the string.
+    """
+    cr = content.count("\r")
+    lf = content.count("\n")
+    crlf = content.count("\r\n")
+    if cr + lf == 0:
+        return None
+    if crlf == cr and crlf == lf:
+        return "\r\n"
+    if cr > lf:
+        return "\r"
+    else:
+        return "\n"
+
+
+def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
+    content = Path(local_path).read_text()
+    match = REGEX_YAML_BLOCK.search(content)
+    if match:
+        yaml_block = match.group(2)
+        data = yaml.safe_load(yaml_block)
+        if data is None or isinstance(data, dict):
+            return data
+        raise ValueError("repo card metadata block should be a dict")
+    else:
+        return None
+
+
+def metadata_save(local_path: Union[str, Path], data: Dict) -> None:
+    """
+    Save the metadata dict in the upper YAML part, trying to preserve newlines as
+    in the existing file. Docs about open() with newline="" parameter:
+    https://docs.python.org/3/library/functions.html?highlight=open#open.
+    Does not work with "^M" linebreaks, which are replaced by \n
+    """
+    line_break = "\n"
+    content = ""
+    # try to detect the existing newline character
+    if os.path.exists(local_path):
+        with open(local_path, "r", newline="", encoding="utf8") as readme:
+            content = readme.read()
+            if isinstance(readme.newlines, tuple):
+                line_break = readme.newlines[0]
+            elif isinstance(readme.newlines, str):
+                line_break = readme.newlines
+
+    # creates a new file if it does not exist yet
+    with open(local_path, "w", newline="", encoding="utf8") as readme:
+        data_yaml = yaml_dump(data, sort_keys=False, line_break=line_break)
+        # sort_keys: keep dict order
+        match = REGEX_YAML_BLOCK.search(content)
+        if match:
+            output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :]
+        else:
+            output = f"---{line_break}{data_yaml}---{line_break}{content}"
+
+        readme.write(output)
+        readme.close()
+
+
+def metadata_eval_result(
+    *,
+    model_pretty_name: str,
+    task_pretty_name: str,
+    task_id: str,
+    metrics_pretty_name: str,
+    metrics_id: str,
+    metrics_value: Any,
+    dataset_pretty_name: str,
+    dataset_id: str,
+    metrics_config: Optional[str] = None,
+    metrics_verified: bool = False,
+    dataset_config: Optional[str] = None,
+    dataset_split: Optional[str] = None,
+    dataset_revision: Optional[str] = None,
+    metrics_verification_token: Optional[str] = None,
+) -> Dict:
+    """
+    Creates a metadata dict with the result from a model evaluated on a dataset.
+
+    Args:
+        model_pretty_name (`str`):
+            The name of the model in natural language.
+        task_pretty_name (`str`):
+            The name of a task in natural language.
+        task_id (`str`):
+            Example: automatic-speech-recognition. A task id.
+        metrics_pretty_name (`str`):
+            A name for the metric in natural language. Example: Test WER.
+        metrics_id (`str`):
+            Example: wer. A metric id from https://hf.co/metrics.
+        metrics_value (`Any`):
+            The value from the metric. Example: 20.0 or "20.0 ± 1.2".
+        dataset_pretty_name (`str`):
+            The name of the dataset in natural language.
+        dataset_id (`str`):
+            Example: common_voice. A dataset id from https://hf.co/datasets.
+        metrics_config (`str`, *optional*):
+            The name of the metric configuration used in `load_metric()`.
+            Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
+        metrics_verified (`bool`, *optional*, defaults to `False`):
+            Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
+        dataset_config (`str`, *optional*):
+            Example: fr. The name of the dataset configuration used in `load_dataset()`.
+        dataset_split (`str`, *optional*):
+            Example: test. The name of the dataset split used in `load_dataset()`.
+        dataset_revision (`str`, *optional*):
+            Example: 5503434ddd753f426f4b38109466949a1217c2bb. The name of the dataset revision
+            used in `load_dataset()`.
+        metrics_verification_token (`str`, *optional*):
+            A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
+
+    Returns:
+        `dict`: a metadata dict with the result from a model evaluated on a dataset.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import metadata_eval_result
+    >>> results = metadata_eval_result(
+    ...     model_pretty_name="RoBERTa fine-tuned on ReactionGIF",
+    ...     task_pretty_name="Text Classification",
+    ...     task_id="text-classification",
+    ...     metrics_pretty_name="Accuracy",
+    ...     metrics_id="accuracy",
+    ...     metrics_value=0.2662102282047272,
+    ...     dataset_pretty_name="ReactionJPEG",
+    ...     dataset_id="julien-c/reactionjpeg",
+    ...     dataset_config="default",
+    ...     dataset_split="test",
+    ... )
+    >>> results == {
+    ...     'model-index': [
+    ...         {
+    ...             'name': 'RoBERTa fine-tuned on ReactionGIF',
+    ...             'results': [
+    ...                 {
+    ...                     'task': {
+    ...                         'type': 'text-classification',
+    ...                         'name': 'Text Classification'
+    ...                     },
+    ...                     'dataset': {
+    ...                         'name': 'ReactionJPEG',
+    ...                         'type': 'julien-c/reactionjpeg',
+    ...                         'config': 'default',
+    ...                         'split': 'test'
+    ...                     },
+    ...                     'metrics': [
+    ...                         {
+    ...                             'type': 'accuracy',
+    ...                             'value': 0.2662102282047272,
+    ...                             'name': 'Accuracy',
+    ...                             'verified': False
+    ...                         }
+    ...                     ]
+    ...                 }
+    ...             ]
+    ...         }
+    ...     ]
+    ...
} + True + + ``` + """ + + return { + "model-index": eval_results_to_model_index( + model_name=model_pretty_name, + eval_results=[ + EvalResult( + task_name=task_pretty_name, + task_type=task_id, + metric_name=metrics_pretty_name, + metric_type=metrics_id, + metric_value=metrics_value, + dataset_name=dataset_pretty_name, + dataset_type=dataset_id, + metric_config=metrics_config, + verified=metrics_verified, + verify_token=metrics_verification_token, + dataset_config=dataset_config, + dataset_split=dataset_split, + dataset_revision=dataset_revision, + ) + ], + ) + } + + +@validate_hf_hub_args +def metadata_update( + repo_id: str, + metadata: Dict, + *, + repo_type: Optional[str] = None, + overwrite: bool = False, + token: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + revision: Optional[str] = None, + create_pr: bool = False, + parent_commit: Optional[str] = None, +) -> str: + """ + Updates the metadata in the README.md of a repository on the Hugging Face Hub. + If the README.md file doesn't exist yet, a new one is created with metadata and an + the default ModelCard or DatasetCard template. For `space` repo, an error is thrown + as a Space cannot exist without a `README.md` file. + + Args: + repo_id (`str`): + The name of the repository. + metadata (`dict`): + A dictionary containing the metadata to be updated. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if updating to a dataset or space, + `None` or `"model"` if updating to a model. Default is `None`. + overwrite (`bool`, *optional*, defaults to `False`): + If set to `True` an existing field can be overwritten, otherwise + attempting to overwrite an existing field will cause an error. + token (`str`, *optional*): + The Hugging Face authentication token. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to + `f"Update metadata with huggingface_hub"` + commit_description (`str` *optional*) + The description of the generated commit + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the + `"main"` branch. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `revision` with that commit. + Defaults to `False`. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + Returns: + `str`: URL of the commit which updated the card metadata. + + Example: + ```python + >>> from huggingface_hub import metadata_update + >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF', + ... 'results': [{'dataset': {'name': 'ReactionGIF', + ... 'type': 'julien-c/reactiongif'}, + ... 'metrics': [{'name': 'Recall', + ... 'type': 'recall', + ... 'value': 0.7762102282047272}], + ... 'task': {'name': 'Text Classification', + ... 
'type': 'text-classification'}}]}]} + >>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata) + + ``` + """ + commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub" + + # Card class given repo_type + card_class: Type[RepoCard] + if repo_type is None or repo_type == "model": + card_class = ModelCard + elif repo_type == "dataset": + card_class = DatasetCard + elif repo_type == "space": + card_class = RepoCard + else: + raise ValueError(f"Unknown repo_type: {repo_type}") + + # Either load repo_card from the Hub or create an empty one. + # NOTE: Will not create the repo if it doesn't exist. + try: + card = card_class.load(repo_id, token=token, repo_type=repo_type) + except EntryNotFoundError: + if repo_type == "space": + raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.") + + # Initialize a ModelCard or DatasetCard from default template and no data. + card = card_class.from_template(CardData()) + + for key, value in metadata.items(): + if key == "model-index": + # if the new metadata doesn't include a name, either use existing one or repo name + if "name" not in value[0]: + value[0]["name"] = getattr(card, "model_name", repo_id) + model_name, new_results = model_index_to_eval_results(value) + if card.data.eval_results is None: + card.data.eval_results = new_results + card.data.model_name = model_name + else: + existing_results = card.data.eval_results + + # Iterate over new results + # Iterate over existing results + # If both results describe the same metric but value is different: + # If overwrite=True: overwrite the metric value + # Else: raise ValueError + # Else: append new result to existing ones. + for new_result in new_results: + result_found = False + for existing_result in existing_results: + if new_result.is_equal_except_value(existing_result): + if new_result != existing_result and not overwrite: + raise ValueError( + "You passed a new value for the existing metric" + f" 'name: {new_result.metric_name}, type: " + f"{new_result.metric_type}'. Set `overwrite=True`" + " to overwrite existing metrics." + ) + result_found = True + existing_result.metric_value = new_result.metric_value + if existing_result.verified is True: + existing_result.verify_token = new_result.verify_token + if not result_found: + card.data.eval_results.append(new_result) + else: + # Any metadata that is not a result metric + if card.data.get(key) is not None and not overwrite and card.data.get(key) != value: + raise ValueError( + f"You passed a new value for the existing meta data field '{key}'." + " Set `overwrite=True` to overwrite existing metadata." 
+ ) + else: + card.data[key] = value + + return card.push_to_hub( + repo_id, + token=token, + repo_type=repo_type, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + revision=revision, + parent_commit=parent_commit, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/repocard_data.py b/phivenv/Lib/site-packages/huggingface_hub/repocard_data.py new file mode 100644 index 0000000000000000000000000000000000000000..62215f2274e482d4ed69a1d6deeafdf34fc5a6a4 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/repocard_data.py @@ -0,0 +1,770 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +from huggingface_hub.utils import logging, yaml_dump + + +logger = logging.get_logger(__name__) + + +@dataclass +class EvalResult: + """ + Flattened representation of individual evaluation results found in model-index of Model Cards. + + For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1. + + Args: + task_type (`str`): + The task identifier. Example: "image-classification". + dataset_type (`str`): + The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets. + dataset_name (`str`): + A pretty name for the dataset. Example: "Common Voice (French)". + metric_type (`str`): + The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics. + metric_value (`Any`): + The metric value. Example: 0.9 or "20.0 ± 1.2". + task_name (`str`, *optional*): + A pretty name for the task. Example: "Speech Recognition". + dataset_config (`str`, *optional*): + The name of the dataset configuration used in `load_dataset()`. + Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info: + https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name + dataset_split (`str`, *optional*): + The split used in `load_dataset()`. Example: "test". + dataset_revision (`str`, *optional*): + The revision (AKA Git Sha) of the dataset used in `load_dataset()`. + Example: 5503434ddd753f426f4b38109466949a1217c2bb + dataset_args (`Dict[str, Any]`, *optional*): + The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}` + metric_name (`str`, *optional*): + A pretty name for the metric. Example: "Test WER". + metric_config (`str`, *optional*): + The name of the metric configuration used in `load_metric()`. + Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations + metric_args (`Dict[str, Any]`, *optional*): + The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4 + verified (`bool`, *optional*): + Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + verify_token (`str`, *optional*): + A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + source_name (`str`, *optional*): + The name of the source of the evaluation result. Example: "Open LLM Leaderboard". + source_url (`str`, *optional*): + The URL of the source of the evaluation result. 
Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard". + """ + + # Required + + # The task identifier + # Example: automatic-speech-recognition + task_type: str + + # The dataset identifier + # Example: common_voice. Use dataset id from https://hf.co/datasets + dataset_type: str + + # A pretty name for the dataset. + # Example: Common Voice (French) + dataset_name: str + + # The metric identifier + # Example: wer. Use metric id from https://hf.co/metrics + metric_type: str + + # Value of the metric. + # Example: 20.0 or "20.0 ± 1.2" + metric_value: Any + + # Optional + + # A pretty name for the task. + # Example: Speech Recognition + task_name: Optional[str] = None + + # The name of the dataset configuration used in `load_dataset()`. + # Example: fr in `load_dataset("common_voice", "fr")`. + # See the `datasets` docs for more info: + # https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name + dataset_config: Optional[str] = None + + # The split used in `load_dataset()`. + # Example: test + dataset_split: Optional[str] = None + + # The revision (AKA Git Sha) of the dataset used in `load_dataset()`. + # Example: 5503434ddd753f426f4b38109466949a1217c2bb + dataset_revision: Optional[str] = None + + # The arguments passed during `Metric.compute()`. + # Example for `bleu`: max_order: 4 + dataset_args: Optional[Dict[str, Any]] = None + + # A pretty name for the metric. + # Example: Test WER + metric_name: Optional[str] = None + + # The name of the metric configuration used in `load_metric()`. + # Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + # See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations + metric_config: Optional[str] = None + + # The arguments passed during `Metric.compute()`. + # Example for `bleu`: max_order: 4 + metric_args: Optional[Dict[str, Any]] = None + + # Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + verified: Optional[bool] = None + + # A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + verify_token: Optional[str] = None + + # The name of the source of the evaluation result. + # Example: Open LLM Leaderboard + source_name: Optional[str] = None + + # The URL of the source of the evaluation result. + # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard + source_url: Optional[str] = None + + @property + def unique_identifier(self) -> tuple: + """Returns a tuple that uniquely identifies this evaluation.""" + return ( + self.task_type, + self.dataset_type, + self.dataset_config, + self.dataset_split, + self.dataset_revision, + ) + + def is_equal_except_value(self, other: "EvalResult") -> bool: + """ + Return True if `self` and `other` describe exactly the same metric but with a + different value. + """ + for key, _ in self.__dict__.items(): + if key == "metric_value": + continue + # For metrics computed by Hugging Face's evaluation service, `verify_token` is derived from `metric_value`, + # so we exclude it here in the comparison. 
+ if key != "verify_token" and getattr(self, key) != getattr(other, key): + return False + return True + + def __post_init__(self) -> None: + if self.source_name is not None and self.source_url is None: + raise ValueError("If `source_name` is provided, `source_url` must also be provided.") + + +@dataclass +class CardData: + """Structure containing metadata from a RepoCard. + + [`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`]. + + Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data + (example: flatten evaluation results). `CardData` behaves as a dictionary (can get, pop, set values) but do not + inherit from `dict` to allow this export step. + """ + + def __init__(self, ignore_metadata_errors: bool = False, **kwargs): + self.__dict__.update(kwargs) + + def to_dict(self): + """Converts CardData to a dict. + + Returns: + `dict`: CardData represented as a dictionary ready to be dumped to a YAML + block for inclusion in a README.md file. + """ + + data_dict = copy.deepcopy(self.__dict__) + self._to_dict(data_dict) + return {key: value for key, value in data_dict.items() if value is not None} + + def _to_dict(self, data_dict): + """Use this method in child classes to alter the dict representation of the data. Alter the dict in-place. + + Args: + data_dict (`dict`): The raw dict representation of the card data. + """ + pass + + def to_yaml(self, line_break=None, original_order: Optional[List[str]] = None) -> str: + """Dumps CardData to a YAML block for inclusion in a README.md file. + + Args: + line_break (str, *optional*): + The line break to use when dumping to yaml. + + Returns: + `str`: CardData represented as a YAML block. + """ + if original_order: + self.__dict__ = { + k: self.__dict__[k] + for k in original_order + list(set(self.__dict__.keys()) - set(original_order)) + if k in self.__dict__ + } + return yaml_dump(self.to_dict(), sort_keys=False, line_break=line_break).strip() + + def __repr__(self): + return repr(self.__dict__) + + def __str__(self): + return self.to_yaml() + + def get(self, key: str, default: Any = None) -> Any: + """Get value for a given metadata key.""" + value = self.__dict__.get(key) + return default if value is None else value + + def pop(self, key: str, default: Any = None) -> Any: + """Pop value for a given metadata key.""" + return self.__dict__.pop(key, default) + + def __getitem__(self, key: str) -> Any: + """Get value for a given metadata key.""" + return self.__dict__[key] + + def __setitem__(self, key: str, value: Any) -> None: + """Set value for a given metadata key.""" + self.__dict__[key] = value + + def __contains__(self, key: str) -> bool: + """Check if a given metadata key is set.""" + return key in self.__dict__ + + def __len__(self) -> int: + """Return the number of metadata keys set.""" + return len(self.__dict__) + + +def _validate_eval_results( + eval_results: Optional[Union[EvalResult, List[EvalResult]]], + model_name: Optional[str], +) -> List[EvalResult]: + if eval_results is None: + return [] + if isinstance(eval_results, EvalResult): + eval_results = [eval_results] + if not isinstance(eval_results, list) or not all(isinstance(r, EvalResult) for r in eval_results): + raise ValueError( + f"`eval_results` should be of type `EvalResult` or a list of `EvalResult`, got {type(eval_results)}." 
+        )
+    if model_name is None:
+        raise ValueError("Passing `eval_results` requires `model_name` to be set.")
+    return eval_results
+
+
+class ModelCardData(CardData):
+    """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
+
+    Args:
+        base_model (`str` or `List[str]`, *optional*):
+            The identifier of the base model from which the model derives. This is applicable for example if your model is a
+            fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs
+            if your model derives from multiple models). Defaults to None.
+        datasets (`Union[str, List[str]]`, *optional*):
+            Dataset or list of datasets that were used to train this model. Should be a dataset ID
+            found on https://hf.co/datasets. Defaults to None.
+        eval_results (`Union[List[EvalResult], EvalResult]`, *optional*):
+            List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided,
+            `model_name` is used as a name on PapersWithCode's leaderboards. Defaults to `None`.
+        language (`Union[str, List[str]]`, *optional*):
+            Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
+            639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
+        library_name (`str`, *optional*):
+            Name of library used by this model. Example: keras or any library from
+            https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
+            Defaults to None.
+        license (`str`, *optional*):
+            License of this model. Example: apache-2.0 or any license from
+            https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
+        license_name (`str`, *optional*):
+            Name of the license of this model. Defaults to None. To be used in conjunction with `license_link`.
+            Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a name. In that case, use `license` instead.
+        license_link (`str`, *optional*):
+            Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`.
+            Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead.
+        metrics (`List[str]`, *optional*):
+            List of metrics used to evaluate this model. Should be a metric name that can be found
+            at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
+        model_name (`str`, *optional*):
+            A name for this model. It is used along with
+            `eval_results` to construct the `model-index` within the card's metadata. The name
+            you supply here is what will be used on PapersWithCode's leaderboards. If None is provided
+            then the repo name is used as a default. Defaults to None.
+        pipeline_tag (`str`, *optional*):
+            The pipeline tag associated with the model. Example: "text-classification".
+        tags (`List[str]`, *optional*):
+            List of tags to add to your model that can be used when filtering on the Hugging
+            Face Hub. Defaults to None.
+        ignore_metadata_errors (`bool`):
+            If True, errors while parsing the metadata section will be ignored. Some information might be lost during
+            the process. Use it at your own risk.
+        kwargs (`dict`, *optional*):
+            Additional metadata that will be added to the model card. Defaults to None.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import ModelCardData
+    >>> card_data = ModelCardData(
+    ...     language="en",
+    ...     license="mit",
+    ...     library_name="timm",
+    ...     tags=['image-classification', 'resnet'],
+    ...
) + >>> card_data.to_dict() + {'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']} + + ``` + """ + + def __init__( + self, + *, + base_model: Optional[Union[str, List[str]]] = None, + datasets: Optional[Union[str, List[str]]] = None, + eval_results: Optional[List[EvalResult]] = None, + language: Optional[Union[str, List[str]]] = None, + library_name: Optional[str] = None, + license: Optional[str] = None, + license_name: Optional[str] = None, + license_link: Optional[str] = None, + metrics: Optional[List[str]] = None, + model_name: Optional[str] = None, + pipeline_tag: Optional[str] = None, + tags: Optional[List[str]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.base_model = base_model + self.datasets = datasets + self.eval_results = eval_results + self.language = language + self.library_name = library_name + self.license = license + self.license_name = license_name + self.license_link = license_link + self.metrics = metrics + self.model_name = model_name + self.pipeline_tag = pipeline_tag + self.tags = _to_unique_list(tags) + + model_index = kwargs.pop("model-index", None) + if model_index: + try: + model_name, eval_results = model_index_to_eval_results(model_index) + self.model_name = model_name + self.eval_results = eval_results + except (KeyError, TypeError) as error: + if ignore_metadata_errors: + logger.warning("Invalid model-index. Not loading eval results into CardData.") + else: + raise ValueError( + f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass" + " `ignore_metadata_errors=True` to ignore this error while loading a Model Card. Warning:" + " some information will be lost. Use it at your own risk." + ) + + super().__init__(**kwargs) + + if self.eval_results: + try: + self.eval_results = _validate_eval_results(self.eval_results, self.model_name) + except Exception as e: + if ignore_metadata_errors: + logger.warning(f"Failed to validate eval_results: {e}. Not loading eval results into CardData.") + else: + raise ValueError(f"Failed to validate eval_results: {e}") from e + + def _to_dict(self, data_dict): + """Format the internal data dict. In this case, we convert eval results to a valid model index""" + if self.eval_results is not None: + data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results) + del data_dict["eval_results"], data_dict["model_name"] + + +class DatasetCardData(CardData): + """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + Args: + language (`List[str]`, *optional*): + Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or + 639-3 code (two/three letters), or a special value like "code", "multilingual". + license (`Union[str, List[str]]`, *optional*): + License(s) of this dataset. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. + annotations_creators (`Union[str, List[str]]`, *optional*): + How the annotations for the dataset were created. + Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'. + language_creators (`Union[str, List[str]]`, *optional*): + How the text-based data in the dataset was created. + Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other' + multilinguality (`Union[str, List[str]]`, *optional*): + Whether the dataset is multilingual. 
+ Options are: 'monolingual', 'multilingual', 'translation', 'other'.
+ size_categories (`Union[str, List[str]]`, *optional*):
+ The number of examples in the dataset. Options are: 'n<1K', '1K<n<10K', '10K<n<100K',
+ '100K<n<1M', '1M<n<10M', '10M<n<100M', '100M<n<1B', '1B<n<10B', '10B<n<100B', '100B<n<1T',
+ 'n>1T', and 'other'.
+ source_datasets (`List[str]`, *optional*):
+ Indicates whether the dataset is an original dataset or extended from another existing dataset.
+ Options are: 'original' and 'extended'.
+ task_categories (`Union[str, List[str]]`, *optional*):
+ What categories of task does the dataset support?
+ task_ids (`Union[str, List[str]]`, *optional*):
+ What specific tasks does the dataset support?
+ paperswithcode_id (`str`, *optional*):
+ ID of the dataset on PapersWithCode.
+ pretty_name (`str`, *optional*):
+ A more human-readable name for the dataset. (ex. "Cats vs. Dogs")
+ train_eval_index (`Dict`, *optional*):
+ A dictionary that describes the necessary spec for doing evaluation on the Hub.
+ If not provided, it will be gathered from the 'train-eval-index' key of the kwargs.
+ config_names (`Union[str, List[str]]`, *optional*):
+ A list of the available dataset configs for the dataset.
+ """
+
+ def __init__(
+ self,
+ *,
+ language: Optional[Union[str, List[str]]] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ annotations_creators: Optional[Union[str, List[str]]] = None,
+ language_creators: Optional[Union[str, List[str]]] = None,
+ multilinguality: Optional[Union[str, List[str]]] = None,
+ size_categories: Optional[Union[str, List[str]]] = None,
+ source_datasets: Optional[List[str]] = None,
+ task_categories: Optional[Union[str, List[str]]] = None,
+ task_ids: Optional[Union[str, List[str]]] = None,
+ paperswithcode_id: Optional[str] = None,
+ pretty_name: Optional[str] = None,
+ train_eval_index: Optional[Dict] = None,
+ config_names: Optional[Union[str, List[str]]] = None,
+ ignore_metadata_errors: bool = False,
+ **kwargs,
+ ):
+ self.annotations_creators = annotations_creators
+ self.language_creators = language_creators
+ self.language = language
+ self.license = license
+ self.multilinguality = multilinguality
+ self.size_categories = size_categories
+ self.source_datasets = source_datasets
+ self.task_categories = task_categories
+ self.task_ids = task_ids
+ self.paperswithcode_id = paperswithcode_id
+ self.pretty_name = pretty_name
+ self.config_names = config_names
+
+ # TODO - maybe handle this similarly to EvalResult?
+ self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None)
+ super().__init__(**kwargs)
+
+ def _to_dict(self, data_dict):
+ data_dict["train-eval-index"] = data_dict.pop("train_eval_index")
+
+
+ class SpaceCardData(CardData):
+ """Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
+
+ To get an exhaustive reference of Spaces configuration, please visit https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference.
+
+ Args:
+ title (`str`, *optional*)
+ Title of the Space.
+ sdk (`str`, *optional*)
+ SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`).
+ sdk_version (`str`, *optional*)
+ Version of the used SDK (if Gradio/Streamlit sdk).
+ python_version (`str`, *optional*)
+ Python version used in the Space (if Gradio/Streamlit sdk).
+ app_file (`str`, *optional*)
+ Path to your main application file (which contains either gradio or streamlit Python code, or static html code).
+ Path is relative to the root of the repository.
+ app_port (`int`, *optional*)
+ Port on which your application is running. Used only if sdk is `docker`. 
+ license (`str`, *optional*)
+ License of this Space. Example: apache-2.0 or any license from
+ https://huggingface.co/docs/hub/repositories-licenses.
+ duplicated_from (`str`, *optional*)
+ ID of the original Space if this is a duplicated Space.
+ models (`List[str]`, *optional*)
+ List of models related to this Space. Should be a model ID found on https://hf.co/models.
+ datasets (`List[str]`, *optional*)
+ List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets.
+ tags (`List[str]`, *optional*)
+ List of tags to add to your Space that can be used when filtering on the Hub.
+ ignore_metadata_errors (`bool`):
+ If True, errors while parsing the metadata section will be ignored. Some information might be lost during
+ the process. Use it at your own risk.
+ kwargs (`dict`, *optional*):
+ Additional metadata that will be added to the space card.
+
+ Example:
+ ```python
+ >>> from huggingface_hub import SpaceCardData
+ >>> card_data = SpaceCardData(
+ ... title="Dreambooth Training",
+ ... license="mit",
+ ... sdk="gradio",
+ ... duplicated_from="multimodalart/dreambooth-training"
+ ... )
+ >>> card_data.to_dict()
+ {'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'}
+ ```
+ """
+
+ def __init__(
+ self,
+ *,
+ title: Optional[str] = None,
+ sdk: Optional[str] = None,
+ sdk_version: Optional[str] = None,
+ python_version: Optional[str] = None,
+ app_file: Optional[str] = None,
+ app_port: Optional[int] = None,
+ license: Optional[str] = None,
+ duplicated_from: Optional[str] = None,
+ models: Optional[List[str]] = None,
+ datasets: Optional[List[str]] = None,
+ tags: Optional[List[str]] = None,
+ ignore_metadata_errors: bool = False,
+ **kwargs,
+ ):
+ self.title = title
+ self.sdk = sdk
+ self.sdk_version = sdk_version
+ self.python_version = python_version
+ self.app_file = app_file
+ self.app_port = app_port
+ self.license = license
+ self.duplicated_from = duplicated_from
+ self.models = models
+ self.datasets = datasets
+ self.tags = _to_unique_list(tags)
+ super().__init__(**kwargs)
+
+
+ def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]:
+ """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects.
+
+ A detailed spec of the model index can be found here:
+ https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+
+ Args:
+ model_index (`List[Dict[str, Any]]`):
+ A model index data structure, likely coming from a README.md file on the
+ Hugging Face Hub.
+
+ Returns:
+ model_name (`str`):
+ The name of the model as found in the model index. This is used as the
+ identifier for the model on leaderboards like PapersWithCode.
+ eval_results (`List[EvalResult]`):
+ A list of `huggingface_hub.EvalResult` objects containing the metrics
+ reported in the provided model_index.
+
+ Example:
+ ```python
+ >>> from huggingface_hub.repocard_data import model_index_to_eval_results
+ >>> # Define a minimal model index
+ >>> model_index = [
+ ... {
+ ... "name": "my-cool-model",
+ ... "results": [
+ ... {
+ ... "task": {
+ ... "type": "image-classification"
+ ... },
+ ... "dataset": {
+ ... "type": "beans",
+ ... "name": "Beans"
+ ... },
+ ... "metrics": [
+ ... {
+ ... "type": "accuracy",
+ ... "value": 0.9
+ ... }
+ ... ]
+ ... }
+ ... ]
+ ... }
+ ... 
] + >>> model_name, eval_results = model_index_to_eval_results(model_index) + >>> model_name + 'my-cool-model' + >>> eval_results[0].task_type + 'image-classification' + >>> eval_results[0].metric_type + 'accuracy' + + ``` + """ + + eval_results = [] + for elem in model_index: + name = elem["name"] + results = elem["results"] + for result in results: + task_type = result["task"]["type"] + task_name = result["task"].get("name") + dataset_type = result["dataset"]["type"] + dataset_name = result["dataset"]["name"] + dataset_config = result["dataset"].get("config") + dataset_split = result["dataset"].get("split") + dataset_revision = result["dataset"].get("revision") + dataset_args = result["dataset"].get("args") + source_name = result.get("source", {}).get("name") + source_url = result.get("source", {}).get("url") + + for metric in result["metrics"]: + metric_type = metric["type"] + metric_value = metric["value"] + metric_name = metric.get("name") + metric_args = metric.get("args") + metric_config = metric.get("config") + verified = metric.get("verified") + verify_token = metric.get("verifyToken") + + eval_result = EvalResult( + task_type=task_type, # Required + dataset_type=dataset_type, # Required + dataset_name=dataset_name, # Required + metric_type=metric_type, # Required + metric_value=metric_value, # Required + task_name=task_name, + dataset_config=dataset_config, + dataset_split=dataset_split, + dataset_revision=dataset_revision, + dataset_args=dataset_args, + metric_name=metric_name, + metric_args=metric_args, + metric_config=metric_config, + verified=verified, + verify_token=verify_token, + source_name=source_name, + source_url=source_url, + ) + eval_results.append(eval_result) + return name, eval_results + + +def _remove_none(obj): + """ + Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778 + """ + if isinstance(obj, (list, tuple, set)): + return type(obj)(_remove_none(x) for x in obj if x is not None) + elif isinstance(obj, dict): + return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None) + else: + return obj + + +def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]: + """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a + valid model-index that will be compatible with the format expected by the + Hugging Face Hub. + + Args: + model_name (`str`): + Name of the model (ex. "my-cool-model"). This is used as the identifier + for the model on leaderboards like PapersWithCode. + eval_results (`List[EvalResult]`): + List of `huggingface_hub.EvalResult` objects containing the metrics to be + reported in the model-index. + + Returns: + model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index. + + Example: + ```python + >>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult + >>> # Define minimal eval_results + >>> eval_results = [ + ... EvalResult( + ... task_type="image-classification", # Required + ... dataset_type="beans", # Required + ... dataset_name="Beans", # Required + ... metric_type="accuracy", # Required + ... metric_value=0.9, # Required + ... ) + ... 
] + >>> eval_results_to_model_index("my-cool-model", eval_results) + [{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}] + + ``` + """ + + # Metrics are reported on a unique task-and-dataset basis. + # Here, we make a map of those pairs and the associated EvalResults. + task_and_ds_types_map: Dict[Any, List[EvalResult]] = defaultdict(list) + for eval_result in eval_results: + task_and_ds_types_map[eval_result.unique_identifier].append(eval_result) + + # Use the map from above to generate the model index data. + model_index_data = [] + for results in task_and_ds_types_map.values(): + # All items from `results` share same metadata + sample_result = results[0] + data = { + "task": { + "type": sample_result.task_type, + "name": sample_result.task_name, + }, + "dataset": { + "name": sample_result.dataset_name, + "type": sample_result.dataset_type, + "config": sample_result.dataset_config, + "split": sample_result.dataset_split, + "revision": sample_result.dataset_revision, + "args": sample_result.dataset_args, + }, + "metrics": [ + { + "type": result.metric_type, + "value": result.metric_value, + "name": result.metric_name, + "config": result.metric_config, + "args": result.metric_args, + "verified": result.verified, + "verifyToken": result.verify_token, + } + for result in results + ], + } + if sample_result.source_url is not None: + source = { + "url": sample_result.source_url, + } + if sample_result.source_name is not None: + source["name"] = sample_result.source_name + data["source"] = source + model_index_data.append(data) + + # TODO - Check if there cases where this list is longer than one? + # Finally, the model index itself is list of dicts. + model_index = [ + { + "name": model_name, + "results": model_index_data, + } + ] + return _remove_none(model_index) + + +def _to_unique_list(tags: Optional[List[str]]) -> Optional[List[str]]: + if tags is None: + return tags + unique_tags = [] # make tags unique + keep order explicitly + for tag in tags: + if tag not in unique_tags: + unique_tags.append(tag) + return unique_tags diff --git a/phivenv/Lib/site-packages/huggingface_hub/repository.py b/phivenv/Lib/site-packages/huggingface_hub/repository.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a904f458b8c2cc03a786832dd3f97005be9e56 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/repository.py @@ -0,0 +1,1477 @@ +import atexit +import os +import re +import subprocess +import threading +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union +from urllib.parse import urlparse + +from huggingface_hub import constants +from huggingface_hub.repocard import metadata_load, metadata_save + +from .hf_api import HfApi, repo_type_and_id_from_hf_id +from .lfs import LFS_MULTIPART_UPLOAD_COMMAND +from .utils import ( + SoftTemporaryDirectory, + get_token, + logging, + run_subprocess, + tqdm, + validate_hf_hub_args, +) +from .utils._deprecation import _deprecate_method + + +logger = logging.get_logger(__name__) + + +class CommandInProgress: + """ + Utility to follow commands launched asynchronously. 
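+
+ Example (an illustrative sketch, not part of the original docstring; assumes
+ `repo` is an existing [`Repository`] on which a push was started with
+ `blocking=False`):
+
+ ```python
+ >>> url, command = repo.git_push(blocking=False)
+ >>> command.is_done # `False` while the push is still running
+ False
+ >>> command.status # `-1` means the command is still ongoing
+ -1
+ ```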
+ """ + + def __init__( + self, + title: str, + is_done_method: Callable, + status_method: Callable, + process: subprocess.Popen, + post_method: Optional[Callable] = None, + ): + self.title = title + self._is_done = is_done_method + self._status = status_method + self._process = process + self._stderr = "" + self._stdout = "" + self._post_method = post_method + + @property + def is_done(self) -> bool: + """ + Whether the process is done. + """ + result = self._is_done() + + if result and self._post_method is not None: + self._post_method() + self._post_method = None + + return result + + @property + def status(self) -> int: + """ + The exit code/status of the current action. Will return `0` if the + command has completed successfully, and a number between 1 and 255 if + the process errored-out. + + Will return -1 if the command is still ongoing. + """ + return self._status() + + @property + def failed(self) -> bool: + """ + Whether the process errored-out. + """ + return self.status > 0 + + @property + def stderr(self) -> str: + """ + The current output message on the standard error. + """ + if self._process.stderr is not None: + self._stderr += self._process.stderr.read() + return self._stderr + + @property + def stdout(self) -> str: + """ + The current output message on the standard output. + """ + if self._process.stdout is not None: + self._stdout += self._process.stdout.read() + return self._stdout + + def __repr__(self): + status = self.status + + if status == -1: + status = "running" + + return ( + f"[{self.title} command, status code: {status}," + f" {'in progress.' if not self.is_done else 'finished.'} PID:" + f" {self._process.pid}]" + ) + + +def is_git_repo(folder: Union[str, Path]) -> bool: + """ + Check if the folder is the root or part of a git repository + + Args: + folder (`str`): + The folder in which to run the command. + + Returns: + `bool`: `True` if the repository is part of a repository, `False` + otherwise. + """ + folder_exists = os.path.exists(os.path.join(folder, ".git")) + git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return folder_exists and git_branch.returncode == 0 + + +def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool: + """ + Check if the folder is a local clone of the remote_url + + Args: + folder (`str` or `Path`): + The folder in which to run the command. + remote_url (`str`): + The url of a git repository. + + Returns: + `bool`: `True` if the repository is a local clone of the remote + repository specified, `False` otherwise. + """ + if not is_git_repo(folder): + return False + + remotes = run_subprocess("git remote -v", folder).stdout + + # Remove token for the test with remotes. + remote_url = re.sub(r"https://.*@", "https://", remote_url) + remotes = [re.sub(r"https://.*@", "https://", remote) for remote in remotes.split()] + return remote_url in remotes + + +def is_tracked_with_lfs(filename: Union[str, Path]) -> bool: + """ + Check if the file passed is tracked with git-lfs. + + Args: + filename (`str` or `Path`): + The filename to check. + + Returns: + `bool`: `True` if the file passed is tracked with git-lfs, `False` + otherwise. 
+ """ + folder = Path(filename).parent + filename = Path(filename).name + + try: + p = run_subprocess("git check-attr -a".split() + [filename], folder) + attributes = p.stdout.strip() + except subprocess.CalledProcessError as exc: + if not is_git_repo(folder): + return False + else: + raise OSError(exc.stderr) + + if len(attributes) == 0: + return False + + found_lfs_tag = {"diff": False, "merge": False, "filter": False} + + for attribute in attributes.split("\n"): + for tag in found_lfs_tag.keys(): + if tag in attribute and "lfs" in attribute: + found_lfs_tag[tag] = True + + return all(found_lfs_tag.values()) + + +def is_git_ignored(filename: Union[str, Path]) -> bool: + """ + Check if file is git-ignored. Supports nested .gitignore files. + + Args: + filename (`str` or `Path`): + The filename to check. + + Returns: + `bool`: `True` if the file passed is ignored by `git`, `False` + otherwise. + """ + folder = Path(filename).parent + filename = Path(filename).name + + try: + p = run_subprocess("git check-ignore".split() + [filename], folder, check=False) + # Will return exit code 1 if not gitignored + is_ignored = not bool(p.returncode) + except subprocess.CalledProcessError as exc: + raise OSError(exc.stderr) + + return is_ignored + + +def is_binary_file(filename: Union[str, Path]) -> bool: + """ + Check if file is a binary file. + + Args: + filename (`str` or `Path`): + The filename to check. + + Returns: + `bool`: `True` if the file passed is a binary file, `False` otherwise. + """ + try: + with open(filename, "rb") as f: + content = f.read(10 * (1024**2)) # Read a maximum of 10MB + + # Code sample taken from the following stack overflow thread + # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391 + text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) + return bool(content.translate(None, text_chars)) + except UnicodeDecodeError: + return True + + +def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]: + """ + Returns a list of filenames that are to be staged. + + Args: + pattern (`str` or `Path`): + The pattern of filenames to check. Put `.` to get all files. + folder (`str` or `Path`): + The folder in which to run the command. + + Returns: + `List[str]`: List of files that are to be staged. + """ + try: + p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder) + if len(p.stdout.strip()): + files = p.stdout.strip().split("\n") + else: + files = [] + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return files + + +def is_tracked_upstream(folder: Union[str, Path]) -> bool: + """ + Check if the current checked-out branch is tracked upstream. + + Args: + folder (`str` or `Path`): + The folder in which to run the command. + + Returns: + `bool`: `True` if the current checked-out branch is tracked upstream, + `False` otherwise. + """ + try: + run_subprocess("git rev-parse --symbolic-full-name --abbrev-ref @{u}", folder) + return True + except subprocess.CalledProcessError as exc: + if "HEAD" in exc.stderr: + raise OSError("No branch checked out") + + return False + + +def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int: + """ + Check the number of commits that would be pushed upstream + + Args: + folder (`str` or `Path`): + The folder in which to run the command. 
+ upstream (`str`, *optional*): + The name of the upstream repository with which the comparison should be + made. + + Returns: + `int`: Number of commits that would be pushed upstream were a `git + push` to proceed. + """ + try: + result = run_subprocess(f"git cherry -v {upstream or ''}", folder) + return len(result.stdout.split("\n")) - 1 + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + +class PbarT(TypedDict): + # Used to store an opened progress bar in `_lfs_log_progress` + bar: tqdm + past_bytes: int + + +@contextmanager +def _lfs_log_progress(): + """ + This is a context manager that will log the Git LFS progress of cleaning, + smudging, pulling and pushing. + """ + + if logger.getEffectiveLevel() >= logging.ERROR: + try: + yield + except Exception: + pass + return + + def output_progress(stopping_event: threading.Event): + """ + To be launched as a separate thread with an event meaning it should stop + the tail. + """ + # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value) + pbars: Dict[Tuple[str, str], PbarT] = {} + + def close_pbars(): + for pbar in pbars.values(): + pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"]) + pbar["bar"].refresh() + pbar["bar"].close() + + def tail_file(filename) -> Iterator[str]: + """ + Creates a generator to be iterated through, which will return each + line one by one. Will stop tailing the file if the stopping_event is + set. + """ + with open(filename, "r") as file: + current_line = "" + while True: + if stopping_event.is_set(): + close_pbars() + break + + line_bit = file.readline() + if line_bit is not None and not len(line_bit.strip()) == 0: + current_line += line_bit + if current_line.endswith("\n"): + yield current_line + current_line = "" + else: + time.sleep(1) + + # If the file isn't created yet, wait for a few seconds before trying again. + # Can be interrupted with the stopping_event. + while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]): + if stopping_event.is_set(): + close_pbars() + return + + time.sleep(2) + + for line in tail_file(os.environ["GIT_LFS_PROGRESS"]): + try: + state, file_progress, byte_progress, filename = line.split() + except ValueError as error: + # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373. 
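+ # Note: each progress line written by git-lfs is expected to hold four
+ # whitespace-separated fields, e.g. "download 3/4 512/1024 path/to/file",
+ # i.e. <state> <file progress> <byte progress> <filename>.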
+ raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error + description = f"{state.capitalize()} file {filename}" + + current_bytes, total_bytes = byte_progress.split("/") + current_bytes_int = int(current_bytes) + total_bytes_int = int(total_bytes) + + pbar = pbars.get((state, filename)) + if pbar is None: + # Initialize progress bar + pbars[(state, filename)] = { + "bar": tqdm( + desc=description, + initial=current_bytes_int, + total=total_bytes_int, + unit="B", + unit_scale=True, + unit_divisor=1024, + name="huggingface_hub.lfs_upload", + ), + "past_bytes": int(current_bytes), + } + else: + # Update progress bar + pbar["bar"].update(current_bytes_int - pbar["past_bytes"]) + pbar["past_bytes"] = current_bytes_int + + current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "") + + with SoftTemporaryDirectory() as tmpdir: + os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress") + logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}") + + exit_event = threading.Event() + x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True) + x.start() + + try: + yield + finally: + exit_event.set() + x.join() + + os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value + + +class Repository: + """ + Helper class to wrap the git and git-lfs commands. + + The aim is to facilitate interacting with huggingface.co hosted model or + dataset repos, though not a lot here (if any) is actually specific to + huggingface.co. + + + + [`Repository`] is deprecated in favor of the http-based alternatives implemented in + [`HfApi`]. Given its large adoption in legacy code, the complete removal of + [`Repository`] will only happen in release `v1.0`. For more details, please read + https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http. + + + """ + + command_queue: List[CommandInProgress] + + @validate_hf_hub_args + @_deprecate_method( + version="1.0", + message=( + "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete" + " removal is only planned on next major release.\nFor more details, please read" + " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http." + ), + ) + def __init__( + self, + local_dir: Union[str, Path], + clone_from: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[bool, str] = True, + git_user: Optional[str] = None, + git_email: Optional[str] = None, + revision: Optional[str] = None, + skip_lfs_files: bool = False, + client: Optional[HfApi] = None, + ): + """ + Instantiate a local clone of a git repo. + + If `clone_from` is set, the repo will be cloned from an existing remote repository. + If the remote repo does not exist, a `EnvironmentError` exception will be thrown. + Please create the remote repo first using [`create_repo`]. + + `Repository` uses the local git credentials by default. If explicitly set, the `token` + or the `git_user`/`git_email` pair will be used instead. + + Args: + local_dir (`str` or `Path`): + path (e.g. `'my_trained_model/'`) to the local directory, where + the `Repository` will be initialized. + clone_from (`str`, *optional*): + Either a repository url or `repo_id`. + Example: + - `"https://huggingface.co/philschmid/playground-tests"` + - `"philschmid/playground-tests"` + repo_type (`str`, *optional*): + To set when cloning a repo from a repo_id. Default is model. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). 
+ If `None` or `True` and machine is logged in (through `hf auth login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + git_user (`str`, *optional*): + will override the `git config user.name` for committing and + pushing files to the hub. + git_email (`str`, *optional*): + will override the `git config user.email` for committing and + pushing files to the hub. + revision (`str`, *optional*): + Revision to checkout after initializing the repository. If the + revision doesn't exist, a branch will be created with that + revision name from the default branch's current HEAD. + skip_lfs_files (`bool`, *optional*, defaults to `False`): + whether to skip git-LFS files or not. + client (`HfApi`, *optional*): + Instance of [`HfApi`] to use when calling the HF Hub API. A new + instance will be created if this is left to `None`. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If the remote repository set in `clone_from` does not exist. + """ + if isinstance(local_dir, Path): + local_dir = str(local_dir) + os.makedirs(local_dir, exist_ok=True) + self.local_dir = os.path.join(os.getcwd(), local_dir) + self._repo_type = repo_type + self.command_queue = [] + self.skip_lfs_files = skip_lfs_files + self.client = client if client is not None else HfApi() + + self.check_git_versions() + + if isinstance(token, str): + self.huggingface_token: Optional[str] = token + elif token is False: + self.huggingface_token = None + else: + # if `True` -> explicit use of the cached token + # if `None` -> implicit use of the cached token + self.huggingface_token = get_token() + + if clone_from is not None: + self.clone_from(repo_url=clone_from) + else: + if is_git_repo(self.local_dir): + logger.debug("[Repository] is a valid git repo") + else: + raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.") + + if self.huggingface_token is not None and (git_email is None or git_user is None): + user = self.client.whoami(self.huggingface_token) + + if git_email is None: + git_email = user.get("email") + + if git_user is None: + git_user = user.get("fullname") + + if git_user is not None or git_email is not None: + self.git_config_username_and_email(git_user, git_email) + + self.lfs_enable_largefiles() + self.git_credential_helper_store() + + if revision is not None: + self.git_checkout(revision, create_branch_ok=True) + + # This ensures that all commands exit before exiting the Python runtime. + # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations. + atexit.register(self.wait_for_commands) + + @property + def current_branch(self) -> str: + """ + Returns the current checked out branch. + + Returns: + `str`: Current checked out branch. + """ + try: + result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return result + + def check_git_versions(self): + """ + Checks that `git` and `git-lfs` can be run. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `git` or `git-lfs` are not installed. 
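+
+ Example (a minimal sketch; assumes `repo` is an existing [`Repository`] and
+ both tools are installed, in which case the call simply logs their versions):
+
+ ```python
+ >>> repo.check_git_versions() # raises EnvironmentError if git or git-lfs is missing
+ ```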
+ """ + try: + git_version = run_subprocess("git --version", self.local_dir).stdout.strip() + except FileNotFoundError: + raise EnvironmentError("Looks like you do not have git installed, please install.") + + try: + lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip() + except FileNotFoundError: + raise EnvironmentError( + "Looks like you do not have git-lfs installed, please install." + " You can install from https://git-lfs.github.com/." + " Then run `git lfs install` (you only have to do this once)." + ) + logger.info(git_version + "\n" + lfs_version) + + @validate_hf_hub_args + def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): + """ + Clone from a remote. If the folder already exists, will try to clone the + repository within it. + + If this folder is a git repository with linked history, will try to + update the repository. + + Args: + repo_url (`str`): + The URL from which to clone the repository + token (`Union[str, bool]`, *optional*): + Whether to use the authentication token. It can be: + - a string which is the token itself + - `False`, which would not use the authentication token + - `True`, which would fetch the authentication token from the + local folder and use it (you should be logged in for this to + work). + - `None`, which would retrieve the value of + `self.huggingface_token`. + + + + Raises the following error: + + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if an organization token (starts with "api_org") is passed. Use must use + your own personal access token (see https://hf.co/settings/tokens). + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if you are trying to clone the repository in a non-empty folder, or if the + `git` operations raise errors. + + + """ + token = ( + token # str -> use it + if isinstance(token, str) + else ( + None # `False` -> explicit no token + if token is False + else self.huggingface_token # `None` or `True` -> use default + ) + ) + if token is not None and token.startswith("api_org"): + raise ValueError( + "You must use your personal access token, not an Organization token" + " (see https://hf.co/settings/tokens)." + ) + + hub_url = self.client.endpoint + if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2): + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url) + repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name + + if repo_type is not None: + self._repo_type = repo_type + + repo_url = hub_url + "/" + + if self._repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_url += constants.REPO_TYPES_URL_PREFIXES[self._repo_type] + + if token is not None: + # Add token in git url when provided + scheme = urlparse(repo_url).scheme + repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@") + + repo_url += repo_id + + # For error messages, it's cleaner to show the repo url without the token. 
+ clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url) + try: + run_subprocess("git lfs install", self.local_dir) + + # checks if repository is initialized in a empty repository or in one with files + if len(os.listdir(self.local_dir)) == 0: + logger.warning(f"Cloning {clean_repo_url} into local empty directory.") + + with _lfs_log_progress(): + env = os.environ.copy() + + if self.skip_lfs_files: + env.update({"GIT_LFS_SKIP_SMUDGE": "1"}) + + run_subprocess( + # 'git lfs clone' is deprecated (will display a warning in the terminal) + # but we still use it as it provides a nicer UX when downloading large + # files (shows progress). + f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .", + self.local_dir, + env=env, + ) + else: + # Check if the folder is the root of a git repository + if not is_git_repo(self.local_dir): + raise EnvironmentError( + "Tried to clone a repository in a non-empty folder that isn't" + f" a git repository ('{self.local_dir}'). If you really want to" + f" do this, do it manually:\n cd {self.local_dir} && git init" + " && git remote add origin && git pull origin main\n or clone" + " repo to a new folder and move your existing files there" + " afterwards." + ) + + if is_local_clone(self.local_dir, repo_url): + logger.warning( + f"{self.local_dir} is already a clone of {clean_repo_url}." + " Make sure you pull the latest changes with" + " `repo.git_pull()`." + ) + else: + output = run_subprocess("git remote get-url origin", self.local_dir, check=False) + + error_msg = ( + f"Tried to clone {clean_repo_url} in an unrelated git" + " repository.\nIf you believe this is an error, please add" + f" a remote with the following URL: {clean_repo_url}." + ) + if output.returncode == 0: + clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout) + error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}" + raise EnvironmentError(error_msg) + + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None): + """ + Sets git username and email (only in the current repo). + + Args: + git_user (`str`, *optional*): + The username to register through `git`. + git_email (`str`, *optional*): + The email to register through `git`. + """ + try: + if git_user is not None: + run_subprocess("git config user.name".split() + [git_user], self.local_dir) + + if git_email is not None: + run_subprocess(f"git config user.email {git_email}".split(), self.local_dir) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_credential_helper_store(self): + """ + Sets the git credential helper to `store` + """ + try: + run_subprocess("git config credential.helper store", self.local_dir) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_head_hash(self) -> str: + """ + Get commit sha on top of HEAD. + + Returns: + `str`: The current checked out commit SHA. + """ + try: + p = run_subprocess("git rev-parse HEAD", self.local_dir) + return p.stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_remote_url(self) -> str: + """ + Get URL to origin remote. + + Returns: + `str`: The URL of the `origin` remote. + """ + try: + p = run_subprocess("git config --get remote.origin.url", self.local_dir) + url = p.stdout.strip() + # Strip basic auth info. 
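+ # Anything between "https://" and "@" (e.g. "user:<token>") is dropped so
+ # that credentials never leak into logs or return values.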
+ return re.sub(r"https://.*@", "https://", url) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_head_commit_url(self) -> str: + """ + Get URL to last commit on HEAD. We assume it's been pushed, and the url + scheme is the same one as for GitHub or HuggingFace. + + Returns: + `str`: The URL to the current checked-out commit. + """ + sha = self.git_head_hash() + url = self.git_remote_url() + if url.endswith("/"): + url = url[:-1] + return f"{url}/commit/{sha}" + + def list_deleted_files(self) -> List[str]: + """ + Returns a list of the files that are deleted in the working directory or + index. + + Returns: + `List[str]`: A list of files that have been deleted in the working + directory or index. + """ + try: + git_status = run_subprocess("git status -s", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if len(git_status) == 0: + return [] + + # Receives a status like the following + # D .gitignore + # D new_file.json + # AD new_file1.json + # ?? new_file2.json + # ?? new_file4.json + + # Strip each line of whitespaces + modified_files_statuses = [status.strip() for status in git_status.split("\n")] + + # Only keep files that are deleted using the D prefix + deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]] + + # Remove the D prefix and strip to keep only the relevant filename + deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses] + + return deleted_files + + def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False): + """ + Tell git-lfs to track files according to a pattern. + + Setting the `filename` argument to `True` will treat the arguments as + literal filenames, not as patterns. Any special glob characters in the + filename will be escaped when writing to the `.gitattributes` file. + + Args: + patterns (`Union[str, List[str]]`): + The pattern, or list of patterns, to track with git-lfs. + filename (`bool`, *optional*, defaults to `False`): + Whether to use the patterns as literal filenames. + """ + if isinstance(patterns, str): + patterns = [patterns] + try: + for pattern in patterns: + run_subprocess( + f"git lfs track {'--filename' if filename else ''} {pattern}", + self.local_dir, + ) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def lfs_untrack(self, patterns: Union[str, List[str]]): + """ + Tell git-lfs to untrack those files. + + Args: + patterns (`Union[str, List[str]]`): + The pattern, or list of patterns, to untrack with git-lfs. + """ + if isinstance(patterns, str): + patterns = [patterns] + try: + for pattern in patterns: + run_subprocess("git lfs untrack".split() + [pattern], self.local_dir) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def lfs_enable_largefiles(self): + """ + HF-specific. This enables upload support of files >5GB. + """ + try: + lfs_config = "git config lfs.customtransfer.multipart" + run_subprocess(f"{lfs_config}.path hf", self.local_dir) + run_subprocess( + f"{lfs_config}.args {LFS_MULTIPART_UPLOAD_COMMAND}", + self.local_dir, + ) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def auto_track_binary_files(self, pattern: str = ".") -> List[str]: + """ + Automatically track binary files with git-lfs. + + Args: + pattern (`str`, *optional*, defaults to "."): + The pattern with which to track files that are binary. 
+ + Returns: + `List[str]`: List of filenames that are now tracked due to being + binary files + """ + files_to_be_tracked_with_lfs = [] + + deleted_files = self.list_deleted_files() + + for filename in files_to_be_staged(pattern, folder=self.local_dir): + if filename in deleted_files: + continue + + path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) + + if not (is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file)): + size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) + + if size_in_mb >= 10: + logger.warning( + "Parsing a large file to check if binary or not. Tracking large" + " files using `repository.auto_track_large_files` is" + " recommended so as to not load the full file in memory." + ) + + is_binary = is_binary_file(path_to_file) + + if is_binary: + self.lfs_track(filename) + files_to_be_tracked_with_lfs.append(filename) + + # Cleanup the .gitattributes if files were deleted + self.lfs_untrack(deleted_files) + + return files_to_be_tracked_with_lfs + + def auto_track_large_files(self, pattern: str = ".") -> List[str]: + """ + Automatically track large files (files that weigh more than 10MBs) with + git-lfs. + + Args: + pattern (`str`, *optional*, defaults to "."): + The pattern with which to track files that are above 10MBs. + + Returns: + `List[str]`: List of filenames that are now tracked due to their + size. + """ + files_to_be_tracked_with_lfs = [] + + deleted_files = self.list_deleted_files() + + for filename in files_to_be_staged(pattern, folder=self.local_dir): + if filename in deleted_files: + continue + + path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) + size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) + + if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file): + self.lfs_track(filename) + files_to_be_tracked_with_lfs.append(filename) + + # Cleanup the .gitattributes if files were deleted + self.lfs_untrack(deleted_files) + + return files_to_be_tracked_with_lfs + + def lfs_prune(self, recent=False): + """ + git lfs prune + + Args: + recent (`bool`, *optional*, defaults to `False`): + Whether to prune files even if they were referenced by recent + commits. See the following + [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files) + for more information. + """ + try: + with _lfs_log_progress(): + result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir) + logger.info(result.stdout) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_pull(self, rebase: bool = False, lfs: bool = False): + """ + git pull + + Args: + rebase (`bool`, *optional*, defaults to `False`): + Whether to rebase the current branch on top of the upstream + branch after fetching. + lfs (`bool`, *optional*, defaults to `False`): + Whether to fetch the LFS files too. This option only changes the + behavior when a repository was cloned without fetching the LFS + files; calling `repo.git_pull(lfs=True)` will then fetch the LFS + file from the remote repository. 
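+
+ Example (a hedged sketch; assumes `repo` is a [`Repository`] cloned with
+ `skip_lfs_files=True`, so the LFS payloads still need to be fetched):
+
+ ```python
+ >>> repo.git_pull(rebase=True) # regular pull, rebasing local commits
+ >>> repo.git_pull(lfs=True) # fetch the LFS files skipped at clone time
+ ```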
+ """ + command = "git pull" if not lfs else "git lfs pull" + if rebase: + command += " --rebase" + try: + with _lfs_log_progress(): + result = run_subprocess(command, self.local_dir) + logger.info(result.stdout) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_add(self, pattern: str = ".", auto_lfs_track: bool = False): + """ + git add + + Setting the `auto_lfs_track` parameter to `True` will automatically + track files that are larger than 10MB with `git-lfs`. + + Args: + pattern (`str`, *optional*, defaults to "."): + The pattern with which to add files to staging. + auto_lfs_track (`bool`, *optional*, defaults to `False`): + Whether to automatically track large and binary files with + git-lfs. Any file over 10MB in size, or in binary format, will + be automatically tracked. + """ + if auto_lfs_track: + # Track files according to their size (>=10MB) + tracked_files = self.auto_track_large_files(pattern) + + # Read the remaining files and track them if they're binary + tracked_files.extend(self.auto_track_binary_files(pattern)) + + if tracked_files: + logger.warning( + f"Adding files tracked by Git LFS: {tracked_files}. This may take a" + " bit of time if the files are large." + ) + + try: + result = run_subprocess("git add -v".split() + [pattern], self.local_dir) + logger.info(f"Adding to index:\n{result.stdout}\n") + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_commit(self, commit_message: str = "commit files to HF hub"): + """ + git commit + + Args: + commit_message (`str`, *optional*, defaults to "commit files to HF hub"): + The message attributed to the commit. + """ + try: + result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir) + logger.info(f"Committed:\n{result.stdout}\n") + except subprocess.CalledProcessError as exc: + if len(exc.stderr) > 0: + raise EnvironmentError(exc.stderr) + else: + raise EnvironmentError(exc.stdout) + + def git_push( + self, + upstream: Optional[str] = None, + blocking: bool = True, + auto_lfs_prune: bool = False, + ) -> Union[str, Tuple[str, CommandInProgress]]: + """ + git push + + If used without setting `blocking`, will return url to commit on remote + repo. If used with `blocking=True`, will return a tuple containing the + url to commit and the command object to follow for information about the + process. + + Args: + upstream (`str`, *optional*): + Upstream to which this should push. If not specified, will push + to the lastly defined upstream or to the default one (`origin + main`). + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the push has + finished. Setting this to `False` will return an + `CommandInProgress` object which has an `is_done` property. This + property will be set to `True` when the push is finished. + auto_lfs_prune (`bool`, *optional*, defaults to `False`): + Whether to automatically prune files once they have been pushed + to the remote. 
+ """ + command = "git push" + + if upstream: + command += f" --set-upstream {upstream}" + + number_of_commits = commits_to_push(self.local_dir, upstream) + + if number_of_commits > 1: + logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.") + if blocking: + logger.warning("The progress bars may be unreliable.") + + try: + with _lfs_log_progress(): + process = subprocess.Popen( + command.split(), + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + encoding="utf-8", + cwd=self.local_dir, + ) + + if blocking: + stdout, stderr = process.communicate() + return_code = process.poll() + process.kill() + + if len(stderr): + logger.warning(stderr) + + if return_code: + raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr) + + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if not blocking: + + def status_method(): + status = process.poll() + if status is None: + return -1 + else: + return status + + command_in_progress = CommandInProgress( + "push", + is_done_method=lambda: process.poll() is not None, + status_method=status_method, + process=process, + post_method=self.lfs_prune if auto_lfs_prune else None, + ) + + self.command_queue.append(command_in_progress) + + return self.git_head_commit_url(), command_in_progress + + if auto_lfs_prune: + self.lfs_prune() + + return self.git_head_commit_url() + + def git_checkout(self, revision: str, create_branch_ok: bool = False): + """ + git checkout a given revision + + Specifying `create_branch_ok` to `True` will create the branch to the + given revision if that revision doesn't exist. + + Args: + revision (`str`): + The revision to checkout. + create_branch_ok (`str`, *optional*, defaults to `False`): + Whether creating a branch named with the `revision` passed at + the current checked-out reference if `revision` isn't an + existing revision is allowed. + """ + try: + result = run_subprocess(f"git checkout {revision}", self.local_dir) + logger.warning(f"Checked out {revision} from {self.current_branch}.") + logger.warning(result.stdout) + except subprocess.CalledProcessError as exc: + if not create_branch_ok: + raise EnvironmentError(exc.stderr) + else: + try: + result = run_subprocess(f"git checkout -b {revision}", self.local_dir) + logger.warning( + f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`." + ) + logger.warning(result.stdout) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool: + """ + Check if a tag exists or not. + + Args: + tag_name (`str`): + The name of the tag to check. + remote (`str`, *optional*): + Whether to check if the tag exists on a remote. This parameter + should be the identifier of the remote. + + Returns: + `bool`: Whether the tag exists. 
+ """ + if remote: + try: + result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return len(result) != 0 + else: + try: + git_tags = run_subprocess("git tag", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + git_tags = git_tags.split("\n") + return tag_name in git_tags + + def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool: + """ + Delete a tag, both local and remote, if it exists + + Args: + tag_name (`str`): + The tag name to delete. + remote (`str`, *optional*): + The remote on which to delete the tag. + + Returns: + `bool`: `True` if deleted, `False` if the tag didn't exist. + If remote is not passed, will just be updated locally + """ + delete_locally = True + delete_remotely = True + + if not self.tag_exists(tag_name): + delete_locally = False + + if not self.tag_exists(tag_name, remote=remote): + delete_remotely = False + + if delete_locally: + try: + run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if remote and delete_remotely: + try: + run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return True + + def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None): + """ + Add a tag at the current head and push it + + If remote is None, will just be updated locally + + If no message is provided, the tag will be lightweight. if a message is + provided, the tag will be annotated. + + Args: + tag_name (`str`): + The name of the tag to be added. + message (`str`, *optional*): + The message that accompanies the tag. The tag will turn into an + annotated tag if a message is passed. + remote (`str`, *optional*): + The remote on which to add the tag. + """ + if message: + tag_args = ["git", "tag", "-a", tag_name, "-m", message] + else: + tag_args = ["git", "tag", tag_name] + + try: + run_subprocess(tag_args, self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if remote: + try: + run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def is_repo_clean(self) -> bool: + """ + Return whether or not the git status is clean or not + + Returns: + `bool`: `True` if the git status is clean, `False` otherwise. + """ + try: + git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return len(git_status) == 0 + + def push_to_hub( + self, + commit_message: str = "commit files to HF hub", + blocking: bool = True, + clean_ok: bool = True, + auto_lfs_prune: bool = False, + ) -> Union[None, str, Tuple[str, CommandInProgress]]: + """ + Helper to add, commit, and push files to remote repository on the + HuggingFace Hub. Will automatically track large files (>10MB). + + Args: + commit_message (`str`): + Message to use for the commit. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has + finished. 
+ clean_ok (`bool`, *optional*, defaults to `True`):
+ If True, this function will return None if the repo is
+ untouched. Default behavior is to fail because the git command
+ fails.
+ auto_lfs_prune (`bool`, *optional*, defaults to `False`):
+ Whether to automatically prune files once they have been pushed
+ to the remote.
+ """
+ if clean_ok and self.is_repo_clean():
+ logger.info("Repo currently clean. Ignoring push_to_hub")
+ return None
+ self.git_add(auto_lfs_track=True)
+ self.git_commit(commit_message)
+ return self.git_push(
+ upstream=f"origin {self.current_branch}",
+ blocking=blocking,
+ auto_lfs_prune=auto_lfs_prune,
+ )
+
+ @contextmanager
+ def commit(
+ self,
+ commit_message: str,
+ branch: Optional[str] = None,
+ track_large_files: bool = True,
+ blocking: bool = True,
+ auto_lfs_prune: bool = False,
+ ):
+ """
+ Context manager utility to handle committing to a repository. This
+ automatically tracks large files (>10MB) with git-lfs. Set the
+ `track_large_files` argument to `False` if you wish to ignore that
+ behavior.
+
+ Args:
+ commit_message (`str`):
+ Message to use for the commit.
+ branch (`str`, *optional*):
+ The branch on which the commit will appear. This branch will be
+ checked-out before any operation.
+ track_large_files (`bool`, *optional*, defaults to `True`):
+ Whether to automatically track large files or not. Will do so by
+ default.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has
+ finished.
+ auto_lfs_prune (`bool`, *optional*, defaults to `False`):
+ Whether to automatically prune files once they have been pushed
+ to the remote.
+
+ Examples:
+
+ ```python
+ >>> with Repository(
+ ... "text-files",
+ ... clone_from="<user>/text-files",
+ ... token=True,
+ >>> ).commit("My first file :)"):
+ ... with open("file.txt", "w+") as f:
+ ... f.write(json.dumps({"hey": 8}))
+
+ >>> import torch
+
+ >>> model = torch.nn.Transformer()
+ >>> with Repository(
+ ... "torch-model",
+ ... clone_from="<user>/torch-model",
+ ... token=True,
+ >>> ).commit("My cool model :)"):
+ ... torch.save(model.state_dict(), "model.pt")
+ ```
+
+ """
+
+ files_to_stage = files_to_be_staged(".", folder=self.local_dir)
+
+ if len(files_to_stage):
+ files_in_msg = str(files_to_stage[:5])[:-1] + ", ...]" if len(files_to_stage) > 5 else str(files_to_stage)
+ logger.error(
+ "There exist some updated files in the local repository that are not"
+ f" committed: {files_in_msg}. This may lead to errors if checking out"
+ " a branch. These files and their modifications will be added to the"
+ " current commit."
+ )
+
+ if branch is not None:
+ self.git_checkout(branch, create_branch_ok=True)
+
+ if is_tracked_upstream(self.local_dir):
+ logger.warning("Pulling changes ...")
+ self.git_pull(rebase=True)
+ else:
+ logger.warning(f"The current branch has no upstream branch. Will push to 'origin {self.current_branch}'")
+
+ current_working_directory = os.getcwd()
+ os.chdir(os.path.join(current_working_directory, self.local_dir))
+
+ try:
+ yield self
+ finally:
+ self.git_add(auto_lfs_track=track_large_files)
+
+ try:
+ self.git_commit(commit_message)
+ except OSError as e:
+ # If no changes are detected, there is nothing to commit.
+ if "nothing to commit" not in str(e):
+ raise e
+
+ try:
+ self.git_push(
+ upstream=f"origin {self.current_branch}",
+ blocking=blocking,
+ auto_lfs_prune=auto_lfs_prune,
+ )
+ except OSError as e:
+ # If the push failed because the user could not be authenticated,
+ # raise a clearer error. 
+ if "could not read Username" in str(e): + raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e + else: + raise e + + os.chdir(current_working_directory) + + def repocard_metadata_load(self) -> Optional[Dict]: + filepath = os.path.join(self.local_dir, constants.REPOCARD_NAME) + if os.path.isfile(filepath): + return metadata_load(filepath) + return None + + def repocard_metadata_save(self, data: Dict) -> None: + return metadata_save(os.path.join(self.local_dir, constants.REPOCARD_NAME), data) + + @property + def commands_failed(self): + """ + Returns the asynchronous commands that failed. + """ + return [c for c in self.command_queue if c.status > 0] + + @property + def commands_in_progress(self): + """ + Returns the asynchronous commands that are currently in progress. + """ + return [c for c in self.command_queue if not c.is_done] + + def wait_for_commands(self): + """ + Blocking method: blocks all subsequent execution until all commands have + been processed. + """ + index = 0 + for command_failed in self.commands_failed: + logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.") + logger.error(command_failed.stderr) + + while self.commands_in_progress: + if index % 10 == 0: + logger.warning( + f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}." + ) + + index += 1 + + time.sleep(1)