Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__main__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/api.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/cd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/constant.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/legacy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/md.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/models.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/version.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/cli/__init__.py +8 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/cli/__main__.py +321 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__main__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/__init__.py +1434 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_commit_api.py +758 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py +353 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py +402 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_local_folder.py +432 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_login.py +520 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py +307 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_space_api.py +160 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py +194 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_upload_large_folder.py +621 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py +137 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py +386 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/community.py +355 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/constants.py +229 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/errors.py +329 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/fastai_utils.py +425 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/file_download.py +1621 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/hf_api.py +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/hf_file_system.py +1140 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/hub_mixin.py +836 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_client.py +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_common.py +446 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +115 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py +48 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py +102 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/object_detection.py +59 -0
.gitattributes
CHANGED
|
@@ -122,3 +122,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 122 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
|
| 123 |
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 124 |
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 122 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
|
| 123 |
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 124 |
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 125 |
+
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__main__.cpython-311.pyc
ADDED
|
Binary file (394 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/api.cpython-311.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/cd.cpython-311.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/constant.cpython-311.pyc
ADDED
|
Binary file (43.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/legacy.cpython-311.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/md.cpython-311.pyc
ADDED
|
Binary file (27.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/version.cpython-311.pyc
ADDED
|
Binary file (400 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from .__main__ import cli_detect, query_yes_no
|
| 4 |
+
|
| 5 |
+
__all__ = (
|
| 6 |
+
"cli_detect",
|
| 7 |
+
"query_yes_no",
|
| 8 |
+
)
|
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__main__.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
from json import dumps
|
| 6 |
+
from os.path import abspath, basename, dirname, join, realpath
|
| 7 |
+
from platform import python_version
|
| 8 |
+
from unicodedata import unidata_version
|
| 9 |
+
|
| 10 |
+
import charset_normalizer.md as md_module
|
| 11 |
+
from charset_normalizer import from_fp
|
| 12 |
+
from charset_normalizer.models import CliDetectionResult
|
| 13 |
+
from charset_normalizer.version import __version__
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def query_yes_no(question: str, default: str = "yes") -> bool:
|
| 17 |
+
"""Ask a yes/no question via input() and return their answer.
|
| 18 |
+
|
| 19 |
+
"question" is a string that is presented to the user.
|
| 20 |
+
"default" is the presumed answer if the user just hits <Enter>.
|
| 21 |
+
It must be "yes" (the default), "no" or None (meaning
|
| 22 |
+
an answer is required of the user).
|
| 23 |
+
|
| 24 |
+
The "answer" return value is True for "yes" or False for "no".
|
| 25 |
+
|
| 26 |
+
Credit goes to (c) https://stackoverflow.com/questions/3041986/apt-command-line-interface-like-yes-no-input
|
| 27 |
+
"""
|
| 28 |
+
valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
|
| 29 |
+
if default is None:
|
| 30 |
+
prompt = " [y/n] "
|
| 31 |
+
elif default == "yes":
|
| 32 |
+
prompt = " [Y/n] "
|
| 33 |
+
elif default == "no":
|
| 34 |
+
prompt = " [y/N] "
|
| 35 |
+
else:
|
| 36 |
+
raise ValueError("invalid default answer: '%s'" % default)
|
| 37 |
+
|
| 38 |
+
while True:
|
| 39 |
+
sys.stdout.write(question + prompt)
|
| 40 |
+
choice = input().lower()
|
| 41 |
+
if default is not None and choice == "":
|
| 42 |
+
return valid[default]
|
| 43 |
+
elif choice in valid:
|
| 44 |
+
return valid[choice]
|
| 45 |
+
else:
|
| 46 |
+
sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def cli_detect(argv: list[str] | None = None) -> int:
|
| 50 |
+
"""
|
| 51 |
+
CLI assistant using ARGV and ArgumentParser
|
| 52 |
+
:param argv:
|
| 53 |
+
:return: 0 if everything is fine, anything else equal trouble
|
| 54 |
+
"""
|
| 55 |
+
parser = argparse.ArgumentParser(
|
| 56 |
+
description="The Real First Universal Charset Detector. "
|
| 57 |
+
"Discover originating encoding used on text file. "
|
| 58 |
+
"Normalize text to unicode."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"files", type=argparse.FileType("rb"), nargs="+", help="File(s) to be analysed"
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"-v",
|
| 66 |
+
"--verbose",
|
| 67 |
+
action="store_true",
|
| 68 |
+
default=False,
|
| 69 |
+
dest="verbose",
|
| 70 |
+
help="Display complementary information about file if any. "
|
| 71 |
+
"Stdout will contain logs about the detection process.",
|
| 72 |
+
)
|
| 73 |
+
parser.add_argument(
|
| 74 |
+
"-a",
|
| 75 |
+
"--with-alternative",
|
| 76 |
+
action="store_true",
|
| 77 |
+
default=False,
|
| 78 |
+
dest="alternatives",
|
| 79 |
+
help="Output complementary possibilities if any. Top-level JSON WILL be a list.",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"-n",
|
| 83 |
+
"--normalize",
|
| 84 |
+
action="store_true",
|
| 85 |
+
default=False,
|
| 86 |
+
dest="normalize",
|
| 87 |
+
help="Permit to normalize input file. If not set, program does not write anything.",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"-m",
|
| 91 |
+
"--minimal",
|
| 92 |
+
action="store_true",
|
| 93 |
+
default=False,
|
| 94 |
+
dest="minimal",
|
| 95 |
+
help="Only output the charset detected to STDOUT. Disabling JSON output.",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"-r",
|
| 99 |
+
"--replace",
|
| 100 |
+
action="store_true",
|
| 101 |
+
default=False,
|
| 102 |
+
dest="replace",
|
| 103 |
+
help="Replace file when trying to normalize it instead of creating a new one.",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"-f",
|
| 107 |
+
"--force",
|
| 108 |
+
action="store_true",
|
| 109 |
+
default=False,
|
| 110 |
+
dest="force",
|
| 111 |
+
help="Replace file without asking if you are sure, use this flag with caution.",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"-i",
|
| 115 |
+
"--no-preemptive",
|
| 116 |
+
action="store_true",
|
| 117 |
+
default=False,
|
| 118 |
+
dest="no_preemptive",
|
| 119 |
+
help="Disable looking at a charset declaration to hint the detector.",
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"-t",
|
| 123 |
+
"--threshold",
|
| 124 |
+
action="store",
|
| 125 |
+
default=0.2,
|
| 126 |
+
type=float,
|
| 127 |
+
dest="threshold",
|
| 128 |
+
help="Define a custom maximum amount of noise allowed in decoded content. 0. <= noise <= 1.",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--version",
|
| 132 |
+
action="version",
|
| 133 |
+
version="Charset-Normalizer {} - Python {} - Unicode {} - SpeedUp {}".format(
|
| 134 |
+
__version__,
|
| 135 |
+
python_version(),
|
| 136 |
+
unidata_version,
|
| 137 |
+
"OFF" if md_module.__file__.lower().endswith(".py") else "ON",
|
| 138 |
+
),
|
| 139 |
+
help="Show version information and exit.",
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
args = parser.parse_args(argv)
|
| 143 |
+
|
| 144 |
+
if args.replace is True and args.normalize is False:
|
| 145 |
+
if args.files:
|
| 146 |
+
for my_file in args.files:
|
| 147 |
+
my_file.close()
|
| 148 |
+
print("Use --replace in addition of --normalize only.", file=sys.stderr)
|
| 149 |
+
return 1
|
| 150 |
+
|
| 151 |
+
if args.force is True and args.replace is False:
|
| 152 |
+
if args.files:
|
| 153 |
+
for my_file in args.files:
|
| 154 |
+
my_file.close()
|
| 155 |
+
print("Use --force in addition of --replace only.", file=sys.stderr)
|
| 156 |
+
return 1
|
| 157 |
+
|
| 158 |
+
if args.threshold < 0.0 or args.threshold > 1.0:
|
| 159 |
+
if args.files:
|
| 160 |
+
for my_file in args.files:
|
| 161 |
+
my_file.close()
|
| 162 |
+
print("--threshold VALUE should be between 0. AND 1.", file=sys.stderr)
|
| 163 |
+
return 1
|
| 164 |
+
|
| 165 |
+
x_ = []
|
| 166 |
+
|
| 167 |
+
for my_file in args.files:
|
| 168 |
+
matches = from_fp(
|
| 169 |
+
my_file,
|
| 170 |
+
threshold=args.threshold,
|
| 171 |
+
explain=args.verbose,
|
| 172 |
+
preemptive_behaviour=args.no_preemptive is False,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
best_guess = matches.best()
|
| 176 |
+
|
| 177 |
+
if best_guess is None:
|
| 178 |
+
print(
|
| 179 |
+
'Unable to identify originating encoding for "{}". {}'.format(
|
| 180 |
+
my_file.name,
|
| 181 |
+
(
|
| 182 |
+
"Maybe try increasing maximum amount of chaos."
|
| 183 |
+
if args.threshold < 1.0
|
| 184 |
+
else ""
|
| 185 |
+
),
|
| 186 |
+
),
|
| 187 |
+
file=sys.stderr,
|
| 188 |
+
)
|
| 189 |
+
x_.append(
|
| 190 |
+
CliDetectionResult(
|
| 191 |
+
abspath(my_file.name),
|
| 192 |
+
None,
|
| 193 |
+
[],
|
| 194 |
+
[],
|
| 195 |
+
"Unknown",
|
| 196 |
+
[],
|
| 197 |
+
False,
|
| 198 |
+
1.0,
|
| 199 |
+
0.0,
|
| 200 |
+
None,
|
| 201 |
+
True,
|
| 202 |
+
)
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
x_.append(
|
| 206 |
+
CliDetectionResult(
|
| 207 |
+
abspath(my_file.name),
|
| 208 |
+
best_guess.encoding,
|
| 209 |
+
best_guess.encoding_aliases,
|
| 210 |
+
[
|
| 211 |
+
cp
|
| 212 |
+
for cp in best_guess.could_be_from_charset
|
| 213 |
+
if cp != best_guess.encoding
|
| 214 |
+
],
|
| 215 |
+
best_guess.language,
|
| 216 |
+
best_guess.alphabets,
|
| 217 |
+
best_guess.bom,
|
| 218 |
+
best_guess.percent_chaos,
|
| 219 |
+
best_guess.percent_coherence,
|
| 220 |
+
None,
|
| 221 |
+
True,
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
if len(matches) > 1 and args.alternatives:
|
| 226 |
+
for el in matches:
|
| 227 |
+
if el != best_guess:
|
| 228 |
+
x_.append(
|
| 229 |
+
CliDetectionResult(
|
| 230 |
+
abspath(my_file.name),
|
| 231 |
+
el.encoding,
|
| 232 |
+
el.encoding_aliases,
|
| 233 |
+
[
|
| 234 |
+
cp
|
| 235 |
+
for cp in el.could_be_from_charset
|
| 236 |
+
if cp != el.encoding
|
| 237 |
+
],
|
| 238 |
+
el.language,
|
| 239 |
+
el.alphabets,
|
| 240 |
+
el.bom,
|
| 241 |
+
el.percent_chaos,
|
| 242 |
+
el.percent_coherence,
|
| 243 |
+
None,
|
| 244 |
+
False,
|
| 245 |
+
)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if args.normalize is True:
|
| 249 |
+
if best_guess.encoding.startswith("utf") is True:
|
| 250 |
+
print(
|
| 251 |
+
'"{}" file does not need to be normalized, as it already came from unicode.'.format(
|
| 252 |
+
my_file.name
|
| 253 |
+
),
|
| 254 |
+
file=sys.stderr,
|
| 255 |
+
)
|
| 256 |
+
if my_file.closed is False:
|
| 257 |
+
my_file.close()
|
| 258 |
+
continue
|
| 259 |
+
|
| 260 |
+
dir_path = dirname(realpath(my_file.name))
|
| 261 |
+
file_name = basename(realpath(my_file.name))
|
| 262 |
+
|
| 263 |
+
o_: list[str] = file_name.split(".")
|
| 264 |
+
|
| 265 |
+
if args.replace is False:
|
| 266 |
+
o_.insert(-1, best_guess.encoding)
|
| 267 |
+
if my_file.closed is False:
|
| 268 |
+
my_file.close()
|
| 269 |
+
elif (
|
| 270 |
+
args.force is False
|
| 271 |
+
and query_yes_no(
|
| 272 |
+
'Are you sure to normalize "{}" by replacing it ?'.format(
|
| 273 |
+
my_file.name
|
| 274 |
+
),
|
| 275 |
+
"no",
|
| 276 |
+
)
|
| 277 |
+
is False
|
| 278 |
+
):
|
| 279 |
+
if my_file.closed is False:
|
| 280 |
+
my_file.close()
|
| 281 |
+
continue
|
| 282 |
+
|
| 283 |
+
try:
|
| 284 |
+
x_[0].unicode_path = join(dir_path, ".".join(o_))
|
| 285 |
+
|
| 286 |
+
with open(x_[0].unicode_path, "wb") as fp:
|
| 287 |
+
fp.write(best_guess.output())
|
| 288 |
+
except OSError as e:
|
| 289 |
+
print(str(e), file=sys.stderr)
|
| 290 |
+
if my_file.closed is False:
|
| 291 |
+
my_file.close()
|
| 292 |
+
return 2
|
| 293 |
+
|
| 294 |
+
if my_file.closed is False:
|
| 295 |
+
my_file.close()
|
| 296 |
+
|
| 297 |
+
if args.minimal is False:
|
| 298 |
+
print(
|
| 299 |
+
dumps(
|
| 300 |
+
[el.__dict__ for el in x_] if len(x_) > 1 else x_[0].__dict__,
|
| 301 |
+
ensure_ascii=True,
|
| 302 |
+
indent=4,
|
| 303 |
+
)
|
| 304 |
+
)
|
| 305 |
+
else:
|
| 306 |
+
for my_file in args.files:
|
| 307 |
+
print(
|
| 308 |
+
", ".join(
|
| 309 |
+
[
|
| 310 |
+
el.encoding or "undefined"
|
| 311 |
+
for el in x_
|
| 312 |
+
if el.path == abspath(my_file.name)
|
| 313 |
+
]
|
| 314 |
+
)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
return 0
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
if __name__ == "__main__":
|
| 321 |
+
cli_detect()
|
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (366 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__main__.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/__init__.py
ADDED
|
@@ -0,0 +1,1434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# ***********
|
| 16 |
+
# `huggingface_hub` init has 2 modes:
|
| 17 |
+
# - Normal usage:
|
| 18 |
+
# If imported to use it, all modules and functions are lazy-loaded. This means
|
| 19 |
+
# they exist at top level in module but are imported only the first time they are
|
| 20 |
+
# used. This way, `from huggingface_hub import something` will import `something`
|
| 21 |
+
# quickly without the hassle of importing all the features from `huggingface_hub`.
|
| 22 |
+
# - Static check:
|
| 23 |
+
# If statically analyzed, all modules and functions are loaded normally. This way
|
| 24 |
+
# static typing check works properly as well as autocomplete in text editors and
|
| 25 |
+
# IDEs.
|
| 26 |
+
#
|
| 27 |
+
# The static model imports are done inside the `if TYPE_CHECKING:` statement at
|
| 28 |
+
# the bottom of this file. Since module/functions imports are duplicated, it is
|
| 29 |
+
# mandatory to make sure to add them twice when adding one. This is checked in the
|
| 30 |
+
# `make quality` command.
|
| 31 |
+
#
|
| 32 |
+
# To update the static imports, please run the following command and commit the changes.
|
| 33 |
+
# ```
|
| 34 |
+
# # Use script
|
| 35 |
+
# python utils/check_static_imports.py --update-file
|
| 36 |
+
#
|
| 37 |
+
# # Or run style on codebase
|
| 38 |
+
# make style
|
| 39 |
+
# ```
|
| 40 |
+
#
|
| 41 |
+
# ***********
|
| 42 |
+
# Lazy loader vendored from https://github.com/scientific-python/lazy_loader
|
| 43 |
+
import importlib
|
| 44 |
+
import os
|
| 45 |
+
import sys
|
| 46 |
+
from typing import TYPE_CHECKING
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
__version__ = "0.28.1"
|
| 50 |
+
|
| 51 |
+
# Alphabetical order of definitions is ensured in tests
|
| 52 |
+
# WARNING: any comment added in this dictionary definition will be lost when
|
| 53 |
+
# re-generating the file !
|
| 54 |
+
_SUBMOD_ATTRS = {
|
| 55 |
+
"_commit_scheduler": [
|
| 56 |
+
"CommitScheduler",
|
| 57 |
+
],
|
| 58 |
+
"_inference_endpoints": [
|
| 59 |
+
"InferenceEndpoint",
|
| 60 |
+
"InferenceEndpointError",
|
| 61 |
+
"InferenceEndpointStatus",
|
| 62 |
+
"InferenceEndpointTimeoutError",
|
| 63 |
+
"InferenceEndpointType",
|
| 64 |
+
],
|
| 65 |
+
"_login": [
|
| 66 |
+
"auth_list",
|
| 67 |
+
"auth_switch",
|
| 68 |
+
"interpreter_login",
|
| 69 |
+
"login",
|
| 70 |
+
"logout",
|
| 71 |
+
"notebook_login",
|
| 72 |
+
],
|
| 73 |
+
"_snapshot_download": [
|
| 74 |
+
"snapshot_download",
|
| 75 |
+
],
|
| 76 |
+
"_space_api": [
|
| 77 |
+
"SpaceHardware",
|
| 78 |
+
"SpaceRuntime",
|
| 79 |
+
"SpaceStage",
|
| 80 |
+
"SpaceStorage",
|
| 81 |
+
"SpaceVariable",
|
| 82 |
+
],
|
| 83 |
+
"_tensorboard_logger": [
|
| 84 |
+
"HFSummaryWriter",
|
| 85 |
+
],
|
| 86 |
+
"_webhooks_payload": [
|
| 87 |
+
"WebhookPayload",
|
| 88 |
+
"WebhookPayloadComment",
|
| 89 |
+
"WebhookPayloadDiscussion",
|
| 90 |
+
"WebhookPayloadDiscussionChanges",
|
| 91 |
+
"WebhookPayloadEvent",
|
| 92 |
+
"WebhookPayloadMovedTo",
|
| 93 |
+
"WebhookPayloadRepo",
|
| 94 |
+
"WebhookPayloadUrl",
|
| 95 |
+
"WebhookPayloadWebhook",
|
| 96 |
+
],
|
| 97 |
+
"_webhooks_server": [
|
| 98 |
+
"WebhooksServer",
|
| 99 |
+
"webhook_endpoint",
|
| 100 |
+
],
|
| 101 |
+
"community": [
|
| 102 |
+
"Discussion",
|
| 103 |
+
"DiscussionComment",
|
| 104 |
+
"DiscussionCommit",
|
| 105 |
+
"DiscussionEvent",
|
| 106 |
+
"DiscussionStatusChange",
|
| 107 |
+
"DiscussionTitleChange",
|
| 108 |
+
"DiscussionWithDetails",
|
| 109 |
+
],
|
| 110 |
+
"constants": [
|
| 111 |
+
"CONFIG_NAME",
|
| 112 |
+
"FLAX_WEIGHTS_NAME",
|
| 113 |
+
"HUGGINGFACE_CO_URL_HOME",
|
| 114 |
+
"HUGGINGFACE_CO_URL_TEMPLATE",
|
| 115 |
+
"PYTORCH_WEIGHTS_NAME",
|
| 116 |
+
"REPO_TYPE_DATASET",
|
| 117 |
+
"REPO_TYPE_MODEL",
|
| 118 |
+
"REPO_TYPE_SPACE",
|
| 119 |
+
"TF2_WEIGHTS_NAME",
|
| 120 |
+
"TF_WEIGHTS_NAME",
|
| 121 |
+
],
|
| 122 |
+
"fastai_utils": [
|
| 123 |
+
"_save_pretrained_fastai",
|
| 124 |
+
"from_pretrained_fastai",
|
| 125 |
+
"push_to_hub_fastai",
|
| 126 |
+
],
|
| 127 |
+
"file_download": [
|
| 128 |
+
"HfFileMetadata",
|
| 129 |
+
"_CACHED_NO_EXIST",
|
| 130 |
+
"get_hf_file_metadata",
|
| 131 |
+
"hf_hub_download",
|
| 132 |
+
"hf_hub_url",
|
| 133 |
+
"try_to_load_from_cache",
|
| 134 |
+
],
|
| 135 |
+
"hf_api": [
|
| 136 |
+
"Collection",
|
| 137 |
+
"CollectionItem",
|
| 138 |
+
"CommitInfo",
|
| 139 |
+
"CommitOperation",
|
| 140 |
+
"CommitOperationAdd",
|
| 141 |
+
"CommitOperationCopy",
|
| 142 |
+
"CommitOperationDelete",
|
| 143 |
+
"DatasetInfo",
|
| 144 |
+
"GitCommitInfo",
|
| 145 |
+
"GitRefInfo",
|
| 146 |
+
"GitRefs",
|
| 147 |
+
"HfApi",
|
| 148 |
+
"ModelInfo",
|
| 149 |
+
"RepoUrl",
|
| 150 |
+
"SpaceInfo",
|
| 151 |
+
"User",
|
| 152 |
+
"UserLikes",
|
| 153 |
+
"WebhookInfo",
|
| 154 |
+
"WebhookWatchedItem",
|
| 155 |
+
"accept_access_request",
|
| 156 |
+
"add_collection_item",
|
| 157 |
+
"add_space_secret",
|
| 158 |
+
"add_space_variable",
|
| 159 |
+
"auth_check",
|
| 160 |
+
"cancel_access_request",
|
| 161 |
+
"change_discussion_status",
|
| 162 |
+
"comment_discussion",
|
| 163 |
+
"create_branch",
|
| 164 |
+
"create_collection",
|
| 165 |
+
"create_commit",
|
| 166 |
+
"create_discussion",
|
| 167 |
+
"create_inference_endpoint",
|
| 168 |
+
"create_pull_request",
|
| 169 |
+
"create_repo",
|
| 170 |
+
"create_tag",
|
| 171 |
+
"create_webhook",
|
| 172 |
+
"dataset_info",
|
| 173 |
+
"delete_branch",
|
| 174 |
+
"delete_collection",
|
| 175 |
+
"delete_collection_item",
|
| 176 |
+
"delete_file",
|
| 177 |
+
"delete_folder",
|
| 178 |
+
"delete_inference_endpoint",
|
| 179 |
+
"delete_repo",
|
| 180 |
+
"delete_space_secret",
|
| 181 |
+
"delete_space_storage",
|
| 182 |
+
"delete_space_variable",
|
| 183 |
+
"delete_tag",
|
| 184 |
+
"delete_webhook",
|
| 185 |
+
"disable_webhook",
|
| 186 |
+
"duplicate_space",
|
| 187 |
+
"edit_discussion_comment",
|
| 188 |
+
"enable_webhook",
|
| 189 |
+
"file_exists",
|
| 190 |
+
"get_collection",
|
| 191 |
+
"get_dataset_tags",
|
| 192 |
+
"get_discussion_details",
|
| 193 |
+
"get_full_repo_name",
|
| 194 |
+
"get_inference_endpoint",
|
| 195 |
+
"get_model_tags",
|
| 196 |
+
"get_paths_info",
|
| 197 |
+
"get_repo_discussions",
|
| 198 |
+
"get_safetensors_metadata",
|
| 199 |
+
"get_space_runtime",
|
| 200 |
+
"get_space_variables",
|
| 201 |
+
"get_token_permission",
|
| 202 |
+
"get_user_overview",
|
| 203 |
+
"get_webhook",
|
| 204 |
+
"grant_access",
|
| 205 |
+
"list_accepted_access_requests",
|
| 206 |
+
"list_collections",
|
| 207 |
+
"list_datasets",
|
| 208 |
+
"list_inference_endpoints",
|
| 209 |
+
"list_liked_repos",
|
| 210 |
+
"list_models",
|
| 211 |
+
"list_organization_members",
|
| 212 |
+
"list_papers",
|
| 213 |
+
"list_pending_access_requests",
|
| 214 |
+
"list_rejected_access_requests",
|
| 215 |
+
"list_repo_commits",
|
| 216 |
+
"list_repo_files",
|
| 217 |
+
"list_repo_likers",
|
| 218 |
+
"list_repo_refs",
|
| 219 |
+
"list_repo_tree",
|
| 220 |
+
"list_spaces",
|
| 221 |
+
"list_user_followers",
|
| 222 |
+
"list_user_following",
|
| 223 |
+
"list_webhooks",
|
| 224 |
+
"merge_pull_request",
|
| 225 |
+
"model_info",
|
| 226 |
+
"move_repo",
|
| 227 |
+
"paper_info",
|
| 228 |
+
"parse_safetensors_file_metadata",
|
| 229 |
+
"pause_inference_endpoint",
|
| 230 |
+
"pause_space",
|
| 231 |
+
"preupload_lfs_files",
|
| 232 |
+
"reject_access_request",
|
| 233 |
+
"rename_discussion",
|
| 234 |
+
"repo_exists",
|
| 235 |
+
"repo_info",
|
| 236 |
+
"repo_type_and_id_from_hf_id",
|
| 237 |
+
"request_space_hardware",
|
| 238 |
+
"request_space_storage",
|
| 239 |
+
"restart_space",
|
| 240 |
+
"resume_inference_endpoint",
|
| 241 |
+
"revision_exists",
|
| 242 |
+
"run_as_future",
|
| 243 |
+
"scale_to_zero_inference_endpoint",
|
| 244 |
+
"set_space_sleep_time",
|
| 245 |
+
"space_info",
|
| 246 |
+
"super_squash_history",
|
| 247 |
+
"unlike",
|
| 248 |
+
"update_collection_item",
|
| 249 |
+
"update_collection_metadata",
|
| 250 |
+
"update_inference_endpoint",
|
| 251 |
+
"update_repo_settings",
|
| 252 |
+
"update_repo_visibility",
|
| 253 |
+
"update_webhook",
|
| 254 |
+
"upload_file",
|
| 255 |
+
"upload_folder",
|
| 256 |
+
"upload_large_folder",
|
| 257 |
+
"whoami",
|
| 258 |
+
],
|
| 259 |
+
"hf_file_system": [
|
| 260 |
+
"HfFileSystem",
|
| 261 |
+
"HfFileSystemFile",
|
| 262 |
+
"HfFileSystemResolvedPath",
|
| 263 |
+
"HfFileSystemStreamFile",
|
| 264 |
+
],
|
| 265 |
+
"hub_mixin": [
|
| 266 |
+
"ModelHubMixin",
|
| 267 |
+
"PyTorchModelHubMixin",
|
| 268 |
+
],
|
| 269 |
+
"inference._client": [
|
| 270 |
+
"InferenceClient",
|
| 271 |
+
"InferenceTimeoutError",
|
| 272 |
+
],
|
| 273 |
+
"inference._generated._async_client": [
|
| 274 |
+
"AsyncInferenceClient",
|
| 275 |
+
],
|
| 276 |
+
"inference._generated.types": [
|
| 277 |
+
"AudioClassificationInput",
|
| 278 |
+
"AudioClassificationOutputElement",
|
| 279 |
+
"AudioClassificationOutputTransform",
|
| 280 |
+
"AudioClassificationParameters",
|
| 281 |
+
"AudioToAudioInput",
|
| 282 |
+
"AudioToAudioOutputElement",
|
| 283 |
+
"AutomaticSpeechRecognitionEarlyStoppingEnum",
|
| 284 |
+
"AutomaticSpeechRecognitionGenerationParameters",
|
| 285 |
+
"AutomaticSpeechRecognitionInput",
|
| 286 |
+
"AutomaticSpeechRecognitionOutput",
|
| 287 |
+
"AutomaticSpeechRecognitionOutputChunk",
|
| 288 |
+
"AutomaticSpeechRecognitionParameters",
|
| 289 |
+
"ChatCompletionInput",
|
| 290 |
+
"ChatCompletionInputFunctionDefinition",
|
| 291 |
+
"ChatCompletionInputFunctionName",
|
| 292 |
+
"ChatCompletionInputGrammarType",
|
| 293 |
+
"ChatCompletionInputGrammarTypeType",
|
| 294 |
+
"ChatCompletionInputMessage",
|
| 295 |
+
"ChatCompletionInputMessageChunk",
|
| 296 |
+
"ChatCompletionInputMessageChunkType",
|
| 297 |
+
"ChatCompletionInputStreamOptions",
|
| 298 |
+
"ChatCompletionInputTool",
|
| 299 |
+
"ChatCompletionInputToolChoiceClass",
|
| 300 |
+
"ChatCompletionInputToolChoiceEnum",
|
| 301 |
+
"ChatCompletionInputURL",
|
| 302 |
+
"ChatCompletionOutput",
|
| 303 |
+
"ChatCompletionOutputComplete",
|
| 304 |
+
"ChatCompletionOutputFunctionDefinition",
|
| 305 |
+
"ChatCompletionOutputLogprob",
|
| 306 |
+
"ChatCompletionOutputLogprobs",
|
| 307 |
+
"ChatCompletionOutputMessage",
|
| 308 |
+
"ChatCompletionOutputToolCall",
|
| 309 |
+
"ChatCompletionOutputTopLogprob",
|
| 310 |
+
"ChatCompletionOutputUsage",
|
| 311 |
+
"ChatCompletionStreamOutput",
|
| 312 |
+
"ChatCompletionStreamOutputChoice",
|
| 313 |
+
"ChatCompletionStreamOutputDelta",
|
| 314 |
+
"ChatCompletionStreamOutputDeltaToolCall",
|
| 315 |
+
"ChatCompletionStreamOutputFunction",
|
| 316 |
+
"ChatCompletionStreamOutputLogprob",
|
| 317 |
+
"ChatCompletionStreamOutputLogprobs",
|
| 318 |
+
"ChatCompletionStreamOutputTopLogprob",
|
| 319 |
+
"ChatCompletionStreamOutputUsage",
|
| 320 |
+
"DepthEstimationInput",
|
| 321 |
+
"DepthEstimationOutput",
|
| 322 |
+
"DocumentQuestionAnsweringInput",
|
| 323 |
+
"DocumentQuestionAnsweringInputData",
|
| 324 |
+
"DocumentQuestionAnsweringOutputElement",
|
| 325 |
+
"DocumentQuestionAnsweringParameters",
|
| 326 |
+
"FeatureExtractionInput",
|
| 327 |
+
"FeatureExtractionInputTruncationDirection",
|
| 328 |
+
"FillMaskInput",
|
| 329 |
+
"FillMaskOutputElement",
|
| 330 |
+
"FillMaskParameters",
|
| 331 |
+
"ImageClassificationInput",
|
| 332 |
+
"ImageClassificationOutputElement",
|
| 333 |
+
"ImageClassificationOutputTransform",
|
| 334 |
+
"ImageClassificationParameters",
|
| 335 |
+
"ImageSegmentationInput",
|
| 336 |
+
"ImageSegmentationOutputElement",
|
| 337 |
+
"ImageSegmentationParameters",
|
| 338 |
+
"ImageSegmentationSubtask",
|
| 339 |
+
"ImageToImageInput",
|
| 340 |
+
"ImageToImageOutput",
|
| 341 |
+
"ImageToImageParameters",
|
| 342 |
+
"ImageToImageTargetSize",
|
| 343 |
+
"ImageToTextEarlyStoppingEnum",
|
| 344 |
+
"ImageToTextGenerationParameters",
|
| 345 |
+
"ImageToTextInput",
|
| 346 |
+
"ImageToTextOutput",
|
| 347 |
+
"ImageToTextParameters",
|
| 348 |
+
"ObjectDetectionBoundingBox",
|
| 349 |
+
"ObjectDetectionInput",
|
| 350 |
+
"ObjectDetectionOutputElement",
|
| 351 |
+
"ObjectDetectionParameters",
|
| 352 |
+
"Padding",
|
| 353 |
+
"QuestionAnsweringInput",
|
| 354 |
+
"QuestionAnsweringInputData",
|
| 355 |
+
"QuestionAnsweringOutputElement",
|
| 356 |
+
"QuestionAnsweringParameters",
|
| 357 |
+
"SentenceSimilarityInput",
|
| 358 |
+
"SentenceSimilarityInputData",
|
| 359 |
+
"SummarizationInput",
|
| 360 |
+
"SummarizationOutput",
|
| 361 |
+
"SummarizationParameters",
|
| 362 |
+
"SummarizationTruncationStrategy",
|
| 363 |
+
"TableQuestionAnsweringInput",
|
| 364 |
+
"TableQuestionAnsweringInputData",
|
| 365 |
+
"TableQuestionAnsweringOutputElement",
|
| 366 |
+
"TableQuestionAnsweringParameters",
|
| 367 |
+
"Text2TextGenerationInput",
|
| 368 |
+
"Text2TextGenerationOutput",
|
| 369 |
+
"Text2TextGenerationParameters",
|
| 370 |
+
"Text2TextGenerationTruncationStrategy",
|
| 371 |
+
"TextClassificationInput",
|
| 372 |
+
"TextClassificationOutputElement",
|
| 373 |
+
"TextClassificationOutputTransform",
|
| 374 |
+
"TextClassificationParameters",
|
| 375 |
+
"TextGenerationInput",
|
| 376 |
+
"TextGenerationInputGenerateParameters",
|
| 377 |
+
"TextGenerationInputGrammarType",
|
| 378 |
+
"TextGenerationOutput",
|
| 379 |
+
"TextGenerationOutputBestOfSequence",
|
| 380 |
+
"TextGenerationOutputDetails",
|
| 381 |
+
"TextGenerationOutputFinishReason",
|
| 382 |
+
"TextGenerationOutputPrefillToken",
|
| 383 |
+
"TextGenerationOutputToken",
|
| 384 |
+
"TextGenerationStreamOutput",
|
| 385 |
+
"TextGenerationStreamOutputStreamDetails",
|
| 386 |
+
"TextGenerationStreamOutputToken",
|
| 387 |
+
"TextToAudioEarlyStoppingEnum",
|
| 388 |
+
"TextToAudioGenerationParameters",
|
| 389 |
+
"TextToAudioInput",
|
| 390 |
+
"TextToAudioOutput",
|
| 391 |
+
"TextToAudioParameters",
|
| 392 |
+
"TextToImageInput",
|
| 393 |
+
"TextToImageOutput",
|
| 394 |
+
"TextToImageParameters",
|
| 395 |
+
"TextToImageTargetSize",
|
| 396 |
+
"TextToSpeechEarlyStoppingEnum",
|
| 397 |
+
"TextToSpeechGenerationParameters",
|
| 398 |
+
"TextToSpeechInput",
|
| 399 |
+
"TextToSpeechOutput",
|
| 400 |
+
"TextToSpeechParameters",
|
| 401 |
+
"TextToVideoInput",
|
| 402 |
+
"TextToVideoOutput",
|
| 403 |
+
"TextToVideoParameters",
|
| 404 |
+
"TokenClassificationAggregationStrategy",
|
| 405 |
+
"TokenClassificationInput",
|
| 406 |
+
"TokenClassificationOutputElement",
|
| 407 |
+
"TokenClassificationParameters",
|
| 408 |
+
"TranslationInput",
|
| 409 |
+
"TranslationOutput",
|
| 410 |
+
"TranslationParameters",
|
| 411 |
+
"TranslationTruncationStrategy",
|
| 412 |
+
"TypeEnum",
|
| 413 |
+
"VideoClassificationInput",
|
| 414 |
+
"VideoClassificationOutputElement",
|
| 415 |
+
"VideoClassificationOutputTransform",
|
| 416 |
+
"VideoClassificationParameters",
|
| 417 |
+
"VisualQuestionAnsweringInput",
|
| 418 |
+
"VisualQuestionAnsweringInputData",
|
| 419 |
+
"VisualQuestionAnsweringOutputElement",
|
| 420 |
+
"VisualQuestionAnsweringParameters",
|
| 421 |
+
"ZeroShotClassificationInput",
|
| 422 |
+
"ZeroShotClassificationOutputElement",
|
| 423 |
+
"ZeroShotClassificationParameters",
|
| 424 |
+
"ZeroShotImageClassificationInput",
|
| 425 |
+
"ZeroShotImageClassificationOutputElement",
|
| 426 |
+
"ZeroShotImageClassificationParameters",
|
| 427 |
+
"ZeroShotObjectDetectionBoundingBox",
|
| 428 |
+
"ZeroShotObjectDetectionInput",
|
| 429 |
+
"ZeroShotObjectDetectionOutputElement",
|
| 430 |
+
"ZeroShotObjectDetectionParameters",
|
| 431 |
+
],
|
| 432 |
+
"inference_api": [
|
| 433 |
+
"InferenceApi",
|
| 434 |
+
],
|
| 435 |
+
"keras_mixin": [
|
| 436 |
+
"KerasModelHubMixin",
|
| 437 |
+
"from_pretrained_keras",
|
| 438 |
+
"push_to_hub_keras",
|
| 439 |
+
"save_pretrained_keras",
|
| 440 |
+
],
|
| 441 |
+
"repocard": [
|
| 442 |
+
"DatasetCard",
|
| 443 |
+
"ModelCard",
|
| 444 |
+
"RepoCard",
|
| 445 |
+
"SpaceCard",
|
| 446 |
+
"metadata_eval_result",
|
| 447 |
+
"metadata_load",
|
| 448 |
+
"metadata_save",
|
| 449 |
+
"metadata_update",
|
| 450 |
+
],
|
| 451 |
+
"repocard_data": [
|
| 452 |
+
"CardData",
|
| 453 |
+
"DatasetCardData",
|
| 454 |
+
"EvalResult",
|
| 455 |
+
"ModelCardData",
|
| 456 |
+
"SpaceCardData",
|
| 457 |
+
],
|
| 458 |
+
"repository": [
|
| 459 |
+
"Repository",
|
| 460 |
+
],
|
| 461 |
+
"serialization": [
|
| 462 |
+
"StateDictSplit",
|
| 463 |
+
"get_tf_storage_size",
|
| 464 |
+
"get_torch_storage_id",
|
| 465 |
+
"get_torch_storage_size",
|
| 466 |
+
"load_state_dict_from_file",
|
| 467 |
+
"load_torch_model",
|
| 468 |
+
"save_torch_model",
|
| 469 |
+
"save_torch_state_dict",
|
| 470 |
+
"split_state_dict_into_shards_factory",
|
| 471 |
+
"split_tf_state_dict_into_shards",
|
| 472 |
+
"split_torch_state_dict_into_shards",
|
| 473 |
+
],
|
| 474 |
+
"serialization._dduf": [
|
| 475 |
+
"DDUFEntry",
|
| 476 |
+
"export_entries_as_dduf",
|
| 477 |
+
"export_folder_as_dduf",
|
| 478 |
+
"read_dduf_file",
|
| 479 |
+
],
|
| 480 |
+
"utils": [
|
| 481 |
+
"CacheNotFound",
|
| 482 |
+
"CachedFileInfo",
|
| 483 |
+
"CachedRepoInfo",
|
| 484 |
+
"CachedRevisionInfo",
|
| 485 |
+
"CorruptedCacheException",
|
| 486 |
+
"DeleteCacheStrategy",
|
| 487 |
+
"HFCacheInfo",
|
| 488 |
+
"HfFolder",
|
| 489 |
+
"cached_assets_path",
|
| 490 |
+
"configure_http_backend",
|
| 491 |
+
"dump_environment_info",
|
| 492 |
+
"get_session",
|
| 493 |
+
"get_token",
|
| 494 |
+
"logging",
|
| 495 |
+
"scan_cache_dir",
|
| 496 |
+
],
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
# WARNING: `__all__` is generated automatically. Any manual edit will be lost when re-generating this file!
|
| 500 |
+
#
|
| 501 |
+
# To update `__all__`, please run the following command and commit the changes.
|
| 502 |
+
# ```
|
| 503 |
+
# # Use script
|
| 504 |
+
# python utils/check_all_variable.py --update
|
| 505 |
+
#
|
| 506 |
+
# # Or run style on codebase
|
| 507 |
+
# make style
|
| 508 |
+
# ```
|
| 509 |
+
|
| 510 |
+
__all__ = [
|
| 511 |
+
"AsyncInferenceClient",
|
| 512 |
+
"AudioClassificationInput",
|
| 513 |
+
"AudioClassificationOutputElement",
|
| 514 |
+
"AudioClassificationOutputTransform",
|
| 515 |
+
"AudioClassificationParameters",
|
| 516 |
+
"AudioToAudioInput",
|
| 517 |
+
"AudioToAudioOutputElement",
|
| 518 |
+
"AutomaticSpeechRecognitionEarlyStoppingEnum",
|
| 519 |
+
"AutomaticSpeechRecognitionGenerationParameters",
|
| 520 |
+
"AutomaticSpeechRecognitionInput",
|
| 521 |
+
"AutomaticSpeechRecognitionOutput",
|
| 522 |
+
"AutomaticSpeechRecognitionOutputChunk",
|
| 523 |
+
"AutomaticSpeechRecognitionParameters",
|
| 524 |
+
"CONFIG_NAME",
|
| 525 |
+
"CacheNotFound",
|
| 526 |
+
"CachedFileInfo",
|
| 527 |
+
"CachedRepoInfo",
|
| 528 |
+
"CachedRevisionInfo",
|
| 529 |
+
"CardData",
|
| 530 |
+
"ChatCompletionInput",
|
| 531 |
+
"ChatCompletionInputFunctionDefinition",
|
| 532 |
+
"ChatCompletionInputFunctionName",
|
| 533 |
+
"ChatCompletionInputGrammarType",
|
| 534 |
+
"ChatCompletionInputGrammarTypeType",
|
| 535 |
+
"ChatCompletionInputMessage",
|
| 536 |
+
"ChatCompletionInputMessageChunk",
|
| 537 |
+
"ChatCompletionInputMessageChunkType",
|
| 538 |
+
"ChatCompletionInputStreamOptions",
|
| 539 |
+
"ChatCompletionInputTool",
|
| 540 |
+
"ChatCompletionInputToolChoiceClass",
|
| 541 |
+
"ChatCompletionInputToolChoiceEnum",
|
| 542 |
+
"ChatCompletionInputURL",
|
| 543 |
+
"ChatCompletionOutput",
|
| 544 |
+
"ChatCompletionOutputComplete",
|
| 545 |
+
"ChatCompletionOutputFunctionDefinition",
|
| 546 |
+
"ChatCompletionOutputLogprob",
|
| 547 |
+
"ChatCompletionOutputLogprobs",
|
| 548 |
+
"ChatCompletionOutputMessage",
|
| 549 |
+
"ChatCompletionOutputToolCall",
|
| 550 |
+
"ChatCompletionOutputTopLogprob",
|
| 551 |
+
"ChatCompletionOutputUsage",
|
| 552 |
+
"ChatCompletionStreamOutput",
|
| 553 |
+
"ChatCompletionStreamOutputChoice",
|
| 554 |
+
"ChatCompletionStreamOutputDelta",
|
| 555 |
+
"ChatCompletionStreamOutputDeltaToolCall",
|
| 556 |
+
"ChatCompletionStreamOutputFunction",
|
| 557 |
+
"ChatCompletionStreamOutputLogprob",
|
| 558 |
+
"ChatCompletionStreamOutputLogprobs",
|
| 559 |
+
"ChatCompletionStreamOutputTopLogprob",
|
| 560 |
+
"ChatCompletionStreamOutputUsage",
|
| 561 |
+
"Collection",
|
| 562 |
+
"CollectionItem",
|
| 563 |
+
"CommitInfo",
|
| 564 |
+
"CommitOperation",
|
| 565 |
+
"CommitOperationAdd",
|
| 566 |
+
"CommitOperationCopy",
|
| 567 |
+
"CommitOperationDelete",
|
| 568 |
+
"CommitScheduler",
|
| 569 |
+
"CorruptedCacheException",
|
| 570 |
+
"DDUFEntry",
|
| 571 |
+
"DatasetCard",
|
| 572 |
+
"DatasetCardData",
|
| 573 |
+
"DatasetInfo",
|
| 574 |
+
"DeleteCacheStrategy",
|
| 575 |
+
"DepthEstimationInput",
|
| 576 |
+
"DepthEstimationOutput",
|
| 577 |
+
"Discussion",
|
| 578 |
+
"DiscussionComment",
|
| 579 |
+
"DiscussionCommit",
|
| 580 |
+
"DiscussionEvent",
|
| 581 |
+
"DiscussionStatusChange",
|
| 582 |
+
"DiscussionTitleChange",
|
| 583 |
+
"DiscussionWithDetails",
|
| 584 |
+
"DocumentQuestionAnsweringInput",
|
| 585 |
+
"DocumentQuestionAnsweringInputData",
|
| 586 |
+
"DocumentQuestionAnsweringOutputElement",
|
| 587 |
+
"DocumentQuestionAnsweringParameters",
|
| 588 |
+
"EvalResult",
|
| 589 |
+
"FLAX_WEIGHTS_NAME",
|
| 590 |
+
"FeatureExtractionInput",
|
| 591 |
+
"FeatureExtractionInputTruncationDirection",
|
| 592 |
+
"FillMaskInput",
|
| 593 |
+
"FillMaskOutputElement",
|
| 594 |
+
"FillMaskParameters",
|
| 595 |
+
"GitCommitInfo",
|
| 596 |
+
"GitRefInfo",
|
| 597 |
+
"GitRefs",
|
| 598 |
+
"HFCacheInfo",
|
| 599 |
+
"HFSummaryWriter",
|
| 600 |
+
"HUGGINGFACE_CO_URL_HOME",
|
| 601 |
+
"HUGGINGFACE_CO_URL_TEMPLATE",
|
| 602 |
+
"HfApi",
|
| 603 |
+
"HfFileMetadata",
|
| 604 |
+
"HfFileSystem",
|
| 605 |
+
"HfFileSystemFile",
|
| 606 |
+
"HfFileSystemResolvedPath",
|
| 607 |
+
"HfFileSystemStreamFile",
|
| 608 |
+
"HfFolder",
|
| 609 |
+
"ImageClassificationInput",
|
| 610 |
+
"ImageClassificationOutputElement",
|
| 611 |
+
"ImageClassificationOutputTransform",
|
| 612 |
+
"ImageClassificationParameters",
|
| 613 |
+
"ImageSegmentationInput",
|
| 614 |
+
"ImageSegmentationOutputElement",
|
| 615 |
+
"ImageSegmentationParameters",
|
| 616 |
+
"ImageSegmentationSubtask",
|
| 617 |
+
"ImageToImageInput",
|
| 618 |
+
"ImageToImageOutput",
|
| 619 |
+
"ImageToImageParameters",
|
| 620 |
+
"ImageToImageTargetSize",
|
| 621 |
+
"ImageToTextEarlyStoppingEnum",
|
| 622 |
+
"ImageToTextGenerationParameters",
|
| 623 |
+
"ImageToTextInput",
|
| 624 |
+
"ImageToTextOutput",
|
| 625 |
+
"ImageToTextParameters",
|
| 626 |
+
"InferenceApi",
|
| 627 |
+
"InferenceClient",
|
| 628 |
+
"InferenceEndpoint",
|
| 629 |
+
"InferenceEndpointError",
|
| 630 |
+
"InferenceEndpointStatus",
|
| 631 |
+
"InferenceEndpointTimeoutError",
|
| 632 |
+
"InferenceEndpointType",
|
| 633 |
+
"InferenceTimeoutError",
|
| 634 |
+
"KerasModelHubMixin",
|
| 635 |
+
"ModelCard",
|
| 636 |
+
"ModelCardData",
|
| 637 |
+
"ModelHubMixin",
|
| 638 |
+
"ModelInfo",
|
| 639 |
+
"ObjectDetectionBoundingBox",
|
| 640 |
+
"ObjectDetectionInput",
|
| 641 |
+
"ObjectDetectionOutputElement",
|
| 642 |
+
"ObjectDetectionParameters",
|
| 643 |
+
"PYTORCH_WEIGHTS_NAME",
|
| 644 |
+
"Padding",
|
| 645 |
+
"PyTorchModelHubMixin",
|
| 646 |
+
"QuestionAnsweringInput",
|
| 647 |
+
"QuestionAnsweringInputData",
|
| 648 |
+
"QuestionAnsweringOutputElement",
|
| 649 |
+
"QuestionAnsweringParameters",
|
| 650 |
+
"REPO_TYPE_DATASET",
|
| 651 |
+
"REPO_TYPE_MODEL",
|
| 652 |
+
"REPO_TYPE_SPACE",
|
| 653 |
+
"RepoCard",
|
| 654 |
+
"RepoUrl",
|
| 655 |
+
"Repository",
|
| 656 |
+
"SentenceSimilarityInput",
|
| 657 |
+
"SentenceSimilarityInputData",
|
| 658 |
+
"SpaceCard",
|
| 659 |
+
"SpaceCardData",
|
| 660 |
+
"SpaceHardware",
|
| 661 |
+
"SpaceInfo",
|
| 662 |
+
"SpaceRuntime",
|
| 663 |
+
"SpaceStage",
|
| 664 |
+
"SpaceStorage",
|
| 665 |
+
"SpaceVariable",
|
| 666 |
+
"StateDictSplit",
|
| 667 |
+
"SummarizationInput",
|
| 668 |
+
"SummarizationOutput",
|
| 669 |
+
"SummarizationParameters",
|
| 670 |
+
"SummarizationTruncationStrategy",
|
| 671 |
+
"TF2_WEIGHTS_NAME",
|
| 672 |
+
"TF_WEIGHTS_NAME",
|
| 673 |
+
"TableQuestionAnsweringInput",
|
| 674 |
+
"TableQuestionAnsweringInputData",
|
| 675 |
+
"TableQuestionAnsweringOutputElement",
|
| 676 |
+
"TableQuestionAnsweringParameters",
|
| 677 |
+
"Text2TextGenerationInput",
|
| 678 |
+
"Text2TextGenerationOutput",
|
| 679 |
+
"Text2TextGenerationParameters",
|
| 680 |
+
"Text2TextGenerationTruncationStrategy",
|
| 681 |
+
"TextClassificationInput",
|
| 682 |
+
"TextClassificationOutputElement",
|
| 683 |
+
"TextClassificationOutputTransform",
|
| 684 |
+
"TextClassificationParameters",
|
| 685 |
+
"TextGenerationInput",
|
| 686 |
+
"TextGenerationInputGenerateParameters",
|
| 687 |
+
"TextGenerationInputGrammarType",
|
| 688 |
+
"TextGenerationOutput",
|
| 689 |
+
"TextGenerationOutputBestOfSequence",
|
| 690 |
+
"TextGenerationOutputDetails",
|
| 691 |
+
"TextGenerationOutputFinishReason",
|
| 692 |
+
"TextGenerationOutputPrefillToken",
|
| 693 |
+
"TextGenerationOutputToken",
|
| 694 |
+
"TextGenerationStreamOutput",
|
| 695 |
+
"TextGenerationStreamOutputStreamDetails",
|
| 696 |
+
"TextGenerationStreamOutputToken",
|
| 697 |
+
"TextToAudioEarlyStoppingEnum",
|
| 698 |
+
"TextToAudioGenerationParameters",
|
| 699 |
+
"TextToAudioInput",
|
| 700 |
+
"TextToAudioOutput",
|
| 701 |
+
"TextToAudioParameters",
|
| 702 |
+
"TextToImageInput",
|
| 703 |
+
"TextToImageOutput",
|
| 704 |
+
"TextToImageParameters",
|
| 705 |
+
"TextToImageTargetSize",
|
| 706 |
+
"TextToSpeechEarlyStoppingEnum",
|
| 707 |
+
"TextToSpeechGenerationParameters",
|
| 708 |
+
"TextToSpeechInput",
|
| 709 |
+
"TextToSpeechOutput",
|
| 710 |
+
"TextToSpeechParameters",
|
| 711 |
+
"TextToVideoInput",
|
| 712 |
+
"TextToVideoOutput",
|
| 713 |
+
"TextToVideoParameters",
|
| 714 |
+
"TokenClassificationAggregationStrategy",
|
| 715 |
+
"TokenClassificationInput",
|
| 716 |
+
"TokenClassificationOutputElement",
|
| 717 |
+
"TokenClassificationParameters",
|
| 718 |
+
"TranslationInput",
|
| 719 |
+
"TranslationOutput",
|
| 720 |
+
"TranslationParameters",
|
| 721 |
+
"TranslationTruncationStrategy",
|
| 722 |
+
"TypeEnum",
|
| 723 |
+
"User",
|
| 724 |
+
"UserLikes",
|
| 725 |
+
"VideoClassificationInput",
|
| 726 |
+
"VideoClassificationOutputElement",
|
| 727 |
+
"VideoClassificationOutputTransform",
|
| 728 |
+
"VideoClassificationParameters",
|
| 729 |
+
"VisualQuestionAnsweringInput",
|
| 730 |
+
"VisualQuestionAnsweringInputData",
|
| 731 |
+
"VisualQuestionAnsweringOutputElement",
|
| 732 |
+
"VisualQuestionAnsweringParameters",
|
| 733 |
+
"WebhookInfo",
|
| 734 |
+
"WebhookPayload",
|
| 735 |
+
"WebhookPayloadComment",
|
| 736 |
+
"WebhookPayloadDiscussion",
|
| 737 |
+
"WebhookPayloadDiscussionChanges",
|
| 738 |
+
"WebhookPayloadEvent",
|
| 739 |
+
"WebhookPayloadMovedTo",
|
| 740 |
+
"WebhookPayloadRepo",
|
| 741 |
+
"WebhookPayloadUrl",
|
| 742 |
+
"WebhookPayloadWebhook",
|
| 743 |
+
"WebhookWatchedItem",
|
| 744 |
+
"WebhooksServer",
|
| 745 |
+
"ZeroShotClassificationInput",
|
| 746 |
+
"ZeroShotClassificationOutputElement",
|
| 747 |
+
"ZeroShotClassificationParameters",
|
| 748 |
+
"ZeroShotImageClassificationInput",
|
| 749 |
+
"ZeroShotImageClassificationOutputElement",
|
| 750 |
+
"ZeroShotImageClassificationParameters",
|
| 751 |
+
"ZeroShotObjectDetectionBoundingBox",
|
| 752 |
+
"ZeroShotObjectDetectionInput",
|
| 753 |
+
"ZeroShotObjectDetectionOutputElement",
|
| 754 |
+
"ZeroShotObjectDetectionParameters",
|
| 755 |
+
"_CACHED_NO_EXIST",
|
| 756 |
+
"_save_pretrained_fastai",
|
| 757 |
+
"accept_access_request",
|
| 758 |
+
"add_collection_item",
|
| 759 |
+
"add_space_secret",
|
| 760 |
+
"add_space_variable",
|
| 761 |
+
"auth_check",
|
| 762 |
+
"auth_list",
|
| 763 |
+
"auth_switch",
|
| 764 |
+
"cached_assets_path",
|
| 765 |
+
"cancel_access_request",
|
| 766 |
+
"change_discussion_status",
|
| 767 |
+
"comment_discussion",
|
| 768 |
+
"configure_http_backend",
|
| 769 |
+
"create_branch",
|
| 770 |
+
"create_collection",
|
| 771 |
+
"create_commit",
|
| 772 |
+
"create_discussion",
|
| 773 |
+
"create_inference_endpoint",
|
| 774 |
+
"create_pull_request",
|
| 775 |
+
"create_repo",
|
| 776 |
+
"create_tag",
|
| 777 |
+
"create_webhook",
|
| 778 |
+
"dataset_info",
|
| 779 |
+
"delete_branch",
|
| 780 |
+
"delete_collection",
|
| 781 |
+
"delete_collection_item",
|
| 782 |
+
"delete_file",
|
| 783 |
+
"delete_folder",
|
| 784 |
+
"delete_inference_endpoint",
|
| 785 |
+
"delete_repo",
|
| 786 |
+
"delete_space_secret",
|
| 787 |
+
"delete_space_storage",
|
| 788 |
+
"delete_space_variable",
|
| 789 |
+
"delete_tag",
|
| 790 |
+
"delete_webhook",
|
| 791 |
+
"disable_webhook",
|
| 792 |
+
"dump_environment_info",
|
| 793 |
+
"duplicate_space",
|
| 794 |
+
"edit_discussion_comment",
|
| 795 |
+
"enable_webhook",
|
| 796 |
+
"export_entries_as_dduf",
|
| 797 |
+
"export_folder_as_dduf",
|
| 798 |
+
"file_exists",
|
| 799 |
+
"from_pretrained_fastai",
|
| 800 |
+
"from_pretrained_keras",
|
| 801 |
+
"get_collection",
|
| 802 |
+
"get_dataset_tags",
|
| 803 |
+
"get_discussion_details",
|
| 804 |
+
"get_full_repo_name",
|
| 805 |
+
"get_hf_file_metadata",
|
| 806 |
+
"get_inference_endpoint",
|
| 807 |
+
"get_model_tags",
|
| 808 |
+
"get_paths_info",
|
| 809 |
+
"get_repo_discussions",
|
| 810 |
+
"get_safetensors_metadata",
|
| 811 |
+
"get_session",
|
| 812 |
+
"get_space_runtime",
|
| 813 |
+
"get_space_variables",
|
| 814 |
+
"get_tf_storage_size",
|
| 815 |
+
"get_token",
|
| 816 |
+
"get_token_permission",
|
| 817 |
+
"get_torch_storage_id",
|
| 818 |
+
"get_torch_storage_size",
|
| 819 |
+
"get_user_overview",
|
| 820 |
+
"get_webhook",
|
| 821 |
+
"grant_access",
|
| 822 |
+
"hf_hub_download",
|
| 823 |
+
"hf_hub_url",
|
| 824 |
+
"interpreter_login",
|
| 825 |
+
"list_accepted_access_requests",
|
| 826 |
+
"list_collections",
|
| 827 |
+
"list_datasets",
|
| 828 |
+
"list_inference_endpoints",
|
| 829 |
+
"list_liked_repos",
|
| 830 |
+
"list_models",
|
| 831 |
+
"list_organization_members",
|
| 832 |
+
"list_papers",
|
| 833 |
+
"list_pending_access_requests",
|
| 834 |
+
"list_rejected_access_requests",
|
| 835 |
+
"list_repo_commits",
|
| 836 |
+
"list_repo_files",
|
| 837 |
+
"list_repo_likers",
|
| 838 |
+
"list_repo_refs",
|
| 839 |
+
"list_repo_tree",
|
| 840 |
+
"list_spaces",
|
| 841 |
+
"list_user_followers",
|
| 842 |
+
"list_user_following",
|
| 843 |
+
"list_webhooks",
|
| 844 |
+
"load_state_dict_from_file",
|
| 845 |
+
"load_torch_model",
|
| 846 |
+
"logging",
|
| 847 |
+
"login",
|
| 848 |
+
"logout",
|
| 849 |
+
"merge_pull_request",
|
| 850 |
+
"metadata_eval_result",
|
| 851 |
+
"metadata_load",
|
| 852 |
+
"metadata_save",
|
| 853 |
+
"metadata_update",
|
| 854 |
+
"model_info",
|
| 855 |
+
"move_repo",
|
| 856 |
+
"notebook_login",
|
| 857 |
+
"paper_info",
|
| 858 |
+
"parse_safetensors_file_metadata",
|
| 859 |
+
"pause_inference_endpoint",
|
| 860 |
+
"pause_space",
|
| 861 |
+
"preupload_lfs_files",
|
| 862 |
+
"push_to_hub_fastai",
|
| 863 |
+
"push_to_hub_keras",
|
| 864 |
+
"read_dduf_file",
|
| 865 |
+
"reject_access_request",
|
| 866 |
+
"rename_discussion",
|
| 867 |
+
"repo_exists",
|
| 868 |
+
"repo_info",
|
| 869 |
+
"repo_type_and_id_from_hf_id",
|
| 870 |
+
"request_space_hardware",
|
| 871 |
+
"request_space_storage",
|
| 872 |
+
"restart_space",
|
| 873 |
+
"resume_inference_endpoint",
|
| 874 |
+
"revision_exists",
|
| 875 |
+
"run_as_future",
|
| 876 |
+
"save_pretrained_keras",
|
| 877 |
+
"save_torch_model",
|
| 878 |
+
"save_torch_state_dict",
|
| 879 |
+
"scale_to_zero_inference_endpoint",
|
| 880 |
+
"scan_cache_dir",
|
| 881 |
+
"set_space_sleep_time",
|
| 882 |
+
"snapshot_download",
|
| 883 |
+
"space_info",
|
| 884 |
+
"split_state_dict_into_shards_factory",
|
| 885 |
+
"split_tf_state_dict_into_shards",
|
| 886 |
+
"split_torch_state_dict_into_shards",
|
| 887 |
+
"super_squash_history",
|
| 888 |
+
"try_to_load_from_cache",
|
| 889 |
+
"unlike",
|
| 890 |
+
"update_collection_item",
|
| 891 |
+
"update_collection_metadata",
|
| 892 |
+
"update_inference_endpoint",
|
| 893 |
+
"update_repo_settings",
|
| 894 |
+
"update_repo_visibility",
|
| 895 |
+
"update_webhook",
|
| 896 |
+
"upload_file",
|
| 897 |
+
"upload_folder",
|
| 898 |
+
"upload_large_folder",
|
| 899 |
+
"webhook_endpoint",
|
| 900 |
+
"whoami",
|
| 901 |
+
]
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
def _attach(package_name, submodules=None, submod_attrs=None):
    """Build the lazy `__getattr__` / `__dir__` hooks for a package.

    Instead of importing submodules and their attributes eagerly:

    ```py
    import mysubmodule
    import anothersubmodule

    from .foo import someattr
    ```

    a package can install the two hooks returned by this function as its
    module-level `__getattr__` and `__dir__`, so that each import only happens the
    first time the corresponding name is accessed:

    ```python
    __getattr__, __dir__ = _attach(
        __name__,
        ['mysubmodule', 'anothersubmodule'],
        {'foo': ['someattr']}
    )
    ```
    This functionality requires Python 3.7 or higher.

    Args:
        package_name (`str`):
            Name of the package the hooks are installed on. Typically `__name__`.
        submodules (iterable of `str`, *optional*):
            Names of submodules to expose; each one is imported as
            `{package_name}.{name}` when accessed.
        submod_attrs (`dict`, *optional*):
            Dictionary of submodule -> list of attributes / functions.
            These attributes are imported as they are used.

    Returns:
        `tuple`: the pair `(__getattr__, __dir__)`. `__getattr__(name)` returns the
        lazily imported submodule or attribute and raises `AttributeError` for an
        unknown name. `__dir__()` returns the `__all__` list of the module in which
        `_attach` is defined.

    """
    lazy_submodules = set() if submodules is None else set(submodules)
    attrs_by_module = {} if submod_attrs is None else submod_attrs

    # Reverse index: attribute name -> submodule that provides it. If the same
    # attribute is listed under several submodules, the last one listed wins.
    owner_of = {}
    for module_name, attr_names in attrs_by_module.items():
        for attr_name in attr_names:
            owner_of[attr_name] = module_name

    def __getattr__(name):
        # Case 1: the name is itself a lazily loaded submodule.
        if name in lazy_submodules:
            target = f"{package_name}.{name}"
            try:
                return importlib.import_module(target)
            except Exception as exc:
                # Report which import failed, then let the original error propagate.
                print(f"Error importing {target}: {exc}")
                raise

        # Case 2: the name is not known to this package at all.
        if name not in owner_of:
            raise AttributeError(f"No {package_name} attribute {name}")

        # Case 3: the name is an attribute living in one of the submodules.
        owner = owner_of[name]
        target = f"{package_name}.{owner}"
        try:
            module = importlib.import_module(target)
        except Exception as exc:
            print(f"Error importing {target}: {exc}")
            raise
        value = getattr(module, name)

        # If the attribute lives in a file (module) with the same name as the
        # attribute, ensure that the attribute and *not* the module is accessible
        # on the package.
        if name == owner:
            sys.modules[package_name].__dict__[name] = value

        return value

    def __dir__():
        return __all__

    return __getattr__, __dir__
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
# Install the lazy-loading hooks built by `_attach` as this module's `__getattr__` and
# `__dir__`. No submodule is exposed by name (`submodules=[]`); attribute names are
# resolved through the `_SUBMOD_ATTRS` mapping.
__getattr__, __dir__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS)

# When the `EAGER_IMPORT` environment variable is set to a non-empty value, resolve
# every name listed in `__all__` immediately instead of waiting for its first access.
if os.environ.get("EAGER_IMPORT", ""):
    for attr in __all__:
        __getattr__(attr)
|
| 992 |
+
|
| 993 |
+
# WARNING: any content below this statement is generated automatically. Any manual edit
|
| 994 |
+
# will be lost when re-generating this file !
|
| 995 |
+
#
|
| 996 |
+
# To update the static imports, please run the following command and commit the changes.
|
| 997 |
+
# ```
|
| 998 |
+
# # Use script
|
| 999 |
+
# python utils/check_static_imports.py --update
|
| 1000 |
+
#
|
| 1001 |
+
# # Or run style on codebase
|
| 1002 |
+
# make style
|
| 1003 |
+
# ```
|
| 1004 |
+
if TYPE_CHECKING: # pragma: no cover
|
| 1005 |
+
from ._commit_scheduler import CommitScheduler # noqa: F401
|
| 1006 |
+
from ._inference_endpoints import (
|
| 1007 |
+
InferenceEndpoint, # noqa: F401
|
| 1008 |
+
InferenceEndpointError, # noqa: F401
|
| 1009 |
+
InferenceEndpointStatus, # noqa: F401
|
| 1010 |
+
InferenceEndpointTimeoutError, # noqa: F401
|
| 1011 |
+
InferenceEndpointType, # noqa: F401
|
| 1012 |
+
)
|
| 1013 |
+
from ._login import (
|
| 1014 |
+
auth_list, # noqa: F401
|
| 1015 |
+
auth_switch, # noqa: F401
|
| 1016 |
+
interpreter_login, # noqa: F401
|
| 1017 |
+
login, # noqa: F401
|
| 1018 |
+
logout, # noqa: F401
|
| 1019 |
+
notebook_login, # noqa: F401
|
| 1020 |
+
)
|
| 1021 |
+
from ._snapshot_download import snapshot_download # noqa: F401
|
| 1022 |
+
from ._space_api import (
|
| 1023 |
+
SpaceHardware, # noqa: F401
|
| 1024 |
+
SpaceRuntime, # noqa: F401
|
| 1025 |
+
SpaceStage, # noqa: F401
|
| 1026 |
+
SpaceStorage, # noqa: F401
|
| 1027 |
+
SpaceVariable, # noqa: F401
|
| 1028 |
+
)
|
| 1029 |
+
from ._tensorboard_logger import HFSummaryWriter # noqa: F401
|
| 1030 |
+
from ._webhooks_payload import (
|
| 1031 |
+
WebhookPayload, # noqa: F401
|
| 1032 |
+
WebhookPayloadComment, # noqa: F401
|
| 1033 |
+
WebhookPayloadDiscussion, # noqa: F401
|
| 1034 |
+
WebhookPayloadDiscussionChanges, # noqa: F401
|
| 1035 |
+
WebhookPayloadEvent, # noqa: F401
|
| 1036 |
+
WebhookPayloadMovedTo, # noqa: F401
|
| 1037 |
+
WebhookPayloadRepo, # noqa: F401
|
| 1038 |
+
WebhookPayloadUrl, # noqa: F401
|
| 1039 |
+
WebhookPayloadWebhook, # noqa: F401
|
| 1040 |
+
)
|
| 1041 |
+
from ._webhooks_server import (
|
| 1042 |
+
WebhooksServer, # noqa: F401
|
| 1043 |
+
webhook_endpoint, # noqa: F401
|
| 1044 |
+
)
|
| 1045 |
+
from .community import (
|
| 1046 |
+
Discussion, # noqa: F401
|
| 1047 |
+
DiscussionComment, # noqa: F401
|
| 1048 |
+
DiscussionCommit, # noqa: F401
|
| 1049 |
+
DiscussionEvent, # noqa: F401
|
| 1050 |
+
DiscussionStatusChange, # noqa: F401
|
| 1051 |
+
DiscussionTitleChange, # noqa: F401
|
| 1052 |
+
DiscussionWithDetails, # noqa: F401
|
| 1053 |
+
)
|
| 1054 |
+
from .constants import (
|
| 1055 |
+
CONFIG_NAME, # noqa: F401
|
| 1056 |
+
FLAX_WEIGHTS_NAME, # noqa: F401
|
| 1057 |
+
HUGGINGFACE_CO_URL_HOME, # noqa: F401
|
| 1058 |
+
HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401
|
| 1059 |
+
PYTORCH_WEIGHTS_NAME, # noqa: F401
|
| 1060 |
+
REPO_TYPE_DATASET, # noqa: F401
|
| 1061 |
+
REPO_TYPE_MODEL, # noqa: F401
|
| 1062 |
+
REPO_TYPE_SPACE, # noqa: F401
|
| 1063 |
+
TF2_WEIGHTS_NAME, # noqa: F401
|
| 1064 |
+
TF_WEIGHTS_NAME, # noqa: F401
|
| 1065 |
+
)
|
| 1066 |
+
from .fastai_utils import (
|
| 1067 |
+
_save_pretrained_fastai, # noqa: F401
|
| 1068 |
+
from_pretrained_fastai, # noqa: F401
|
| 1069 |
+
push_to_hub_fastai, # noqa: F401
|
| 1070 |
+
)
|
| 1071 |
+
from .file_download import (
|
| 1072 |
+
_CACHED_NO_EXIST, # noqa: F401
|
| 1073 |
+
HfFileMetadata, # noqa: F401
|
| 1074 |
+
get_hf_file_metadata, # noqa: F401
|
| 1075 |
+
hf_hub_download, # noqa: F401
|
| 1076 |
+
hf_hub_url, # noqa: F401
|
| 1077 |
+
try_to_load_from_cache, # noqa: F401
|
| 1078 |
+
)
|
| 1079 |
+
from .hf_api import (
|
| 1080 |
+
Collection, # noqa: F401
|
| 1081 |
+
CollectionItem, # noqa: F401
|
| 1082 |
+
CommitInfo, # noqa: F401
|
| 1083 |
+
CommitOperation, # noqa: F401
|
| 1084 |
+
CommitOperationAdd, # noqa: F401
|
| 1085 |
+
CommitOperationCopy, # noqa: F401
|
| 1086 |
+
CommitOperationDelete, # noqa: F401
|
| 1087 |
+
DatasetInfo, # noqa: F401
|
| 1088 |
+
GitCommitInfo, # noqa: F401
|
| 1089 |
+
GitRefInfo, # noqa: F401
|
| 1090 |
+
GitRefs, # noqa: F401
|
| 1091 |
+
HfApi, # noqa: F401
|
| 1092 |
+
ModelInfo, # noqa: F401
|
| 1093 |
+
RepoUrl, # noqa: F401
|
| 1094 |
+
SpaceInfo, # noqa: F401
|
| 1095 |
+
User, # noqa: F401
|
| 1096 |
+
UserLikes, # noqa: F401
|
| 1097 |
+
WebhookInfo, # noqa: F401
|
| 1098 |
+
WebhookWatchedItem, # noqa: F401
|
| 1099 |
+
accept_access_request, # noqa: F401
|
| 1100 |
+
add_collection_item, # noqa: F401
|
| 1101 |
+
add_space_secret, # noqa: F401
|
| 1102 |
+
add_space_variable, # noqa: F401
|
| 1103 |
+
auth_check, # noqa: F401
|
| 1104 |
+
cancel_access_request, # noqa: F401
|
| 1105 |
+
change_discussion_status, # noqa: F401
|
| 1106 |
+
comment_discussion, # noqa: F401
|
| 1107 |
+
create_branch, # noqa: F401
|
| 1108 |
+
create_collection, # noqa: F401
|
| 1109 |
+
create_commit, # noqa: F401
|
| 1110 |
+
create_discussion, # noqa: F401
|
| 1111 |
+
create_inference_endpoint, # noqa: F401
|
| 1112 |
+
create_pull_request, # noqa: F401
|
| 1113 |
+
create_repo, # noqa: F401
|
| 1114 |
+
create_tag, # noqa: F401
|
| 1115 |
+
create_webhook, # noqa: F401
|
| 1116 |
+
dataset_info, # noqa: F401
|
| 1117 |
+
delete_branch, # noqa: F401
|
| 1118 |
+
delete_collection, # noqa: F401
|
| 1119 |
+
delete_collection_item, # noqa: F401
|
| 1120 |
+
delete_file, # noqa: F401
|
| 1121 |
+
delete_folder, # noqa: F401
|
| 1122 |
+
delete_inference_endpoint, # noqa: F401
|
| 1123 |
+
delete_repo, # noqa: F401
|
| 1124 |
+
delete_space_secret, # noqa: F401
|
| 1125 |
+
delete_space_storage, # noqa: F401
|
| 1126 |
+
delete_space_variable, # noqa: F401
|
| 1127 |
+
delete_tag, # noqa: F401
|
| 1128 |
+
delete_webhook, # noqa: F401
|
| 1129 |
+
disable_webhook, # noqa: F401
|
| 1130 |
+
duplicate_space, # noqa: F401
|
| 1131 |
+
edit_discussion_comment, # noqa: F401
|
| 1132 |
+
enable_webhook, # noqa: F401
|
| 1133 |
+
file_exists, # noqa: F401
|
| 1134 |
+
get_collection, # noqa: F401
|
| 1135 |
+
get_dataset_tags, # noqa: F401
|
| 1136 |
+
get_discussion_details, # noqa: F401
|
| 1137 |
+
get_full_repo_name, # noqa: F401
|
| 1138 |
+
get_inference_endpoint, # noqa: F401
|
| 1139 |
+
get_model_tags, # noqa: F401
|
| 1140 |
+
get_paths_info, # noqa: F401
|
| 1141 |
+
get_repo_discussions, # noqa: F401
|
| 1142 |
+
get_safetensors_metadata, # noqa: F401
|
| 1143 |
+
get_space_runtime, # noqa: F401
|
| 1144 |
+
get_space_variables, # noqa: F401
|
| 1145 |
+
get_token_permission, # noqa: F401
|
| 1146 |
+
get_user_overview, # noqa: F401
|
| 1147 |
+
get_webhook, # noqa: F401
|
| 1148 |
+
grant_access, # noqa: F401
|
| 1149 |
+
list_accepted_access_requests, # noqa: F401
|
| 1150 |
+
list_collections, # noqa: F401
|
| 1151 |
+
list_datasets, # noqa: F401
|
| 1152 |
+
list_inference_endpoints, # noqa: F401
|
| 1153 |
+
list_liked_repos, # noqa: F401
|
| 1154 |
+
list_models, # noqa: F401
|
| 1155 |
+
list_organization_members, # noqa: F401
|
| 1156 |
+
list_papers, # noqa: F401
|
| 1157 |
+
list_pending_access_requests, # noqa: F401
|
| 1158 |
+
list_rejected_access_requests, # noqa: F401
|
| 1159 |
+
list_repo_commits, # noqa: F401
|
| 1160 |
+
list_repo_files, # noqa: F401
|
| 1161 |
+
list_repo_likers, # noqa: F401
|
| 1162 |
+
list_repo_refs, # noqa: F401
|
| 1163 |
+
list_repo_tree, # noqa: F401
|
| 1164 |
+
list_spaces, # noqa: F401
|
| 1165 |
+
list_user_followers, # noqa: F401
|
| 1166 |
+
list_user_following, # noqa: F401
|
| 1167 |
+
list_webhooks, # noqa: F401
|
| 1168 |
+
merge_pull_request, # noqa: F401
|
| 1169 |
+
model_info, # noqa: F401
|
| 1170 |
+
move_repo, # noqa: F401
|
| 1171 |
+
paper_info, # noqa: F401
|
| 1172 |
+
parse_safetensors_file_metadata, # noqa: F401
|
| 1173 |
+
pause_inference_endpoint, # noqa: F401
|
| 1174 |
+
pause_space, # noqa: F401
|
| 1175 |
+
preupload_lfs_files, # noqa: F401
|
| 1176 |
+
reject_access_request, # noqa: F401
|
| 1177 |
+
rename_discussion, # noqa: F401
|
| 1178 |
+
repo_exists, # noqa: F401
|
| 1179 |
+
repo_info, # noqa: F401
|
| 1180 |
+
repo_type_and_id_from_hf_id, # noqa: F401
|
| 1181 |
+
request_space_hardware, # noqa: F401
|
| 1182 |
+
request_space_storage, # noqa: F401
|
| 1183 |
+
restart_space, # noqa: F401
|
| 1184 |
+
resume_inference_endpoint, # noqa: F401
|
| 1185 |
+
revision_exists, # noqa: F401
|
| 1186 |
+
run_as_future, # noqa: F401
|
| 1187 |
+
scale_to_zero_inference_endpoint, # noqa: F401
|
| 1188 |
+
set_space_sleep_time, # noqa: F401
|
| 1189 |
+
space_info, # noqa: F401
|
| 1190 |
+
super_squash_history, # noqa: F401
|
| 1191 |
+
unlike, # noqa: F401
|
| 1192 |
+
update_collection_item, # noqa: F401
|
| 1193 |
+
update_collection_metadata, # noqa: F401
|
| 1194 |
+
update_inference_endpoint, # noqa: F401
|
| 1195 |
+
update_repo_settings, # noqa: F401
|
| 1196 |
+
update_repo_visibility, # noqa: F401
|
| 1197 |
+
update_webhook, # noqa: F401
|
| 1198 |
+
upload_file, # noqa: F401
|
| 1199 |
+
upload_folder, # noqa: F401
|
| 1200 |
+
upload_large_folder, # noqa: F401
|
| 1201 |
+
whoami, # noqa: F401
|
| 1202 |
+
)
|
| 1203 |
+
from .hf_file_system import (
|
| 1204 |
+
HfFileSystem, # noqa: F401
|
| 1205 |
+
HfFileSystemFile, # noqa: F401
|
| 1206 |
+
HfFileSystemResolvedPath, # noqa: F401
|
| 1207 |
+
HfFileSystemStreamFile, # noqa: F401
|
| 1208 |
+
)
|
| 1209 |
+
from .hub_mixin import (
|
| 1210 |
+
ModelHubMixin, # noqa: F401
|
| 1211 |
+
PyTorchModelHubMixin, # noqa: F401
|
| 1212 |
+
)
|
| 1213 |
+
from .inference._client import (
|
| 1214 |
+
InferenceClient, # noqa: F401
|
| 1215 |
+
InferenceTimeoutError, # noqa: F401
|
| 1216 |
+
)
|
| 1217 |
+
from .inference._generated._async_client import AsyncInferenceClient # noqa: F401
|
| 1218 |
+
from .inference._generated.types import (
|
| 1219 |
+
AudioClassificationInput, # noqa: F401
|
| 1220 |
+
AudioClassificationOutputElement, # noqa: F401
|
| 1221 |
+
AudioClassificationOutputTransform, # noqa: F401
|
| 1222 |
+
AudioClassificationParameters, # noqa: F401
|
| 1223 |
+
AudioToAudioInput, # noqa: F401
|
| 1224 |
+
AudioToAudioOutputElement, # noqa: F401
|
| 1225 |
+
AutomaticSpeechRecognitionEarlyStoppingEnum, # noqa: F401
|
| 1226 |
+
AutomaticSpeechRecognitionGenerationParameters, # noqa: F401
|
| 1227 |
+
AutomaticSpeechRecognitionInput, # noqa: F401
|
| 1228 |
+
AutomaticSpeechRecognitionOutput, # noqa: F401
|
| 1229 |
+
AutomaticSpeechRecognitionOutputChunk, # noqa: F401
|
| 1230 |
+
AutomaticSpeechRecognitionParameters, # noqa: F401
|
| 1231 |
+
ChatCompletionInput, # noqa: F401
|
| 1232 |
+
ChatCompletionInputFunctionDefinition, # noqa: F401
|
| 1233 |
+
ChatCompletionInputFunctionName, # noqa: F401
|
| 1234 |
+
ChatCompletionInputGrammarType, # noqa: F401
|
| 1235 |
+
ChatCompletionInputGrammarTypeType, # noqa: F401
|
| 1236 |
+
ChatCompletionInputMessage, # noqa: F401
|
| 1237 |
+
ChatCompletionInputMessageChunk, # noqa: F401
|
| 1238 |
+
ChatCompletionInputMessageChunkType, # noqa: F401
|
| 1239 |
+
ChatCompletionInputStreamOptions, # noqa: F401
|
| 1240 |
+
ChatCompletionInputTool, # noqa: F401
|
| 1241 |
+
ChatCompletionInputToolChoiceClass, # noqa: F401
|
| 1242 |
+
ChatCompletionInputToolChoiceEnum, # noqa: F401
|
| 1243 |
+
ChatCompletionInputURL, # noqa: F401
|
| 1244 |
+
ChatCompletionOutput, # noqa: F401
|
| 1245 |
+
ChatCompletionOutputComplete, # noqa: F401
|
| 1246 |
+
ChatCompletionOutputFunctionDefinition, # noqa: F401
|
| 1247 |
+
ChatCompletionOutputLogprob, # noqa: F401
|
| 1248 |
+
ChatCompletionOutputLogprobs, # noqa: F401
|
| 1249 |
+
ChatCompletionOutputMessage, # noqa: F401
|
| 1250 |
+
ChatCompletionOutputToolCall, # noqa: F401
|
| 1251 |
+
ChatCompletionOutputTopLogprob, # noqa: F401
|
| 1252 |
+
ChatCompletionOutputUsage, # noqa: F401
|
| 1253 |
+
ChatCompletionStreamOutput, # noqa: F401
|
| 1254 |
+
ChatCompletionStreamOutputChoice, # noqa: F401
|
| 1255 |
+
ChatCompletionStreamOutputDelta, # noqa: F401
|
| 1256 |
+
ChatCompletionStreamOutputDeltaToolCall, # noqa: F401
|
| 1257 |
+
ChatCompletionStreamOutputFunction, # noqa: F401
|
| 1258 |
+
ChatCompletionStreamOutputLogprob, # noqa: F401
|
| 1259 |
+
ChatCompletionStreamOutputLogprobs, # noqa: F401
|
| 1260 |
+
ChatCompletionStreamOutputTopLogprob, # noqa: F401
|
| 1261 |
+
ChatCompletionStreamOutputUsage, # noqa: F401
|
| 1262 |
+
DepthEstimationInput, # noqa: F401
|
| 1263 |
+
DepthEstimationOutput, # noqa: F401
|
| 1264 |
+
DocumentQuestionAnsweringInput, # noqa: F401
|
| 1265 |
+
DocumentQuestionAnsweringInputData, # noqa: F401
|
| 1266 |
+
DocumentQuestionAnsweringOutputElement, # noqa: F401
|
| 1267 |
+
DocumentQuestionAnsweringParameters, # noqa: F401
|
| 1268 |
+
FeatureExtractionInput, # noqa: F401
|
| 1269 |
+
FeatureExtractionInputTruncationDirection, # noqa: F401
|
| 1270 |
+
FillMaskInput, # noqa: F401
|
| 1271 |
+
FillMaskOutputElement, # noqa: F401
|
| 1272 |
+
FillMaskParameters, # noqa: F401
|
| 1273 |
+
ImageClassificationInput, # noqa: F401
|
| 1274 |
+
ImageClassificationOutputElement, # noqa: F401
|
| 1275 |
+
ImageClassificationOutputTransform, # noqa: F401
|
| 1276 |
+
ImageClassificationParameters, # noqa: F401
|
| 1277 |
+
ImageSegmentationInput, # noqa: F401
|
| 1278 |
+
ImageSegmentationOutputElement, # noqa: F401
|
| 1279 |
+
ImageSegmentationParameters, # noqa: F401
|
| 1280 |
+
ImageSegmentationSubtask, # noqa: F401
|
| 1281 |
+
ImageToImageInput, # noqa: F401
|
| 1282 |
+
ImageToImageOutput, # noqa: F401
|
| 1283 |
+
ImageToImageParameters, # noqa: F401
|
| 1284 |
+
ImageToImageTargetSize, # noqa: F401
|
| 1285 |
+
ImageToTextEarlyStoppingEnum, # noqa: F401
|
| 1286 |
+
ImageToTextGenerationParameters, # noqa: F401
|
| 1287 |
+
ImageToTextInput, # noqa: F401
|
| 1288 |
+
ImageToTextOutput, # noqa: F401
|
| 1289 |
+
ImageToTextParameters, # noqa: F401
|
| 1290 |
+
ObjectDetectionBoundingBox, # noqa: F401
|
| 1291 |
+
ObjectDetectionInput, # noqa: F401
|
| 1292 |
+
ObjectDetectionOutputElement, # noqa: F401
|
| 1293 |
+
ObjectDetectionParameters, # noqa: F401
|
| 1294 |
+
Padding, # noqa: F401
|
| 1295 |
+
QuestionAnsweringInput, # noqa: F401
|
| 1296 |
+
QuestionAnsweringInputData, # noqa: F401
|
| 1297 |
+
QuestionAnsweringOutputElement, # noqa: F401
|
| 1298 |
+
QuestionAnsweringParameters, # noqa: F401
|
| 1299 |
+
SentenceSimilarityInput, # noqa: F401
|
| 1300 |
+
SentenceSimilarityInputData, # noqa: F401
|
| 1301 |
+
SummarizationInput, # noqa: F401
|
| 1302 |
+
SummarizationOutput, # noqa: F401
|
| 1303 |
+
SummarizationParameters, # noqa: F401
|
| 1304 |
+
SummarizationTruncationStrategy, # noqa: F401
|
| 1305 |
+
TableQuestionAnsweringInput, # noqa: F401
|
| 1306 |
+
TableQuestionAnsweringInputData, # noqa: F401
|
| 1307 |
+
TableQuestionAnsweringOutputElement, # noqa: F401
|
| 1308 |
+
TableQuestionAnsweringParameters, # noqa: F401
|
| 1309 |
+
Text2TextGenerationInput, # noqa: F401
|
| 1310 |
+
Text2TextGenerationOutput, # noqa: F401
|
| 1311 |
+
Text2TextGenerationParameters, # noqa: F401
|
| 1312 |
+
Text2TextGenerationTruncationStrategy, # noqa: F401
|
| 1313 |
+
TextClassificationInput, # noqa: F401
|
| 1314 |
+
TextClassificationOutputElement, # noqa: F401
|
| 1315 |
+
TextClassificationOutputTransform, # noqa: F401
|
| 1316 |
+
TextClassificationParameters, # noqa: F401
|
| 1317 |
+
TextGenerationInput, # noqa: F401
|
| 1318 |
+
TextGenerationInputGenerateParameters, # noqa: F401
|
| 1319 |
+
TextGenerationInputGrammarType, # noqa: F401
|
| 1320 |
+
TextGenerationOutput, # noqa: F401
|
| 1321 |
+
TextGenerationOutputBestOfSequence, # noqa: F401
|
| 1322 |
+
TextGenerationOutputDetails, # noqa: F401
|
| 1323 |
+
TextGenerationOutputFinishReason, # noqa: F401
|
| 1324 |
+
TextGenerationOutputPrefillToken, # noqa: F401
|
| 1325 |
+
TextGenerationOutputToken, # noqa: F401
|
| 1326 |
+
TextGenerationStreamOutput, # noqa: F401
|
| 1327 |
+
TextGenerationStreamOutputStreamDetails, # noqa: F401
|
| 1328 |
+
TextGenerationStreamOutputToken, # noqa: F401
|
| 1329 |
+
TextToAudioEarlyStoppingEnum, # noqa: F401
|
| 1330 |
+
TextToAudioGenerationParameters, # noqa: F401
|
| 1331 |
+
TextToAudioInput, # noqa: F401
|
| 1332 |
+
TextToAudioOutput, # noqa: F401
|
| 1333 |
+
TextToAudioParameters, # noqa: F401
|
| 1334 |
+
TextToImageInput, # noqa: F401
|
| 1335 |
+
TextToImageOutput, # noqa: F401
|
| 1336 |
+
TextToImageParameters, # noqa: F401
|
| 1337 |
+
TextToImageTargetSize, # noqa: F401
|
| 1338 |
+
TextToSpeechEarlyStoppingEnum, # noqa: F401
|
| 1339 |
+
TextToSpeechGenerationParameters, # noqa: F401
|
| 1340 |
+
TextToSpeechInput, # noqa: F401
|
| 1341 |
+
TextToSpeechOutput, # noqa: F401
|
| 1342 |
+
TextToSpeechParameters, # noqa: F401
|
| 1343 |
+
TextToVideoInput, # noqa: F401
|
| 1344 |
+
TextToVideoOutput, # noqa: F401
|
| 1345 |
+
TextToVideoParameters, # noqa: F401
|
| 1346 |
+
TokenClassificationAggregationStrategy, # noqa: F401
|
| 1347 |
+
TokenClassificationInput, # noqa: F401
|
| 1348 |
+
TokenClassificationOutputElement, # noqa: F401
|
| 1349 |
+
TokenClassificationParameters, # noqa: F401
|
| 1350 |
+
TranslationInput, # noqa: F401
|
| 1351 |
+
TranslationOutput, # noqa: F401
|
| 1352 |
+
TranslationParameters, # noqa: F401
|
| 1353 |
+
TranslationTruncationStrategy, # noqa: F401
|
| 1354 |
+
TypeEnum, # noqa: F401
|
| 1355 |
+
VideoClassificationInput, # noqa: F401
|
| 1356 |
+
VideoClassificationOutputElement, # noqa: F401
|
| 1357 |
+
VideoClassificationOutputTransform, # noqa: F401
|
| 1358 |
+
VideoClassificationParameters, # noqa: F401
|
| 1359 |
+
VisualQuestionAnsweringInput, # noqa: F401
|
| 1360 |
+
VisualQuestionAnsweringInputData, # noqa: F401
|
| 1361 |
+
VisualQuestionAnsweringOutputElement, # noqa: F401
|
| 1362 |
+
VisualQuestionAnsweringParameters, # noqa: F401
|
| 1363 |
+
ZeroShotClassificationInput, # noqa: F401
|
| 1364 |
+
ZeroShotClassificationOutputElement, # noqa: F401
|
| 1365 |
+
ZeroShotClassificationParameters, # noqa: F401
|
| 1366 |
+
ZeroShotImageClassificationInput, # noqa: F401
|
| 1367 |
+
ZeroShotImageClassificationOutputElement, # noqa: F401
|
| 1368 |
+
ZeroShotImageClassificationParameters, # noqa: F401
|
| 1369 |
+
ZeroShotObjectDetectionBoundingBox, # noqa: F401
|
| 1370 |
+
ZeroShotObjectDetectionInput, # noqa: F401
|
| 1371 |
+
ZeroShotObjectDetectionOutputElement, # noqa: F401
|
| 1372 |
+
ZeroShotObjectDetectionParameters, # noqa: F401
|
| 1373 |
+
)
|
| 1374 |
+
from .inference_api import InferenceApi # noqa: F401
|
| 1375 |
+
from .keras_mixin import (
|
| 1376 |
+
KerasModelHubMixin, # noqa: F401
|
| 1377 |
+
from_pretrained_keras, # noqa: F401
|
| 1378 |
+
push_to_hub_keras, # noqa: F401
|
| 1379 |
+
save_pretrained_keras, # noqa: F401
|
| 1380 |
+
)
|
| 1381 |
+
from .repocard import (
|
| 1382 |
+
DatasetCard, # noqa: F401
|
| 1383 |
+
ModelCard, # noqa: F401
|
| 1384 |
+
RepoCard, # noqa: F401
|
| 1385 |
+
SpaceCard, # noqa: F401
|
| 1386 |
+
metadata_eval_result, # noqa: F401
|
| 1387 |
+
metadata_load, # noqa: F401
|
| 1388 |
+
metadata_save, # noqa: F401
|
| 1389 |
+
metadata_update, # noqa: F401
|
| 1390 |
+
)
|
| 1391 |
+
from .repocard_data import (
|
| 1392 |
+
CardData, # noqa: F401
|
| 1393 |
+
DatasetCardData, # noqa: F401
|
| 1394 |
+
EvalResult, # noqa: F401
|
| 1395 |
+
ModelCardData, # noqa: F401
|
| 1396 |
+
SpaceCardData, # noqa: F401
|
| 1397 |
+
)
|
| 1398 |
+
from .repository import Repository # noqa: F401
|
| 1399 |
+
from .serialization import (
|
| 1400 |
+
StateDictSplit, # noqa: F401
|
| 1401 |
+
get_tf_storage_size, # noqa: F401
|
| 1402 |
+
get_torch_storage_id, # noqa: F401
|
| 1403 |
+
get_torch_storage_size, # noqa: F401
|
| 1404 |
+
load_state_dict_from_file, # noqa: F401
|
| 1405 |
+
load_torch_model, # noqa: F401
|
| 1406 |
+
save_torch_model, # noqa: F401
|
| 1407 |
+
save_torch_state_dict, # noqa: F401
|
| 1408 |
+
split_state_dict_into_shards_factory, # noqa: F401
|
| 1409 |
+
split_tf_state_dict_into_shards, # noqa: F401
|
| 1410 |
+
split_torch_state_dict_into_shards, # noqa: F401
|
| 1411 |
+
)
|
| 1412 |
+
from .serialization._dduf import (
|
| 1413 |
+
DDUFEntry, # noqa: F401
|
| 1414 |
+
export_entries_as_dduf, # noqa: F401
|
| 1415 |
+
export_folder_as_dduf, # noqa: F401
|
| 1416 |
+
read_dduf_file, # noqa: F401
|
| 1417 |
+
)
|
| 1418 |
+
from .utils import (
|
| 1419 |
+
CachedFileInfo, # noqa: F401
|
| 1420 |
+
CachedRepoInfo, # noqa: F401
|
| 1421 |
+
CachedRevisionInfo, # noqa: F401
|
| 1422 |
+
CacheNotFound, # noqa: F401
|
| 1423 |
+
CorruptedCacheException, # noqa: F401
|
| 1424 |
+
DeleteCacheStrategy, # noqa: F401
|
| 1425 |
+
HFCacheInfo, # noqa: F401
|
| 1426 |
+
HfFolder, # noqa: F401
|
| 1427 |
+
cached_assets_path, # noqa: F401
|
| 1428 |
+
configure_http_backend, # noqa: F401
|
| 1429 |
+
dump_environment_info, # noqa: F401
|
| 1430 |
+
get_session, # noqa: F401
|
| 1431 |
+
get_token, # noqa: F401
|
| 1432 |
+
logging, # noqa: F401
|
| 1433 |
+
scan_cache_dir, # noqa: F401
|
| 1434 |
+
)
|
.venv/lib/python3.11/site-packages/huggingface_hub/_commit_api.py
ADDED
|
@@ -0,0 +1,758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Type definitions and utilities for the `create_commit` API
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import io
|
| 7 |
+
import os
|
| 8 |
+
import warnings
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from contextlib import contextmanager
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from itertools import groupby
|
| 13 |
+
from pathlib import Path, PurePosixPath
|
| 14 |
+
from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
from tqdm.contrib.concurrent import thread_map
|
| 17 |
+
|
| 18 |
+
from . import constants
|
| 19 |
+
from .errors import EntryNotFoundError
|
| 20 |
+
from .file_download import hf_hub_url
|
| 21 |
+
from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
|
| 22 |
+
from .utils import (
|
| 23 |
+
FORBIDDEN_FOLDERS,
|
| 24 |
+
chunk_iterable,
|
| 25 |
+
get_session,
|
| 26 |
+
hf_raise_for_status,
|
| 27 |
+
logging,
|
| 28 |
+
sha,
|
| 29 |
+
tqdm_stream_file,
|
| 30 |
+
validate_hf_hub_args,
|
| 31 |
+
)
|
| 32 |
+
from .utils import tqdm as hf_tqdm
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if TYPE_CHECKING:
|
| 36 |
+
from .hf_api import RepoFile
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Module-level logger, obtained from the package's own `logging` helper.
logger = logging.get_logger(__name__)


# The two accepted upload-mode values: "lfs" or "regular".
UploadMode = Literal["lfs", "regular"]

# Max is 1,000 per request on the Hub for HfApi.get_paths_info
# Otherwise we get:
#     HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters
# See https://github.com/huggingface/huggingface_hub/issues/1503
# The value below (500) stays under that documented limit.
# NOTE(review): presumably the number of paths sent per request when fetching file info - confirm at the call site.
FETCH_LFS_BATCH_SIZE = 500
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class CommitOperationDelete:
    """
    Describe the removal of a single file, or of a whole folder, from a repository
    on the Hub.

    Args:
        path_in_repo (`str`):
            Relative path in the repo. Use for example `"checkpoints/1fec34a/weights.bin"`
            to target a file or `"checkpoints/1fec34a/"` to target a folder.
        is_folder (`bool` or `Literal["auto"]`, *optional*):
            Tells whether the operation targets a folder. With `"auto"` (the default),
            the kind of path is guessed from `path_in_repo`: a path ending with `"/"`
            is a folder, anything else is a file. Pass `is_folder=True` or
            `is_folder=False` to set the kind of path explicitly.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `is_folder` is neither a boolean nor `"auto"`.
    """

    path_in_repo: str
    is_folder: Union[bool, Literal["auto"]] = "auto"

    def __post_init__(self):
        # The path is passed through the shared validator first; the "auto" detection
        # below looks at the value it returns.
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        folder_flag = self.is_folder
        if folder_flag == "auto":
            # Resolve "auto" into a concrete boolean based on the trailing slash.
            folder_flag = self.path_in_repo.endswith("/")
            self.is_folder = folder_flag

        if isinstance(folder_flag, bool):
            return
        raise ValueError(
            f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'."
        )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
class CommitOperationCopy:
    """
    Describe the copy of a file inside a repository on the Hub.

    Limitations:
      - Only LFS files can be copied. A regular file has to be downloaded locally and uploaded again.
      - Copying from one repository to another is not supported.

    Note: an LFS file can be renamed on the Hub by combining a [`CommitOperationCopy`] with a
    [`CommitOperationDelete`].

    Args:
        src_path_in_repo (`str`):
            Relative path in the repo of the file to copy, e.g. `"checkpoints/1fec34a/weights.bin"`.
        path_in_repo (`str`):
            Relative path in the repo where the copy must be created, e.g.
            `"checkpoints/1fec34a/weights_copy.bin"`.
        src_revision (`str`, *optional*):
            Git revision the source file is taken from. Any valid git revision is accepted.
            Defaults to the target commit revision.
    """

    src_path_in_repo: str
    path_in_repo: str
    src_revision: Optional[str] = None
    # OID of the source file, when that file has already been uploaded.
    # Helps determine whether a commit would be empty.
    _src_oid: Optional[str] = None
    # OID of the destination file, when that file has already been uploaded.
    # Helps determine whether a commit would be empty.
    _dest_oid: Optional[str] = None

    def __post_init__(self):
        # Run both paths through the shared validator, source first and destination second.
        for path_field in ("src_path_in_repo", "path_in_repo"):
            setattr(self, path_field, _validate_path_in_repo(getattr(self, path_field)))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@dataclass
class CommitOperationAdd:
    """
    Data structure holding necessary info to upload a file to a repository on the Hub.

    Args:
        path_in_repo (`str`):
            Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
        path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`):
            Either:
            - a path to a local file (as `str` or `pathlib.Path`) to upload
            - a buffer of bytes (`bytes`) holding the content of the file to upload
            - a "file object" (subclass of `io.BufferedIOBase`), typically obtained
                with `open(path, "rb")`. It must support `seek()` and `tell()` methods.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both
            `seek()` and `tell()`.
    """

    path_in_repo: str
    path_or_fileobj: Union[str, Path, bytes, BinaryIO]
    upload_info: UploadInfo = field(init=False, repr=False)

    # Internal attributes

    # set to "lfs" or "regular" once known
    _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None)

    # set to True if .gitignore rules prevent the file from being uploaded as LFS
    # (server-side check)
    _should_ignore: Optional[bool] = field(init=False, repr=False, default=None)

    # set to the remote OID of the file if it has already been uploaded
    # useful to determine if a commit will be empty or not
    _remote_oid: Optional[str] = field(init=False, repr=False, default=None)

    # set to True once the file has been uploaded as LFS
    _is_uploaded: bool = field(init=False, repr=False, default=False)

    # set to True once the file has been committed
    _is_committed: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validates `path_or_fileobj` and compute `upload_info`."""
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        # Validate `path_or_fileobj` value
        if isinstance(self.path_or_fileobj, Path):
            # From here on, local paths are always handled as `str`.
            self.path_or_fileobj = str(self.path_or_fileobj)
        if isinstance(self.path_or_fileobj, str):
            # NOTE(review): the existence check uses the expanded/normalized path, but the stored
            # attribute keeps the raw value. Confirm that `UploadInfo.from_path` and `open()` in
            # `as_file` behave as intended for paths starting with `~`.
            path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj))
            if not os.path.isfile(path_or_fileobj):
                raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system")
        elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
            # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode
            raise ValueError(
                "path_or_fileobj must be either an instance of str, bytes or"
                " io.BufferedIOBase. If you passed a file-like object, make sure it is"
                " in binary mode."
            )
        if isinstance(self.path_or_fileobj, io.BufferedIOBase):
            # Probe `tell()`/`seek()` without moving the cursor (seek of 0 relative to current position).
            try:
                self.path_or_fileobj.tell()
                self.path_or_fileobj.seek(0, os.SEEK_CUR)
            except (OSError, AttributeError) as exc:
                raise ValueError(
                    "path_or_fileobj is a file-like object but does not implement seek() and tell()"
                ) from exc

        # Compute "upload_info" attribute
        if isinstance(self.path_or_fileobj, str):
            self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, bytes):
            self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
        else:
            self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)

    @contextmanager
    def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]:
        """
        A context manager that yields a file-like object allowing to read the underlying
        data behind `path_or_fileobj`.

        When `path_or_fileobj` is a file object, its position is restored on exit, even
        if the body of the `with` block raises.

        Args:
            with_tqdm (`bool`, *optional*, defaults to `False`):
                If True, iterating over the file object will display a progress bar. Only
                works if the file-like object is a path to a file. Pure bytes and buffers
                are not supported.

        Example:

        ```python
        >>> operation = CommitOperationAdd(
        ...        path_in_repo="remote/dir/weights.h5",
        ...        path_or_fileobj="./local/weights.h5",
        ... )
        CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5')

        >>> with operation.as_file() as file:
        ...     content = file.read()

        >>> with operation.as_file(with_tqdm=True) as file:
        ...     while True:
        ...         data = file.read(1024)
        ...         if not data:
        ...              break
        config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]

        >>> with operation.as_file(with_tqdm=True) as file:
        ...     requests.put(..., data=file)
        config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
        ```
        """
        if isinstance(self.path_or_fileobj, (str, Path)):
            if with_tqdm:
                with tqdm_stream_file(self.path_or_fileobj) as file:
                    yield file
            else:
                with open(self.path_or_fileobj, "rb") as file:
                    yield file
        elif isinstance(self.path_or_fileobj, bytes):
            yield io.BytesIO(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
            prev_pos = self.path_or_fileobj.tell()
            try:
                yield self.path_or_fileobj
            finally:
                # Always hand the buffer back at its original position, including when the
                # caller's `with` body raised: the same buffer may be read again later.
                self.path_or_fileobj.seek(prev_pos, io.SEEK_SET)

    def b64content(self) -> bytes:
        """
        The base64-encoded content of `path_or_fileobj`

        Returns: `bytes`
        """
        with self.as_file() as file:
            return base64.b64encode(file.read())

    @property
    def _local_oid(self) -> Optional[str]:
        """Return the OID of the local file.

        This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
        If the file did not change, we won't upload it again to prevent empty commits.

        For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref).
        For regular files, the OID corresponds to the SHA1 of the file content.
        Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1 of the
              pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes
              and more convenient client-side.

        Returns `None` while the upload mode is still unknown.
        """
        if self._upload_mode is None:
            return None
        elif self._upload_mode == "lfs":
            return self.upload_info.sha256.hex()
        else:
            # Regular file => compute sha1
            # => no need to read by chunk since the file is guaranteed to be <=5MB.
            with self.as_file() as file:
                return sha.git_hash(file.read())
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _validate_path_in_repo(path_in_repo: str) -> str:
|
| 285 |
+
# Validate `path_in_repo` value to prevent a server-side issue
|
| 286 |
+
if path_in_repo.startswith("/"):
|
| 287 |
+
path_in_repo = path_in_repo[1:]
|
| 288 |
+
if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"):
|
| 289 |
+
raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
|
| 290 |
+
if path_in_repo.startswith("./"):
|
| 291 |
+
path_in_repo = path_in_repo[2:]
|
| 292 |
+
for forbidden in FORBIDDEN_FOLDERS:
|
| 293 |
+
if any(part == forbidden for part in path_in_repo.split("/")):
|
| 294 |
+
raise ValueError(
|
| 295 |
+
f"Invalid `path_in_repo` in CommitOperation: cannot update files under a '{forbidden}/' folder (path:"
|
| 296 |
+
f" '{path_in_repo}')."
|
| 297 |
+
)
|
| 298 |
+
return path_in_repo
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete]
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:
    """
    Emit a warning when operations of a single commit are expected to overwrite each other.

    Rules:
    - A path added by more than one `CommitOperationAdd` triggers a warning.
    - A path (or a folder containing a path) added by a `CommitOperationAdd` and later
      deleted by a `CommitOperationDelete` triggers a warning.
    - A `CommitOperationDelete` followed by a `CommitOperationAdd` on the same path does
      not trigger anything. It is usually useless (no need to delete before upload) but
      happens when a user deletes an entire folder and then adds new files to it.
    """
    additions_seen: Dict[str, int] = defaultdict(int)

    for op in operations:
        target = op.path_in_repo

        if isinstance(op, CommitOperationAdd):
            if additions_seen[target] > 0:
                warnings.warn(
                    "About to update multiple times the same file in the same commit:"
                    f" '{target}'. This can cause undesired inconsistencies in"
                    " your repo."
                )
            additions_seen[target] += 1
            # Count the addition on every parent folder as well, so that deleting a
            # folder afterwards is detected as overwriting the files it contains.
            for folder in PurePosixPath(target).parents:
                additions_seen[str(folder)] += 1

        if isinstance(op, CommitOperationDelete):
            if additions_seen[str(PurePosixPath(target))] <= 0:
                continue
            if op.is_folder:
                message = (
                    "About to delete a folder containing files that have just been"
                    f" updated within the same commit: '{target}'. This can"
                    " cause undesired inconsistencies in your repo."
                )
            else:
                message = (
                    "About to delete a file that have just been updated within the"
                    f" same commit: '{target}'. This can cause undesired"
                    " inconsistencies in your repo."
                )
            warnings.warn(message)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
@validate_hf_hub_args
def _upload_lfs_files(
    *,
    additions: List[CommitOperationAdd],
    repo_type: str,
    repo_id: str,
    headers: Dict[str, str],
    endpoint: Optional[str] = None,
    num_threads: int = 5,
    revision: Optional[str] = None,
):
    """
    Uploads the content of `additions` to the Hub using the large file storage protocol.

    Relevant external documentation:
        - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md

    Args:
        additions (`List` of `CommitOperationAdd`):
            The files to be uploaded
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, *optional*):
            Forwarded as-is to the LFS batch request and to each file upload.
        num_threads (`int`, *optional*):
            The number of concurrent threads to use when uploading. Defaults to 5.
        revision (`str`, *optional*):
            The git revision to upload to.

    Raises:
        [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
            If an upload failed for any reason. The original error is chained as `__cause__`.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the LFS batch endpoint reports errors, or returns an action whose OID does
            not match any of the `additions`.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
            If the LFS batch endpoint returned an HTTP error.
    """
    # Step 1: retrieve upload instructions from the LFS batch endpoint.
    #         Upload instructions are retrieved by chunk of 256 files to avoid reaching
    #         the payload limit.
    batch_actions: List[Dict] = []
    for chunk in chunk_iterable(additions, chunk_size=256):
        batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
            upload_infos=[op.upload_info for op in chunk],
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            endpoint=endpoint,
            headers=headers,
            token=None,  # already passed in 'headers'
        )

        # If at least 1 error, we do not retrieve information for other chunks
        if batch_errors_chunk:
            message = "\n".join(
                [
                    f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
                    for err in batch_errors_chunk
                ]
            )
            raise ValueError(f"LFS batch endpoint returned errors:\n{message}")

        batch_actions += batch_actions_chunk
    # Batch actions are matched back to their operation through the sha256 of the content.
    oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions}

    # Step 2: ignore files that have already been uploaded
    filtered_actions = []
    for action in batch_actions:
        if action.get("actions") is None:
            logger.debug(
                f"Content of file {oid2addop[action['oid']].path_in_repo} is already"
                " present upstream - skipping upload."
            )
        else:
            filtered_actions.append(action)

    if len(filtered_actions) == 0:
        logger.debug("No LFS files to upload.")
        return

    # Step 3: upload files concurrently according to these instructions
    def _wrapped_lfs_upload(batch_action) -> None:
        # Resolve the operation *before* the upload `try` block: the error message below
        # references `operation`, which would otherwise be unbound if the lookup itself
        # failed (hiding the real problem behind an `UnboundLocalError`).
        try:
            operation = oid2addop[batch_action["oid"]]
        except KeyError as exc:
            raise ValueError(
                f"LFS batch endpoint returned an action with a missing or unknown OID: {batch_action}"
            ) from exc
        try:
            lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint)
        except Exception as exc:
            raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc

    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        # With `hf_transfer` enabled, files are uploaded one after the other from this thread.
        logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.")
        for action in hf_tqdm(filtered_actions, name="huggingface_hub.lfs_upload"):
            _wrapped_lfs_upload(action)
    elif len(filtered_actions) == 1:
        logger.debug("Uploading 1 LFS file to the Hub")
        _wrapped_lfs_upload(filtered_actions[0])
    else:
        logger.debug(
            f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently"
        )
        thread_map(
            _wrapped_lfs_upload,
            filtered_actions,
            desc=f"Upload {len(filtered_actions)} LFS files",
            max_workers=num_threads,
            tqdm_class=hf_tqdm,
        )
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def _validate_preupload_info(preupload_info: dict):
|
| 462 |
+
files = preupload_info.get("files")
|
| 463 |
+
if not isinstance(files, list):
|
| 464 |
+
raise ValueError("preupload_info is improperly formatted")
|
| 465 |
+
for file_info in files:
|
| 466 |
+
if not (
|
| 467 |
+
isinstance(file_info, dict)
|
| 468 |
+
and isinstance(file_info.get("path"), str)
|
| 469 |
+
and isinstance(file_info.get("uploadMode"), str)
|
| 470 |
+
and (file_info["uploadMode"] in ("lfs", "regular"))
|
| 471 |
+
):
|
| 472 |
+
raise ValueError("preupload_info is improperly formatted:")
|
| 473 |
+
return preupload_info
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
@validate_hf_hub_args
def _fetch_upload_modes(
    additions: Iterable[CommitOperationAdd],
    repo_type: str,
    repo_id: str,
    headers: Dict[str, str],
    revision: str,
    endpoint: Optional[str] = None,
    create_pr: bool = False,
    gitignore_content: Optional[str] = None,
) -> None:
    """
    Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob
    or as git LFS blob. Input `additions` are mutated in-place with the upload mode.

    Args:
        additions (`Iterable` of :class:`CommitOperationAdd`):
            Iterable of :class:`CommitOperationAdd` describing the files to
            upload to the Hub.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        revision (`str`):
            The git revision to upload the files to. Can be any valid git revision.
        endpoint (`str`, *optional*):
            Base URL of the Hub. Defaults to `constants.ENDPOINT`.
        create_pr (`bool`, *optional*, defaults to `False`):
            If `True`, `create_pr=1` is sent as a query parameter of the preupload request.
        gitignore_content (`str`, *optional*):
            The content of the `.gitignore` file to know which files should be ignored. The order of priority
            is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present
            in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub
            (if any).
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    endpoint = endpoint if endpoint is not None else constants.ENDPOINT

    # `additions` is iterated several times below (once to build the requests, then to
    # mutate each operation). Materialize it so that one-shot iterables such as generators
    # are not silently exhausted after the first pass.
    additions = list(additions)

    # Fetch upload mode (LFS or regular) chunk by chunk.
    upload_modes: Dict[str, UploadMode] = {}
    should_ignore_info: Dict[str, bool] = {}
    oid_info: Dict[str, Optional[str]] = {}

    for chunk in chunk_iterable(additions, 256):
        payload: Dict = {
            "files": [
                {
                    "path": op.path_in_repo,
                    "sample": base64.b64encode(op.upload_info.sample).decode("ascii"),
                    "size": op.upload_info.size,
                }
                for op in chunk
            ]
        }
        if gitignore_content is not None:
            payload["gitIgnore"] = gitignore_content

        resp = get_session().post(
            f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
            json=payload,
            headers=headers,
            params={"create_pr": "1"} if create_pr else None,
        )
        hf_raise_for_status(resp)
        preupload_info = _validate_preupload_info(resp.json())
        # Index the per-file answers by path so they can be matched back to the operations.
        for file in preupload_info["files"]:
            upload_modes[file["path"]] = file["uploadMode"]
            should_ignore_info[file["path"]] = file["shouldIgnore"]
            oid_info[file["path"]] = file.get("oid")

    # Set upload mode for each addition operation
    for addition in additions:
        addition._upload_mode = upload_modes[addition.path_in_repo]
        addition._should_ignore = should_ignore_info[addition.path_in_repo]
        addition._remote_oid = oid_info[addition.path_in_repo]

    # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented)
    # => empty files are uploaded as "regular" to still allow users to commit them.
    for addition in additions:
        if addition.upload_info.size == 0:
            addition._upload_mode = "regular"
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
@validate_hf_hub_args
def _fetch_files_to_copy(
    copies: Iterable[CommitOperationCopy],
    repo_type: str,
    repo_id: str,
    headers: Dict[str, str],
    revision: str,
    endpoint: Optional[str] = None,
) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]:
    """
    Fetch information about the files to copy.

    For LFS files, we only need their metadata (file size and sha256) while for regular files
    we need to download the raw content from the Hub.

    Each operation in `copies` is mutated in-place: `_src_oid` and `_dest_oid` are set from
    the blob ids reported by the Hub (`None` when the path was not reported).

    Args:
        copies (`Iterable` of :class:`CommitOperationCopy`):
            Iterable of :class:`CommitOperationCopy` describing the files to
            copy on the Hub.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        revision (`str`):
            The git revision to upload the files to. Can be any valid git revision.
        endpoint (`str`, *optional*):
            Forwarded to `HfApi` and `hf_hub_url`.

    Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]]`
        Key is the file path and revision of the file to copy.
        Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files).

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
        `EntryNotFoundError`
            If a source file is missing on the repo.
        `NotImplementedError`
            If a source path points to a folder.
    """
    from .hf_api import HfApi, RepoFolder

    # `copies` is iterated twice below (destination paths, then grouping by source revision).
    # Materialize it so that one-shot iterables such as generators are not exhausted after
    # the first pass.
    copies = list(copies)

    hf_api = HfApi(endpoint=endpoint, headers=headers)
    files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {}
    # Store (path, revision) -> oid mapping
    oid_info: Dict[Tuple[str, Optional[str]], Optional[str]] = {}
    # 1. Fetch OIDs for destination paths in batches.
    dest_paths = [op.path_in_repo for op in copies]
    for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE):
        dest_repo_files = hf_api.get_paths_info(
            repo_id=repo_id,
            paths=dest_paths[offset : offset + FETCH_LFS_BATCH_SIZE],
            revision=revision,
            repo_type=repo_type,
        )
        for file in dest_repo_files:
            if not isinstance(file, RepoFolder):
                oid_info[(file.path, revision)] = file.blob_id

    # 2. Group by source revision and fetch source file info in batches.
    #    `groupby` only merges *consecutive* operations sharing a revision; operations with
    #    the same revision that are not adjacent simply end up in separate groups.
    for src_revision, operations in groupby(copies, key=lambda op: op.src_revision):
        operations = list(operations)  # type: ignore
        src_paths = [op.src_path_in_repo for op in operations]
        for offset in range(0, len(src_paths), FETCH_LFS_BATCH_SIZE):
            src_repo_files = hf_api.get_paths_info(
                repo_id=repo_id,
                paths=src_paths[offset : offset + FETCH_LFS_BATCH_SIZE],
                revision=src_revision or revision,
                repo_type=repo_type,
            )

            for src_repo_file in src_repo_files:
                if isinstance(src_repo_file, RepoFolder):
                    raise NotImplementedError("Copying a folder is not implemented.")
                oid_info[(src_repo_file.path, src_revision)] = src_repo_file.blob_id
                # If it's an LFS file, store the RepoFile object. Otherwise, download raw bytes.
                if src_repo_file.lfs:
                    files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file
                else:
                    # TODO: (optimization) download regular files to copy concurrently
                    url = hf_hub_url(
                        endpoint=endpoint,
                        repo_type=repo_type,
                        repo_id=repo_id,
                        revision=src_revision or revision,
                        filename=src_repo_file.path,
                    )
                    response = get_session().get(url, headers=headers)
                    hf_raise_for_status(response)
                    files_to_copy[(src_repo_file.path, src_revision)] = response.content
        # 3. Ensure all operations found a corresponding file in the Hub
        #  and track src/dest OIDs for each operation.
        for operation in operations:
            if (operation.src_path_in_repo, src_revision) not in files_to_copy:
                raise EntryNotFoundError(
                    f"Cannot copy {operation.src_path_in_repo} at revision "
                    f"{src_revision or revision}: file is missing on repo."
                )
            operation._src_oid = oid_info.get((operation.src_path_in_repo, operation.src_revision))
            operation._dest_oid = oid_info.get((operation.path_in_repo, revision))
    return files_to_copy
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def _prepare_commit_payload(
    operations: Iterable[CommitOperation],
    files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]],
    commit_message: str,
    commit_description: Optional[str] = None,
    parent_commit: Optional[str] = None,
) -> Iterable[Dict[str, Any]]:
    """
    Generate the items of the payload to POST to the `/commit` API of the Hub.

    Items are produced lazily, one dict at a time, so that the payload can be streamed
    as a ndjson in the POST request: first a `"header"` item carrying the commit metadata,
    then one item per operation. Additions flagged as ignored are skipped.

    For more information, see:
        - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073
        - http://ndjson.org/
    """
    # 1. Send a header item with the commit metadata
    header_value = {
        "summary": commit_message,
        "description": "" if commit_description is None else commit_description,
    }
    if parent_commit is not None:
        header_value["parentCommit"] = parent_commit
    yield {"key": "header", "value": header_value}

    nb_ignored_files = 0

    # 2. Send operations, one per line
    for operation in operations:
        is_addition = isinstance(operation, CommitOperationAdd)

        # Skip ignored files
        if is_addition and operation._should_ignore:
            logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
            nb_ignored_files += 1
            continue

        if is_addition and operation._upload_mode == "regular":
            # 2.a. Case adding a regular file: content is inlined, base64-encoded
            item = {
                "key": "file",
                "value": {
                    "content": operation.b64content().decode(),
                    "path": operation.path_in_repo,
                    "encoding": "base64",
                },
            }
        elif is_addition and operation._upload_mode == "lfs":
            # 2.b. Case adding an LFS file: only the pointer information is sent
            item = {
                "key": "lfsFile",
                "value": {
                    "path": operation.path_in_repo,
                    "algo": "sha256",
                    "oid": operation.upload_info.sha256.hex(),
                    "size": operation.upload_info.size,
                },
            }
        elif isinstance(operation, CommitOperationDelete):
            # 2.c. Case deleting a file or folder
            item = {
                "key": "deletedFolder" if operation.is_folder else "deletedFile",
                "value": {"path": operation.path_in_repo},
            }
        elif isinstance(operation, CommitOperationCopy):
            # 2.d. Case copying a file or folder
            file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)]
            if isinstance(file_to_copy, bytes):
                item = {
                    "key": "file",
                    "value": {
                        "content": base64.b64encode(file_to_copy).decode(),
                        "path": operation.path_in_repo,
                        "encoding": "base64",
                    },
                }
            elif file_to_copy.lfs:
                item = {
                    "key": "lfsFile",
                    "value": {
                        "path": operation.path_in_repo,
                        "algo": "sha256",
                        "oid": file_to_copy.lfs.sha256,
                    },
                }
            else:
                raise ValueError(
                    "Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info."
                )
        else:
            # 2.e. Never expected to happen
            raise ValueError(
                f"Unknown operation to commit. Operation: {operation}. Upload mode:"
                f" {getattr(operation, '_upload_mode', None)}"
            )

        yield item

    if nb_ignored_files > 0:
        logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).")
|
.venv/lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from concurrent.futures import Future
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from io import SEEK_END, SEEK_SET, BytesIO
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from threading import Lock, Thread
|
| 10 |
+
from typing import Dict, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi
|
| 13 |
+
from .utils import filter_repo_objects
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
class _FileToUpload:
    """Immutable record describing one local file selected for a scheduled commit.

    Internal helper: not meant to be used directly.
    """

    # Location of the file on the local filesystem.
    local_path: Path
    # Destination path of the file inside the repository.
    path_in_repo: str
    # File size recorded when the folder was scanned; the upload is capped to it.
    size_limit: int
    # Modification time recorded when the folder was scanned.
    last_modified: float
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CommitScheduler:
    """
    Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).

    The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
    properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
    with the `stop` method. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
    to learn more about how to use it.

    Args:
        repo_id (`str`):
            The id of the repo to commit to.
        folder_path (`str` or `Path`):
            Path to the local folder to upload regularly.
        every (`int` or `float`, *optional*):
            The number of minutes between each commit. Defaults to 5 minutes.
        path_in_repo (`str`, *optional*):
            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
            of the repository.
        repo_type (`str`, *optional*):
            The type of the repo to commit to. Defaults to `model`.
        revision (`str`, *optional*):
            The revision of the repo to commit to. Defaults to `main`.
        private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
        token (`str`, *optional*):
            The token to use to commit to the repo. Defaults to the token saved on the machine.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are uploaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not uploaded.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performances on the repo when it grows too large.
        hf_api (`HfApi`, *optional*):
            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).

    Example:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    # Scheduler uploads every 10 minutes
    >>> csv_path = Path("watched_folder/data.csv")
    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)

    >>> with csv_path.open("a") as f:
    ...     f.write("first line")

    # Some time later (...)
    >>> with csv_path.open("a") as f:
    ...     f.write("second line")
    ```

    Example using a context manager:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
    ...     csv_path = Path("watched_folder/data.csv")
    ...     with csv_path.open("a") as f:
    ...         f.write("first line")
    ...     (...)
    ...     with csv_path.open("a") as f:
    ...         f.write("second line")

    # Scheduler is now stopped and last commit has been triggered
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = None,
        repo_type: Optional[str] = None,
        revision: Optional[str] = None,
        private: Optional[bool] = None,
        token: Optional[str] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        squash_history: bool = False,
        hf_api: Optional["HfApi"] = None,
    ) -> None:
        self.api = hf_api or HfApi(token=token)

        # Folder
        self.folder_path = Path(folder_path).expanduser().resolve()
        self.path_in_repo = path_in_repo or ""
        self.allow_patterns = allow_patterns

        # Normalize `ignore_patterns` to a list, then always append the default ignore patterns.
        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS

        if self.folder_path.is_file():
            raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.")
        self.folder_path.mkdir(parents=True, exist_ok=True)

        # Repository
        repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True)
        self.repo_id = repo_url.repo_id
        self.repo_type = repo_type
        self.revision = revision
        self.token = token

        # Keep track of already uploaded files
        self.last_uploaded: Dict[Path, float] = {}  # key is local path, value is timestamp

        # Scheduler
        if not every > 0:
            raise ValueError(f"'every' must be a positive integer, not '{every}'.")
        self.lock = Lock()
        self.every = every
        self.squash_history = squash_history

        # The flag must exist BEFORE the background thread starts: the thread immediately triggers
        # `_push_to_hub`, which reads it. Setting it afterwards leaves a window for an `AttributeError`.
        self.__stopped = False

        logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.")
        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
        self._scheduler_thread.start()
        # Attempt one last push when the interpreter exits (no-op if the scheduler has been stopped).
        atexit.register(self._push_to_hub)

    def stop(self) -> None:
        """Stop the scheduler.

        A stopped scheduler cannot be restarted. Mostly for tests purposes.
        """
        self.__stopped = True

    def __enter__(self) -> "CommitScheduler":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Upload last changes before exiting
        self.trigger().result()
        self.stop()
        return

    def _run_scheduler(self) -> None:
        """Dumb thread waiting between each scheduled push to Hub."""
        while True:
            self.last_future = self.trigger()
            time.sleep(self.every * 60)
            if self.__stopped:
                break

    def trigger(self) -> Future:
        """Trigger a `push_to_hub` and return a future.

        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
        immediately, without waiting for the next scheduled commit.
        """
        return self.api.run_as_future(self._push_to_hub)

    def _push_to_hub(self) -> Optional[CommitInfo]:
        """Run `push_to_hub` (and the optional history squash), logging any error before re-raising it.

        Returns `None` without doing anything once the scheduler has been stopped.
        """
        if self.__stopped:  # If stopped, already scheduled commits are ignored
            return None

        logger.info("(Background) scheduled commit triggered.")
        try:
            value = self.push_to_hub()
            if self.squash_history:
                logger.info("(Background) squashing repo history.")
                self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision)
            return value
        except Exception as e:
            logger.error(f"Error while pushing to Hub: {e}")  # Depending on the setup, error might be silenced
            raise

    def push_to_hub(self) -> Optional[CommitInfo]:
        """
        Push folder to the Hub and return the commit info.

        <Tip warning={true}>

        This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
        queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
        issues.

        </Tip>

        The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
        uploads only changed files. If no changes are found, the method returns without committing anything. If you want
        to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
        for example to compress data together in a single file before committing. For more details and examples, check
        out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
        """
        # Check files to upload (with lock)
        with self.lock:
            logger.debug("Listing files to upload for scheduled commit.")

            # List files from folder (taken from `_prepare_upload_folder_additions`)
            relpath_to_abspath = {
                path.relative_to(self.folder_path).as_posix(): path
                for path in sorted(self.folder_path.glob("**/*"))  # sorted to be deterministic
                if path.is_file()
            }
            prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""

            # Filter with pattern + filter out unchanged files + retrieve current file size
            files_to_upload: List[_FileToUpload] = []
            for relpath in filter_repo_objects(
                relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns
            ):
                local_path = relpath_to_abspath[relpath]
                stat = local_path.stat()
                # A file is (re)uploaded when it was never uploaded or its mtime differs from the recorded one.
                if self.last_uploaded.get(local_path) != stat.st_mtime:
                    files_to_upload.append(
                        _FileToUpload(
                            local_path=local_path,
                            path_in_repo=prefix + relpath,
                            size_limit=stat.st_size,
                            last_modified=stat.st_mtime,
                        )
                    )

        # Return if nothing to upload
        if not files_to_upload:
            logger.debug("Dropping schedule commit: no changed file to upload.")
            return None

        # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
        logger.debug("Removing unchanged files since previous scheduled commit.")
        add_operations = [
            CommitOperationAdd(
                # Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
                path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit),
                path_in_repo=file_to_upload.path_in_repo,
            )
            for file_to_upload in files_to_upload
        ]

        # Upload files (append mode expected - no need for lock)
        logger.debug("Uploading files for scheduled commit.")
        commit_info = self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=add_operations,
            commit_message="Scheduled Commit",
            revision=self.revision,
        )

        # Successful commit: keep track of the latest "last_modified" for each file
        for file in files_to_upload:
            self.last_uploaded[file.local_path] = file.last_modified
        return commit_info
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class PartialFileIO(BytesIO):
    """A file-like object that reads only the first part of a file.

    Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
    file is uploaded (i.e. the part that was available when the filesystem was first scanned).

    In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
    disturbance for the user. The object is passed to `CommitOperationAdd`.

    Only supports `read`, `tell` and `seek` methods.

    Args:
        file_path (`str` or `Path`):
            Path to the file to read.
        size_limit (`int`):
            The maximum number of bytes to read from the file. If the file is larger than this, only the first part
            will be read (and uploaded).
    """

    def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
        self._file_path = Path(file_path)
        self._file = self._file_path.open("rb")
        # Never advertise more bytes than the file currently holds.
        self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)

    def __del__(self) -> None:
        # `_file` does not exist when `open()` failed in `__init__` (or when the instance was created without
        # `__init__`): the finalizer must not raise in that case.
        file = getattr(self, "_file", None)
        if file is not None:
            file.close()
        return super().__del__()

    def __repr__(self) -> str:
        return f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"

    def __len__(self) -> int:
        return self._size_limit

    def __getattribute__(self, name: str):
        # Private/dunder names stay reachable so the object itself keeps working; every other public
        # attribute inherited from `BytesIO` is explicitly rejected.
        if name.startswith("_") or name in ("read", "tell", "seek"):  # only 3 public methods supported
            return super().__getattribute__(name)
        raise NotImplementedError(f"PartialFileIO does not support '{name}'.")

    def tell(self) -> int:
        """Return the current file position."""
        return self._file.tell()

    def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
        """Change the stream position to the given offset.

        Behavior is the same as a regular file, except that the position is capped to the size limit.
        """
        if __whence == SEEK_END:
            # SEEK_END => set from the truncated end
            __offset = len(self) + __offset
            __whence = SEEK_SET

        pos = self._file.seek(__offset, __whence)
        if pos > self._size_limit:
            return self._file.seek(self._size_limit)
        return pos

    def read(self, __size: Optional[int] = -1) -> bytes:
        """Read at most `__size` bytes from the file.

        Behavior is the same as a regular file, except that it is capped to the size limit.
        """
        current = self._file.tell()
        if __size is None or __size < 0:
            # Read until file limit
            truncated_size = self._size_limit - current
        else:
            # Read until file limit or __size
            truncated_size = min(__size, self._size_limit - current)
        return self._file.read(truncated_size)
|
.venv/lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import TYPE_CHECKING, Dict, Optional, Union
|
| 6 |
+
|
| 7 |
+
from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError
|
| 8 |
+
|
| 9 |
+
from .inference._client import InferenceClient
|
| 10 |
+
from .inference._generated._async_client import AsyncInferenceClient
|
| 11 |
+
from .utils import get_session, logging, parse_datetime
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from .hf_api import HfApi
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class InferenceEndpointStatus(str, Enum):
    """Possible values of the `status` field of an [`InferenceEndpoint`].

    Members also subclass `str`, so they compare equal to their plain string value.
    """

    PENDING = "pending"
    INITIALIZING = "initializing"
    UPDATING = "updating"
    UPDATE_FAILED = "updateFailed"
    RUNNING = "running"
    PAUSED = "paused"
    FAILED = "failed"
    SCALED_TO_ZERO = "scaledToZero"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class InferenceEndpointType(str, Enum):
    """Possible values of the `type` field of an [`InferenceEndpoint`] (public, protected, private).

    Members also subclass `str`, so they compare equal to their plain string value.
    """

    # NOTE: `PUBlIC` (lowercase "l") is a historical typo. It stays the canonical member so existing code and
    # `.name` values keep working; `PUBLIC` is an enum alias resolving to the very same member.
    PUBlIC = "public"
    PUBLIC = "public"
    PROTECTED = "protected"
    PRIVATE = "private"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class InferenceEndpoint:
|
| 40 |
+
"""
|
| 41 |
+
Contains information about a deployed Inference Endpoint.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
name (`str`):
|
| 45 |
+
The unique name of the Inference Endpoint.
|
| 46 |
+
namespace (`str`):
|
| 47 |
+
The namespace where the Inference Endpoint is located.
|
| 48 |
+
repository (`str`):
|
| 49 |
+
The name of the model repository deployed on this Inference Endpoint.
|
| 50 |
+
status ([`InferenceEndpointStatus`]):
|
| 51 |
+
The current status of the Inference Endpoint.
|
| 52 |
+
url (`str`, *optional*):
|
| 53 |
+
The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
|
| 54 |
+
framework (`str`):
|
| 55 |
+
The machine learning framework used for the model.
|
| 56 |
+
revision (`str`):
|
| 57 |
+
The specific model revision deployed on the Inference Endpoint.
|
| 58 |
+
task (`str`):
|
| 59 |
+
The task associated with the deployed model.
|
| 60 |
+
created_at (`datetime.datetime`):
|
| 61 |
+
The timestamp when the Inference Endpoint was created.
|
| 62 |
+
updated_at (`datetime.datetime`):
|
| 63 |
+
The timestamp of the last update of the Inference Endpoint.
|
| 64 |
+
type ([`InferenceEndpointType`]):
|
| 65 |
+
The type of the Inference Endpoint (public, protected, private).
|
| 66 |
+
raw (`Dict`):
|
| 67 |
+
The raw dictionary data returned from the API.
|
| 68 |
+
token (`str` or `bool`, *optional*):
|
| 69 |
+
Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the
|
| 70 |
+
locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server.
|
| 71 |
+
|
| 72 |
+
Example:
|
| 73 |
+
```python
|
| 74 |
+
>>> from huggingface_hub import get_inference_endpoint
|
| 75 |
+
>>> endpoint = get_inference_endpoint("my-text-to-image")
|
| 76 |
+
>>> endpoint
|
| 77 |
+
InferenceEndpoint(name='my-text-to-image', ...)
|
| 78 |
+
|
| 79 |
+
# Get status
|
| 80 |
+
>>> endpoint.status
|
| 81 |
+
'running'
|
| 82 |
+
>>> endpoint.url
|
| 83 |
+
'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'
|
| 84 |
+
|
| 85 |
+
# Run inference
|
| 86 |
+
>>> endpoint.client.text_to_image(...)
|
| 87 |
+
|
| 88 |
+
# Pause endpoint to save $$$
|
| 89 |
+
>>> endpoint.pause()
|
| 90 |
+
|
| 91 |
+
# ...
|
| 92 |
+
# Resume and wait for deployment
|
| 93 |
+
>>> endpoint.resume()
|
| 94 |
+
>>> endpoint.wait()
|
| 95 |
+
>>> endpoint.client.text_to_image(...)
|
| 96 |
+
```
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
# Field in __repr__
|
| 100 |
+
name: str = field(init=False)
|
| 101 |
+
namespace: str
|
| 102 |
+
repository: str = field(init=False)
|
| 103 |
+
status: InferenceEndpointStatus = field(init=False)
|
| 104 |
+
url: Optional[str] = field(init=False)
|
| 105 |
+
|
| 106 |
+
# Other fields
|
| 107 |
+
framework: str = field(repr=False, init=False)
|
| 108 |
+
revision: str = field(repr=False, init=False)
|
| 109 |
+
task: str = field(repr=False, init=False)
|
| 110 |
+
created_at: datetime = field(repr=False, init=False)
|
| 111 |
+
updated_at: datetime = field(repr=False, init=False)
|
| 112 |
+
type: InferenceEndpointType = field(repr=False, init=False)
|
| 113 |
+
|
| 114 |
+
# Raw dict from the API
|
| 115 |
+
raw: Dict = field(repr=False)
|
| 116 |
+
|
| 117 |
+
# Internal fields
|
| 118 |
+
_token: Union[str, bool, None] = field(repr=False, compare=False)
|
| 119 |
+
_api: "HfApi" = field(repr=False, compare=False)
|
| 120 |
+
|
| 121 |
+
@classmethod
def from_raw(
    cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None
) -> "InferenceEndpoint":
    """Build an instance from the raw dictionary describing the endpoint.

    When `api` is omitted a default `HfApi` client is created; when `token` is omitted the client's token is used.
    """
    if api is None:
        # Imported lazily, only when no client was supplied.
        from .hf_api import HfApi

        api = HfApi()

    resolved_token = api.token if token is None else token

    # All other fields are populated in __post_init__
    return cls(raw=raw, namespace=namespace, _token=resolved_token, _api=api)
|
| 135 |
+
|
| 136 |
+
def __post_init__(self) -> None:
    """Dataclass hook: fill the remaining fields by delegating to `_populate_from_raw`."""
    self._populate_from_raw()
|
| 139 |
+
|
| 140 |
+
@property
def client(self) -> InferenceClient:
    """Returns a client to make predictions on this Inference Endpoint.

    Returns:
        [`InferenceClient`]: an inference client pointing to the deployed endpoint.

    Raises:
        [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
    """
    # Without a URL there is nothing to point the client at.
    if self.url is None:
        not_deployed_msg = (
            "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
            "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
        )
        raise InferenceEndpointError(not_deployed_msg)

    sync_client = InferenceClient(
        model=self.url,
        token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
    )
    return sync_client
|
| 159 |
+
|
| 160 |
+
@property
def async_client(self) -> AsyncInferenceClient:
    """Returns an asyncio-compatible client to make predictions on this Inference Endpoint.

    Returns:
        [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.

    Raises:
        [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
    """
    # Without a URL there is nothing to point the client at.
    if self.url is None:
        not_deployed_msg = (
            "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
            "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
        )
        raise InferenceEndpointError(not_deployed_msg)

    aio_client = AsyncInferenceClient(
        model=self.url,
        token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
    )
    return aio_client
|
| 179 |
+
|
| 180 |
+
def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
    """Wait for the Inference Endpoint to be deployed.

    Information from the server will be fetched every `refresh_every` seconds (5s by default). If the Inference
    Endpoint is not deployed after `timeout` seconds, a [`InferenceEndpointTimeoutError`] will be raised. The
    [`InferenceEndpoint`] will be mutated in place with the latest data.

    Args:
        timeout (`int`, *optional*):
            The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
            indefinitely.
        refresh_every (`int`, *optional*):
            The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.

    Raises:
        `ValueError`
            If `timeout` is negative or `refresh_every` is not strictly positive.
        [`InferenceEndpointError`]
            If the Inference Endpoint ended up in a failed state.
        [`InferenceEndpointTimeoutError`]
            If the Inference Endpoint is not deployed after `timeout` seconds.
    """
    if timeout is not None and timeout < 0:
        raise ValueError("`timeout` cannot be negative.")
    if refresh_every <= 0:
        raise ValueError("`refresh_every` must be positive.")

    # Monotonic clock: the elapsed time must not be affected by system clock adjustments.
    start = time.monotonic()
    while True:
        if self.url is not None:
            # Means the URL is provisioned => check if the endpoint is reachable
            response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
            if response.status_code == 200:
                logger.info("Inference Endpoint is ready to be used.")
                return self
        if self.status == InferenceEndpointStatus.FAILED:
            raise InferenceEndpointError(
                f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
            )
        if timeout is not None and time.monotonic() - start > timeout:
            raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
        logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
        time.sleep(refresh_every)
        self.fetch()
|
| 226 |
+
|
| 227 |
+
def fetch(self) -> "InferenceEndpoint":
    """Fetch latest information about the Inference Endpoint.

    Queries the API client for this endpoint, stores the returned raw data and re-populates the fields from it.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    latest = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)  # type: ignore [arg-type]
    self.raw = latest.raw
    self._populate_from_raw()
    return self
|
| 237 |
+
|
| 238 |
+
def update(
    self,
    *,
    # Compute update
    accelerator: Optional[str] = None,
    instance_size: Optional[str] = None,
    instance_type: Optional[str] = None,
    min_replica: Optional[int] = None,
    max_replica: Optional[int] = None,
    scale_to_zero_timeout: Optional[int] = None,
    # Model update
    repository: Optional[str] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    task: Optional[str] = None,
    custom_image: Optional[Dict] = None,
    secrets: Optional[Dict[str, str]] = None,
) -> "InferenceEndpoint":
    """Update the Inference Endpoint.

    Either the compute configuration, the deployed model, or both can be updated. All arguments are optional but
    at least one must be provided.

    This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
    latest data from the server.

    Args:
        accelerator (`str`, *optional*):
            The hardware accelerator to be used for inference (e.g. `"cpu"`).
        instance_size (`str`, *optional*):
            The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
        instance_type (`str`, *optional*):
            The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
        min_replica (`int`, *optional*):
            The minimum number of replicas (instances) to keep running for the Inference Endpoint.
        max_replica (`int`, *optional*):
            The maximum number of replicas (instances) to scale to for the Inference Endpoint.
        scale_to_zero_timeout (`int`, *optional*):
            The duration in minutes before an inactive endpoint is scaled to zero.

        repository (`str`, *optional*):
            The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
        framework (`str`, *optional*):
            The machine learning framework used for the model (e.g. `"custom"`).
        revision (`str`, *optional*):
            The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
        task (`str`, *optional*):
            The task on which to deploy the model (e.g. `"text-classification"`).
        custom_image (`Dict`, *optional*):
            A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
            Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
        secrets (`Dict[str, str]`, *optional*):
            Secret values to inject in the container environment.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    # Gather every keyword argument first, then forward them in a single API call.
    request_kwargs = {
        "name": self.name,
        "namespace": self.namespace,
        "accelerator": accelerator,
        "instance_size": instance_size,
        "instance_type": instance_type,
        "min_replica": min_replica,
        "max_replica": max_replica,
        "scale_to_zero_timeout": scale_to_zero_timeout,
        "repository": repository,
        "framework": framework,
        "revision": revision,
        "task": task,
        "custom_image": custom_image,
        "secrets": secrets,
        "token": self._token,  # type: ignore [arg-type]
    }
    updated = self._api.update_inference_endpoint(**request_kwargs)

    # Mutate current object
    self.raw = updated.raw
    self._populate_from_raw()
    return self
|
| 317 |
+
|
| 318 |
+
def pause(self) -> "InferenceEndpoint":
    """Pause the Inference Endpoint.

    A paused Inference Endpoint will not be charged and can be restarted at any time with
    [`InferenceEndpoint.resume`]. Pausing differs from [`InferenceEndpoint.scale_to_zero`]: an endpoint scaled
    to zero is automatically restarted when a request is made to it, a paused one is not.

    This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
    latest data from the server.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    # Take the raw payload straight from the object returned by the API call.
    self.raw = self._api.pause_inference_endpoint(  # type: ignore [arg-type]
        name=self.name,
        namespace=self.namespace,
        token=self._token,
    ).raw
    # Refresh every derived attribute from the new payload.
    self._populate_from_raw()
    return self
|
| 335 |
+
|
| 336 |
+
def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
    """Resume the Inference Endpoint.

    This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
    latest data from the server.

    Args:
        running_ok (`bool`, *optional*):
            If `True`, the method will not raise an error if the Inference Endpoint is already running.
            Defaults to `True`.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    # Take the raw payload straight from the object returned by the API call.
    self.raw = self._api.resume_inference_endpoint(  # type: ignore [arg-type]
        name=self.name,
        namespace=self.namespace,
        running_ok=running_ok,
        token=self._token,
    ).raw
    # Refresh every derived attribute from the new payload.
    self._populate_from_raw()
    return self
|
| 356 |
+
|
| 357 |
+
def scale_to_zero(self) -> "InferenceEndpoint":
    """Scale Inference Endpoint to zero.

    An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request made to
    it, with a cold start delay. This differs from pausing with [`InferenceEndpoint.pause`], which requires a
    manual resume with [`InferenceEndpoint.resume`].

    This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place
    with the latest data from the server.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    # Take the raw payload straight from the object returned by the API call.
    self.raw = self._api.scale_to_zero_inference_endpoint(  # type: ignore [arg-type]
        name=self.name,
        namespace=self.namespace,
        token=self._token,
    ).raw
    # Refresh every derived attribute from the new payload.
    self._populate_from_raw()
    return self
|
| 374 |
+
|
| 375 |
+
def delete(self) -> None:
    """Delete the Inference Endpoint.

    This operation is not reversible. If the goal is only to stop being charged for an Inference Endpoint,
    prefer pausing it with [`InferenceEndpoint.pause`] or scaling it to zero with
    [`InferenceEndpoint.scale_to_zero`].

    This is an alias for [`HfApi.delete_inference_endpoint`].
    """
    # The current object is left untouched; only the remote call is made.
    self._api.delete_inference_endpoint(  # type: ignore [arg-type]
        name=self.name,
        namespace=self.namespace,
        token=self._token,
    )
|
| 384 |
+
|
| 385 |
+
def _populate_from_raw(self) -> None:
    """Populate fields from raw dictionary.

    Called in __post_init__ + each time the Inference Endpoint is updated.

    Reads `self.raw` and copies the relevant entries onto attributes. Every key except `status.url` is
    accessed with `[]`, so a missing key raises `KeyError`.
    """
    # Repr fields
    self.name = self.raw["name"]
    model_info = self.raw["model"]
    self.repository = model_info["repository"]
    status_info = self.raw["status"]
    self.status = status_info["state"]
    self.url = status_info.get("url")  # only entry read with .get(): defaults to None when absent

    # Other fields
    self.framework = model_info["framework"]
    self.revision = model_info["revision"]
    self.task = model_info["task"]
    self.created_at = parse_datetime(status_info["createdAt"])
    self.updated_at = parse_datetime(status_info["updatedAt"])
    self.type = self.raw["type"]
|
.venv/lib/python3.11/site-packages/huggingface_hub/_local_folder.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024-present, the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Contains utilities to handle the `../.cache/huggingface` folder in local directories.
|
| 16 |
+
|
| 17 |
+
First discussed in https://github.com/huggingface/huggingface_hub/issues/1738 to store
|
| 18 |
+
download metadata when downloading files from the hub to a local directory (without
|
| 19 |
+
using the cache).
|
| 20 |
+
|
| 21 |
+
./.cache/huggingface folder structure:
|
| 22 |
+
[4.0K] data
|
| 23 |
+
├── [4.0K] .cache
|
| 24 |
+
│ └── [4.0K] huggingface
|
| 25 |
+
│ └── [4.0K] download
|
| 26 |
+
│ ├── [ 16] file.parquet.metadata
|
| 27 |
+
│ ├── [ 16] file.txt.metadata
|
| 28 |
+
│ └── [4.0K] folder
|
| 29 |
+
│ └── [ 16] file.parquet.metadata
|
| 30 |
+
│
|
| 31 |
+
├── [6.5G] file.parquet
|
| 32 |
+
├── [1.5K] file.txt
|
| 33 |
+
└── [4.0K] folder
|
| 34 |
+
└── [ 16] file.parquet
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Download metadata file structure:
|
| 38 |
+
```
|
| 39 |
+
# file.txt.metadata
|
| 40 |
+
11c5a3d5811f50298f278a704980280950aedb10
|
| 41 |
+
a16a55fda99d2f2e7b69cce5cf93ff4ad3049930
|
| 42 |
+
1712656091.123
|
| 43 |
+
|
| 44 |
+
# file.parquet.metadata
|
| 45 |
+
11c5a3d5811f50298f278a704980280950aedb10
|
| 46 |
+
7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421
|
| 47 |
+
1712656091.123
|
| 48 |
+
}
|
| 49 |
+
```
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import base64
|
| 53 |
+
import hashlib
|
| 54 |
+
import logging
|
| 55 |
+
import os
|
| 56 |
+
import time
|
| 57 |
+
from dataclasses import dataclass
|
| 58 |
+
from pathlib import Path
|
| 59 |
+
from typing import Optional
|
| 60 |
+
|
| 61 |
+
from .utils import WeakFileLock
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
logger = logging.getLogger(__name__)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
class LocalDownloadFilePaths:
    """
    Bundle of filesystem paths involved when downloading one file into a local dir.

    Returned by [`get_local_download_paths`].

    Attributes:
        file_path (`Path`):
            Path where the file will be saved.
        lock_path (`Path`):
            Path to the lock file used to ensure atomicity when reading/writing metadata.
        metadata_path (`Path`):
            Path to the metadata file.
    """

    file_path: Path
    lock_path: Path
    metadata_path: Path

    def incomplete_path(self, etag: str) -> Path:
        """Return the path where a file will be temporarily downloaded before being moved to `file_path`."""
        # The temporary file sits next to the metadata file; its name is built from a
        # short hash of the metadata file name plus the etag.
        hashed_name = _short_hash(self.metadata_path.name)
        temp_name = f"{hashed_name}.{etag}.incomplete"
        return self.metadata_path.parent / temp_name
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass(frozen=True)
class LocalUploadFilePaths:
    """
    Immutable bundle of paths involved when uploading one file from a local dir.

    Returned by [`get_local_upload_paths`].

    Attributes:
        path_in_repo (`str`):
            Path of the file in the repo.
        file_path (`Path`):
            Path where the file will be saved.
        lock_path (`Path`):
            Path to the lock file used to ensure atomicity when reading/writing metadata.
        metadata_path (`Path`):
            Path to the metadata file.
    """

    # frozen=True: instances cannot be mutated after creation.
    path_in_repo: str
    file_path: Path
    lock_path: Path
    metadata_path: Path
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass
class LocalDownloadFileMetadata:
    """
    Metadata about a file in the local directory related to a download process.

    Attributes:
        filename (`str`):
            Path of the file in the repo.
        commit_hash (`str`):
            Commit hash of the file in the repo.
        etag (`str`):
            ETag of the file in the repo. Used to check if the file has changed.
            For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash.
        timestamp (`float`):
            Unix timestamp (in seconds, possibly fractional) of when the metadata was saved i.e. when the
            metadata was accurate.
    """

    filename: str
    commit_hash: str
    etag: str
    timestamp: float
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass
class LocalUploadFileMetadata:
    """
    Metadata about a file in the local directory related to an upload process.
    """

    size: int

    # Default values correspond to "we don't know yet"
    timestamp: Optional[float] = None
    should_ignore: Optional[bool] = None
    sha256: Optional[str] = None
    upload_mode: Optional[str] = None
    is_uploaded: bool = False
    is_committed: bool = False

    def save(self, paths: LocalUploadFilePaths) -> None:
        """Save the metadata to disk.

        The metadata file holds one value per line, in this order: timestamp, size, should_ignore, sha256,
        upload_mode, is_uploaded, is_committed. Fields that are `None` are written as empty lines and
        booleans as `0`/`1`. On success `self.timestamp` is set to the timestamp that was written.
        """
        with WeakFileLock(paths.lock_path):
            with paths.metadata_path.open("w") as f:
                saved_at = time.time()
                # Unknown (None) optional fields become empty lines so the line positions stay fixed.
                optional_fields = (
                    "" if self.should_ignore is None else str(int(self.should_ignore)),
                    "" if self.sha256 is None else self.sha256,
                    "" if self.upload_mode is None else self.upload_mode,
                )
                entries = [
                    str(saved_at),
                    str(self.size),  # never None
                    *optional_fields,
                    str(int(self.is_uploaded)),
                    str(int(self.is_committed)),
                ]
                f.write("".join(entry + "\n" for entry in entries))

            # Only record the new timestamp once the file has been written.
            self.timestamp = saved_at
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFilePaths:
    """Compute paths to the files related to a download process.

    Folders containing the paths are all guaranteed to exist.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        [`LocalDownloadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path, incomplete_path).

    Raises:
        `ValueError`: on Windows, if the filename starts with or contains a `..` path component.
    """
    on_windows = os.name == "nt"

    # filename is the path in the Hub repository (separated by '/')
    # => rebuild it with the separator of the current platform
    relative_name = os.path.join(*filename.split("/"))
    if on_windows and (relative_name.startswith("..\\") or "\\..\\" in relative_name):
        raise ValueError(
            f"Invalid filename: cannot handle filename '{relative_name}' on Windows. Please ask the repository"
            " owner to rename this file."
        )

    target_file = local_dir / relative_name
    metadata_file = _huggingface_dir(local_dir) / "download" / f"{relative_name}.metadata"
    lock_file = metadata_file.with_suffix(".lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix
    if on_windows and not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_file)) > 255:
        extended_prefix = "\\\\?\\"
        target_file = Path(extended_prefix + os.path.abspath(target_file))
        lock_file = Path(extended_prefix + os.path.abspath(lock_file))
        metadata_file = Path(extended_prefix + os.path.abspath(metadata_file))

    for folder in (target_file.parent, metadata_file.parent):
        folder.mkdir(parents=True, exist_ok=True)
    return LocalDownloadFilePaths(file_path=target_file, lock_path=lock_file, metadata_path=metadata_file)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths:
    """Compute paths to the files related to an upload process.

    Folders containing the paths are all guaranteed to exist.

    Args:
        local_dir (`Path`):
            Path to the local directory that is uploaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        [`LocalUploadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path).

    Raises:
        `ValueError`: on Windows, if the filename starts with or contains a `..` path component.
    """
    on_windows = os.name == "nt"

    # filename is the path in the Hub repository (separated by '/')
    # => rebuild it with the separator of the current platform
    relative_name = os.path.join(*filename.split("/"))
    if on_windows and (relative_name.startswith("..\\") or "\\..\\" in relative_name):
        raise ValueError(
            f"Invalid filename: cannot handle filename '{relative_name}' on Windows. Please ask the repository"
            " owner to rename this file."
        )

    source_file = local_dir / relative_name
    metadata_file = _huggingface_dir(local_dir) / "upload" / f"{relative_name}.metadata"
    lock_file = metadata_file.with_suffix(".lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix
    if on_windows and not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_file)) > 255:
        extended_prefix = "\\\\?\\"
        source_file = Path(extended_prefix + os.path.abspath(source_file))
        lock_file = Path(extended_prefix + os.path.abspath(lock_file))
        metadata_file = Path(extended_prefix + os.path.abspath(metadata_file))

    for folder in (source_file.parent, metadata_file.parent):
        folder.mkdir(parents=True, exist_ok=True)
    return LocalUploadFilePaths(
        path_in_repo=filename,
        file_path=source_file,
        lock_path=lock_file,
        metadata_path=metadata_file,
    )
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDownloadFileMetadata]:
    """Read metadata about a file in the local directory related to a download process.

    The metadata file is expected to hold three lines: commit hash, etag and the timestamp at which the
    metadata was saved. A file that cannot be parsed is treated as corrupted: it is removed from disk
    (best effort) and `None` is returned.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        `[LocalDownloadFileMetadata]` or `None`: the metadata if it exists, is readable and is not older than
        the downloaded file; `None` otherwise.
    """
    paths = get_local_download_paths(local_dir, filename)
    with WeakFileLock(paths.lock_path):
        if not paths.metadata_path.exists():
            return None

        try:
            with paths.metadata_path.open() as f:
                commit_hash = f.readline().strip()
                etag = f.readline().strip()
                timestamp = float(f.readline().strip())
            metadata = LocalDownloadFileMetadata(
                filename=filename,
                commit_hash=commit_hash,
                etag=etag,
                timestamp=timestamp,
            )
        except Exception as e:
            # remove the metadata file if it is corrupted / not the right format
            logger.warning(
                f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue."
            )
            try:
                paths.metadata_path.unlink()
            except Exception as unlink_error:
                logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {unlink_error}")
            # No usable metadata: stop here instead of falling through to the freshness check
            # below, which would otherwise hit an unbound `metadata` variable.
            return None

        try:
            # check if the file exists and hasn't been modified since the metadata was saved
            stat = paths.file_path.stat()
            # allow 1s difference as stat.st_mtime might not be precise
            if stat.st_mtime - 1 <= metadata.timestamp:
                return metadata
            logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.")
        except FileNotFoundError:
            # file does not exist => metadata is outdated
            return None
    return None
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetadata:
    """Read metadata about a file in the local directory related to an upload process.

    TODO: factorize logic with `read_download_metadata`.

    The metadata file is expected to hold one value per line: timestamp, size, should_ignore, sha256,
    upload_mode, is_uploaded, is_committed (empty lines mean "unknown" for the optional fields). A file that
    cannot be parsed is treated as corrupted: it is removed from disk (best effort) and ignored.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        [`LocalUploadFileMetadata`]: the stored metadata if it exists, is readable and is not older than the
        local file. Otherwise, a fresh metadata object in which only `size` is set.

    Raises:
        `FileNotFoundError`: if no usable metadata exists and the local file itself is missing (its size
        cannot be read).
    """
    paths = get_local_upload_paths(local_dir, filename)
    with WeakFileLock(paths.lock_path):
        if paths.metadata_path.exists():
            metadata: Optional[LocalUploadFileMetadata] = None
            try:
                with paths.metadata_path.open() as f:
                    timestamp = float(f.readline().strip())

                    size = int(f.readline().strip())  # never None

                    _should_ignore = f.readline().strip()
                    should_ignore = None if _should_ignore == "" else bool(int(_should_ignore))

                    _sha256 = f.readline().strip()
                    sha256 = None if _sha256 == "" else _sha256

                    _upload_mode = f.readline().strip()
                    upload_mode = None if _upload_mode == "" else _upload_mode
                    if upload_mode not in (None, "regular", "lfs"):
                        raise ValueError(f"Invalid upload mode in metadata {paths.path_in_repo}: {upload_mode}")

                    is_uploaded = bool(int(f.readline().strip()))
                    is_committed = bool(int(f.readline().strip()))

                metadata = LocalUploadFileMetadata(
                    timestamp=timestamp,
                    size=size,
                    should_ignore=should_ignore,
                    sha256=sha256,
                    upload_mode=upload_mode,
                    is_uploaded=is_uploaded,
                    is_committed=is_committed,
                )
            except Exception as e:
                # remove the metadata file if it is corrupted / not the right format
                logger.warning(
                    f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue."
                )
                try:
                    paths.metadata_path.unlink()
                except Exception as unlink_error:
                    logger.warning(
                        f"Could not remove corrupted metadata file {paths.metadata_path}: {unlink_error}"
                    )

            # Only inspect the metadata if it was parsed successfully. Without this guard, a corrupted
            # file would lead to an unbound `metadata` variable being used below.
            if metadata is not None:
                # TODO: can we do better?
                if (
                    metadata.timestamp is not None
                    and metadata.is_uploaded  # file was uploaded
                    and not metadata.is_committed  # but not committed
                    and time.time() - metadata.timestamp > 20 * 3600  # and it's been more than 20 hours
                ):  # => we consider it as garbage-collected by S3
                    metadata.is_uploaded = False

                # check if the file exists and hasn't been modified since the metadata was saved
                try:
                    if metadata.timestamp is not None and paths.file_path.stat().st_mtime <= metadata.timestamp:
                        return metadata
                    logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.")
                except FileNotFoundError:
                    # file does not exist => metadata is outdated
                    pass

    # empty metadata => we don't know anything except its size
    return LocalUploadFileMetadata(size=paths.file_path.stat().st_size)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def write_download_metadata(local_dir: Path, filename: str, commit_hash: str, etag: str) -> None:
    """Write metadata about a file in the local directory related to a download process.

    The metadata file receives three lines: the commit hash, the etag and the current time
    (`time.time()`).

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.
        commit_hash (`str`):
            Commit hash written on the first line of the metadata file.
        etag (`str`):
            ETag written on the second line of the metadata file.
    """
    paths = get_local_download_paths(local_dir, filename)
    with WeakFileLock(paths.lock_path), paths.metadata_path.open("w") as metadata_file:
        metadata_file.writelines([f"{commit_hash}\n", f"{etag}\n", f"{time.time()}\n"])
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _huggingface_dir(local_dir: Path) -> Path:
    """Return the path to the `.cache/huggingface` directory in a local directory.

    The directory is created if needed. A `.gitignore` file containing `*` is also created inside it when
    missing; failures while doing so are deliberately ignored (best effort).
    """
    hf_dir = local_dir / ".cache" / "huggingface"
    hf_dir.mkdir(exist_ok=True, parents=True)

    ignore_file = hf_dir / ".gitignore"
    if ignore_file.exists():
        # Nothing to create => never overwrite an existing .gitignore
        return hf_dir

    # Create the .gitignore file under a short-lived lock.
    # Should be thread-safe enough like this.
    ignore_lock = hf_dir / ".gitignore.lock"
    try:
        with WeakFileLock(ignore_lock, timeout=0.1):
            ignore_file.write_text("*")
    except (IndexError, OSError):  # OSError: TimeoutError, FileNotFoundError, PermissionError, etc.
        pass

    # Clean up the lock file, whether or not the write succeeded.
    try:
        ignore_lock.unlink()
    except OSError:
        pass
    return hf_dir
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def _short_hash(filename: str) -> str:
|
| 432 |
+
return base64.urlsafe_b64encode(hashlib.sha1(filename.encode()).digest()).decode()
|
.venv/lib/python3.11/site-packages/huggingface_hub/_login.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Contains methods to log in to the Hub."""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import subprocess
|
| 18 |
+
from getpass import getpass
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
from . import constants
|
| 23 |
+
from .commands._cli_utils import ANSI
|
| 24 |
+
from .utils import (
|
| 25 |
+
capture_output,
|
| 26 |
+
get_token,
|
| 27 |
+
is_google_colab,
|
| 28 |
+
is_notebook,
|
| 29 |
+
list_credential_helpers,
|
| 30 |
+
logging,
|
| 31 |
+
run_subprocess,
|
| 32 |
+
set_git_credential,
|
| 33 |
+
unset_git_credential,
|
| 34 |
+
)
|
| 35 |
+
from .utils._auth import (
|
| 36 |
+
_get_token_by_name,
|
| 37 |
+
_get_token_from_environment,
|
| 38 |
+
_get_token_from_file,
|
| 39 |
+
_get_token_from_google_colab,
|
| 40 |
+
_save_stored_tokens,
|
| 41 |
+
_save_token,
|
| 42 |
+
get_stored_tokens,
|
| 43 |
+
)
|
| 44 |
+
from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
_HF_LOGO_ASCII = """
|
| 50 |
+
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
|
| 51 |
+
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
| 52 |
+
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
|
| 53 |
+
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
| 54 |
+
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@_deprecate_arguments(
    version="1.0",
    deprecated_args="write_permission",
    custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.",
)
@_deprecate_positional_args(version="1.0")
def login(
    token: Optional[str] = None,
    *,
    add_to_git_credential: bool = False,
    new_session: bool = True,
    write_permission: bool = False,
) -> None:
    """Log this machine in to the Hub.

    When `token` is provided, it is handed to the private login helper, which checks it
    against the Hub, stores it locally and makes it the active token. When `token` is
    omitted, the user is asked for one, either through a widget (in a notebook) or
    through a terminal prompt.

    Outside of a script, the `huggingface-cli login` command can be used instead; it is
    a cli command that wraps [`login`].

    <Tip>

    [`login`] can be used as a drop-in replacement for [`notebook_login`]: it wraps it
    and extends its capabilities.

    </Tip>

    <Tip>

    Without a token, [`login`] tries to detect whether it runs in a notebook. Given the
    variety of notebook environments, that detection may be wrong. In that case, call
    [`notebook_login`] or [`interpreter_login`] directly to force the UI.

    </Tip>

    Args:
        token (`str`, *optional*):
            User access token to generate from https://huggingface.co/settings/tokens.
        add_to_git_credential (`bool`, defaults to `False`):
            If `True`, token will be set as git credential. If no git credential helper
            is configured, a warning will be displayed to the user. If `token` is `None`,
            this value is not forwarded: the interactive login asks the user instead.
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
            Only used when `token` is `None`.
        write_permission (`bool`):
            Ignored and deprecated argument.
    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If an organization token is passed. Only personal account tokens are valid
            to log in.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If token is invalid.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If running in a notebook but `ipywidgets` is not installed.
    """
    if token is None:
        # No token supplied: fall back to an interactive prompt suited to the environment.
        prompt_for_token = notebook_login if is_notebook() else interpreter_login
        prompt_for_token(new_session=new_session)
        return

    if not add_to_git_credential:
        # Purely informational: tell the user how to also store the token in git.
        logger.info(
            "The token has not been saved to the git credentials helper. Pass "
            "`add_to_git_credential=True` in this function directly or "
            "`--add-to-git-credential` if using via `huggingface-cli` if "
            "you want to set the git credential as well."
        )
    _login(token, add_to_git_credential=add_to_git_credential)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def logout(token_name: Optional[str] = None) -> None:
    """Log the machine out from the Hub.

    Removes the locally stored token(s) and unsets the git credential. With no
    `token_name`, both the active-token file and the stored-tokens file are deleted.
    With a `token_name`, only that entry is removed.

    Args:
        token_name (`str`, *optional*):
            Name of the access token to logout from. If `None`, will logout from all saved access tokens.
    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError):
            If, after the local cleanup, a token is still provided by a Google Colab
            secret or by an environment variable.
    """
    logged_in = get_token() is not None or bool(get_stored_tokens())
    if not logged_in:  # No active token and no saved access tokens
        logger.warning("Not logged in!")
        return

    if token_name:
        # NOTE(review): `_logout_from_token` returns silently for an unknown name, so the
        # success message below is logged in that case too.
        _logout_from_token(token_name)
        logger.info(f"Successfully logged out from access token: {token_name}.")
    else:
        # Delete all saved access tokens and token
        for token_file in (constants.HF_TOKEN_PATH, constants.HF_STORED_TOKENS_PATH):
            Path(token_file).unlink(missing_ok=True)
        logger.info("Successfully logged out from all access tokens.")

    unset_git_credential()

    # Check if still logged in through a source this function cannot clear.
    colab_token = _get_token_from_google_colab()
    if colab_token is not None:
        raise EnvironmentError(
            "You are automatically logged in using a Google Colab secret.\n"
            "To log out, you must unset the `HF_TOKEN` secret in your Colab settings."
        )
    environment_token = _get_token_from_environment()
    if environment_token is not None:
        raise EnvironmentError(
            "Token has been deleted from your machine but you are still logged in.\n"
            "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables."
        )
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def auth_switch(token_name: str, add_to_git_credential: bool = False) -> None:
    """Make a previously stored access token the active one.

    Args:
        token_name (`str`):
            Name of the access token to switch to.
        add_to_git_credential (`bool`, defaults to `False`):
            If `True`, token will be set as git credential. If no git credential helper
            is configured, a warning will be displayed to the user.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If the access token name is not found.
    """
    selected_token = _get_token_by_name(token_name)
    if not selected_token:
        raise ValueError(f"Access token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}")

    # Write token to HF_TOKEN_PATH
    _set_active_token(token_name, add_to_git_credential)
    logger.info(f"The current active token is: {token_name}")

    env_token = _get_token_from_environment()
    if env_token is None or env_token == selected_token:
        # Nothing in the environment conflicts with the token just selected.
        return
    logger.warning(
        "The environment variable `HF_TOKEN` is set and will override the access token you've just switched to."
    )
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def auth_list() -> None:
    """Print a table of all stored access tokens.

    Each row shows the token name and a masked version of its value (first 3 and last 4
    characters). The row whose value equals the current token is flagged with `*`.
    After the table, a warning is logged if a token is set through the environment, or
    if none of the stored tokens is the current one.
    """
    tokens = get_stored_tokens()

    if not tokens:
        logger.info("No access tokens found.")
        return

    # Find out whether one of the stored tokens is the current one.
    current_token = get_token()
    has_active_stored_token = any(value == current_token for value in tokens.values())

    # Print header. Data rows start with a 2-character gutter ("* " or "  ") and the
    # separator is `name_width + 2` dashes wide, so the header needs the same
    # 2-character gutter for the "|" column to line up.
    name_width = max(len("token"), max(len(name) for name in tokens)) + 2
    print(f"  {'name':<{name_width}}| {'token':<15}")
    print("-" * (name_width + 2) + "|" + "-" * 15)

    # Print saved access tokens
    for name, value in tokens.items():
        # A stored value equal to the "<not set>" placeholder is shown verbatim.
        masked_value = value if value == "<not set>" else f"{value[:3]}****{value[-4:]}"
        marker = "*" if value == current_token else " "
        print(f"{marker} {name:<{name_width}}| {masked_value:<15}")

    if _get_token_from_environment():
        logger.warning(
            "\nNote: Environment variable `HF_TOKEN` is set and is the current active token independently from the stored tokens listed above."
        )
    elif not has_active_stored_token:
        logger.warning(
            "\nNote: No active token is set and no environment variable `HF_TOKEN` is found. Use `huggingface-cli login` to log in."
        )
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
###
|
| 241 |
+
# Interpreter-based login (text)
|
| 242 |
+
###
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@_deprecate_arguments(
    version="1.0",
    deprecated_args="write_permission",
    custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.",
)
@_deprecate_positional_args(version="1.0")
def interpreter_login(*, new_session: bool = True, write_permission: bool = False) -> None:
    """
    Prompts for a token in the terminal and logs in with it.

    This is what [`login`] does when no token is passed and the code is not running in a
    notebook. Call [`interpreter_login`] directly to force the terminal prompt instead
    of a notebook widget.

    For more details, see [`login`].

    Args:
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`):
            Ignored and deprecated argument.
    """
    if not new_session and get_token() is not None:
        logger.info("User is already logged in.")
        return

    # Imported here rather than at module level, as in the original code.
    from .commands.delete_cache import _ask_for_confirmation_no_tui

    print(_HF_LOGO_ASCII)

    already_has_token = get_token() is not None
    if already_has_token:
        logger.info(
            " A token is already saved on your machine. Run `huggingface-cli whoami` to get more information"
            " or `huggingface-cli logout` if you want to log out."
        )
        logger.info(" Setting a new token will erase the existing one.")

    logger.info(
        " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens ."
    )
    if os.name == "nt":
        logger.info("Token can be pasted using 'Right-Click'.")

    entered_token = getpass("Enter your token (input will not be visible): ")
    wants_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?")

    _login(token=entered_token, add_to_git_credential=wants_git_credential)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
###
|
| 294 |
+
# Notebook-based login (widget)
|
| 295 |
+
###
|
| 296 |
+
|
| 297 |
+
NOTEBOOK_LOGIN_PASSWORD_HTML = """<center> <img
|
| 298 |
+
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
|
| 299 |
+
alt='Hugging Face'> <br> Immediately click login after typing your password or
|
| 300 |
+
it might be stored in plain text in this notebook file. </center>"""
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
NOTEBOOK_LOGIN_TOKEN_HTML_START = """<center> <img
|
| 304 |
+
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
|
| 305 |
+
alt='Hugging Face'> <br> Copy a token from <a
|
| 306 |
+
href="https://huggingface.co/settings/tokens" target="_blank">your Hugging Face
|
| 307 |
+
tokens page</a> and paste it below. <br> Immediately click login after copying
|
| 308 |
+
your token or it might be stored in plain text in this notebook file. </center>"""
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
NOTEBOOK_LOGIN_TOKEN_HTML_END = """
|
| 312 |
+
<b>Pro Tip:</b> If you don't already have one, you can create a dedicated
|
| 313 |
+
'notebooks' token with 'write' access, that you can then easily reuse for all
|
| 314 |
+
notebooks. </center>"""
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@_deprecate_arguments(
    version="1.0",
    deprecated_args="write_permission",
    custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.",
)
@_deprecate_positional_args(version="1.0")
def notebook_login(*, new_session: bool = True, write_permission: bool = False) -> None:
    """
    Shows a notebook widget asking for a token and logs in with it.

    This is what [`login`] does when no token is passed and the code runs in a notebook.
    Call [`notebook_login`] directly to force the widget instead of a terminal prompt.

    For more details, see [`login`].

    Args:
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`):
            Ignored and deprecated argument.
    """
    # The widget dependencies are checked first, before looking at any saved token.
    try:
        import ipywidgets.widgets as widgets  # type: ignore
        from IPython.display import display  # type: ignore
    except ImportError:
        raise ImportError(
            "The `notebook_login` function can only be used in a notebook (Jupyter or"
            " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`."
        )

    if not new_session and get_token() is not None:
        logger.info("User is already logged in.")
        return

    box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%")

    token_widget = widgets.Password(description="Token:")
    git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?")
    token_finish_button = widgets.Button(description="Login")

    form_children = [
        widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START),
        token_widget,
        git_checkbox_widget,
        token_finish_button,
        widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END),
    ]
    login_token_widget = widgets.VBox(form_children, layout=box_layout)
    display(login_token_widget)

    def _on_login_clicked(_button):
        """Read the form, attempt the login and show the outcome in place of the form."""
        submitted_token = token_widget.value
        wants_git_credential = git_checkbox_widget.value
        # Erase token and clear value to make sure it's not saved in the notebook.
        token_widget.value = ""
        # Hide inputs
        login_token_widget.children = [widgets.Label("Connecting...")]
        try:
            with capture_output() as captured:
                _login(submitted_token, add_to_git_credential=wants_git_credential)
        except Exception as error:
            # Any failure is shown to the user in the widget instead of being raised.
            message = str(error)
        else:
            message = captured.getvalue()
        # Print result (success message or error), one label per non-blank line.
        result_labels = []
        for line in message.split("\n"):
            if line.strip():
                result_labels.append(widgets.Label(line))
        login_token_widget.children = result_labels

    token_finish_button.on_click(_on_login_clicked)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
###
|
| 391 |
+
# Login private helpers
|
| 392 |
+
###
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _login(
    token: str,
    add_to_git_credential: bool,
) -> None:
    """Check `token` against the Hub, store it locally and make it the active token.

    The token is looked up with `whoami`; the `displayName` returned for it is used as
    the name under which the token is stored and activated.

    Args:
        token (`str`):
            The access token to log in with.
        add_to_git_credential (`bool`):
            Forwarded to `_set_active_token`: whether the token should also be saved as
            git credential.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If `token` starts with `api_org` (organization token).
    """
    from .hf_api import whoami  # avoid circular import

    if token.startswith("api_org"):
        raise ValueError("You must use your personal account token, not an organization token.")

    # Both the role and the display name live under the same nested entry.
    access_token_info = whoami(token)["auth"]["accessToken"]
    permission = access_token_info["role"]
    logger.info(f"Token is valid (permission: {permission}).")

    token_name = access_token_info["displayName"]
    # Store token locally
    _save_token(token=token, token_name=token_name)
    # Set active token
    _set_active_token(token_name=token_name, add_to_git_credential=add_to_git_credential)
    logger.info("Login successful.")
    if _get_token_from_environment():
        logger.warning(
            "Note: Environment variable `HF_TOKEN` is set and is the current active token independently from the token you've just configured."
        )
    else:
        logger.info(f"The current active token is: `{token_name}`")
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def _logout_from_token(token_name: str) -> None:
    """Remove one named access token from the stored tokens.

    If no tokens are stored, or `token_name` is not among them, the function returns
    without doing anything (no error is raised). If the removed token is also the one
    in the active-token file, that file is deleted as well.

    Args:
        token_name (`str`):
            The name of the access token to logout from.
    """
    saved_tokens = get_stored_tokens()
    # Nothing saved, or unknown name: do nothing
    if not (saved_tokens and token_name in saved_tokens):
        return

    removed_token = saved_tokens.pop(token_name)
    _save_stored_tokens(saved_tokens)

    if removed_token != _get_token_from_file():
        return
    # The removed token was the active one: drop the active-token file too.
    logger.warning(f"Active token '{token_name}' has been deleted.")
    Path(constants.HF_TOKEN_PATH).unlink(missing_ok=True)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _set_active_token(
    token_name: str,
    add_to_git_credential: bool,
) -> None:
    """Make the stored token called `token_name` the active one.

    The token value is looked up by name and written to `constants.HF_TOKEN_PATH`.
    Optionally, it is also saved as git credential.

    Args:
        token_name (`str`):
            The name of the token to set as active.
        add_to_git_credential (`bool`):
            If `True`, the token is also saved through the configured git credential
            helpers. If no helper is configured, only a warning is logged.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If no token is stored under `token_name`.
    """
    active_token = _get_token_by_name(token_name)
    if not active_token:
        raise ValueError(f"Token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}")

    if add_to_git_credential:
        if not _is_git_credential_helper_configured():
            logger.warning("Token has not been saved to git credential helper.")
        else:
            set_git_credential(active_token)
            helper_names = ",".join(list_credential_helpers())
            logger.info(f"Your token has been saved in your configured git credential helpers ({helper_names}).")

    # Write token to HF_TOKEN_PATH, creating its parent folder if needed
    token_file = Path(constants.HF_TOKEN_PATH)
    token_file.parent.mkdir(parents=True, exist_ok=True)
    token_file.write_text(active_token)
    logger.info(f"Your token has been saved to {constants.HF_TOKEN_PATH}")
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def _is_git_credential_helper_configured() -> bool:
    """Tell whether a git credential helper is available.

    Returns `True` when at least one helper is listed. In Google Colab, when none is
    listed, the "store" helper is set globally and `True` is returned as well. In any
    other case, a red warning is printed and `False` is returned.
    """
    if list_credential_helpers():
        return True  # Do not warn: at least 1 helper is set

    # Only in Google Colab to avoid the warning message
    # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710
    if is_google_colab():
        _set_store_as_git_credential_helper_globally()
        return True  # Do not warn: "store" is used by default in Google Colab

    # Otherwise, warn user
    warning_text = (
        "Cannot authenticate through git-credential as no helper is defined on your machine.\n"
        "You might have to re-authenticate when pushing to the Hugging Face Hub.\n"
        "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n"
        "\n"
        "git config --global credential.helper store\n"
        "\n"
        "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details."
    )
    print(ANSI.red(warning_text))
    return False
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def _set_store_as_git_credential_helper_globally() -> None:
    """Set globally the credential.helper to `store`.

    Runs `git config --global credential.helper store`.

    To be used only in Google Colab as we assume the user doesn't care about the git
    credential config. It is the only particular case where we don't want to display the
    warning message in [`notebook_login()`].

    Related:
    - https://github.com/huggingface/huggingface_hub/issues/1043
    - https://github.com/huggingface/huggingface_hub/issues/1051
    - https://git-scm.com/docs/git-credential-store

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError):
            If the git command fails. The message is the stderr of the failed command.
    """
    try:
        run_subprocess("git config --global credential.helper store")
    except subprocess.CalledProcessError as exc:
        # Chain explicitly so the traceback marks the git failure as the direct cause.
        raise EnvironmentError(exc.stderr) from exc
|
.venv/lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, List, Literal, Optional, Union
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
from tqdm.auto import tqdm as base_tqdm
|
| 7 |
+
from tqdm.contrib.concurrent import thread_map
|
| 8 |
+
|
| 9 |
+
from . import constants
|
| 10 |
+
from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
| 11 |
+
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
|
| 12 |
+
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
|
| 13 |
+
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
|
| 14 |
+
from .utils import tqdm as hf_tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@validate_hf_hub_args
|
| 21 |
+
def snapshot_download(
|
| 22 |
+
repo_id: str,
|
| 23 |
+
*,
|
| 24 |
+
repo_type: Optional[str] = None,
|
| 25 |
+
revision: Optional[str] = None,
|
| 26 |
+
cache_dir: Union[str, Path, None] = None,
|
| 27 |
+
local_dir: Union[str, Path, None] = None,
|
| 28 |
+
library_name: Optional[str] = None,
|
| 29 |
+
library_version: Optional[str] = None,
|
| 30 |
+
user_agent: Optional[Union[Dict, str]] = None,
|
| 31 |
+
proxies: Optional[Dict] = None,
|
| 32 |
+
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
|
| 33 |
+
force_download: bool = False,
|
| 34 |
+
token: Optional[Union[bool, str]] = None,
|
| 35 |
+
local_files_only: bool = False,
|
| 36 |
+
allow_patterns: Optional[Union[List[str], str]] = None,
|
| 37 |
+
ignore_patterns: Optional[Union[List[str], str]] = None,
|
| 38 |
+
max_workers: int = 8,
|
| 39 |
+
tqdm_class: Optional[base_tqdm] = None,
|
| 40 |
+
headers: Optional[Dict[str, str]] = None,
|
| 41 |
+
endpoint: Optional[str] = None,
|
| 42 |
+
# Deprecated args
|
| 43 |
+
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
| 44 |
+
resume_download: Optional[bool] = None,
|
| 45 |
+
) -> str:
|
| 46 |
+
"""Download repo files.
|
| 47 |
+
|
| 48 |
+
Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
|
| 49 |
+
a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
|
| 50 |
+
to keep their actual filename relative to that folder. You can also filter which files to download using
|
| 51 |
+
`allow_patterns` and `ignore_patterns`.
|
| 52 |
+
|
| 53 |
+
If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
|
| 54 |
+
option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
|
| 55 |
+
to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
|
| 56 |
+
cache-system, it's optimized for regularly pulling the latest version of a repository.
|
| 57 |
+
|
| 58 |
+
An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
|
| 59 |
+
configured. It is also not possible to filter which files to download when cloning a repository using git.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
repo_id (`str`):
|
| 63 |
+
A user or an organization name and a repo name separated by a `/`.
|
| 64 |
+
repo_type (`str`, *optional*):
|
| 65 |
+
Set to `"dataset"` or `"space"` if downloading from a dataset or space,
|
| 66 |
+
`None` or `"model"` if downloading from a model. Default is `None`.
|
| 67 |
+
revision (`str`, *optional*):
|
| 68 |
+
An optional Git revision id which can be a branch name, a tag, or a
|
| 69 |
+
commit hash.
|
| 70 |
+
cache_dir (`str`, `Path`, *optional*):
|
| 71 |
+
Path to the folder where cached files are stored.
|
| 72 |
+
local_dir (`str` or `Path`, *optional*):
|
| 73 |
+
If provided, the downloaded files will be placed under this directory.
|
| 74 |
+
library_name (`str`, *optional*):
|
| 75 |
+
The name of the library to which the object corresponds.
|
| 76 |
+
library_version (`str`, *optional*):
|
| 77 |
+
The version of the library.
|
| 78 |
+
user_agent (`str`, `dict`, *optional*):
|
| 79 |
+
The user-agent info in the form of a dictionary or a string.
|
| 80 |
+
proxies (`dict`, *optional*):
|
| 81 |
+
Dictionary mapping protocol to the URL of the proxy passed to
|
| 82 |
+
`requests.request`.
|
| 83 |
+
etag_timeout (`float`, *optional*, defaults to `10`):
|
| 84 |
+
When fetching ETag, how many seconds to wait for the server to send
|
| 85 |
+
data before giving up which is passed to `requests.request`.
|
| 86 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 87 |
+
Whether the file should be downloaded even if it already exists in the local cache.
|
| 88 |
+
token (`str`, `bool`, *optional*):
|
| 89 |
+
A token to be used for the download.
|
| 90 |
+
- If `True`, the token is read from the HuggingFace config
|
| 91 |
+
folder.
|
| 92 |
+
- If a string, it's used as the authentication token.
|
| 93 |
+
headers (`dict`, *optional*):
|
| 94 |
+
Additional headers to include in the request. Those headers take precedence over the others.
|
| 95 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
| 96 |
+
If `True`, avoid downloading the file and return the path to the
|
| 97 |
+
local cached file if it exists.
|
| 98 |
+
allow_patterns (`List[str]` or `str`, *optional*):
|
| 99 |
+
If provided, only files matching at least one pattern are downloaded.
|
| 100 |
+
ignore_patterns (`List[str]` or `str`, *optional*):
|
| 101 |
+
If provided, files matching any of the patterns are not downloaded.
|
| 102 |
+
max_workers (`int`, *optional*):
|
| 103 |
+
Number of concurrent threads to download files (1 thread = 1 file download).
|
| 104 |
+
Defaults to 8.
|
| 105 |
+
tqdm_class (`tqdm`, *optional*):
|
| 106 |
+
If provided, overwrites the default behavior for the progress bar. Passed
|
| 107 |
+
argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
|
| 108 |
+
Note that the `tqdm_class` is not passed to each individual download.
|
| 109 |
+
Defaults to the custom HF progress bar that can be disabled by setting
|
| 110 |
+
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
`str`: folder path of the repo snapshot.
|
| 114 |
+
|
| 115 |
+
Raises:
|
| 116 |
+
[`~utils.RepositoryNotFoundError`]
|
| 117 |
+
If the repository to download from cannot be found. This may be because it doesn't exist,
|
| 118 |
+
or because it is set to `private` and you do not have access.
|
| 119 |
+
[`~utils.RevisionNotFoundError`]
|
| 120 |
+
If the revision to download from cannot be found.
|
| 121 |
+
[`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
| 122 |
+
If `token=True` and the token cannot be found.
|
| 123 |
+
[`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
|
| 124 |
+
ETag cannot be determined.
|
| 125 |
+
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
| 126 |
+
if some parameter value is invalid.
|
| 127 |
+
"""
|
| 128 |
+
if cache_dir is None:
|
| 129 |
+
cache_dir = constants.HF_HUB_CACHE
|
| 130 |
+
if revision is None:
|
| 131 |
+
revision = constants.DEFAULT_REVISION
|
| 132 |
+
if isinstance(cache_dir, Path):
|
| 133 |
+
cache_dir = str(cache_dir)
|
| 134 |
+
|
| 135 |
+
if repo_type is None:
|
| 136 |
+
repo_type = "model"
|
| 137 |
+
if repo_type not in constants.REPO_TYPES:
|
| 138 |
+
raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
|
| 139 |
+
|
| 140 |
+
storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
|
| 141 |
+
|
| 142 |
+
repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
|
| 143 |
+
api_call_error: Optional[Exception] = None
|
| 144 |
+
if not local_files_only:
|
| 145 |
+
# try/except logic to handle different errors => taken from `hf_hub_download`
|
| 146 |
+
try:
|
| 147 |
+
# if we have internet connection we want to list files to download
|
| 148 |
+
api = HfApi(
|
| 149 |
+
library_name=library_name,
|
| 150 |
+
library_version=library_version,
|
| 151 |
+
user_agent=user_agent,
|
| 152 |
+
endpoint=endpoint,
|
| 153 |
+
headers=headers,
|
| 154 |
+
)
|
| 155 |
+
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
|
| 156 |
+
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
|
| 157 |
+
# Actually raise for those subclasses of ConnectionError
|
| 158 |
+
raise
|
| 159 |
+
except (
|
| 160 |
+
requests.exceptions.ConnectionError,
|
| 161 |
+
requests.exceptions.Timeout,
|
| 162 |
+
OfflineModeIsEnabled,
|
| 163 |
+
) as error:
|
| 164 |
+
# Internet connection is down
|
| 165 |
+
# => will try to use local files only
|
| 166 |
+
api_call_error = error
|
| 167 |
+
pass
|
| 168 |
+
except RevisionNotFoundError:
|
| 169 |
+
# The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
|
| 170 |
+
raise
|
| 171 |
+
except requests.HTTPError as error:
|
| 172 |
+
# Multiple reasons for an http error:
|
| 173 |
+
# - Repository is private and invalid/missing token sent
|
| 174 |
+
# - Repository is gated and invalid/missing token sent
|
| 175 |
+
# - Hub is down (error 500 or 504)
|
| 176 |
+
# => let's switch to 'local_files_only=True' to check if the files are already cached.
|
| 177 |
+
# (if it's not the case, the error will be re-raised)
|
| 178 |
+
api_call_error = error
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
# At this stage, if `repo_info` is None it means either:
|
| 182 |
+
# - internet connection is down
|
| 183 |
+
# - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
|
| 184 |
+
# - repo is private/gated and invalid/missing token sent
|
| 185 |
+
# - Hub is down
|
| 186 |
+
# => let's look if we can find the appropriate folder in the cache:
|
| 187 |
+
# - if the specified revision is a commit hash, look inside "snapshots".
|
| 188 |
+
# - if the specified revision is a branch or tag, look inside "refs".
|
| 189 |
+
# => if local_dir is not None, we will return the path to the local folder if it exists.
|
| 190 |
+
if repo_info is None:
|
| 191 |
+
# Try to get which commit hash corresponds to the specified revision
|
| 192 |
+
commit_hash = None
|
| 193 |
+
if REGEX_COMMIT_HASH.match(revision):
|
| 194 |
+
commit_hash = revision
|
| 195 |
+
else:
|
| 196 |
+
ref_path = os.path.join(storage_folder, "refs", revision)
|
| 197 |
+
if os.path.exists(ref_path):
|
| 198 |
+
# retrieve commit_hash from refs file
|
| 199 |
+
with open(ref_path) as f:
|
| 200 |
+
commit_hash = f.read()
|
| 201 |
+
|
| 202 |
+
# Try to locate snapshot folder for this commit hash
|
| 203 |
+
if commit_hash is not None:
|
| 204 |
+
snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
|
| 205 |
+
if os.path.exists(snapshot_folder):
|
| 206 |
+
# Snapshot folder exists => let's return it
|
| 207 |
+
# (but we can't check if all the files are actually there)
|
| 208 |
+
return snapshot_folder
|
| 209 |
+
# If local_dir is not None, return it if it exists and is not empty
|
| 210 |
+
if local_dir is not None:
|
| 211 |
+
local_dir = Path(local_dir)
|
| 212 |
+
if local_dir.is_dir() and any(local_dir.iterdir()):
|
| 213 |
+
logger.warning(
|
| 214 |
+
f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
|
| 215 |
+
)
|
| 216 |
+
return str(local_dir.resolve())
|
| 217 |
+
# If we couldn't find the appropriate folder on disk, raise an error.
|
| 218 |
+
if local_files_only:
|
| 219 |
+
raise LocalEntryNotFoundError(
|
| 220 |
+
"Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
|
| 221 |
+
"outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
|
| 222 |
+
"'local_files_only=False' as input."
|
| 223 |
+
)
|
| 224 |
+
elif isinstance(api_call_error, OfflineModeIsEnabled):
|
| 225 |
+
raise LocalEntryNotFoundError(
|
| 226 |
+
"Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
|
| 227 |
+
"outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
|
| 228 |
+
"'HF_HUB_OFFLINE=0' as environment variable."
|
| 229 |
+
) from api_call_error
|
| 230 |
+
elif isinstance(api_call_error, RepositoryNotFoundError) or isinstance(api_call_error, GatedRepoError):
|
| 231 |
+
# Repo not found => let's raise the actual error
|
| 232 |
+
raise api_call_error
|
| 233 |
+
else:
|
| 234 |
+
# Otherwise: most likely a connection issue or Hub downtime => let's warn the user
|
| 235 |
+
raise LocalEntryNotFoundError(
|
| 236 |
+
"An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
|
| 237 |
+
" snapshot folder for the specified revision on the local disk. Please check your internet connection"
|
| 238 |
+
" and try again."
|
| 239 |
+
) from api_call_error
|
| 240 |
+
|
| 241 |
+
# At this stage, internet connection is up and running
|
| 242 |
+
# => let's download the files!
|
| 243 |
+
assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
|
| 244 |
+
assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
|
| 245 |
+
filtered_repo_files = list(
|
| 246 |
+
filter_repo_objects(
|
| 247 |
+
items=[f.rfilename for f in repo_info.siblings],
|
| 248 |
+
allow_patterns=allow_patterns,
|
| 249 |
+
ignore_patterns=ignore_patterns,
|
| 250 |
+
)
|
| 251 |
+
)
|
| 252 |
+
commit_hash = repo_info.sha
|
| 253 |
+
snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
|
| 254 |
+
# if passed revision is not identical to commit_hash
|
| 255 |
+
# then revision has to be a branch name or tag name.
|
| 256 |
+
# In that case store a ref.
|
| 257 |
+
if revision != commit_hash:
|
| 258 |
+
ref_path = os.path.join(storage_folder, "refs", revision)
|
| 259 |
+
try:
|
| 260 |
+
os.makedirs(os.path.dirname(ref_path), exist_ok=True)
|
| 261 |
+
with open(ref_path, "w") as f:
|
| 262 |
+
f.write(commit_hash)
|
| 263 |
+
except OSError as e:
|
| 264 |
+
logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")
|
| 265 |
+
|
| 266 |
+
# we pass the commit_hash to hf_hub_download
|
| 267 |
+
# so no network call happens if we already
|
| 268 |
+
# have the file locally.
|
| 269 |
+
def _inner_hf_hub_download(repo_file: str):
|
| 270 |
+
return hf_hub_download(
|
| 271 |
+
repo_id,
|
| 272 |
+
filename=repo_file,
|
| 273 |
+
repo_type=repo_type,
|
| 274 |
+
revision=commit_hash,
|
| 275 |
+
endpoint=endpoint,
|
| 276 |
+
cache_dir=cache_dir,
|
| 277 |
+
local_dir=local_dir,
|
| 278 |
+
local_dir_use_symlinks=local_dir_use_symlinks,
|
| 279 |
+
library_name=library_name,
|
| 280 |
+
library_version=library_version,
|
| 281 |
+
user_agent=user_agent,
|
| 282 |
+
proxies=proxies,
|
| 283 |
+
etag_timeout=etag_timeout,
|
| 284 |
+
resume_download=resume_download,
|
| 285 |
+
force_download=force_download,
|
| 286 |
+
token=token,
|
| 287 |
+
headers=headers,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
if constants.HF_HUB_ENABLE_HF_TRANSFER:
|
| 291 |
+
# when using hf_transfer we don't want extra parallelism
|
| 292 |
+
# from the one hf_transfer provides
|
| 293 |
+
for file in filtered_repo_files:
|
| 294 |
+
_inner_hf_hub_download(file)
|
| 295 |
+
else:
|
| 296 |
+
thread_map(
|
| 297 |
+
_inner_hf_hub_download,
|
| 298 |
+
filtered_repo_files,
|
| 299 |
+
desc=f"Fetching {len(filtered_repo_files)} files",
|
| 300 |
+
max_workers=max_workers,
|
| 301 |
+
# User can use its own tqdm class or the default one from `huggingface_hub.utils`
|
| 302 |
+
tqdm_class=tqdm_class or hf_tqdm,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
if local_dir is not None:
|
| 306 |
+
return str(os.path.realpath(local_dir))
|
| 307 |
+
return snapshot_folder
|
.venv/lib/python3.11/site-packages/huggingface_hub/_space_api.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2019-present, the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from enum import Enum
|
| 18 |
+
from typing import Dict, Optional
|
| 19 |
+
|
| 20 |
+
from huggingface_hub.utils import parse_datetime
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SpaceStage(str, Enum):
    """
    Stages a Space can be in on the Hub.

    Each member is also a `str`, so it compares equal to its raw value:
    ```py
    assert SpaceStage.BUILDING == "BUILDING"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url).
    """

    # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo)
    NO_APP_FILE = "NO_APP_FILE"
    CONFIG_ERROR = "CONFIG_ERROR"
    BUILDING = "BUILDING"
    BUILD_ERROR = "BUILD_ERROR"
    RUNNING = "RUNNING"
    RUNNING_BUILDING = "RUNNING_BUILDING"
    RUNTIME_ERROR = "RUNTIME_ERROR"
    DELETING = "DELETING"
    STOPPED = "STOPPED"
    PAUSED = "PAUSED"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SpaceHardware(str, Enum):
    """
    Hardware flavors a Space can run on.

    Each member is also a `str`, so it compares equal to its raw value:
    ```py
    assert SpaceHardware.CPU_BASIC == "cpu-basic"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L73 (private url).
    """

    CPU_BASIC = "cpu-basic"
    CPU_UPGRADE = "cpu-upgrade"
    T4_SMALL = "t4-small"
    T4_MEDIUM = "t4-medium"
    L4X1 = "l4x1"
    L4X4 = "l4x4"
    ZERO_A10G = "zero-a10g"
    A10G_SMALL = "a10g-small"
    A10G_LARGE = "a10g-large"
    A10G_LARGEX2 = "a10g-largex2"
    A10G_LARGEX4 = "a10g-largex4"
    A100_LARGE = "a100-large"
    V5E_1X1 = "v5e-1x1"
    V5E_2X2 = "v5e-2x2"
    V5E_2X4 = "v5e-2x4"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SpaceStorage(str, Enum):
    """
    Persistent storage tiers available for a Space on the Hub.

    Each member is also a `str`, so it compares equal to its raw value:
    ```py
    assert SpaceStorage.SMALL == "small"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url).
    """

    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
class SpaceRuntime:
    """
    Contains information about the current runtime of a Space.

    Args:
        stage (`str`):
            Current stage of the space. Example: RUNNING.
        hardware (`str` or `None`):
            Current hardware of the space. Example: "cpu-basic". Can be `None` if Space
            is `BUILDING` for the first time.
        requested_hardware (`str` or `None`):
            Requested hardware. Can be different than `hardware` especially if the request
            has just been made. Example: "t4-medium". Can be `None` if no hardware has
            been requested yet.
        sleep_time (`int` or `None`):
            Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the
            Space will never go to sleep if it's running on an upgraded hardware, while it will go to sleep after 48
            hours on a free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time.
        storage (`str` or `None`):
            Value of the `"storage"` key of the server response. `None` if the key is absent.
        raw (`dict`):
            Raw response from the server. Contains more information about the Space
            runtime like number of replicas, number of cpu, memory size,...
    """

    stage: SpaceStage
    hardware: Optional[SpaceHardware]
    requested_hardware: Optional[SpaceHardware]
    sleep_time: Optional[int]
    storage: Optional[SpaceStorage]
    raw: Dict

    def __init__(self, data: Dict) -> None:
        """Build the runtime description from the raw server payload `data`.

        Raises:
            `KeyError`: if `data` has no `"stage"` entry.
        """
        # The "hardware" entry can be missing *or* explicitly null. `data.get("hardware", {})`
        # only covers the first case, so normalize both to an empty mapping.
        hardware = data.get("hardware") or {}

        self.stage = data["stage"]
        self.hardware = hardware.get("current")
        self.requested_hardware = hardware.get("requested")
        # The server reports the sleep time under the "gcTimeout" key.
        self.sleep_time = data.get("gcTimeout")
        self.storage = data.get("storage")
        self.raw = data
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
class SpaceVariable:
    """
    Contains information about the current variables of a Space.

    Args:
        key (`str`):
            Variable key. Example: `"MODEL_REPO_ID"`
        value (`str`):
            Variable value. Example: `"the_model_repo_id"`.
        description (`str` or None):
            Description of the variable. Example: `"Model Repo ID of the implemented model"`.
        updated_at (`datetime` or None):
            datetime of the last update of the variable (if the variable has been updated at least once).
            Parsed from the `"updatedAt"` entry of the server payload.
    """

    key: str
    value: str
    description: Optional[str]
    updated_at: Optional[datetime]

    def __init__(self, key: str, values: Dict) -> None:
        """Build the variable named `key` from its raw server payload `values`."""
        raw_updated_at = values.get("updatedAt")

        self.key = key
        self.value = values["value"]
        self.description = values.get("description")
        # Only parse the timestamp when the server actually sent one.
        if raw_updated_at is None:
            self.updated_at = None
        else:
            self.updated_at = parse_datetime(raw_updated_at)
|
.venv/lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Contains a logger to push training logs to the Hub, using Tensorboard."""
|
| 15 |
+
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import TYPE_CHECKING, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
from ._commit_scheduler import CommitScheduler
|
| 20 |
+
from .errors import EntryNotFoundError
|
| 21 |
+
from .repocard import ModelCard
|
| 22 |
+
from .utils import experimental
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Depending on user's setup, SummaryWriter can come either from 'tensorboardX'
|
| 26 |
+
# or from 'torch.utils.tensorboard'. Both are compatible so let's try to load
|
| 27 |
+
# from either of them.
|
| 28 |
+
try:
    from tensorboardX import SummaryWriter

    is_summary_writer_available = True

except ImportError:
    try:
        from torch.utils.tensorboard import SummaryWriter

        # The torch-provided writer is a usable backend too, so flag it as available.
        # Leaving this `False` makes `HFSummaryWriter.__new__` raise ImportError even
        # though a working `SummaryWriter` was imported.
        is_summary_writer_available = True
    except ImportError:
        # Dummy class to avoid failing at import. Will raise on instance creation.
        SummaryWriter = object
        is_summary_writer_available = False
|
| 42 |
+
|
| 43 |
+
if TYPE_CHECKING:
|
| 44 |
+
from tensorboardX import SummaryWriter
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class HFSummaryWriter(SummaryWriter):
    """
    Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.

    Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
    thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
    issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
    minutes (default to every 5 minutes).

    <Tip warning={true}>

    `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.

    </Tip>

    Args:
        repo_id (`str`):
            The id of the repo to which the logs will be pushed.
        logdir (`str`, *optional*):
            The directory where the logs will be written. If not specified, a local directory will be created by the
            underlying `SummaryWriter` object.
        commit_every (`int` or `float`, *optional*):
            The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performances on the repo when it grows too large.
        repo_type (`str`, *optional*):
            The type of the repo to which the logs will be pushed. Defaults to "model".
        repo_revision (`str`, *optional*):
            The revision of the repo to which the logs will be pushed. Defaults to "main".
        repo_private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's
            default is private. This value is ignored if the repo already exists.
        path_in_repo (`str`, *optional*):
            The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/".
            The name of the log directory is always appended to it.
        repo_allow_patterns (`List[str]` or `str`, *optional*):
            A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
        repo_ignore_patterns (`List[str]` or `str`, *optional*):
            A list of patterns to exclude in the upload. Check out the
            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
        token (`str`, *optional*):
            Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
            details
        kwargs:
            Additional keyword arguments passed to `SummaryWriter`.

    Examples:
    ```diff
    # Taken from https://pytorch.org/docs/stable/tensorboard.html
    - from torch.utils.tensorboard import SummaryWriter
    + from huggingface_hub import HFSummaryWriter

    import numpy as np

    - writer = SummaryWriter()
    + writer = HFSummaryWriter(repo_id="username/my-trained-model")

    for n_iter in range(100):
        writer.add_scalar('Loss/train', np.random.random(), n_iter)
        writer.add_scalar('Loss/test', np.random.random(), n_iter)
        writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
        writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
    ```

    ```py
    >>> from huggingface_hub import HFSummaryWriter

    # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager
    >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger:
    ...     logger.add_scalar("a", 1)
    ...     logger.add_scalar("b", 2)
    ```
    """

    @experimental
    def __new__(cls, *args, **kwargs) -> "HFSummaryWriter":
        """Create the instance, failing early when no `SummaryWriter` backend is flagged as available."""
        if is_summary_writer_available:
            return super().__new__(cls)
        raise ImportError(
            "You must have `tensorboard` installed to use `HFSummaryWriter`. "
            "Please run `pip install --upgrade tensorboardX` first."
        )

    def __init__(
        self,
        repo_id: str,
        *,
        logdir: Optional[str] = None,
        commit_every: Union[int, float] = 5,
        squash_history: bool = False,
        repo_type: Optional[str] = None,
        repo_revision: Optional[str] = None,
        repo_private: Optional[bool] = None,
        path_in_repo: Optional[str] = "tensorboard",
        repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*",
        repo_ignore_patterns: Optional[Union[List[str], str]] = None,
        token: Optional[str] = None,
        **kwargs,
    ):
        # Let the underlying SummaryWriter set itself up first (this is what defines `self.logdir`).
        super().__init__(logdir=logdir, **kwargs)

        # Fail early if logdir has not been initialized as a string. In practice, SummaryWriter takes care of it.
        if not isinstance(self.logdir, str):
            raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")

        # Logs are pushed under "<path_in_repo>/<logdir name>", or just "<logdir name>" when no prefix is given.
        logdir_name = Path(self.logdir).name
        if path_in_repo is None or path_in_repo == "":
            path_in_repo = logdir_name
        else:
            path_in_repo = f"{path_in_repo.strip('/')}/{logdir_name}"

        # Scheduler in charge of committing the log folder to the Hub.
        self.scheduler = CommitScheduler(
            folder_path=self.logdir,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=repo_revision,
            private=repo_private,
            token=token,
            allow_patterns=repo_allow_patterns,
            ignore_patterns=repo_ignore_patterns,
            every=commit_every,
            squash_history=squash_history,
        )

        # Mirror the scheduler's repo information at root level for convenience.
        self.repo_id = self.scheduler.repo_id
        self.repo_type = self.scheduler.repo_type
        self.repo_revision = self.scheduler.revision

        # Make sure the model card metadata carries the `hf-summary-writer` tag.
        try:
            card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type)
        except EntryNotFoundError:
            # No card to load: start from an empty one.
            card = ModelCard("")
        tags = card.data.get("tags", [])
        if "hf-summary-writer" not in tags:
            # Only push when the tag actually had to be added.
            tags.append("hf-summary-writer")
            card.data["tags"] = tags
            card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying writer, then trigger a push to the Hub and wait on the returned future."""
        super().__exit__(exc_type, exc_val, exc_tb)
        self.scheduler.trigger().result()
|
.venv/lib/python3.11/site-packages/huggingface_hub/_upload_large_folder.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024-present, the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import enum
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import queue
|
| 19 |
+
import shutil
|
| 20 |
+
import sys
|
| 21 |
+
import threading
|
| 22 |
+
import time
|
| 23 |
+
import traceback
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from threading import Lock
|
| 27 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 28 |
+
|
| 29 |
+
from . import constants
|
| 30 |
+
from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes
|
| 31 |
+
from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata
|
| 32 |
+
from .constants import DEFAULT_REVISION, REPO_TYPES
|
| 33 |
+
from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm
|
| 34 |
+
from .utils._cache_manager import _format_size
|
| 35 |
+
from .utils.sha import sha_fileobj
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if TYPE_CHECKING:
|
| 39 |
+
from .hf_api import HfApi
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
WAITING_TIME_IF_NO_TASKS = 10 # seconds
|
| 44 |
+
MAX_NB_REGULAR_FILES_PER_COMMIT = 75
|
| 45 |
+
MAX_NB_LFS_FILES_PER_COMMIT = 150
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def upload_large_folder_internal(
    api: "HfApi",
    repo_id: str,
    folder_path: Union[str, Path],
    *,
    repo_type: str,  # Repo type is required!
    revision: Optional[str] = None,
    private: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    num_workers: Optional[int] = None,
    print_report: bool = True,
    print_report_every: int = 60,
):
    """Upload a large folder to the Hub in the most resilient way possible.

    See [`HfApi.upload_large_folder`] for the full documentation.

    Args:
        api: Client used to create the repo; also handed to every worker thread.
        repo_id: Target repo. Replaced by the `repo_id` returned by `api.create_repo`.
        folder_path: Local folder to upload. It is user-expanded and resolved before use.
        repo_type: Must be one of `REPO_TYPES`.
        revision: Target revision. Defaults to `DEFAULT_REVISION`.
        private: Forwarded to `api.create_repo`.
        allow_patterns: Forwarded to `filter_repo_objects`.
        ignore_patterns: Forwarded to `filter_repo_objects`, with `DEFAULT_IGNORE_PATTERNS` appended.
            The caller's list is never modified.
        num_workers: Number of worker threads. Defaults to all cores but 2, with a minimum of 2.
        print_report: Whether to print a status report on stdout.
        print_report_every: Minimum delay (in seconds) between two printed reports.

    Raises:
        ValueError: If `repo_type` is `None` or not in `REPO_TYPES`, or if `folder_path` is not a directory.
    """
    # 1. Check args and setup
    if repo_type is None:
        raise ValueError(
            "For large uploads, `repo_type` is explicitly required. Please set it to `model`, `dataset` or `space`."
            " If you are using the CLI, pass it as `--repo-type=model`."
        )
    if repo_type not in REPO_TYPES:
        raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
    if revision is None:
        revision = DEFAULT_REVISION

    folder_path = Path(folder_path).expanduser().resolve()
    if not folder_path.is_dir():
        raise ValueError(f"Provided path: '{folder_path}' is not a directory")

    if ignore_patterns is None:
        ignore_patterns = []
    elif isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    # Build a new list instead of `+=`: an in-place extend would silently grow the list object owned by the caller.
    ignore_patterns = [*ignore_patterns, *DEFAULT_IGNORE_PATTERNS]

    if num_workers is None:
        nb_cores = os.cpu_count() or 1
        num_workers = max(nb_cores - 2, 2)  # Use all but 2 cores, or at least 2 cores

    # 2. Create repo if missing
    repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True)
    logger.info(f"Repo created: {repo_url}")
    repo_id = repo_url.repo_id

    # 3. List files to upload
    filtered_paths_list = filter_repo_objects(
        (path.relative_to(folder_path).as_posix() for path in folder_path.glob("**/*") if path.is_file()),
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
    )
    paths_list = [get_local_upload_paths(folder_path, relpath) for relpath in filtered_paths_list]
    logger.info(f"Found {len(paths_list)} candidate files to upload")

    # Read metadata for each file
    items = [
        (paths, read_upload_metadata(folder_path, paths.path_in_repo))
        for paths in tqdm(paths_list, desc="Recovering from metadata files")
    ]

    # 4. Start workers
    status = LargeUploadStatus(items)
    threads = [
        threading.Thread(
            target=_worker_job,
            kwargs={
                "status": status,
                "api": api,
                "repo_id": repo_id,
                "repo_type": repo_type,
                "revision": revision,
            },
        )
        for _ in range(num_workers)
    ]

    for thread in threads:
        thread.start()

    # 5. Print regular reports
    if print_report:
        print("\n\n" + status.current_report())
    last_report_ts = time.time()
    while True:
        time.sleep(1)
        if time.time() - last_report_ts >= print_report_every:
            if print_report:
                _print_overwrite(status.current_report())
            last_report_ts = time.time()
        if status.is_done():
            # Use the module logger (not the root `logging` module) like the rest of this file.
            logger.info("Is done: exiting main loop")
            break

    for thread in threads:
        thread.join()

    logger.info(status.current_report())
    logger.info("Upload is complete!")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
####################
|
| 152 |
+
# Logic to manage workers and synchronize tasks
|
| 153 |
+
####################
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class WorkerJob(enum.Enum):
    """Kinds of job that the scheduler can hand to a worker thread."""

    # Hash a single file.
    SHA256 = enum.auto()
    # Fetch the upload mode for a batch of files.
    GET_UPLOAD_MODE = enum.auto()
    # Pre-upload a single LFS file.
    PREUPLOAD_LFS = enum.auto()
    # Commit a batch of files.
    COMMIT = enum.auto()
    # if no tasks are available but we don't want to exit
    WAIT = enum.auto()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# One unit of work flowing through the queues: the paths of a file paired with its upload metadata.
JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class LargeUploadStatus:
    """Shared state of a large upload: the items, one queue per processing stage and the worker counters."""

    def __init__(self, items: List[JOB_ITEM_T]):
        """Create the queues and counters, then place each item in the queue matching its metadata."""
        self.items = items
        self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.queue_preupload_lfs: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.queue_commit: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.lock = Lock()

        self.nb_workers_sha256: int = 0
        self.nb_workers_get_upload_mode: int = 0
        self.nb_workers_preupload_lfs: int = 0
        self.nb_workers_commit: int = 0
        self.nb_workers_waiting: int = 0
        self.last_commit_attempt: Optional[float] = None

        self._started_at = datetime.now()

        # Setup queues: each item goes to the first stage its metadata has not completed yet.
        for entry in self.items:
            file_paths, meta = entry
            if meta.sha256 is None:
                self.queue_sha256.put(entry)
                continue
            if meta.upload_mode is None:
                self.queue_get_upload_mode.put(entry)
                continue
            if meta.upload_mode == "lfs" and not meta.is_uploaded:
                self.queue_preupload_lfs.put(entry)
                continue
            if not meta.is_committed:
                self.queue_commit.put(entry)
                continue
            logger.debug(f"Skipping file {file_paths.path_in_repo} (already uploaded and committed)")

    def current_report(self) -> str:
        """Generate a report of the current status of the large upload."""
        with self.lock:
            # Ignored files are only counted; they are excluded from every other figure.
            tracked = [meta for _, meta in self.items if not meta.should_ignore]
            ignored_files = len(self.items) - len(tracked)
            total_files = len(tracked)
            total_size = sum(meta.size for meta in tracked)

            hashed = [meta for meta in tracked if meta.sha256 is not None]
            nb_hashed = len(hashed)
            size_hashed = sum(meta.size for meta in hashed)

            nb_lfs = sum(1 for meta in tracked if meta.upload_mode == "lfs")
            nb_lfs_unsure = sum(1 for meta in tracked if meta.upload_mode is None)

            preuploaded = [meta for meta in tracked if meta.is_uploaded]
            nb_preuploaded = len(preuploaded)
            size_preuploaded = sum(meta.size for meta in preuploaded)

            committed = [meta for meta in tracked if meta.is_committed]
            nb_committed = len(committed)
            size_committed = sum(meta.size for meta in committed)

            total_size_str = _format_size(total_size)

        now = datetime.now()
        now_str = now.strftime("%Y-%m-%d %H:%M:%S")
        elapsed_str = str(now - self._started_at).split(".")[0]  # remove milliseconds

        unsure_part = f" (+{nb_lfs_unsure} unsure)" if nb_lfs_unsure > 0 else ""

        header = "\n" + "-" * 10 + f" {now_str} ({elapsed_str}) " + "-" * 10 + "\n"
        files_line = (
            "Files: "
            + f"hashed {nb_hashed}/{total_files} ({_format_size(size_hashed)}/{total_size_str}) | "
            + f"pre-uploaded: {nb_preuploaded}/{nb_lfs} ({_format_size(size_preuploaded)}/{total_size_str})"
            + unsure_part
            + f" | committed: {nb_committed}/{total_files} ({_format_size(size_committed)}/{total_size_str})"
            + f" | ignored: {ignored_files}\n"
        )
        workers_line = (
            "Workers: "
            + f"hashing: {self.nb_workers_sha256} | "
            + f"get upload mode: {self.nb_workers_get_upload_mode} | "
            + f"pre-uploading: {self.nb_workers_preupload_lfs} | "
            + f"committing: {self.nb_workers_commit} | "
            + f"waiting: {self.nb_workers_waiting}\n"
        )
        return header + files_line + workers_line + "-" * 51

    def is_done(self) -> bool:
        """Return True once every item is either committed or flagged as ignored."""
        with self.lock:
            for _, meta in self.items:
                if not (meta.is_committed or meta.should_ignore):
                    return False
            return True
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _worker_job(
    status: LargeUploadStatus,
    api: "HfApi",
    repo_id: str,
    repo_type: str,
    revision: str,
):
    """
    Main process for a worker. The worker will perform tasks based on the priority list until all files are uploaded
    and committed. If no tasks are available, the worker will wait for `WAITING_TIME_IF_NO_TASKS` seconds before
    checking again.

    If a task fails for any reason, the item(s) are put back in the queue for another worker to pick up. The error
    message is logged at ERROR level and the full traceback at DEBUG level.

    Read `upload_large_folder` docstring for more information on how tasks are prioritized.

    Args:
        status: Shared queues and counters. The counter incremented by `_determine_next_job` for the selected job
            is decremented here, under `status.lock`, once the job is over.
        api: Client forwarded to the atomic jobs.
        repo_id: Target repo, forwarded to the atomic jobs.
        repo_type: Repo type, forwarded to the atomic jobs.
        revision: Target revision, forwarded to the atomic jobs.
    """
    while True:
        # Determine next task (`None` means that there is nothing left to do)
        next_job = _determine_next_job(status)
        if next_job is None:
            return
        job, items = next_job

        # Perform task
        if job == WorkerJob.SHA256:
            item = items[0]  # single item
            try:
                _compute_sha256(item)
                status.queue_get_upload_mode.put(item)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to compute sha256: {e}")
                # `format_exc()` only builds the string: it has to be logged explicitly.
                logger.debug(traceback.format_exc())
                status.queue_sha256.put(item)

            with status.lock:
                status.nb_workers_sha256 -= 1

        elif job == WorkerJob.GET_UPLOAD_MODE:
            try:
                _get_upload_mode(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to get upload mode: {e}")
                logger.debug(traceback.format_exc())

            # Items are either:
            # - dropped (if should_ignore)
            # - put in LFS queue (if LFS)
            # - put in commit queue (if regular)
            # - or put back (if error occurred).
            for item in items:
                _, metadata = item
                if metadata.should_ignore:
                    continue
                if metadata.upload_mode == "lfs":
                    status.queue_preupload_lfs.put(item)
                elif metadata.upload_mode == "regular":
                    status.queue_commit.put(item)
                else:
                    status.queue_get_upload_mode.put(item)

            with status.lock:
                status.nb_workers_get_upload_mode -= 1

        elif job == WorkerJob.PREUPLOAD_LFS:
            item = items[0]  # single item
            try:
                _preupload_lfs(item, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
                status.queue_commit.put(item)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to preupload LFS: {e}")
                logger.debug(traceback.format_exc())
                status.queue_preupload_lfs.put(item)

            with status.lock:
                status.nb_workers_preupload_lfs -= 1

        elif job == WorkerJob.COMMIT:
            try:
                _commit(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to commit: {e}")
                logger.debug(traceback.format_exc())
                # Only a failed commit puts its items back in the queue.
                for item in items:
                    status.queue_commit.put(item)
            with status.lock:
                # Recorded for both successful and failed attempts.
                status.last_commit_attempt = time.time()
                status.nb_workers_commit -= 1

        elif job == WorkerJob.WAIT:
            time.sleep(WAITING_TIME_IF_NO_TASKS)
            with status.lock:
                status.nb_workers_waiting -= 1
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]]:
    """Select the next job for the calling worker, following a fixed priority list.

    The whole decision is made while holding `status.lock`: the `nb_workers_*` counter of the selected job is
    incremented and the items are taken out of their queue before the lock is released. The caller is expected to
    decrement that counter once the job is over.

    Returns:
        A `(job, items)` tuple (`items` is empty for `WorkerJob.WAIT`), or `None` when every item is either
        committed or ignored, in which case the worker should exit.
    """
    with status.lock:
        # 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file)
        if (
            status.nb_workers_commit == 0
            and status.queue_commit.qsize() > 0
            and status.last_commit_attempt is not None
            and time.time() - status.last_commit_attempt > 5 * 60
        ):
            status.nb_workers_commit += 1
            logger.debug("Job: commit (more than 5 minutes since last commit attempt)")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 2. Commit if at least 150 files are ready to commit
        elif status.nb_workers_commit == 0 and status.queue_commit.qsize() >= 150:
            status.nb_workers_commit += 1
            logger.debug("Job: commit (>=150 files ready)")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 3. Get upload mode if at least 10 files
        elif status.queue_get_upload_mode.qsize() >= 10:
            status.nb_workers_get_upload_mode += 1
            logger.debug("Job: get upload mode (>=10 files ready)")
            return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, 50))

        # 4. Preupload LFS file if at least 1 file and no worker is preuploading LFS
        elif status.queue_preupload_lfs.qsize() > 0 and status.nb_workers_preupload_lfs == 0:
            status.nb_workers_preupload_lfs += 1
            logger.debug("Job: preupload LFS (no other worker preuploading LFS)")
            return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs))

        # 5. Compute sha256 if at least 1 file and no worker is computing sha256
        elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0:
            status.nb_workers_sha256 += 1
            logger.debug("Job: sha256 (no other worker computing sha256)")
            return (WorkerJob.SHA256, _get_one(status.queue_sha256))

        # 6. Get upload mode if at least 1 file and no worker is getting upload mode
        elif status.queue_get_upload_mode.qsize() > 0 and status.nb_workers_get_upload_mode == 0:
            status.nb_workers_get_upload_mode += 1
            logger.debug("Job: get upload mode (no other worker getting upload mode)")
            return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, 50))

        # 7. Preupload LFS file if at least 1 file
        #    Skip if hf_transfer is enabled and there is already a worker preuploading LFS
        elif status.queue_preupload_lfs.qsize() > 0 and (
            status.nb_workers_preupload_lfs == 0 or not constants.HF_HUB_ENABLE_HF_TRANSFER
        ):
            status.nb_workers_preupload_lfs += 1
            logger.debug("Job: preupload LFS")
            return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs))

        # 8. Compute sha256 if at least 1 file
        elif status.queue_sha256.qsize() > 0:
            status.nb_workers_sha256 += 1
            logger.debug("Job: sha256")
            return (WorkerJob.SHA256, _get_one(status.queue_sha256))

        # 9. Get upload mode if at least 1 file
        elif status.queue_get_upload_mode.qsize() > 0:
            status.nb_workers_get_upload_mode += 1
            logger.debug("Job: get upload mode")
            return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, 50))

        # 10. Commit if at least 1 file and 1 min since last commit attempt
        elif (
            status.nb_workers_commit == 0
            and status.queue_commit.qsize() > 0
            and status.last_commit_attempt is not None
            and time.time() - status.last_commit_attempt > 1 * 60
        ):
            status.nb_workers_commit += 1
            logger.debug("Job: commit (1 min since last commit attempt)")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 11. Commit if at least 1 file, all other queues are empty and no worker is hashing, getting upload
        #     modes or pre-uploading (e.g. when it's the last commit)
        elif (
            status.nb_workers_commit == 0
            and status.queue_commit.qsize() > 0
            and status.queue_sha256.qsize() == 0
            and status.queue_get_upload_mode.qsize() == 0
            and status.queue_preupload_lfs.qsize() == 0
            and status.nb_workers_sha256 == 0
            and status.nb_workers_get_upload_mode == 0
            and status.nb_workers_preupload_lfs == 0
        ):
            status.nb_workers_commit += 1
            logger.debug("Job: commit")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 12. If every item is either committed or ignored, exit
        elif all(metadata.is_committed or metadata.should_ignore for _, metadata in status.items):
            logger.info("All files have been processed! Exiting worker.")
            return None

        # 13. If no task is available, wait
        else:
            status.nb_workers_waiting += 1
            logger.debug(f"No task available, waiting... ({WAITING_TIME_IF_NO_TASKS}s)")
            return (WorkerJob.WAIT, [])
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
####################
|
| 476 |
+
# Atomic jobs (sha256, get_upload_mode, preupload_lfs, commit)
|
| 477 |
+
####################
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _compute_sha256(item: JOB_ITEM_T) -> None:
    """Make sure the metadata of `item` holds a sha256 (hex string), then save the metadata.

    The file is only read when no sha256 is stored yet; the metadata is saved in both cases.
    """
    file_paths, meta = item
    needs_hashing = meta.sha256 is None
    if needs_hashing:
        with file_paths.file_path.open("rb") as stream:
            digest = sha_fileobj(stream)
        meta.sha256 = digest.hex()
    meta.save(file_paths)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
    """Fetch the upload mode of a batch of files and store it in their metadata.

    The "should ignore" flag set on each operation by `_fetch_upload_modes` is stored as well. The metadata of
    every item is saved once updated.
    """
    operations = [_build_hacky_operation(entry) for entry in items]
    request_headers = api._build_hf_headers()
    _fetch_upload_modes(
        additions=operations,
        repo_type=repo_type,
        repo_id=repo_id,
        headers=request_headers,
        revision=revision,
    )
    # `_fetch_upload_modes` has set the results on the operations: copy them over, item by item.
    for (file_paths, meta), operation in zip(items, operations):
        meta.upload_mode = operation._upload_mode
        meta.should_ignore = operation._should_ignore
        meta.save(file_paths)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _preupload_lfs(item: JOB_ITEM_T, api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
    """Pre-upload a single LFS file, then flag it as uploaded in its metadata and save it."""
    file_paths, meta = item
    operation = _build_hacky_operation(item)
    api.preupload_lfs_files(repo_id=repo_id, repo_type=repo_type, revision=revision, additions=[operation])

    # Only reached when the pre-upload did not raise.
    meta.is_uploaded = True
    meta.save(file_paths)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def _commit(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
    """Create a single commit adding all `items`, then flag each of them as committed and save its metadata."""
    operations = []
    for entry in items:
        operations.append(_build_hacky_operation(entry))

    api.create_commit(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        operations=operations,
        commit_message="Add files using upload-large-folder tool",
    )

    # Only reached when the commit did not raise.
    for file_paths, meta in items:
        meta.is_committed = True
        meta.save(file_paths)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
####################
|
| 540 |
+
# Hacks with CommitOperationAdd to bypass checks/sha256 calculation
|
| 541 |
+
####################
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class HackyCommitOperationAdd(CommitOperationAdd):
    """`CommitOperationAdd` variant whose `__post_init__` only converts a `Path` source to `str`."""

    def __post_init__(self) -> None:
        # NOTE(review): the parent `__post_init__` is intentionally not called, presumably to bypass its checks
        # and sha256 computation (see the section header of this file) -- confirm against `CommitOperationAdd`.
        source = self.path_or_fileobj
        if isinstance(source, Path):
            self.path_or_fileobj = str(source)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def _build_hacky_operation(item: JOB_ITEM_T) -> HackyCommitOperationAdd:
    """Build the add-operation of an item, filling `upload_info` from the metadata instead of re-hashing the file.

    Raises:
        ValueError: If the metadata has no sha256 yet.
    """
    file_paths, meta = item
    op = HackyCommitOperationAdd(path_in_repo=file_paths.path_in_repo, path_or_fileobj=file_paths.file_path)

    # Only the first bytes of the file (512 at most) are read, to be used as the sample.
    with file_paths.file_path.open("rb") as stream:
        head = stream.peek(512)[:512]

    if meta.sha256 is None:
        raise ValueError("sha256 must have been computed by now!")

    op.upload_info = UploadInfo(sha256=bytes.fromhex(meta.sha256), size=meta.size, sample=head)
    return op
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
####################
|
| 562 |
+
# Misc helpers
|
| 563 |
+
####################
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]:
    """Take a single item out of `queue` and return it wrapped in a list (blocks while the queue is empty)."""
    single = queue.get()
    return [single]
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> List[JOB_ITEM_T]:
    """Take up to `n` items out of `queue`, bounded by the queue size measured when the call starts."""
    count = min(queue.qsize(), n)
    batch: List[JOB_ITEM_T] = []
    for _ in range(count):
        batch.append(queue.get())
    return batch
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def _get_items_to_commit(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]:
    """Special case for commit job: the number of items to commit depends on the type of files."""
    # Items are taken until the queue is empty, or until either `MAX_NB_LFS_FILES_PER_COMMIT` LFS files or
    # `MAX_NB_REGULAR_FILES_PER_COMMIT` other files have been collected.
    batch: List[JOB_ITEM_T] = []
    lfs_count = 0
    regular_count = 0
    while (
        queue.qsize() != 0
        and lfs_count < MAX_NB_LFS_FILES_PER_COMMIT
        and regular_count < MAX_NB_REGULAR_FILES_PER_COMMIT
    ):
        entry = queue.get()
        batch.append(entry)
        # Anything that is not flagged "lfs" counts as a regular file.
        if entry[1].upload_mode == "lfs":
            lfs_count += 1
        else:
            regular_count += 1
    return batch
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def _print_overwrite(report: str) -> None:
|
| 599 |
+
"""Print a report, overwriting the previous lines.
|
| 600 |
+
|
| 601 |
+
Since tqdm in using `sys.stderr` to (re-)write progress bars, we need to use `sys.stdout`
|
| 602 |
+
to print the report.
|
| 603 |
+
|
| 604 |
+
Note: works well only if no other process is writing to `sys.stdout`!
|
| 605 |
+
"""
|
| 606 |
+
report += "\n"
|
| 607 |
+
# Get terminal width
|
| 608 |
+
terminal_width = shutil.get_terminal_size().columns
|
| 609 |
+
|
| 610 |
+
# Count number of lines that should be cleared
|
| 611 |
+
nb_lines = sum(len(line) // terminal_width + 1 for line in report.splitlines())
|
| 612 |
+
|
| 613 |
+
# Clear previous lines based on the number of lines in the report
|
| 614 |
+
for _ in range(nb_lines):
|
| 615 |
+
sys.stdout.write("\r\033[K") # Clear line
|
| 616 |
+
sys.stdout.write("\033[F") # Move cursor up one line
|
| 617 |
+
|
| 618 |
+
# Print the new report, filling remaining space with whitespace
|
| 619 |
+
sys.stdout.write(report)
|
| 620 |
+
sys.stdout.write(" " * (terminal_width - len(report.splitlines()[-1])))
|
| 621 |
+
sys.stdout.flush()
|
.venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Contains data structures to parse the webhooks payload."""
|
| 16 |
+
|
| 17 |
+
from typing import List, Literal, Optional
|
| 18 |
+
|
| 19 |
+
from .utils import is_pydantic_available
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if not is_pydantic_available():
    # Define a dummy BaseModel to avoid import errors when pydantic is not installed
    # Import error will be raised when trying to use the class

    class BaseModel:  # type: ignore [no-redef]
        def __init__(self, *args, **kwargs) -> None:
            # Instantiating any model without pydantic fails with an explicit installation hint.
            raise ImportError(
                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
            )

else:
    from pydantic import BaseModel
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
# are not in use anymore. To be kept in sync when the format is updated in
# https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link).
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Closed sets of string values (and of integer versions) accepted by the payload models of this module.
WebhookEvent_T = Literal["create", "delete", "move", "update"]
RepoChangeEvent_T = Literal["add", "move", "remove", "update"]
RepoType_T = Literal["dataset", "model", "space"]
DiscussionStatus_T = Literal["closed", "draft", "open", "merged"]
# Only version 3 is accepted.
SupportedWebhookVersion = Literal[3]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ObjectId(BaseModel):
    # Minimal model holding a single required string identifier.
    # Several other payload models of this module subclass it to inherit the `id` field.
    id: str
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class WebhookPayloadUrl(BaseModel):
    # Pair of URLs attached to a payload object: `web` is required, `api` is optional (defaults to None).
    web: str
    api: Optional[str] = None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class WebhookPayloadMovedTo(BaseModel):
    # Type of the optional `movedTo` field of `WebhookPayload`: a name and the id of an owner, both required.
    name: str
    owner: ObjectId
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class WebhookPayloadWebhook(ObjectId):
    # Identifier of the webhook (inherited `id`) plus its version, restricted to `SupportedWebhookVersion`.
    version: SupportedWebhookVersion
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class WebhookPayloadEvent(BaseModel):
    # `action` is restricted to the values of `WebhookEvent_T`; `scope` is a free-form string.
    action: WebhookEvent_T
    scope: str
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class WebhookPayloadDiscussionChanges(BaseModel):
    # Type of the optional `changes` field of `WebhookPayloadDiscussion`.
    # `base` is required; `mergeCommitId` is optional (defaults to None).
    base: str
    mergeCommitId: Optional[str] = None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class WebhookPayloadComment(ObjectId):
    # Comment object (inherits `id`). Every field is required except `content`, which defaults to None.
    author: ObjectId
    hidden: bool
    content: Optional[str] = None
    url: WebhookPayloadUrl
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class WebhookPayloadDiscussion(ObjectId):
    # Discussion object (inherits `id`). `status` is restricted to the values of `DiscussionStatus_T`.
    num: int
    author: ObjectId
    url: WebhookPayloadUrl
    title: str
    isPullRequest: bool
    status: DiscussionStatus_T
    # Optional fields, None when absent from the payload.
    changes: Optional[WebhookPayloadDiscussionChanges] = None
    pinned: Optional[bool] = None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class WebhookPayloadRepo(ObjectId):
    # Repo object (inherits `id`). `head_sha`, `subdomain` and `tags` are optional and default to None.
    owner: ObjectId
    head_sha: Optional[str] = None
    name: str
    private: bool
    subdomain: Optional[str] = None
    tags: Optional[List[str]] = None
    # Use the module-level alias rather than repeating the same literal values inline.
    type: RepoType_T
    url: WebhookPayloadUrl
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class WebhookPayloadUpdatedRef(BaseModel):
    # Element type of the optional `updatedRefs` list of `WebhookPayload`.
    # `ref` is required; `oldSha` and `newSha` are optional and default to None.
    ref: str
    oldSha: Optional[str] = None
    newSha: Optional[str] = None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class WebhookPayload(BaseModel):
    # Top-level payload model. `event`, `repo` and `webhook` are required.
    # `discussion`, `comment`, `movedTo` and `updatedRefs` are optional and default to None.
    event: WebhookPayloadEvent
    repo: WebhookPayloadRepo
    discussion: Optional[WebhookPayloadDiscussion] = None
    comment: Optional[WebhookPayloadComment] = None
    webhook: WebhookPayloadWebhook
    movedTo: Optional[WebhookPayloadMovedTo] = None
    updatedRefs: Optional[List[WebhookPayloadUpdatedRef]] = None
|
.venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
|
| 16 |
+
|
| 17 |
+
import atexit
|
| 18 |
+
import inspect
|
| 19 |
+
import os
|
| 20 |
+
from functools import wraps
|
| 21 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
| 22 |
+
|
| 23 |
+
from .utils import experimental, is_fastapi_available, is_gradio_available
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
import gradio as gr
|
| 28 |
+
from fastapi import Request
|
| 29 |
+
|
| 30 |
+
if is_fastapi_available():
|
| 31 |
+
from fastapi import FastAPI, Request
|
| 32 |
+
from fastapi.responses import JSONResponse
|
| 33 |
+
else:
|
| 34 |
+
# Will fail at runtime if FastAPI is not available
|
| 35 |
+
FastAPI = Request = JSONResponse = None # type: ignore [misc, assignment]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
_global_app: Optional["WebhooksServer"] = None
|
| 39 |
+
_is_local = os.environ.get("SPACE_ID") is None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@experimental
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to be
    called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
            webhook endpoints are opened without any security.

    Example:

        ```python
        import gradio as gr
        from huggingface_hub import WebhooksServer, WebhookPayload

        with gr.Blocks() as ui:
            ...

        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

        @app.add_webhook("/say_hello")
        async def hello(payload: WebhookPayload):
            return {"message": "hello"}

        app.launch()
        ```
    """

    def __new__(cls, *args, **kwargs) -> "WebhooksServer":
        # Fail at instantiation time, with an actionable message, when an optional dependency is missing.
        if not is_gradio_available():
            raise ImportError(
                "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
                " first."
            )
        if not is_fastapi_available():
            raise ImportError(
                "You must have `fastapi` installed to use `WebhooksServer`. Please run `pip install --upgrade fastapi`"
                " first."
            )
        return super().__new__(cls)

    def __init__(
        self,
        ui: Optional["gr.Blocks"] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        self._ui = ui

        # An explicit argument takes precedence over the `WEBHOOK_SECRET` environment variable.
        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        # Maps the absolute route (e.g. "/webhooks/my_hook") to the user function handling it.
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        The function is only recorded here; the route itself is created on the FastAPI app when
        [`~WebhooksServer.launch`] is called.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Returns:
            `Callable`: the decorated function itself when used as `@app.add_webhook`, or a decorator returning the
            decorated function when used as `@app.add_webhook(...)`.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.launch()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        @wraps(FastAPI.post)
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func
            # Hand the function back so that the decorated name is not rebound to `None`.
            return func

        return _inner_post

    def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
        """Launch the Gradio app and register webhooks to the underlying FastAPI server.

        Input parameters are forwarded to Gradio when launching the app.

        Args:
            prevent_thread_lock (`bool`, defaults to `False`):
                If `False`, the calling thread is blocked (via `ui.block_thread()`) once the webhooks are registered.
            **launch_kwargs:
                Extra keyword arguments forwarded to the UI's `launch` method. `share` defaults to `True` when the
                app runs locally (i.e. `SPACE_ID` is not set).
        """
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        #   - as non-blocking so that webhooks can be added afterwards
        #   - as shared if launch locally (to debug webhooks)
        launch_kwargs.setdefault("share", _is_local)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        space_host = os.environ.get("SPACE_HOST")
        url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url)
        url = url.strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        if not prevent_thread_lock:
            ui.block_thread()

    def _get_default_ui(self) -> "gr.Blocks":
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        import gradio as gr

        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Huggingface Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )
            # The conditional expression is parenthesized on purpose: `a + b if c else d` parses as
            # `(a + b) if c else d`, which would drop the "Go to ..." sentence when running on a Space.
            gr.Markdown(
                "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
                + (
                    "\nYou app is running locally. Please look at the logs to check the full URL you need to set."
                    if _is_local
                    else (
                        "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                        " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
                    )
                )
            )
        return ui
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@experimental
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.

    This is a shortcut to get started quickly: all decorated functions are attached to a single, shared server. If
    you need more flexibility (custom landing page or webhook secret), use [`WebhooksServer`] directly. The decorator
    can be applied several times to register several endpoints on that same server.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        path (`str`, optional):
            The URL path to register the webhook function. If not provided, the function name will be used as the path.
            In any case, all webhooks are registered under `/webhooks`.

    Returns:
        `Callable`: the decorated function, with an extra `launch` attribute that starts the shared server
        immediately instead of waiting for interpreter exit.

    Examples:
        The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
        The server will be started automatically at exit (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
        are running it in a notebook.

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually
        trigger_training.launch()
        ```
    """
    # Bare usage (`@webhook_endpoint`): `path` is in fact the decorated function.
    if callable(path):
        return webhook_endpoint()(path)

    @wraps(WebhooksServer.add_webhook)
    def _inner(func: Callable) -> Callable:
        server = _get_global_app()
        server.add_webhook(path)(func)

        # The very first registration schedules the server start at interpreter exit (only once).
        is_first_webhook = len(server.registered_webhooks) == 1
        if is_first_webhook:
            atexit.register(server.launch)

        @wraps(server.launch)
        def _launch_now():
            # Start right away and cancel the start scheduled at exit.
            atexit.unregister(server.launch)
            server.launch()

        func.launch = _launch_now  # type: ignore
        return func

    return _inner
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _get_global_app() -> WebhooksServer:
    """Return the module-wide [`WebhooksServer`], creating it with default arguments on first use."""
    global _global_app
    if _global_app is not None:
        return _global_app
    _global_app = WebhooksServer()
    return _global_app
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
|
| 322 |
+
if webhook_secret is None:
|
| 323 |
+
print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
|
| 324 |
+
print(
|
| 325 |
+
"To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
|
| 326 |
+
"\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
|
| 327 |
+
)
|
| 328 |
+
print(
|
| 329 |
+
"For more details about webhook secrets, please refer to"
|
| 330 |
+
" https://huggingface.co/docs/hub/webhooks#webhook-secret."
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
print("Webhook secret is correctly defined.")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
|
| 337 |
+
"""Returns the anchor to a given webhook in the docs (experimental)"""
|
| 338 |
+
return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI relies on route
    parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
    object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
    the function will return a 403 error (forbidden).

    Inspired by https://stackoverflow.com/a/33112180.

    Args:
        func (`Callable`):
            The user webhook function, either sync or async.
        webhook_secret (`str`):
            The secret that the "x-webhook-secret" request header must match.

    Returns:
        `Callable`: an async function that validates the header and then delegates to `func`. Its signature is the
        one of `func`, with a leading `request` parameter added when `func` does not declare one.
    """
    # Local import so that this block does not rely on a change to the module-level imports.
    import hmac

    initial_sig = inspect.signature(func)
    # Compare as bytes: `hmac.compare_digest` rejects `str` arguments containing non-ASCII characters.
    expected_secret = webhook_secret.encode("utf-8")

    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        # Constant-time comparison: the header is attacker-controlled, so avoid a timing side channel.
        if not hmac.compare_digest(request_secret.encode("utf-8"), expected_secret):
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if "request" in initial_sig.parameters:
            kwargs["request"] = request

        # Handle both sync and async routes
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        else:
            return func(**kwargs)

    # Update signature to include request
    if "request" not in initial_sig.parameters:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func
|
.venv/lib/python3.11/site-packages/huggingface_hub/community.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data structures to interact with Discussions and Pull Requests on the Hub.
|
| 3 |
+
|
| 4 |
+
See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)
|
| 5 |
+
for more information on Pull Requests, Discussions, and the community tab.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import List, Literal, Optional, Union
|
| 11 |
+
|
| 12 |
+
from . import constants
|
| 13 |
+
from .utils import parse_datetime
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
DiscussionStatus = Literal["open", "closed", "merged", "draft"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class Discussion:
    """
    A Discussion or Pull Request on the Hub.

    This dataclass is not intended to be instantiated directly.

    Attributes:
        title (`str`):
            The title of the Discussion / Pull Request
        status (`str`):
            The status of the Discussion / Pull Request.
            It must be one of:
                * `"open"`
                * `"closed"`
                * `"merged"` (only for Pull Requests)
                * `"draft"` (only for Pull Requests)
        num (`int`):
            The number of the Discussion / Pull Request.
        repo_id (`str`):
            The id (`"{namespace}/{repo_name}"`) of the repo on which
            the Discussion / Pull Request was opened.
        repo_type (`str`):
            The type of the repo on which the Discussion / Pull Request was opened.
            Possible values are: `"model"`, `"dataset"`, `"space"`.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        is_pull_request (`bool`):
            Whether or not this is a Pull Request.
        created_at (`datetime`):
            The `datetime` of creation of the Discussion / Pull Request.
        endpoint (`str`):
            Endpoint of the Hub. Default is https://huggingface.co.
        git_reference (`str`, *optional*):
            (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
        url (`str`):
            (property) URL of the discussion on the Hub.
    """

    title: str
    status: DiscussionStatus
    num: int
    repo_id: str
    repo_type: str
    author: str
    is_pull_request: bool
    created_at: datetime
    endpoint: str

    @property
    def git_reference(self) -> Optional[str]:
        """
        Git reference (`refs/pr/{num}`) to which changes can be pushed when this is a Pull Request.
        `None` for a plain Discussion.
        """
        return f"refs/pr/{self.num}" if self.is_pull_request else None

    @property
    def url(self) -> str:
        """Returns the URL of the discussion on the Hub."""
        # Model repos (and an unset repo type) have no type prefix in their URL; other types use `{repo_type}s/`.
        is_model_repo = self.repo_type is None or self.repo_type == constants.REPO_TYPE_MODEL
        if not is_model_repo:
            return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}"
        return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dataclass
class DiscussionWithDetails(Discussion):
    """
    Subclass of [`Discussion`] carrying the detailed content of a Discussion / Pull Request.

    Attributes:
        title (`str`):
            The title of the Discussion / Pull Request
        status (`str`):
            The status of the Discussion / Pull Request.
            It can be one of:
                * `"open"`
                * `"closed"`
                * `"merged"` (only for Pull Requests)
                * `"draft"` (only for Pull Requests)
        num (`int`):
            The number of the Discussion / Pull Request.
        repo_id (`str`):
            The id (`"{namespace}/{repo_name}"`) of the repo on which
            the Discussion / Pull Request was opened.
        repo_type (`str`):
            The type of the repo on which the Discussion / Pull Request was opened.
            Possible values are: `"model"`, `"dataset"`, `"space"`.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        is_pull_request (`bool`):
            Whether or not this is a Pull Request.
        created_at (`datetime`):
            The `datetime` of creation of the Discussion / Pull Request.
        events (`list` of [`DiscussionEvent`]):
            The list of [`DiscussionEvent`] objects in this Discussion or Pull Request.
        conflicting_files (`Union[List[str], bool, None]`, *optional*):
            A list of conflicting files if this is a Pull Request.
            `None` if `self.is_pull_request` is `False`.
            `True` if there are conflicting files but the list can't be retrieved.
        target_branch (`str`, *optional*):
            The branch into which changes are to be merged if this is a
            Pull Request. `None` if `self.is_pull_request` is `False`.
        merge_commit_oid (`str`, *optional*):
            If this is a merged Pull Request, this is set to the OID / SHA of
            the merge commit, `None` otherwise.
        diff (`str`, *optional*):
            The git diff if this is a Pull Request, `None` otherwise.
        endpoint (`str`):
            Endpoint of the Hub. Default is https://huggingface.co.
        git_reference (`str`, *optional*):
            (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
        url (`str`):
            (property) URL of the discussion on the Hub.
    """

    events: List["DiscussionEvent"]
    conflicting_files: Union[List[str], bool, None]
    target_branch: Optional[str]
    merge_commit_oid: Optional[str]
    diff: Optional[str]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@dataclass
class DiscussionEvent:
    """
    An event in a Discussion or Pull Request.

    Use concrete classes:
        * [`DiscussionComment`]
        * [`DiscussionStatusChange`]
        * [`DiscussionCommit`]
        * [`DiscussionTitleChange`]

    Attributes:
        id (`str`):
            The ID of the event. A hexadecimal string.
        type (`str`):
            The type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
    """

    id: str
    type: str
    created_at: datetime
    author: str

    # Stores the original event data, in case we need to access it later.
    _event: dict
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@dataclass
class DiscussionComment(DiscussionEvent):
    """A comment in a Discussion / Pull Request.

    Subclass of [`DiscussionEvent`].


    Attributes:
        id (`str`):
            The ID of the event. A hexadecimal string.
        type (`str`):
            The type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        content (`str`):
            The raw markdown content of the comment. Mentions, links and images are not rendered.
        edited (`bool`):
            Whether or not this comment has been edited.
        hidden (`bool`):
            Whether or not this comment has been hidden.
    """

    content: str
    edited: bool
    hidden: bool

    @property
    def rendered(self) -> str:
        """The rendered comment, as an HTML string"""
        return self._event["data"]["latest"]["html"]

    @property
    def last_edited_at(self) -> datetime:
        """The last edit time, as a `datetime` object."""
        return parse_datetime(self._event["data"]["latest"]["updatedAt"])

    @property
    def last_edited_by(self) -> str:
        """The username of the author of the latest version of the comment (`"deleted"` if unavailable)."""
        # `or {}` also covers an `author` key that is present but set to `None`.
        latest_author = self._event["data"]["latest"].get("author") or {}
        return latest_author.get("name", "deleted")

    @property
    def edit_history(self) -> List[dict]:
        """The edit history of the comment"""
        return self._event["data"]["history"]

    @property
    def number_of_edits(self) -> int:
        """The number of entries in the edit history of the comment."""
        return len(self.edit_history)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@dataclass
class DiscussionStatusChange(DiscussionEvent):
    """A change of status in a Discussion / Pull Request.

    Subclass of [`DiscussionEvent`].

    Attributes:
        id (`str`):
            The ID of the event. A hexadecimal string.
        type (`str`):
            The type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        new_status (`str`):
            The status of the Discussion / Pull Request after the change.
            It can be one of:
                * `"open"`
                * `"closed"`
                * `"merged"` (only for Pull Requests)
    """

    new_status: str
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@dataclass
class DiscussionCommit(DiscussionEvent):
    """A commit in a Pull Request.

    Subclass of [`DiscussionEvent`].

    Attributes:
        id (`str`):
            The ID of the event. A hexadecimal string.
        type (`str`):
            The type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        summary (`str`):
            The summary of the commit.
        oid (`str`):
            The OID / SHA of the commit, as a hexadecimal string.
    """

    summary: str
    oid: str
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@dataclass
class DiscussionTitleChange(DiscussionEvent):
    """A rename event in a Discussion / Pull Request.

    Subclass of [`DiscussionEvent`].

    Attributes:
        id (`str`):
            The ID of the event. A hexadecimal string.
        type (`str`):
            The type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            The username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        old_title (`str`):
            The previous title for the Discussion / Pull Request.
        new_title (`str`):
            The new title.
    """

    old_title: str
    new_title: str
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def deserialize_event(event: dict) -> DiscussionEvent:
    """Instantiate the [`DiscussionEvent`] subclass matching a raw event dict.

    Args:
        event (`dict`):
            Raw event payload. Must contain the `"id"`, `"type"` and `"createdAt"`
            keys; the recognized event types additionally read fields from
            `event["data"]`.

    Returns:
        [`DiscussionEvent`]: a [`DiscussionComment`], [`DiscussionStatusChange`],
        [`DiscussionCommit`] or [`DiscussionTitleChange`] for the event types
        `"comment"`, `"status-change"`, `"commit"` and `"title-change"` respectively,
        and a plain [`DiscussionEvent`] for any other type.

    Raises:
        `KeyError`: if a key required for the event type is missing from the payload.
    """
    event_id: str = event["id"]
    event_type: str = event["type"]
    created_at = parse_datetime(event["createdAt"])

    # The author falls back to "deleted" when the key is missing. `or {}` also covers
    # a payload where "author" is present but null, which would otherwise raise
    # `AttributeError` on the chained `.get`.
    author_info = event.get("author") or {}

    common_args = {
        "id": event_id,
        "type": event_type,
        "created_at": created_at,
        "author": author_info.get("name", "deleted"),
        "_event": event,
    }

    if event_type == "comment":
        return DiscussionComment(
            **common_args,
            edited=event["data"]["edited"],
            hidden=event["data"]["hidden"],
            content=event["data"]["latest"]["raw"],
        )
    if event_type == "status-change":
        return DiscussionStatusChange(
            **common_args,
            new_status=event["data"]["status"],
        )
    if event_type == "commit":
        return DiscussionCommit(
            **common_args,
            summary=event["data"]["subject"],
            oid=event["data"]["oid"],
        )
    if event_type == "title-change":
        return DiscussionTitleChange(
            **common_args,
            old_title=event["data"]["from"],
            new_title=event["data"]["to"],
        )

    # Unknown event types are still surfaced, as the generic base event.
    return DiscussionEvent(**common_args)
|
.venv/lib/python3.11/site-packages/huggingface_hub/constants.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import typing
|
| 4 |
+
from typing import Literal, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Possible values for env variables
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
| 11 |
+
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _is_true(value: Optional[str]) -> bool:
    """Tell whether an environment-variable string counts as "true".

    `None` yields `False`. Any other value is upper-cased and looked up in
    `ENV_VARS_TRUE_VALUES`, which makes the comparison case-insensitive.
    """
    return value is not None and value.upper() in ENV_VARS_TRUE_VALUES
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _as_int(value: Optional[str]) -> Optional[int]:
|
| 21 |
+
if value is None:
|
| 22 |
+
return None
|
| 23 |
+
return int(value)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Constants for file downloads
|
| 27 |
+
|
| 28 |
+
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
|
| 29 |
+
TF2_WEIGHTS_NAME = "tf_model.h5"
|
| 30 |
+
TF_WEIGHTS_NAME = "model.ckpt"
|
| 31 |
+
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
| 32 |
+
CONFIG_NAME = "config.json"
|
| 33 |
+
REPOCARD_NAME = "README.md"
|
| 34 |
+
DEFAULT_ETAG_TIMEOUT = 10
|
| 35 |
+
DEFAULT_DOWNLOAD_TIMEOUT = 10
|
| 36 |
+
DEFAULT_REQUEST_TIMEOUT = 10
|
| 37 |
+
DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024
|
| 38 |
+
HF_TRANSFER_CONCURRENCY = 100
|
| 39 |
+
|
| 40 |
+
# Constants for serialization
|
| 41 |
+
|
| 42 |
+
PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin" # Unsafe pickle: use safetensors instead
|
| 43 |
+
SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors"
|
| 44 |
+
TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5"
|
| 45 |
+
|
| 46 |
+
# Constants for safetensors repos
|
| 47 |
+
|
| 48 |
+
SAFETENSORS_SINGLE_FILE = "model.safetensors"
|
| 49 |
+
SAFETENSORS_INDEX_FILE = "model.safetensors.index.json"
|
| 50 |
+
SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000
|
| 51 |
+
|
| 52 |
+
# Timeout of acquiring file lock and logging the attempt
|
| 53 |
+
FILELOCK_LOG_EVERY_SECONDS = 10
|
| 54 |
+
|
| 55 |
+
# Git-related constants
|
| 56 |
+
|
| 57 |
+
DEFAULT_REVISION = "main"
|
| 58 |
+
REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}")
|
| 59 |
+
|
| 60 |
+
HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"
|
| 61 |
+
|
| 62 |
+
_staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING"))
|
| 63 |
+
|
| 64 |
+
_HF_DEFAULT_ENDPOINT = "https://huggingface.co"
|
| 65 |
+
_HF_DEFAULT_STAGING_ENDPOINT = "https://hub-ci.huggingface.co"
|
| 66 |
+
ENDPOINT = os.getenv("HF_ENDPOINT", "").rstrip("/") or (
|
| 67 |
+
_HF_DEFAULT_STAGING_ENDPOINT if _staging_mode else _HF_DEFAULT_ENDPOINT
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
|
| 71 |
+
HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit"
|
| 72 |
+
HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag"
|
| 73 |
+
HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size"
|
| 74 |
+
|
| 75 |
+
INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co")
|
| 76 |
+
|
| 77 |
+
# See https://huggingface.co/docs/inference-endpoints/index
|
| 78 |
+
INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2"
|
| 79 |
+
|
| 80 |
+
# Proxy for third-party providers
|
| 81 |
+
INFERENCE_PROXY_TEMPLATE = ENDPOINT + "/api/inference-proxy/{provider}"
|
| 82 |
+
|
| 83 |
+
REPO_ID_SEPARATOR = "--"
|
| 84 |
+
# ^ this substring is not allowed in repo_ids on hf.co
|
| 85 |
+
# and is the canonical one we use for serialization of repo ids elsewhere.
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
REPO_TYPE_DATASET = "dataset"
|
| 89 |
+
REPO_TYPE_SPACE = "space"
|
| 90 |
+
REPO_TYPE_MODEL = "model"
|
| 91 |
+
REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
|
| 92 |
+
SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"]
|
| 93 |
+
|
| 94 |
+
REPO_TYPES_URL_PREFIXES = {
|
| 95 |
+
REPO_TYPE_DATASET: "datasets/",
|
| 96 |
+
REPO_TYPE_SPACE: "spaces/",
|
| 97 |
+
}
|
| 98 |
+
REPO_TYPES_MAPPING = {
|
| 99 |
+
"datasets": REPO_TYPE_DATASET,
|
| 100 |
+
"spaces": REPO_TYPE_SPACE,
|
| 101 |
+
"models": REPO_TYPE_MODEL,
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
# Accepted filter values for discussions. Each tuple is derived from its `Literal`
# alias with `typing.get_args`, so the alias and the tuple cannot drift apart.
DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)
DiscussionStatusFilter = Literal["all", "open", "closed"]
# Annotated with the *status* alias: the values come from `DiscussionStatusFilter`.
DISCUSSION_STATUS: Tuple[DiscussionStatusFilter, ...] = typing.get_args(DiscussionStatusFilter)
|
| 108 |
+
|
| 109 |
+
# Webhook subscription types
|
| 110 |
+
WEBHOOK_DOMAIN_T = Literal["repo", "discussions"]
|
| 111 |
+
|
| 112 |
+
# default cache
|
| 113 |
+
default_home = os.path.join(os.path.expanduser("~"), ".cache")
|
| 114 |
+
HF_HOME = os.path.expanduser(
|
| 115 |
+
os.getenv(
|
| 116 |
+
"HF_HOME",
|
| 117 |
+
os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"),
|
| 118 |
+
)
|
| 119 |
+
)
|
| 120 |
+
hf_cache_home = HF_HOME # for backward compatibility. TODO: remove this in 1.0.0
|
| 121 |
+
|
| 122 |
+
default_cache_path = os.path.join(HF_HOME, "hub")
|
| 123 |
+
default_assets_cache_path = os.path.join(HF_HOME, "assets")
|
| 124 |
+
|
| 125 |
+
# Legacy env variables
|
| 126 |
+
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
|
| 127 |
+
HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path)
|
| 128 |
+
|
| 129 |
+
# New env variables
|
| 130 |
+
HF_HUB_CACHE = os.getenv("HF_HUB_CACHE", HUGGINGFACE_HUB_CACHE)
|
| 131 |
+
HF_ASSETS_CACHE = os.getenv("HF_ASSETS_CACHE", HUGGINGFACE_ASSETS_CACHE)
|
| 132 |
+
|
| 133 |
+
HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE"))
|
| 134 |
+
|
| 135 |
+
# Opt-out from telemetry requests
|
| 136 |
+
HF_HUB_DISABLE_TELEMETRY = (
|
| 137 |
+
_is_true(os.environ.get("HF_HUB_DISABLE_TELEMETRY")) # HF-specific env variable
|
| 138 |
+
or _is_true(os.environ.get("DISABLE_TELEMETRY"))
|
| 139 |
+
or _is_true(os.environ.get("DO_NOT_TRACK")) # https://consoledonottrack.com/
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# In the past, token was stored in a hardcoded location
|
| 143 |
+
# `_OLD_HF_TOKEN_PATH` is deprecated and will be removed "at some point".
|
| 144 |
+
# See https://github.com/huggingface/huggingface_hub/issues/1232
|
| 145 |
+
_OLD_HF_TOKEN_PATH = os.path.expanduser("~/.huggingface/token")
|
| 146 |
+
HF_TOKEN_PATH = os.environ.get("HF_TOKEN_PATH", os.path.join(HF_HOME, "token"))
|
| 147 |
+
HF_STORED_TOKENS_PATH = os.path.join(os.path.dirname(HF_TOKEN_PATH), "stored_tokens")
|
| 148 |
+
|
| 149 |
+
if _staging_mode:
|
| 150 |
+
# In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens
|
| 151 |
+
_staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging")
|
| 152 |
+
HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub")
|
| 153 |
+
_OLD_HF_TOKEN_PATH = os.path.join(_staging_home, "_old_token")
|
| 154 |
+
HF_TOKEN_PATH = os.path.join(_staging_home, "token")
|
| 155 |
+
|
| 156 |
+
# Here, `True` will disable progress bars globally without possibility of enabling it
|
| 157 |
+
# programmatically. `False` will enable them without possibility of disabling them.
|
| 158 |
+
# If environment variable is not set (None), then the user is free to enable/disable
|
| 159 |
+
# them programmatically.
|
| 160 |
+
# TL;DR: env variable has priority over code
|
| 161 |
+
__HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS")
|
| 162 |
+
HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = (
|
| 163 |
+
_is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
|
| 167 |
+
HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))
|
| 168 |
+
|
| 169 |
+
# Disable warning when using experimental features
|
| 170 |
+
HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))
|
| 171 |
+
|
| 172 |
+
# Disable sending the cached token by default in all HTTP requests to the Hub
|
| 173 |
+
HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))
|
| 174 |
+
|
| 175 |
+
# Enable fast-download using external dependency "hf_transfer"
|
| 176 |
+
# See:
|
| 177 |
+
# - https://pypi.org/project/hf-transfer/
|
| 178 |
+
# - https://github.com/huggingface/hf_transfer (private)
|
| 179 |
+
HF_HUB_ENABLE_HF_TRANSFER: bool = _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER"))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# UNUSED
|
| 183 |
+
# We don't use symlinks in local dir anymore.
|
| 184 |
+
HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = (
|
| 185 |
+
_as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Used to override the etag timeout on a system level
|
| 189 |
+
HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT
|
| 190 |
+
|
| 191 |
+
# Used to override the get request timeout on a system level
|
| 192 |
+
HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT
|
| 193 |
+
|
| 194 |
+
# List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are
|
| 195 |
+
# deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by
|
| 196 |
+
# default. We still keep the full list of supported frameworks in case we want to scan all of them.
|
| 197 |
+
MAIN_INFERENCE_API_FRAMEWORKS = [
|
| 198 |
+
"diffusers",
|
| 199 |
+
"sentence-transformers",
|
| 200 |
+
"text-generation-inference",
|
| 201 |
+
"transformers",
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [
|
| 205 |
+
"adapter-transformers",
|
| 206 |
+
"allennlp",
|
| 207 |
+
"asteroid",
|
| 208 |
+
"bertopic",
|
| 209 |
+
"doctr",
|
| 210 |
+
"espnet",
|
| 211 |
+
"fairseq",
|
| 212 |
+
"fastai",
|
| 213 |
+
"fasttext",
|
| 214 |
+
"flair",
|
| 215 |
+
"k2",
|
| 216 |
+
"keras",
|
| 217 |
+
"mindspore",
|
| 218 |
+
"nemo",
|
| 219 |
+
"open_clip",
|
| 220 |
+
"paddlenlp",
|
| 221 |
+
"peft",
|
| 222 |
+
"pyannote-audio",
|
| 223 |
+
"sklearn",
|
| 224 |
+
"spacy",
|
| 225 |
+
"span-marker",
|
| 226 |
+
"speechbrain",
|
| 227 |
+
"stanza",
|
| 228 |
+
"timm",
|
| 229 |
+
]
|
.venv/lib/python3.11/site-packages/huggingface_hub/errors.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains all custom errors."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
from requests import HTTPError, Response
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# CACHE ERRORS
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CacheNotFound(Exception):
    """Raised when the Huggingface cache cannot be found.

    The path that was looked up is kept on the exception as `cache_dir`.
    """

    cache_dir: Union[str, Path]

    def __init__(self, msg: str, cache_dir: Union[str, Path], *args, **kwargs):
        # `cache_dir` is stored on the instance only; it is not forwarded to `Exception`,
        # so it does not appear in `self.args`.
        super().__init__(msg, *args, **kwargs)
        self.cache_dir = cache_dir
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CorruptedCacheException(Exception):
|
| 23 |
+
"""Exception for any unexpected structure in the Huggingface cache-system."""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# HEADERS ERRORS
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LocalTokenNotFoundError(EnvironmentError):
|
| 30 |
+
"""Raised if local token is required but not found."""
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# HTTP ERRORS
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class OfflineModeIsEnabled(ConnectionError):
|
| 37 |
+
"""Raised when a request is made but `HF_HUB_OFFLINE=1` is set as environment variable."""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class HfHubHTTPError(HTTPError):
    """
    Base class for every custom HTTP error raised in HF Hub.

    Any HTTPError is converted at least into a `HfHubHTTPError`. When the server sends
    back some information, it is added to the error message.

    Added details:
    - Request id from "X-Request-Id" header if exists. If not, fallback to "X-Amzn-Trace-Id" header if exists.
    - Server error message from the header "X-Error-Message".
    - Server error message if we can find one in the response body.

    Attributes:
        request_id (`str`, *optional*):
            Request id read from the response headers, or `None` when no response is given.
        server_message (`str`, *optional*):
            The `server_message` passed at construction time.

    Example:
    ```py
        import requests
        from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError

        response = get_session().post(...)
        try:
            hf_raise_for_status(response)
        except HfHubHTTPError as e:
            print(str(e)) # formatted message
            e.request_id, e.server_message # details returned by server

            # Complete the error message with additional information once it's raised
            e.append_to_message("\n`create_commit` expects the repository to exist.")
            raise
    ```
    """

    def __init__(self, message: str, response: Optional[Response] = None, *, server_message: Optional[str] = None):
        if response is None:
            # No response to inspect: there is no request id to report.
            self.request_id = None
        else:
            headers = response.headers
            self.request_id = headers.get("x-request-id") or headers.get("X-Amzn-Trace-Id")
        self.server_message = server_message

        super().__init__(
            message,
            response=response,  # type: ignore [arg-type]
            request=None if response is None else response.request,  # type: ignore [arg-type]
        )

    def append_to_message(self, additional_message: str) -> None:
        """Append additional information to the `HfHubHTTPError` initial message."""
        # Only the first element of `args` (the message) is extended; the rest is kept as-is.
        self.args = (self.args[0] + additional_message, *self.args[1:])
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# INFERENCE CLIENT ERRORS
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class InferenceTimeoutError(HTTPError, TimeoutError):
|
| 93 |
+
"""Error raised when a model is unavailable or the request times out."""
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# INFERENCE ENDPOINT ERRORS
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class InferenceEndpointError(Exception):
|
| 100 |
+
"""Generic exception when dealing with Inference Endpoints."""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError):
|
| 104 |
+
"""Exception for timeouts while waiting for Inference Endpoint."""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# SAFETENSORS ERRORS
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class SafetensorsParsingError(Exception):
|
| 111 |
+
"""Raised when failing to parse a safetensors file metadata.
|
| 112 |
+
|
| 113 |
+
This can be the case if the file is not a safetensors file or does not respect the specification.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class NotASafetensorsRepoError(Exception):
|
| 118 |
+
"""Raised when a repo is not a Safetensors repo i.e. doesn't have either a `model.safetensors` or a
|
| 119 |
+
`model.safetensors.index.json` file.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# TEXT GENERATION ERRORS
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class TextGenerationError(HTTPError):
|
| 127 |
+
"""Generic error raised if text-generation went wrong."""
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Text Generation Inference Errors
|
| 131 |
+
class ValidationError(TextGenerationError):
|
| 132 |
+
"""Server-side validation error."""
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class GenerationError(TextGenerationError):
|
| 136 |
+
pass
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class OverloadedError(TextGenerationError):
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class IncompleteGenerationError(TextGenerationError):
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class UnknownError(TextGenerationError):
|
| 148 |
+
pass
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# VALIDATION ERRORS
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class HFValidationError(ValueError):
|
| 155 |
+
"""Generic exception thrown by `huggingface_hub` validators.
|
| 156 |
+
|
| 157 |
+
Inherits from [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError).
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# FILE METADATA ERRORS
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class FileMetadataError(OSError):
|
| 165 |
+
"""Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash).
|
| 166 |
+
|
| 167 |
+
Inherits from `OSError` for backward compatibility.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# REPOSITORY ERRORS
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class RepositoryNotFoundError(HfHubHTTPError):
|
| 175 |
+
"""
|
| 176 |
+
Raised when trying to access a hf.co URL with an invalid repository name, or
|
| 177 |
+
with a private repo name the user does not have access to.
|
| 178 |
+
|
| 179 |
+
Example:
|
| 180 |
+
|
| 181 |
+
```py
|
| 182 |
+
>>> from huggingface_hub import model_info
|
| 183 |
+
>>> model_info("<non_existent_repository>")
|
| 184 |
+
(...)
|
| 185 |
+
huggingface_hub.utils._errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP)
|
| 186 |
+
|
| 187 |
+
Repository Not Found for url: https://huggingface.co/api/models/%3Cnon_existent_repository%3E.
|
| 188 |
+
Please make sure you specified the correct `repo_id` and `repo_type`.
|
| 189 |
+
If the repo is private, make sure you are authenticated.
|
| 190 |
+
Invalid username or password.
|
| 191 |
+
```
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class GatedRepoError(RepositoryNotFoundError):
|
| 196 |
+
"""
|
| 197 |
+
Raised when trying to access a gated repository for which the user is not on the
|
| 198 |
+
authorized list.
|
| 199 |
+
|
| 200 |
+
Note: derives from `RepositoryNotFoundError` to ensure backward compatibility.
|
| 201 |
+
|
| 202 |
+
Example:
|
| 203 |
+
|
| 204 |
+
```py
|
| 205 |
+
>>> from huggingface_hub import model_info
|
| 206 |
+
>>> model_info("<gated_repository>")
|
| 207 |
+
(...)
|
| 208 |
+
huggingface_hub.utils._errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa)
|
| 209 |
+
|
| 210 |
+
Cannot access gated repo for url https://huggingface.co/api/models/ardent-figment/gated-model.
|
| 211 |
+
Access to model ardent-figment/gated-model is restricted and you are not in the authorized list.
|
| 212 |
+
Visit https://huggingface.co/ardent-figment/gated-model to ask for access.
|
| 213 |
+
```
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class DisabledRepoError(HfHubHTTPError):
|
| 218 |
+
"""
|
| 219 |
+
Raised when trying to access a repository that has been disabled by its author.
|
| 220 |
+
|
| 221 |
+
Example:
|
| 222 |
+
|
| 223 |
+
```py
|
| 224 |
+
>>> from huggingface_hub import dataset_info
|
| 225 |
+
>>> dataset_info("laion/laion-art")
|
| 226 |
+
(...)
|
| 227 |
+
huggingface_hub.utils._errors.DisabledRepoError: 403 Client Error. (Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8)
|
| 228 |
+
|
| 229 |
+
Cannot access repository for url https://huggingface.co/api/datasets/laion/laion-art.
|
| 230 |
+
Access to this resource is disabled.
|
| 231 |
+
```
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# REVISION ERROR
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class RevisionNotFoundError(HfHubHTTPError):
|
| 239 |
+
"""
|
| 240 |
+
Raised when trying to access a hf.co URL with a valid repository but an invalid
|
| 241 |
+
revision.
|
| 242 |
+
|
| 243 |
+
Example:
|
| 244 |
+
|
| 245 |
+
```py
|
| 246 |
+
>>> from huggingface_hub import hf_hub_download
|
| 247 |
+
>>> hf_hub_download('bert-base-cased', 'config.json', revision='<non-existent-revision>')
|
| 248 |
+
(...)
|
| 249 |
+
huggingface_hub.utils._errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX)
|
| 250 |
+
|
| 251 |
+
Revision Not Found for url: https://huggingface.co/bert-base-cased/resolve/%3Cnon-existent-revision%3E/config.json.
|
| 252 |
+
```
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ENTRY ERRORS
|
| 257 |
+
class EntryNotFoundError(HfHubHTTPError):
|
| 258 |
+
"""
|
| 259 |
+
Raised when trying to access a hf.co URL with a valid repository and revision
|
| 260 |
+
but an invalid filename.
|
| 261 |
+
|
| 262 |
+
Example:
|
| 263 |
+
|
| 264 |
+
```py
|
| 265 |
+
>>> from huggingface_hub import hf_hub_download
|
| 266 |
+
>>> hf_hub_download('bert-base-cased', '<non-existent-file>')
|
| 267 |
+
(...)
|
| 268 |
+
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x)
|
| 269 |
+
|
| 270 |
+
Entry Not Found for url: https://huggingface.co/bert-base-cased/resolve/main/%3Cnon-existent-file%3E.
|
| 271 |
+
```
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError):
    """
    Raised when a file or snapshot is requested but is not on the disk while the network is
    disabled or unavailable (connection issue). The entry may exist on the Hub.

    Note: `ValueError` type is to ensure backward compatibility.
    Note: `LocalEntryNotFoundError` derives from `HTTPError` because of `EntryNotFoundError`
          even when it is not a network issue.

    Example:

    ```py
    >>> from huggingface_hub import hf_hub_download
    >>> hf_hub_download('bert-base-cased', '<non-cached-file>', local_files_only=True)
    (...)
    huggingface_hub.utils._errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False.
    ```
    """

    def __init__(self, message: str) -> None:
        # This error is always constructed without an HTTP response.
        super().__init__(message, response=None)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# REQUEST ERROR
|
| 299 |
+
class BadRequestError(HfHubHTTPError, ValueError):
|
| 300 |
+
"""
|
| 301 |
+
Raised by `hf_raise_for_status` when the server returns a HTTP 400 error.
|
| 302 |
+
|
| 303 |
+
Example:
|
| 304 |
+
|
| 305 |
+
```py
|
| 306 |
+
>>> resp = requests.post("hf.co/api/check", ...)
|
| 307 |
+
>>> hf_raise_for_status(resp, endpoint_name="check")
|
| 308 |
+
huggingface_hub.utils._errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX)
|
| 309 |
+
```
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# DDUF file format ERROR
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class DDUFError(Exception):
|
| 317 |
+
"""Base exception for errors related to the DDUF format."""
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class DDUFCorruptedFileError(DDUFError):
|
| 321 |
+
"""Exception thrown when the DDUF file is corrupted."""
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class DDUFExportError(DDUFError):
|
| 325 |
+
"""Base exception for errors during DDUF export."""
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class DDUFInvalidEntryNameError(DDUFExportError):
|
| 329 |
+
"""Exception thrown when the entry name is invalid."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/fastai_utils.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from pickle import DEFAULT_PROTOCOL, PicklingError
|
| 5 |
+
from typing import Any, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
from packaging import version
|
| 8 |
+
|
| 9 |
+
from huggingface_hub import constants, snapshot_download
|
| 10 |
+
from huggingface_hub.hf_api import HfApi
|
| 11 |
+
from huggingface_hub.utils import (
|
| 12 |
+
SoftTemporaryDirectory,
|
| 13 |
+
get_fastai_version,
|
| 14 |
+
get_fastcore_version,
|
| 15 |
+
get_python_version,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from .utils import logging, validate_hf_hub_args
|
| 19 |
+
from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility...
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _check_fastai_fastcore_versions(
    fastai_min_version: str = "2.4",
    fastcore_min_version: str = "1.3.27",
):
    """
    Checks that the installed fastai and fastcore versions are compatible for pickle serialization.

    Args:
        fastai_min_version (`str`, *optional*):
            The minimum fastai version supported.
        fastcore_min_version (`str`, *optional*):
            The minimum fastcore version supported.

    <Tip>
    Raises the following error:

        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
          if the fastai or fastcore libraries are not available or are of an invalid version.

    </Tip>
    """
    installed_fastai = get_fastai_version()
    installed_fastcore = get_fastcore_version()

    # This function treats the literal "N/A" as "library not available".
    # Both values must be tested independently: `(a or b) == "N/A"` only ever
    # inspects `a`, because any non-empty string (including "N/A") is truthy.
    if "N/A" in (installed_fastai, installed_fastcore):
        raise ImportError(
            f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
            f" required. Currently using fastai=={installed_fastai} and"
            f" fastcore=={installed_fastcore}."
        )

    current_fastai_version = version.Version(installed_fastai)
    current_fastcore_version = version.Version(installed_fastcore)

    if current_fastai_version < version.Version(fastai_min_version):
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
            f" fastai>={fastai_min_version} version, but you are using fastai version"
            f" {installed_fastai} which is incompatible. Upgrade with `pip install"
            " fastai==2.5.6`."
        )

    if current_fastcore_version < version.Version(fastcore_min_version):
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
            f" fastcore>={fastcore_min_version} version, but you are using fastcore"
            f" version {installed_fastcore} which is incompatible. Upgrade with"
            " `pip install fastcore==1.3.27`."
        )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _check_fastai_fastcore_pyproject_versions(
    storage_folder: str,
    fastai_min_version: str = "2.4",
    fastcore_min_version: str = "1.3.27",
):
    """
    Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions
    that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist
    or does not contain versions for fastai and fastcore, then it logs a warning.

    Args:
        storage_folder (`str`):
            Folder to look for the `pyproject.toml` file.
        fastai_min_version (`str`, *optional*):
            The minimum fastai version supported.
        fastcore_min_version (`str`, *optional*):
            The minimum fastcore version supported.

    <Tip>
    Raises the following errors:

        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
          if the `toml` module is not installed.
        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
          if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.

    </Tip>
    """

    try:
        import toml
    except ModuleNotFoundError as err:
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module."
            " Install it with `pip install toml`."
        ) from err

    # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository.
    # If so, get a list of required packages.
    pyproject_path = os.path.join(storage_folder, "pyproject.toml")
    if not os.path.isfile(pyproject_path):
        logger.warning(
            "There is no `pyproject.toml` in the repository that contains the fastai"
            " `Learner`. The `pyproject.toml` would allow us to verify that your fastai"
            " and fastcore versions are compatible with those of the model you want to"
            " load."
        )
        return
    pyproject_toml = toml.load(pyproject_path)

    if "build-system" not in pyproject_toml:
        logger.warning(
            "There is no `build-system` section in the pyproject.toml of the repository"
            " that contains the fastai `Learner`. The `build-system` would allow us to"
            " verify that your fastai and fastcore versions are compatible with those"
            " of the model you want to load."
        )
        return
    build_system_toml = pyproject_toml["build-system"]

    if "requires" not in build_system_toml:
        logger.warning(
            "There is no `requires` section in the pyproject.toml of the repository"
            " that contains the fastai `Learner`. The `requires` would allow us to"
            " verify that your fastai and fastcore versions are compatible with those"
            " of the model you want to load."
        )
        return
    package_versions = build_system_toml["requires"]

    # Extracts the fastai and fastcore versions from `pyproject.toml` if available.
    # A requirement without a version (e.g. "fastai" instead of "fastai=2.4") yields an
    # empty string and is not compared against the minimum.
    fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")]
    if not fastai_packages:
        logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.")
    else:
        # `partition("=")` leaves a leading "=" for the standard "pkg==X" form; strip it so
        # that both "fastai=2.4" and "fastai==2.4" resolve to "2.4".
        fastai_version = str(fastai_packages[0]).partition("=")[2].lstrip("=")
        if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version):
            raise ImportError(
                "`from_pretrained_fastai` requires"
                f" fastai>={fastai_min_version} version but the model to load uses"
                f" {fastai_version} which is incompatible."
            )

    fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")]
    if not fastcore_packages:
        logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.")
    else:
        # Same "=" / "==" tolerance as for fastai above.
        fastcore_version = str(fastcore_packages[0]).partition("=")[2].lstrip("=")
        if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version):
            raise ImportError(
                "`from_pretrained_fastai` requires"
                f" fastcore>={fastcore_min_version} version, but you are using fastcore"
                f" version {fastcore_version} which is incompatible."
            )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
README_TEMPLATE = """---
|
| 172 |
+
tags:
|
| 173 |
+
- fastai
|
| 174 |
+
---
|
| 175 |
+
|
| 176 |
+
# Amazing!
|
| 177 |
+
|
| 178 |
+
🥳 Congratulations on hosting your fastai model on the Hugging Face Hub!
|
| 179 |
+
|
| 180 |
+
# Some next steps
|
| 181 |
+
1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))!
|
| 182 |
+
|
| 183 |
+
2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)).
|
| 184 |
+
|
| 185 |
+
3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)!
|
| 186 |
+
|
| 187 |
+
Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card.
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
---
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Model card
|
| 194 |
+
|
| 195 |
+
## Model description
|
| 196 |
+
More information needed
|
| 197 |
+
|
| 198 |
+
## Intended uses & limitations
|
| 199 |
+
More information needed
|
| 200 |
+
|
| 201 |
+
## Training and evaluation data
|
| 202 |
+
More information needed
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
PYPROJECT_TEMPLATE = f"""[build-system]
|
| 206 |
+
requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"]
|
| 207 |
+
build-backend = "setuptools.build_meta:__legacy__"
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _create_model_card(repo_dir: Path):
    """
    Writes the default `README.md` model card into a repository folder.

    An existing `README.md` is never overwritten.

    Args:
        repo_dir (`Path`):
            Directory where model card is created.
    """
    card_file = repo_dir / "README.md"

    # Keep whatever card is already present.
    if card_file.exists():
        return

    card_file.write_text(README_TEMPLATE, encoding="utf-8")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _create_model_pyproject(repo_dir: Path):
    """
    Writes the default `pyproject.toml` into a repository folder.

    An existing `pyproject.toml` is never overwritten.

    Args:
        repo_dir (`Path`):
            Directory where `pyproject.toml` is created.
    """
    toml_file = repo_dir / "pyproject.toml"

    # Keep whatever file is already present.
    if toml_file.exists():
        return

    toml_file.write_text(PYPROJECT_TEMPLATE, encoding="utf-8")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _save_pretrained_fastai(
    learner,
    save_directory: Union[str, Path],
    config: Optional[Dict[str, Any]] = None,
):
    """
    Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used.

    Besides `model.pkl`, the directory receives a `README.md` and a `pyproject.toml` (each only if missing)
    and, when `config` is given, a JSON config file named `constants.CONFIG_NAME`.

    Args:
        learner (`Learner`):
            The `fastai.Learner` you'd like to save.
        save_directory (`str` or `Path`):
            Specific directory in which you want to save the fastai learner.
        config (`dict`, *optional*):
            Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.

    <Tip>

    Raises the following errors:

        - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
          if the config file provided is not a dictionary.
        - `pickle.PicklingError`
          if `learner.export` fails to pickle the learner.

    </Tip>
    """
    _check_fastai_fastcore_versions()

    # Validate the arguments before touching the filesystem so that a bad `config`
    # does not leave a half-initialised directory behind.
    if config is not None and not isinstance(config, dict):
        raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")

    os.makedirs(save_directory, exist_ok=True)

    # The user-provided config is written as-is.
    if config is not None:
        path = os.path.join(save_directory, constants.CONFIG_NAME)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f)

    _create_model_card(Path(save_directory))
    _create_model_pyproject(Path(save_directory))

    # learner.export saves the model in `self.path`.
    learner.path = Path(save_directory)
    try:
        learner.export(
            fname="model.pkl",
            pickle_protocol=DEFAULT_PROTOCOL,
        )
    except PicklingError as err:
        raise PicklingError(
            "You are using a lambda function, i.e., an anonymous function. `pickle`"
            " cannot pickle function objects and requires that all functions have"
            " names. One possible solution is to name the function."
        ) from err
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@validate_hf_hub_args
def from_pretrained_fastai(
    repo_id: str,
    revision: Optional[str] = None,
):
    """
    Load pretrained fastai model from the Hub or from a local directory.

    Args:
        repo_id (`str`):
            The location where the pickled fastai.Learner is. It can be either of the two:
                - An existing local directory containing `model.pkl` and, ideally, a `pyproject.toml`
                  indicating the fastai and fastcore versions used to build the `fastai.Learner`.
                  E.g.: `./my_model_directory/`.
                - Anything else is treated as a repository id on the Hugging Face Hub and is fetched
                  with `snapshot_download`. E.g.: 'espejelomar/fastai-pet-breeds-classification'.
        revision (`str`, *optional*):
            Revision at which the repo's files are downloaded. See documentation of `snapshot_download`.
            Only used when `repo_id` is not a local directory.

    Returns:
        The `fastai.Learner` model in the `repo_id` repo.
    """
    _check_fastai_fastcore_versions()

    # A path that exists locally is used directly; everything else is downloaded.
    # `snapshot_download` returns the folder where the files were stored.
    if os.path.isdir(repo_id):
        storage_folder = repo_id
    else:
        storage_folder = snapshot_download(
            repo_id=repo_id,
            revision=revision,
            library_name="fastai",
            library_version=get_fastai_version(),
        )

    _check_fastai_fastcore_pyproject_versions(storage_folder)

    # Imported lazily, after the version checks above have passed.
    from fastai.learner import load_learner  # type: ignore

    # NOTE(review): `model.pkl` is a pickle file; only load repositories you trust.
    model_file = os.path.join(storage_folder, "model.pkl")
    return load_learner(model_file)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@validate_hf_hub_args
def push_to_hub_fastai(
    learner,
    *,
    repo_id: str,
    commit_message: str = "Push FastAI model using huggingface_hub.",
    private: Optional[bool] = None,
    token: Optional[str] = None,
    config: Optional[dict] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
    api_endpoint: Optional[str] = None,
):
    """
    Upload learner checkpoint files to the Hub.

    The repository is created if needed (`exist_ok=True`), the learner is exported into a temporary
    folder, and that folder is uploaded with a single [`upload_folder`] call.

    Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
    `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
    details.

    Args:
        learner (`Learner`):
            The `fastai.Learner` you'd like to push to the Hub.
        repo_id (`str`):
            The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be
            your individual account or an organization to which you have write access (for example,
            'stanfordnlp/stanza-de').
        commit_message (`str`, *optional*):
            Message to commit while pushing. Defaults to `"Push FastAI model using huggingface_hub."`.
        private (`bool`, *optional*):
            Whether or not the repository created should be private. Passed as-is to `create_repo`.
        token (`str`, *optional*):
            The Hugging Face account token to use as HTTP bearer authorization for remote files.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        branch (`str`, *optional*):
            The git branch on which to push the model. Passed as `revision` to [`upload_folder`].
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit.
        api_endpoint (`str`, *optional*):
            The API endpoint to use when pushing the model to the hub.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.
        delete_patterns (`List[str]` or `str`, *optional*):
            If provided, remote files matching any of the patterns will be deleted from the repo.

    Returns:
        The value returned by [`upload_folder`] for the commit pushed to the given repository.

    <Tip>

    Raises the following error:

        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if the user is not logged in to the Hugging Face Hub (per the original documentation; it is not
          raised directly in this function).

    </Tip>
    """
    _check_fastai_fastcore_versions()

    hub_api = HfApi(endpoint=api_endpoint)
    # Use the repo id returned by the Hub from here on.
    created_repo = hub_api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
    repo_id = created_repo.repo_id

    # Push the files to the repo in a single commit. The upload must happen while the
    # temporary directory still exists, hence the return inside the `with` block.
    with SoftTemporaryDirectory() as tmp_dir:
        export_dir = Path(tmp_dir) / repo_id
        _save_pretrained_fastai(learner, export_dir, config=config)
        upload_result = hub_api.upload_folder(
            repo_id=repo_id,
            token=token,
            folder_path=export_dir,
            commit_message=commit_message,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
        return upload_result
|
.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py
ADDED
|
@@ -0,0 +1,1621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import copy
|
| 3 |
+
import errno
|
| 4 |
+
import inspect
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
import stat
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
import warnings
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union
|
| 15 |
+
from urllib.parse import quote, urlparse
|
| 16 |
+
|
| 17 |
+
import requests
|
| 18 |
+
|
| 19 |
+
from . import (
|
| 20 |
+
__version__, # noqa: F401 # for backward compatibility
|
| 21 |
+
constants,
|
| 22 |
+
)
|
| 23 |
+
from ._local_folder import get_local_download_paths, read_download_metadata, write_download_metadata
|
| 24 |
+
from .constants import (
|
| 25 |
+
HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 # for backward compatibility
|
| 26 |
+
HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility
|
| 27 |
+
)
|
| 28 |
+
from .errors import (
|
| 29 |
+
EntryNotFoundError,
|
| 30 |
+
FileMetadataError,
|
| 31 |
+
GatedRepoError,
|
| 32 |
+
LocalEntryNotFoundError,
|
| 33 |
+
RepositoryNotFoundError,
|
| 34 |
+
RevisionNotFoundError,
|
| 35 |
+
)
|
| 36 |
+
from .utils import (
|
| 37 |
+
OfflineModeIsEnabled,
|
| 38 |
+
SoftTemporaryDirectory,
|
| 39 |
+
WeakFileLock,
|
| 40 |
+
build_hf_headers,
|
| 41 |
+
get_fastai_version, # noqa: F401 # for backward compatibility
|
| 42 |
+
get_fastcore_version, # noqa: F401 # for backward compatibility
|
| 43 |
+
get_graphviz_version, # noqa: F401 # for backward compatibility
|
| 44 |
+
get_jinja_version, # noqa: F401 # for backward compatibility
|
| 45 |
+
get_pydot_version, # noqa: F401 # for backward compatibility
|
| 46 |
+
get_session,
|
| 47 |
+
get_tf_version, # noqa: F401 # for backward compatibility
|
| 48 |
+
get_torch_version, # noqa: F401 # for backward compatibility
|
| 49 |
+
hf_raise_for_status,
|
| 50 |
+
is_fastai_available, # noqa: F401 # for backward compatibility
|
| 51 |
+
is_fastcore_available, # noqa: F401 # for backward compatibility
|
| 52 |
+
is_graphviz_available, # noqa: F401 # for backward compatibility
|
| 53 |
+
is_jinja_available, # noqa: F401 # for backward compatibility
|
| 54 |
+
is_pydot_available, # noqa: F401 # for backward compatibility
|
| 55 |
+
is_tf_available, # noqa: F401 # for backward compatibility
|
| 56 |
+
is_torch_available, # noqa: F401 # for backward compatibility
|
| 57 |
+
logging,
|
| 58 |
+
reset_sessions,
|
| 59 |
+
tqdm,
|
| 60 |
+
validate_hf_hub_args,
|
| 61 |
+
)
|
| 62 |
+
from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
|
| 63 |
+
from .utils._typing import HTTP_METHOD_T
|
| 64 |
+
from .utils.sha import sha_fileobj
|
| 65 |
+
from .utils.tqdm import is_tqdm_disabled
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
logger = logging.get_logger(__name__)
|
| 69 |
+
|
| 70 |
+
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
|
| 71 |
+
_CACHED_NO_EXIST = object()
|
| 72 |
+
_CACHED_NO_EXIST_T = Any
|
| 73 |
+
|
| 74 |
+
# Regex to get filename from a "Content-Disposition" header for CDN-served files
|
| 75 |
+
HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P<filename>.*?)";')
|
| 76 |
+
|
| 77 |
+
# Regex to check if the revision IS directly a commit_hash
|
| 78 |
+
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
|
| 79 |
+
|
| 80 |
+
# Regex to check if the file etag IS a valid sha256
|
| 81 |
+
REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$")
|
| 82 |
+
|
| 83 |
+
_are_symlinks_supported_in_dir: Dict[str, bool] = {}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
    """Return whether the symlinks are supported on the machine.

    Since symlinks support can change depending on the mounted disk, we need to check
    on the precise cache folder. By default, the default HF cache directory is checked.

    The check is done by creating a temporary directory inside `cache_dir` and trying to
    create a relative symlink in it. The outcome is memoized per resolved directory path,
    so the probe (and the warning, if any) happens at most once per directory.

    Args:
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.

    Returns: [bool] Whether symlinks are supported in the directory.

    Raises:
        `OSError`: propagated as-is if the directory or the temporary probe files cannot
        be created. In that case nothing is memoized, so the next call probes again.
    """
    # Defaults to HF cache
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE
    cache_dir = str(Path(cache_dir).expanduser().resolve())  # make it unique

    # Check symlink compatibility only once (per cache directory) at first time use
    if cache_dir not in _are_symlinks_supported_in_dir:
        os.makedirs(cache_dir, exist_ok=True)
        with SoftTemporaryDirectory(dir=cache_dir) as tmpdir:
            src_path = Path(tmpdir) / "dummy_file_src"
            src_path.touch()
            dst_path = Path(tmpdir) / "dummy_file_dst"

            # Relative source path as in `_create_symlink``
            relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
            try:
                os.symlink(relative_src, dst_path)
            except OSError:
                # Likely running on Windows
                _are_symlinks_supported_in_dir[cache_dir] = False

                if not constants.HF_HUB_DISABLE_SYMLINKS_WARNING:
                    message = (
                        "`huggingface_hub` cache-system uses symlinks by default to"
                        " efficiently store duplicated files but your machine does not"
                        f" support them in {cache_dir}. Caching files will still work"
                        " but in a degraded version that might require more space on"
                        " your disk. This warning can be disabled by setting the"
                        " `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For"
                        " more details, see"
                        " https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
                    )
                    if os.name == "nt":
                        message += (
                            "\nTo support symlinks on Windows, you either need to"
                            " activate Developer Mode or to run Python as an"
                            " administrator. In order to activate developer mode,"
                            " see this article:"
                            " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
                        )
                    warnings.warn(message)
            else:
                # Record success only once the symlink has really been created. Writing `True`
                # up-front would leave a wrong answer in the memo if the set-up above raised.
                _are_symlinks_supported_in_dir[cache_dir] = True

    return _are_symlinks_supported_in_dir[cache_dir]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass(frozen=True)
class HfFileMetadata:
    """Immutable record describing a file versioned on the Hub.

    Returned by [`get_hf_file_metadata`] based on a URL.

    Args:
        commit_hash (`str`, *optional*):
            The commit_hash related to the file.
        etag (`str`, *optional*):
            Etag of the file on the server.
        location (`str`):
            Location where to download the file. Can be a Hub url or not (CDN).
        size (`int`, *optional*):
            Size of the file. In case of an LFS file, contains the size of the actual
            LFS file, not the pointer.
    """

    commit_hash: Optional[str]
    etag: Optional[str]
    location: str
    size: Optional[int]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@validate_hf_hub_args
def hf_hub_url(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    endpoint: Optional[str] = None,
) -> str:
    """Build the URL of a file hosted in a repo from its coordinates.

    The resolved address can either be a huggingface.co-hosted url, or a link to
    Cloudfront (a Content Delivery Network, or CDN) for large files which are
    more than a few MBs.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) name and a repo name separated
            by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the repo. When set to a
            non-empty value, it is prepended to `filename` with a `/` separator.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Falls back to `constants.DEFAULT_REVISION` when `None`.
        endpoint (`str`, *optional*):
            If set and the built URL starts with `constants.ENDPOINT`, that prefix is
            replaced by `endpoint`.

    Returns:
        `str`: the URL of the file.

    Raises:
        `ValueError`: if `repo_type` is not one of `constants.REPO_TYPES`.

    Example:

    ```python
    >>> from huggingface_hub import hf_hub_url

    >>> hf_hub_url(
    ...     repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin"
    ... )
    'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin'
    ```

    <Tip>

    Notes:

        Cloudfront is replicated over the globe so downloads are way faster for
        the end user (and it also lowers our bandwidth costs).

        Cloudfront aggressively caches files by default (default TTL is 24
        hours), however this is not an issue here because we implement a
        git-based versioning system on huggingface.co, which means that we store
        the files on S3/Cloudfront in a content-addressable way (i.e., the file
        name is its hash). Using content-addressable filenames means cache can't
        ever be stale.

        In terms of client-side caching from this library, we base our caching
        on the objects' entity tag (`ETag`), which is an identifier of a
        specific version of a resource [1]_. An object's ETag is: its git-sha1
        if stored in git, or its sha256 if stored in git-lfs.

    </Tip>

    References:

    -  [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
    """
    # An empty subfolder is equivalent to no subfolder at all.
    if subfolder is not None and subfolder != "":
        filename = f"{subfolder}/{filename}"

    if repo_type not in constants.REPO_TYPES:
        raise ValueError("Invalid repo type")

    # Some repo types live under a URL prefix that is glued in front of the repo id.
    if repo_type in constants.REPO_TYPES_URL_PREFIXES:
        repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id

    effective_revision = constants.DEFAULT_REVISION if revision is None else revision
    # The revision is quoted with no safe characters; the filename keeps the default of `quote`.
    url = HUGGINGFACE_CO_URL_TEMPLATE.format(
        repo_id=repo_id,
        revision=quote(effective_revision, safe=""),
        filename=quote(filename),
    )

    # Update endpoint if provided
    if endpoint is None or not url.startswith(constants.ENDPOINT):
        return url
    return endpoint + url[len(constants.ENDPOINT) :]
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _request_wrapper(
    method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params
) -> requests.Response:
    """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when
    redirects are otherwise disabled through `params` (e.g. `allow_redirects=False`).

    Args:
        method (`str`):
            HTTP method, such as 'GET' or 'HEAD'.
        url (`str`):
            The URL of the resource to fetch.
        follow_relative_redirects (`bool`, *optional*, defaults to `False`)
            If True, relative redirection (redirection to the same site) will be resolved even when redirects are
            disabled in `params`. Useful when we want to follow a redirection to a renamed repository without
            following redirection to a CDN.
        **params (`dict`, *optional*):
            Params to pass to `requests.request`.

    Returns:
        `requests.Response`: the response of the last request performed. It has been checked with
        `hf_raise_for_status` before being returned.
    """
    # Recursively follow relative redirects
    if follow_relative_redirects:
        response = _request_wrapper(
            method=method,
            url=url,
            follow_relative_redirects=False,
            **params,
        )

        # If redirection, we redirect only relative paths.
        # This is useful in case of a renamed repository.
        if 300 <= response.status_code <= 399:
            # Not every 3xx response carries a `Location` header (e.g. 304 Not Modified).
            # Without one there is nothing to follow: hand the response back untouched.
            location = response.headers.get("Location")
            if location is not None:
                parsed_target = urlparse(location)
                if parsed_target.netloc == "":
                    # This means it is a relative 'location' headers, as allowed by RFC 7231.
                    # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
                    # We want to follow this relative redirect !
                    #
                    # Highly inspired by `resolve_redirects` from requests library.
                    # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
                    next_url = urlparse(url)._replace(path=parsed_target.path).geturl()
                    return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params)
        return response

    # Perform the request and let `hf_raise_for_status` turn HTTP errors into exceptions.
    response = get_session().request(method=method, url=url, **params)
    hf_raise_for_status(response)
    return response
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def http_get(
    url: str,
    temp_file: BinaryIO,
    *,
    proxies: Optional[Dict] = None,
    resume_size: float = 0,
    headers: Optional[Dict[str, str]] = None,
    expected_size: Optional[int] = None,
    displayed_filename: Optional[str] = None,
    _nb_retries: int = 5,
    _tqdm_bar: Optional[tqdm] = None,
) -> None:
    """
    Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.

    If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a
    transient error (network outage?). We log a warning message and try to resume the download a few times before
    giving up. The method gives up after 5 attempts if no new data has been received from the server.

    Args:
        url (`str`):
            The URL of the file to download.
        temp_file (`BinaryIO`):
            The file-like object where to save the file.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
        resume_size (`float`, *optional*):
            The number of bytes already downloaded. If set to 0 (default), the whole file is downloaded. If set to a
            positive number, the download will resume at the given position.
        headers (`dict`, *optional*):
            Dictionary of HTTP Headers to send with the request. The dictionary is copied, never modified.
        expected_size (`int`, *optional*):
            The expected size of the file to download. If set, the download will raise an error if the size of the
            received content is different from the expected one.
        displayed_filename (`str`, *optional*):
            The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If
            not set, the filename is guessed from the URL or the `Content-Disposition` header.
        _nb_retries (`int`, *optional*, defaults to 5):
            Internal. Number of resume attempts left. Reset to 5 each time a chunk is received.
        _tqdm_bar (`tqdm`, *optional*):
            Internal. Existing progress bar to update instead of creating a new one. It is not closed by this
            function.

    Raises:
        `ValueError`: if `hf_transfer` is enabled but the package cannot be imported.
        `RuntimeError`: if the download through `hf_transfer` fails.
        `EnvironmentError`: if the size of the downloaded file differs from `expected_size`.
        `requests.ConnectionError` or `requests.ReadTimeout`: if streaming keeps failing once retries are exhausted.
    """
    if expected_size is not None and resume_size == expected_size:
        # If the file is already fully downloaded, we don't need to download it again.
        return

    hf_transfer = None
    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        if resume_size != 0:
            warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method")
        elif proxies is not None:
            warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method")
        else:
            try:
                import hf_transfer  # type: ignore[no-redef]
            except ImportError:
                raise ValueError(
                    "Fast download using 'hf_transfer' is enabled"
                    " (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is not"
                    " available in your environment. Try `pip install hf_transfer`."
                )

    # Keep the caller's headers untouched: a retry must start again from them, without our `Range` header.
    initial_headers = headers
    headers = copy.deepcopy(headers) or {}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)

    r = _request_wrapper(
        method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT
    )
    hf_raise_for_status(r)
    content_length = r.headers.get("Content-Length")

    # NOTE: 'total' is the total number of bytes to download, not the number of bytes in the file.
    #       If the file is compressed, the number of bytes in the saved file will be higher than 'total'.
    total = resume_size + int(content_length) if content_length is not None else None

    if displayed_filename is None:
        displayed_filename = url
        content_disposition = r.headers.get("Content-Disposition")
        if content_disposition is not None:
            match = HEADER_FILENAME_PATTERN.search(content_disposition)
            if match is not None:
                # Means file is on CDN
                displayed_filename = match.groupdict()["filename"]

    # Truncate filename if too long to display
    if len(displayed_filename) > 40:
        displayed_filename = f"(…){displayed_filename[-40:]}"

    # Plain template with named fields, filled in only when an error is raised. The filename is passed
    # as a `.format` argument (never embedded in the template itself), so a name containing `{` or `}`
    # cannot break the formatting.
    consistency_error_message = (
        "Consistency check failed: file should be of size {expected_size} but has size"
        " {actual_size} ({displayed_filename}).\nThis is usually due to network issues while downloading the file."
        " Please retry with `force_download=True`."
    )

    # Stream file to buffer
    progress_cm: tqdm = (
        tqdm(  # type: ignore[assignment]
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc=displayed_filename,
            disable=is_tqdm_disabled(logger.getEffectiveLevel()),
            name="huggingface_hub.http_get",
        )
        if _tqdm_bar is None
        else contextlib.nullcontext(_tqdm_bar)
        # ^ `contextlib.nullcontext` mimics a context manager that does nothing
        #   Makes it easier to use the same code path for both cases but in the later
        #   case, the progress bar is not closed when exiting the context manager.
    )

    with progress_cm as progress:
        if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE:
            supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
            if not supports_callback:
                warnings.warn(
                    "You are using an outdated version of `hf_transfer`. "
                    "Consider upgrading to latest version to enable progress bars "
                    "using `pip install -U hf_transfer`."
                )
            try:
                hf_transfer.download(
                    url=url,
                    filename=temp_file.name,
                    max_files=constants.HF_TRANSFER_CONCURRENCY,
                    chunk_size=constants.DOWNLOAD_CHUNK_SIZE,
                    headers=headers,
                    parallel_failures=3,
                    max_retries=5,
                    **({"callback": progress.update} if supports_callback else {}),
                )
            except Exception as e:
                raise RuntimeError(
                    "An error occurred while downloading using `hf_transfer`. Consider"
                    " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
                ) from e
            if not supports_callback:
                # No incremental updates were possible: jump the bar to the end in one go.
                progress.update(total)
            if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
                raise EnvironmentError(
                    consistency_error_message.format(
                        expected_size=expected_size,
                        actual_size=os.path.getsize(temp_file.name),
                        displayed_filename=displayed_filename,
                    )
                )
            return
        new_resume_size = resume_size
        try:
            for chunk in r.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
                    new_resume_size += len(chunk)
                    # Some data has been downloaded from the server so we reset the number of retries.
                    _nb_retries = 5
        except (requests.ConnectionError, requests.ReadTimeout) as e:
            # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
            # a transient error (network outage?). We log a warning message and try to resume the download a few times
            # before giving up. The retry mechanism is basic but should be enough in most cases.
            if _nb_retries <= 0:
                logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
                raise
            logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
            time.sleep(1)
            reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects
            return http_get(
                url=url,
                temp_file=temp_file,
                proxies=proxies,
                resume_size=new_resume_size,
                headers=initial_headers,
                expected_size=expected_size,
                _nb_retries=_nb_retries - 1,
                _tqdm_bar=_tqdm_bar,
            )

    if expected_size is not None and expected_size != temp_file.tell():
        raise EnvironmentError(
            consistency_error_message.format(
                expected_size=expected_size,
                actual_size=temp_file.tell(),
                displayed_filename=displayed_filename,
            )
        )
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def _normalize_etag(etag: Optional[str]) -> Optional[str]:
    """Normalize ETag HTTP header, so it can be used to create nice filepaths.

    The HTTP spec allows two forms of ETag:
      ETag: W/"<etag_value>"
      ETag: "<etag_value>"

    For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
    more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.

    Args:
        etag (`str`, *optional*): HTTP header

    Returns:
        `str` or `None`: string that can be used as a nice directory name.
        Returns `None` if input is None.
    """
    if etag is None:
        return None
    # Drop the weak-validator marker only when it is the literal `W/` prefix.
    # `str.lstrip("W/")` treats its argument as a set of characters, so it would also eat
    # the leading characters of a value that merely starts with `W` or `/`.
    if etag.startswith("W/"):
        etag = etag[2:]
    return etag.strip('"')
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
    """Alias of `_create_symlink`, used in `transformers` conversion script."""
    # Pure pass-through: the three arguments are forwarded unchanged.
    return _create_symlink(src, dst, new_blob)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
    """Make `dst` a symbolic link to `src`, or a real file when symlinks are unavailable.

    Any existing file at `dst` is removed first.

    A relative link target is preferred, for two reasons:
    - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder will
      not break.
    - Relative paths seems to be better handled on Windows. Issue was reported 3 times in less than a week when
      changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398,
      https://github.com/huggingface/diffusers/issues/2729 and https://github.com/huggingface/transformers/pull/22228.
      NOTE: The issue with absolute paths doesn't happen on admin mode.
    When creating a symlink from the cache to a local folder, it is possible that a relative path cannot be created.
    This happens when paths are not on the same volume. In that case, we use absolute paths.

    The result layout looks something like
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd

    If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by
    having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file
    (`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing
    cache, the file is duplicated on the disk.

    Symlink support is probed with `are_symlinks_supported`, which warns the user when symlinks are not supported.
    That warning can be disabled with the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable.

    Args:
        src (`str`): path of the existing file the link must point to.
        dst (`str`): path of the link (or copy) to create.
        new_blob (`bool`, *optional*, defaults to `False`):
            Whether `src` is a new file that may be moved (rather than copied) when symlinks cannot be used.
    """
    # Clear the destination. Failure (e.g. nothing to remove) is deliberately ignored.
    try:
        os.remove(dst)
    except OSError:
        pass

    source_abs = os.path.abspath(os.path.expanduser(src))
    target_abs = os.path.abspath(os.path.expanduser(dst))
    target_parent = os.path.dirname(target_abs)

    # Prefer a link target expressed relatively to the folder that will contain the link.
    try:
        source_rel = os.path.relpath(source_abs, target_parent)
    except ValueError:
        # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink to a
        # local_dir instead of within the cache directory.
        # See https://docs.python.org/3/library/os.path.html#os.path.relpath
        source_rel = None

    try:
        shared_root = os.path.commonpath([source_abs, target_abs])
        can_symlink = are_symlinks_supported(shared_root)
    except ValueError:
        # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos.
        # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
        can_symlink = os.name != "nt"
    except OSError as error:
        # Two cases lead to probing the destination folder instead of the common path:
        # - PermissionError: src and dst are not in the same volume (e.g. destination path has been provided
        #   by the user via `local_dir`).
        # - errno=30 (EROFS): the commonpath is readonly on Linux/MacOS.
        # Any other OS error is not ours to handle.
        if not isinstance(error, PermissionError) and error.errno != errno.EROFS:
            raise
        can_symlink = are_symlinks_supported(target_parent)

    # Symlinks are supported => let's create a symlink.
    if can_symlink:
        link_target = source_rel or source_abs
        logger.debug(f"Creating pointer from {link_target} to {target_abs}")
        try:
            os.symlink(link_target, target_abs)
        except FileExistsError:
            already_linked = os.path.islink(target_abs) and os.path.realpath(target_abs) == os.path.realpath(
                source_abs
            )
            if not already_linked:
                # Very unlikely to happen. Means a file `dst` has been created exactly between `os.remove` and
                # `os.symlink` and is not a symlink to the `abs_src` blob file. Raise exception.
                raise
            # `dst` already exists and is a symlink to the blob. It is most likely that the file has
            # been cached twice concurrently (exactly between `os.remove` and `os.symlink`). Do nothing.
            return
        except PermissionError:
            # Permission error means src and dst are not in the same volume (e.g. download to local dir) and symlink
            # is supported on both volumes but not between them. Let's just make a hard copy in that case.
            pass
        else:
            return

    # Symlinks are not supported => let's move or copy the file.
    if not new_blob:
        logger.info(f"Symlink not supported. Copying file from {source_abs} to {target_abs}")
        shutil.copyfile(source_abs, target_abs)
    else:
        logger.info(f"Symlink not supported. Moving file from {source_abs} to {target_abs}")
        shutil.move(source_abs, target_abs, copy_function=_copy_no_matter_what)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
    """Store, under `<storage_folder>/refs/<revision>`, the commit hash that `revision` points to.

    `revision` can be a tag, a branch or a truncated commit hash.

    Does nothing if `revision` is already a proper `commit_hash` or reference is already cached.
    """
    if revision == commit_hash:
        return

    ref_file = Path(storage_folder, "refs", revision)
    ref_file.parent.mkdir(parents=True, exist_ok=True)

    # Update ref only if has been updated. Could cause useless error in case
    # repo is already cached and user doesn't have write access to cache folder.
    # See https://github.com/huggingface/huggingface_hub/issues/1216.
    if ref_file.exists() and ref_file.read_text() == commit_hash:
        return
    ref_file.write_text(commit_hash)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
@validate_hf_hub_args
def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
    """Serialize a hf.co repo name and type into a name safe for disk storage
    as a single non-nested folder.

    The name is made of `repo_type` followed by an `s`, then every `/`-separated part of
    `repo_id`, all joined with `constants.REPO_ID_SEPARATOR`.

    Example: models--julien-c--EsperBERTo-small
    """
    # Splitting on `/` removes every slash, so the result can never be a nested path.
    name_parts = [f"{repo_type}s"]
    name_parts.extend(repo_id.split("/"))
    return constants.REPO_ID_SEPARATOR.join(name_parts)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None:
    """Warn when the disk holding `target_dir` has less free space than the file needs.

    The disk usage is read on `target_dir` itself, then on each of its parents in turn, until
    one of them can be inspected. If none can, the function returns silently.

    Args:
        expected_size (`int`):
            The expected size of the file in bytes.
        target_dir (`str`):
            The directory where the file will be stored after downloading.
    """
    target_dir = Path(target_dir)  # format as `Path`

    # First candidate is the directory itself, followed by its ancestors, closest first.
    for candidate in (target_dir, *target_dir.parents):
        try:
            free_bytes = shutil.disk_usage(candidate).free
        except OSError:  # raise on anything: file does not exist or space disk cannot be checked
            continue

        if free_bytes < expected_size:
            warnings.warn(
                "Not enough free disk space to download the file. "
                f"The expected file size is: {expected_size / 1e6:.2f} MB. "
                f"The target location {target_dir} only has {free_bytes / 1e6:.2f} MB free disk space."
            )
        # The first path that could be inspected settles the question.
        return
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
@validate_hf_hub_args
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    user_agent: Union[Dict, str, None] = None,
    force_download: bool = False,
    proxies: Optional[Dict] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    headers: Optional[Dict[str, str]] = None,
    endpoint: Optional[str] = None,
    resume_download: Optional[bool] = None,
    force_filename: Optional[str] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
) -> str:
    """Download a given file if it's not already present in the local cache.

    The new cache file layout looks like this:
    - The cache directory contains one subfolder per repo_id (namespaced by repo type)
    - inside each repo folder:
        - refs is a list of the latest known revision => commit_hash pairs
        - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on
          whether they're LFS files or not)
        - snapshots contains one subfolder per commit, each "commit" contains the subset of the files
          that have been resolved at that particular commit. Each filename is a symlink to the blob
          at that particular commit.

    ```
    [  96]  .
    └── [ 160]  models--julien-c--EsperBERTo-small
        ├── [ 160]  blobs
        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
        ├── [  96]  refs
        │   └── [  40]  main
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
    ```

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
    cache-system, it's optimized for regularly pulling the latest version of a repository.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the model repo. When set to a
            non-empty string, it is prepended to `filename` with a forward slash.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `constants.DEFAULT_REVISION`.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Defaults to `constants.HF_HUB_CACHE`.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded file will be placed under this directory.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in
            the local cache.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `constants.DEFAULT_ETAG_TIMEOUT`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`. Ignored in favour of
            `constants.HF_HUB_ETAG_TIMEOUT` whenever that constant differs from the default.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        headers (`dict`, *optional*):
            Additional headers to be sent with the request.
        endpoint (`str`, *optional*):
            Forwarded untouched to the download helpers (presumably the Hub base URL to use
            instead of the default one; confirm against `hf_hub_url`).
        resume_download (`bool`, *optional*):
            Deprecated. Passing any non-`None` value only triggers a `FutureWarning`.
        force_filename (`str`, *optional*):
            Deprecated. Passing any non-`None` value only triggers a `FutureWarning`.
        local_dir_use_symlinks (`bool` or `"auto"`, *optional*, defaults to `"auto"`):
            Deprecated and ignored. A warning is emitted when `local_dir` is set and this value
            is not `"auto"`.

    Returns:
        `str`: Local path of file or if networking is off, last version of file cached on disk.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`~utils.EntryNotFoundError`]
            If the file to download cannot be found.
        [`~utils.LocalEntryNotFoundError`]
            If network is disabled or unavailable and file is not found in cache.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` but the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If some parameter value is invalid.

    """
    # A non-default `HF_HUB_ETAG_TIMEOUT` constant takes precedence over the caller's value.
    if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT:
        etag_timeout = constants.HF_HUB_ETAG_TIMEOUT

    # Deprecated parameters: accepted for backward compatibility, only reported.
    if force_filename is not None:
        warnings.warn(
            "The `force_filename` parameter is deprecated as a new caching system, "
            "which keeps the filenames as they are on the Hub, is now in place.",
            FutureWarning,
        )
    if resume_download is not None:
        warnings.warn(
            "`resume_download` is deprecated and will be removed in version 1.0.0. "
            "Downloads always resume when possible. "
            "If you want to force a new download, use `force_download=True`.",
            FutureWarning,
        )

    # Fill in defaults and normalize path-like arguments to plain strings.
    cache_dir = constants.HF_HUB_CACHE if cache_dir is None else cache_dir
    revision = constants.DEFAULT_REVISION if revision is None else revision
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if isinstance(local_dir, Path):
        local_dir = str(local_dir)

    # An empty subfolder is treated like no subfolder at all.
    if subfolder is not None and subfolder != "":
        # This is used to create a URL, and not a local path, hence the forward slash.
        filename = f"{subfolder}/{filename}"

    repo_type = "model" if repo_type is None else repo_type
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    hf_headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        headers=headers,
    )

    # Keyword arguments understood by both download helpers.
    shared_kwargs = {
        # File info
        "repo_id": repo_id,
        "repo_type": repo_type,
        "filename": filename,
        "revision": revision,
        # HTTP info
        "endpoint": endpoint,
        "etag_timeout": etag_timeout,
        "headers": hf_headers,
        "proxies": proxies,
        "token": token,
        # Additional options
        "cache_dir": cache_dir,
        "force_download": force_download,
        "local_files_only": local_files_only,
    }

    if local_dir is None:
        return _hf_hub_download_to_cache_dir(**shared_kwargs)

    if local_dir_use_symlinks != "auto":
        warnings.warn(
            "`local_dir_use_symlinks` parameter is deprecated and will be ignored. "
            "The process to download files to a local folder has been updated and do "
            "not rely on symlinks anymore. You only need to pass a destination folder "
            "as`local_dir`.\n"
            "For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder."
        )

    return _hf_hub_download_to_local_dir(local_dir=local_dir, **shared_kwargs)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
def _hf_hub_download_to_cache_dir(
    *,
    # Destination
    cache_dir: str,
    # File info
    repo_id: str,
    filename: str,
    repo_type: str,
    revision: str,
    # HTTP info
    endpoint: Optional[str],
    etag_timeout: float,
    headers: Dict[str, str],
    proxies: Optional[Dict],
    token: Optional[Union[bool, str]],
    # Additional options
    local_files_only: bool,
    force_download: bool,
) -> str:
    """Download a given file to a cache folder, if not already present.

    Method should not be called directly. Please use `hf_hub_download` instead.

    Resolution order:
        1. `revision` is a commit hash and the pointer file already exists => return it.
        2. Fetch metadata from the server. On failure, fall back to a pointer file found
           locally for that revision, otherwise delegate to `_raise_on_head_call_error`.
        3. Store the `revision -> commit_hash` ref, then reuse an existing pointer or blob.
        4. Otherwise download the blob under a file lock and create the pointer to it.

    Returns:
        `str`: Path of the pointer file inside the `snapshots` folder of the repo cache.

    Raises:
        `ValueError`: On Windows, if the filename contains a `..` path component.
    """
    locks_dir = os.path.join(cache_dir, ".locks")
    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    # cross platform transcription of filename, to be used as a local file path.
    relative_filename = os.path.join(*filename.split("/"))
    if os.name == "nt":
        if relative_filename.startswith("..\\") or "\\..\\" in relative_filename:
            raise ValueError(
                f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository"
                " owner to rename this file."
            )

    # if user provides a commit_hash and they already have the file on disk, shortcut everything.
    if REGEX_COMMIT_HASH.match(revision):
        pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
        if os.path.exists(pointer_path) and not force_download:
            return pointer_path

    # Try to get metadata (etag, commit_hash, url, size) from the server.
    # If we can't, a HEAD request error is returned.
    (url_to_download, etag, commit_hash, expected_size, head_call_error) = _get_metadata_or_catch_error(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
        proxies=proxies,
        etag_timeout=etag_timeout,
        headers=headers,
        token=token,
        local_files_only=local_files_only,
        storage_folder=storage_folder,
        relative_filename=relative_filename,
    )

    # etag can be None for several reasons:
    # 1. we passed local_files_only.
    # 2. we don't have a connection
    # 3. Hub is down (HTTP 500, 503, 504)
    # 4. repo is not found -for example private or gated- and invalid/missing token sent
    # 5. Hub is blocked by a firewall or proxy is not set correctly.
    # => Try to get the last downloaded one from the specified revision.
    #
    # If the specified revision is a commit hash, look inside "snapshots".
    # If the specified revision is a branch or tag, look inside "refs".
    if head_call_error is not None:
        # Couldn't make a HEAD call => let's try to find a local file
        if not force_download:
            commit_hash = None
            if REGEX_COMMIT_HASH.match(revision):
                commit_hash = revision
            else:
                ref_path = os.path.join(storage_folder, "refs", revision)
                if os.path.isfile(ref_path):
                    with open(ref_path) as f:
                        # Strip so that a trailing newline/whitespace in the ref file
                        # cannot break the snapshot folder lookup below.
                        commit_hash = f.read().strip()

            # Return pointer file if exists
            if commit_hash is not None:
                pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
                if os.path.exists(pointer_path):
                    return pointer_path

        # Otherwise, raise appropriate error
        # NOTE(review): the code below relies on this helper always raising - confirm.
        _raise_on_head_call_error(head_call_error, force_download, local_files_only)

    # From now on, etag, commit_hash, url and size are not None.
    assert etag is not None, "etag must have been retrieved from server"
    assert commit_hash is not None, "commit_hash must have been retrieved from server"
    assert url_to_download is not None, "file location must have been retrieved from server"
    assert expected_size is not None, "expected_size must have been retrieved from server"
    blob_path = os.path.join(storage_folder, "blobs", etag)
    pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)

    os.makedirs(os.path.dirname(blob_path), exist_ok=True)
    os.makedirs(os.path.dirname(pointer_path), exist_ok=True)

    # if passed revision is not identical to commit_hash
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

    # If file already exists, return it (except if force_download=True)
    if not force_download:
        if os.path.exists(pointer_path):
            return pointer_path

        if os.path.exists(blob_path):
            # we have the blob already, but not the pointer
            _create_symlink(blob_path, pointer_path, new_blob=False)
            return pointer_path

    # Prevent parallel downloads of the same file with a lock.
    # etag could be duplicated across repos, hence the per-repo lock folder.
    lock_path = os.path.join(locks_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type), f"{etag}.lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix.
    if os.name == "nt" and len(os.path.abspath(lock_path)) > 255:
        lock_path = "\\\\?\\" + os.path.abspath(lock_path)

    if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
        blob_path = "\\\\?\\" + os.path.abspath(blob_path)

    Path(lock_path).parent.mkdir(parents=True, exist_ok=True)
    with WeakFileLock(lock_path):
        _download_to_tmp_and_move(
            incomplete_path=Path(blob_path + ".incomplete"),
            destination_path=Path(blob_path),
            url_to_download=url_to_download,
            proxies=proxies,
            headers=headers,
            expected_size=expected_size,
            filename=filename,
            force_download=force_download,
        )
        # The pointer is created while the lock is still held.
        if not os.path.exists(pointer_path):
            _create_symlink(blob_path, pointer_path, new_blob=True)

    return pointer_path
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def _hf_hub_download_to_local_dir(
    *,
    # Destination
    local_dir: Union[str, Path],
    # File info
    repo_id: str,
    repo_type: str,
    filename: str,
    revision: str,
    # HTTP info
    endpoint: Optional[str],
    etag_timeout: float,
    headers: Dict[str, str],
    proxies: Optional[Dict],
    token: Union[bool, str, None],
    # Additional options
    cache_dir: str,
    force_download: bool,
    local_files_only: bool,
) -> str:
    """Download a given file to a local folder, if not already present.

    Method should not be called directly. Please use `hf_hub_download` instead.

    The existing local file is reused whenever it can be shown to be current (matching
    commit hash, matching etag, or matching sha256 when no metadata is stored). Failing
    that, the file is copied from the cache if present there, and downloaded otherwise.

    Returns:
        `str`: Path of the file inside `local_dir`.
    """
    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix.
    if os.name == "nt" and len(os.path.abspath(local_dir)) > 255:
        local_dir = "\\\\?\\" + os.path.abspath(local_dir)
    local_dir = Path(local_dir)
    paths = get_local_download_paths(local_dir=local_dir, filename=filename)
    local_metadata = read_download_metadata(local_dir=local_dir, filename=filename)

    # Fast path: pinned commit hash + local file + stored metadata for that very commit.
    if (
        not force_download
        and REGEX_COMMIT_HASH.match(revision)
        and paths.file_path.is_file()
        and local_metadata is not None
        and local_metadata.commit_hash == revision
    ):
        return str(paths.file_path)

    # Anything else requires the remote etag.
    (url_to_download, etag, commit_hash, expected_size, head_call_error) = _get_metadata_or_catch_error(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
        proxies=proxies,
        etag_timeout=etag_timeout,
        headers=headers,
        token=token,
        local_files_only=local_files_only,
    )

    if head_call_error is not None:
        # Hub unreachable but a local copy exists => serve the local copy.
        if not force_download and paths.file_path.is_file():
            logger.warning(
                f"Couldn't access the Hub to check for update but local file already exists. Defaulting to existing file. (error: {head_call_error})"
            )
            return str(paths.file_path)
        # Nothing usable locally => let the helper raise the appropriate error.
        _raise_on_head_call_error(head_call_error, force_download, local_files_only)

    # From now on, etag, commit_hash, url and size are not None.
    assert etag is not None, "etag must have been retrieved from server"
    assert commit_hash is not None, "commit_hash must have been retrieved from server"
    assert url_to_download is not None, "file location must have been retrieved from server"
    assert expected_size is not None, "expected_size must have been retrieved from server"

    def _store_metadata_and_return() -> str:
        # Record which commit/etag the local file corresponds to, then hand back its path.
        write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
        return str(paths.file_path)

    # A local file is present: find out whether it is still up-to-date.
    if not force_download and paths.file_path.is_file():
        # Stored etag equals the remote one => only the metadata needs refreshing.
        if local_metadata is not None and local_metadata.etag == etag:
            return _store_metadata_and_return()

        # No stored metadata + etag is a sha256
        # => means it's an LFS file (large)
        # => hash the local file and compare with the etag
        # => on a match, write the metadata and reuse the file
        if local_metadata is None and REGEX_SHA256.match(etag) is not None:
            with open(paths.file_path, "rb") as local_file:
                local_sha = sha_fileobj(local_file).hex()
            if local_sha == etag:
                return _store_metadata_and_return()

    # Local file is missing or stale => get it from the cache or from the remote.

    # With some luck the file already sits in the cache => copy it over.
    if not force_download:
        cached_path = try_to_load_from_cache(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            revision=commit_hash,
            repo_type=repo_type,
        )
        if isinstance(cached_path, str):
            with WeakFileLock(paths.lock_path):
                paths.file_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(cached_path, paths.file_path)
            return _store_metadata_and_return()

    # Last resort: download the file.
    with WeakFileLock(paths.lock_path):
        paths.file_path.unlink(missing_ok=True)  # delete outdated file first
        _download_to_tmp_and_move(
            incomplete_path=paths.incomplete_path(etag),
            destination_path=paths.file_path,
            url_to_download=url_to_download,
            proxies=proxies,
            headers=headers,
            expected_size=expected_size,
            filename=filename,
            force_download=force_download,
        )

    return _store_metadata_and_return()
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
@validate_hf_hub_args
def try_to_load_from_cache(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> Union[str, _CACHED_NO_EXIST_T, None]:
    """
    Explores the cache to return the latest cached file for a given revision if found.

    This function will not raise any exception if the file is not cached.

    Args:
        cache_dir (`str` or `os.PathLike`):
            The folder where the cached files lie. Defaults to `constants.HF_HUB_CACHE`.
        repo_id (`str`):
            The ID of the repo on huggingface.co.
        filename (`str`):
            The filename to look for inside `repo_id`.
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
            provided either.
        repo_type (`str`, *optional*):
            The type of the repository. Will default to `"model"`.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
            Will return `None` if the file was not cached. Otherwise:
            - The exact path to the cached file if it's found in the cache
            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
              cached.

    Raises:
        `ValueError`: If `repo_type` is not one of `constants.REPO_TYPES`.

    Example:

    ```python
    from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST

    filepath = try_to_load_from_cache("julien-c/EsperBERTo-small", "pytorch_model.bin")
    if isinstance(filepath, str):
        # file exists and is cached
        ...
    elif filepath is _CACHED_NO_EXIST:
        # non-existence of file is cached
        ...
    else:
        # file is not cached
        ...
    ```
    """
    if revision is None:
        revision = "main"
    if repo_type is None:
        repo_type = "model"
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE

    # Use the same folder-name helper as the code that writes the cache, so that
    # reader and writer can never disagree on the layout.
    repo_cache = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None

    refs_dir = os.path.join(repo_cache, "refs")
    snapshots_dir = os.path.join(repo_cache, "snapshots")
    no_exist_dir = os.path.join(repo_cache, ".no_exist")

    # Resolve refs (for instance to convert main to the associated commit sha)
    revision_file = os.path.join(refs_dir, revision)
    if os.path.isfile(revision_file):
        with open(revision_file) as f:
            # Strip so that a trailing newline/whitespace in the ref file cannot
            # break the folder lookups below.
            revision = f.read().strip()

    # Check if file is cached as "no_exist"
    if os.path.isfile(os.path.join(no_exist_dir, revision, filename)):
        return _CACHED_NO_EXIST

    # Check if revision folder exists. `isdir` (rather than `exists`) keeps the
    # "never raises" promise when `snapshots` is unexpectedly not a directory.
    if not os.path.isdir(snapshots_dir):
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = os.path.join(snapshots_dir, revision, filename)
    return cached_file if os.path.isfile(cached_file) else None
|
| 1242 |
+
|
| 1243 |
+
|
| 1244 |
+
@validate_hf_hub_args
def get_hf_file_metadata(
    url: str,
    token: Union[bool, str, None] = None,
    proxies: Optional[Dict] = None,
    timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    headers: Optional[Dict[str, str]] = None,
) -> HfFileMetadata:
    """Fetch metadata of a file versioned on the Hub for a given url.

    A single `HEAD` request is issued with compression disabled, and the metadata
    is read from the response headers.

    Args:
        url (`str`):
            File url, for example returned by [`hf_hub_url`].
        token (`str` or `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If `False` or `None`, no token is provided.
                - If a string, it's used as the authentication token.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        timeout (`float`, *optional*, defaults to `constants.DEFAULT_REQUEST_TIMEOUT`):
            How many seconds to wait for the server to send metadata before giving up.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        headers (`dict`, *optional*):
            Additional headers to be sent with the request.

    Returns:
        A [`HfFileMetadata`] object containing metadata such as location, etag, size and
        commit_hash.
    """
    request_headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        headers=headers,
    )
    # No compression: the reported size must be the real size of the file.
    request_headers["Accept-Encoding"] = "identity"

    # Ask the server for the metadata only (HEAD), without following absolute redirects.
    r = _request_wrapper(
        method="HEAD",
        url=url,
        headers=request_headers,
        allow_redirects=False,
        follow_relative_redirects=True,
        proxies=proxies,
        timeout=timeout,
    )
    hf_raise_for_status(r)

    response_headers = r.headers

    commit_hash = response_headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT)

    # The custom header carrying the etag of the linked resource wins over the
    # regular `ETag` header.
    raw_etag = response_headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response_headers.get("ETag")

    # Either from response headers (if redirected) or defaults to request url.
    # `url` is not used directly, as `_request_wrapper` might have followed relative
    # redirects.
    location = response_headers.get("Location") or r.request.url  # type: ignore

    # Same precedence for the size: linked-size header first, then `Content-Length`.
    raw_size = response_headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or response_headers.get(
        "Content-Length"
    )

    return HfFileMetadata(
        commit_hash=commit_hash,
        etag=_normalize_etag(raw_etag),
        location=location,
        size=_int_or_none(raw_size),
    )
|
| 1319 |
+
|
| 1320 |
+
|
| 1321 |
+
def _get_metadata_or_catch_error(
    *,
    repo_id: str,
    filename: str,
    repo_type: str,
    revision: str,
    endpoint: Optional[str],
    proxies: Optional[Dict],
    etag_timeout: Optional[float],
    headers: Dict[str, str],  # mutated inplace!
    token: Union[bool, str, None],
    local_files_only: bool,
    relative_filename: Optional[str] = None,  # only used to store `.no_exists` in cache
    storage_folder: Optional[str] = None,  # only used to store `.no_exists` in cache
) -> Union[
    # Either an exception is caught and returned
    Tuple[None, None, None, None, Exception],
    # Or the metadata is returned as
    # `(url_to_download, etag, commit_hash, expected_size, None)`
    Tuple[str, str, str, int, None],
]:
    """Look up a file's metadata on the Hub, returning recoverable failures instead of raising them.

    The result is always a 5-tuple `(url_to_download, etag, commit_hash, expected_size, error)`:

    - With `local_files_only=True`, no request is made and the tuple is
      `(None, None, None, None, OfflineModeIsEnabled(...))`.
    - On success, `error` is `None` and the other slots hold the metadata.
    - When the lookup fails with a connection error, timeout, offline mode, HTTP error
      or `FileMetadataError`, the exception is returned in the last slot so the caller
      can fall back on the local cache. The other slots keep whatever was resolved
      before the failure.

    SSL errors, proxy errors, `RevisionNotFoundError` and `EntryNotFoundError` are not
    recoverable here and propagate to the caller. When the entry is not found and both
    `storage_folder` and `relative_filename` are given, its non-existence is recorded
    under `<storage_folder>/.no_exist/<commit_hash>/` before re-raising.

    NOTE: `headers` is mutated in place. The `authorization` entry is removed when the
          file location reported by the server is on a different domain than the
          requested url (typically a signed CDN / S3 url for an LFS blob).
    """
    if local_files_only:
        offline_error = OfflineModeIsEnabled(
            f"Cannot access file since 'local_files_only=True' as been set. (repo_id: {repo_id}, repo_type: {repo_type}, revision: {revision}, filename: {filename})"
        )
        return (None, None, None, None, offline_error)

    url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)
    url_to_download: str = url
    etag: Optional[str] = None
    commit_hash: Optional[str] = None
    expected_size: Optional[int] = None
    head_error_call: Optional[Exception] = None

    # Query the server. Missing or inaccessible files are not fatal at this point:
    # most failures are stored in `head_error_call` and handed back to the caller.
    try:
        try:
            metadata = get_hf_file_metadata(
                url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token
            )
        except EntryNotFoundError as http_error:
            can_cache_absence = storage_folder is not None and relative_filename is not None
            if can_cache_absence:
                # Remember that the file does not exist for this commit.
                commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT)
                if commit_hash is not None:
                    absence_marker = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
                    try:
                        absence_marker.parent.mkdir(parents=True, exist_ok=True)
                        absence_marker.touch()
                    except OSError as e:
                        # Best effort only: a cache write failure must not hide the original error.
                        logger.error(
                            f"Could not cache non-existence of file. Will ignore error and continue. Error: {e}"
                        )
                    _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
            raise

        # A commit hash is mandatory.
        commit_hash = metadata.commit_hash
        if commit_hash is None:
            raise FileMetadataError(
                "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue"
                " prevents you from downloading resources from https://huggingface.co. Please check your firewall"
                " and proxy settings and make sure your SSL certificates are updated."
            )

        # An etag is mandatory.
        etag = metadata.etag
        if etag is None:
            raise FileMetadataError(
                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
            )

        # A size is mandatory.
        expected_size = metadata.size
        if expected_size is None:
            raise FileMetadataError("Distant resource does not have a Content-Length.")

        # When the server redirected us, download straight from the final location. This
        # saves a redirect on the GET and pins the exact version seen by the HEAD call.
        #
        # Different domain => CDN with a signed url => do not send the auth header.
        # Same domain => redirect caused by a repo rename for a regular file => keep auth.
        final_location = metadata.location
        if url != final_location:
            url_to_download = final_location
            if urlparse(url).netloc != urlparse(final_location).netloc:
                headers.pop("authorization", None)
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # These are subclasses of ConnectionError that must surface to the user, hence
        # this clause has to stay ahead of the generic ConnectionError one below.
        raise
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        OfflineModeIsEnabled,
    ) as error:
        # No connectivity: etag stays None and the error is reported to the caller.
        head_error_call = error
    except (RevisionNotFoundError, EntryNotFoundError):
        # The repo exists but the revision or entry does not (never existed or was deleted).
        raise
    except (requests.HTTPError, FileMetadataError) as error:
        # HTTP error: private or gated repo with an invalid/missing token, or Hub down (500/504).
        # FileMetadataError: wrong network configuration (proxy, firewall, SSL certificates)
        # or an inconsistency on the Hub.
        # In both cases the caller checks the local cache and re-raises if nothing is found.
        head_error_call = error

    if not local_files_only and etag is None and head_error_call is None:
        raise RuntimeError("etag is empty due to uncovered problems")

    return (url_to_download, etag, commit_hash, expected_size, head_error_call)  # type: ignore [return-value]
|
| 1460 |
+
|
| 1461 |
+
|
| 1462 |
+
def _raise_on_head_call_error(head_call_error: Exception, force_download: bool, local_files_only: bool) -> NoReturn:
|
| 1463 |
+
"""Raise an appropriate error when the HEAD call failed and we cannot locate a local file."""
|
| 1464 |
+
|
| 1465 |
+
# No head call => we cannot force download.
|
| 1466 |
+
if force_download:
|
| 1467 |
+
if local_files_only:
|
| 1468 |
+
raise ValueError("Cannot pass 'force_download=True' and 'local_files_only=True' at the same time.")
|
| 1469 |
+
elif isinstance(head_call_error, OfflineModeIsEnabled):
|
| 1470 |
+
raise ValueError("Cannot pass 'force_download=True' when offline mode is enabled.") from head_call_error
|
| 1471 |
+
else:
|
| 1472 |
+
raise ValueError("Force download failed due to the above error.") from head_call_error
|
| 1473 |
+
|
| 1474 |
+
# No head call + couldn't find an appropriate file on disk => raise an error.
|
| 1475 |
+
if local_files_only:
|
| 1476 |
+
raise LocalEntryNotFoundError(
|
| 1477 |
+
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable"
|
| 1478 |
+
" hf.co look-ups and downloads online, set 'local_files_only' to False."
|
| 1479 |
+
)
|
| 1480 |
+
elif isinstance(head_call_error, RepositoryNotFoundError) or isinstance(head_call_error, GatedRepoError):
|
| 1481 |
+
# Repo not found or gated => let's raise the actual error
|
| 1482 |
+
raise head_call_error
|
| 1483 |
+
else:
|
| 1484 |
+
# Otherwise: most likely a connection issue or Hub downtime => let's warn the user
|
| 1485 |
+
raise LocalEntryNotFoundError(
|
| 1486 |
+
"An error happened while trying to locate the file on the Hub and we cannot find the requested files"
|
| 1487 |
+
" in the local cache. Please check your connection and try again or make sure your Internet connection"
|
| 1488 |
+
" is on."
|
| 1489 |
+
) from head_call_error
|
| 1490 |
+
|
| 1491 |
+
|
| 1492 |
+
def _download_to_tmp_and_move(
    incomplete_path: Path,
    destination_path: Path,
    url_to_download: str,
    proxies: Optional[Dict],
    headers: Dict[str, str],
    expected_size: Optional[int],
    filename: str,
    force_download: bool,
) -> None:
    """Fetch `url_to_download` into `incomplete_path`, then move the result to `destination_path`.

    Steps, in order:
    - do nothing if `destination_path` already exists (unless `force_download=True`)
    - discard a leftover incomplete file if `force_download=True`, or if
      `HF_HUB_ENABLE_HF_TRANSFER` is set and no proxies are used; otherwise the download
      resumes from the current size of the incomplete file
    - check the available disk space for both locations when `expected_size` is known
    - download the content, appending to the incomplete file
    - set permissions on the downloaded file and move it to `destination_path`

    Both `incomplete_path` and `destination_path` must be on the same volume to avoid a local copy.
    """
    already_downloaded = destination_path.exists()
    if already_downloaded and not force_download:
        return

    # Resuming is the default. It is disabled by `force_download=True` and by hf_transfer
    # (only taken into account when no proxies are configured).
    transfer_without_proxy = constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies
    if incomplete_path.exists() and (force_download or transfer_without_proxy):
        reason = ""
        if force_download:
            reason = " (force_download=True)"
        elif transfer_without_proxy:
            reason = " (hf_transfer=True)"
        logger.info(f"Removing incomplete file '{incomplete_path}'" + reason)
        incomplete_path.unlink(missing_ok=True)

    with incomplete_path.open("ab") as f:
        # In append mode the initial position is the number of bytes already on disk.
        resume_size = f.tell()
        progress_note = ""
        if resume_size > 0 and expected_size is not None:
            progress_note = f" (resume from {resume_size}/{expected_size})"
        logger.info(f"Downloading '{filename}' to '{incomplete_path}'" + progress_note)

        # `expected_size` can be None when the HTTP header was not set correctly.
        if expected_size is not None:
            for folder in (incomplete_path.parent, destination_path.parent):
                _check_disk_space(expected_size, folder)

        http_get(
            url_to_download,
            f,
            proxies=proxies,
            resume_size=resume_size,
            headers=headers,
            expected_size=expected_size,
        )

    logger.info(f"Download complete. Moving file to {destination_path}")
    _chmod_and_move(incomplete_path, destination_path)
|
| 1554 |
+
|
| 1555 |
+
|
| 1556 |
+
def _int_or_none(value: Optional[str]) -> Optional[int]:
|
| 1557 |
+
try:
|
| 1558 |
+
return int(value) # type: ignore
|
| 1559 |
+
except (TypeError, ValueError):
|
| 1560 |
+
return None
|
| 1561 |
+
|
| 1562 |
+
|
| 1563 |
+
def _chmod_and_move(src: Path, dst: Path) -> None:
    """Give `src` the permissions of a freshly created cache file, then move it to `dst`.

    The process `umask` is not read directly as there is no convenient thread-safe way
    to get it. Instead, a throwaway file is created two levels above `dst` and its mode
    is copied onto `src`. If that fails with an `OSError`, a warning is logged and the
    move still happens.

    See:
    - About umask: https://docs.python.org/3/library/os.html#os.umask
    - Thread-safety: https://stackoverflow.com/a/70343066
    - About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591
    - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141
    - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215
    """
    # Throwaway file used to discover the default mode of new files.
    probe = dst.parent.parent / f"tmp_{uuid.uuid4()}"
    try:
        probe.touch()
        default_mode = stat.S_IMODE(probe.stat().st_mode)
        os.chmod(str(src), default_mode)
    except OSError as e:
        logger.warning(
            f"Could not set the permissions on the file '{src}'. Error: {e}.\nContinuing without setting permissions."
        )
    finally:
        try:
            probe.unlink()
        except OSError:
            # Nothing to remove if `probe.touch()` itself failed.
            # See https://github.com/huggingface/huggingface_hub/issues/2359
            pass

    shutil.move(str(src), str(dst), copy_function=_copy_no_matter_what)
|
| 1595 |
+
|
| 1596 |
+
|
| 1597 |
+
def _copy_no_matter_what(src: str, dst: str) -> None:
|
| 1598 |
+
"""Copy file from src to dst.
|
| 1599 |
+
|
| 1600 |
+
If `shutil.copy2` fails, fallback to `shutil.copyfile`.
|
| 1601 |
+
"""
|
| 1602 |
+
try:
|
| 1603 |
+
# Copy file with metadata and permission
|
| 1604 |
+
# Can fail e.g. if dst is an S3 mount
|
| 1605 |
+
shutil.copy2(src, dst)
|
| 1606 |
+
except OSError:
|
| 1607 |
+
# Copy only file content
|
| 1608 |
+
shutil.copyfile(src, dst)
|
| 1609 |
+
|
| 1610 |
+
|
| 1611 |
+
def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
|
| 1612 |
+
# Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
|
| 1613 |
+
snapshot_path = os.path.join(storage_folder, "snapshots")
|
| 1614 |
+
pointer_path = os.path.join(snapshot_path, revision, relative_filename)
|
| 1615 |
+
if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents:
|
| 1616 |
+
raise ValueError(
|
| 1617 |
+
"Invalid pointer path: cannot create pointer path in snapshot folder if"
|
| 1618 |
+
f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and"
|
| 1619 |
+
f" `relative_filename='{relative_filename}'`."
|
| 1620 |
+
)
|
| 1621 |
+
return pointer_path
|
.venv/lib/python3.11/site-packages/huggingface_hub/hf_api.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/hf_file_system.py
ADDED
|
@@ -0,0 +1,1140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import tempfile
|
| 4 |
+
from collections import deque
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from itertools import chain
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union
|
| 10 |
+
from urllib.parse import quote, unquote
|
| 11 |
+
|
| 12 |
+
import fsspec
|
| 13 |
+
from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
|
| 14 |
+
from fsspec.utils import isfilelike
|
| 15 |
+
from requests import Response
|
| 16 |
+
|
| 17 |
+
from . import constants
|
| 18 |
+
from ._commit_api import CommitOperationCopy, CommitOperationDelete
|
| 19 |
+
from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
| 20 |
+
from .file_download import hf_hub_url, http_get
|
| 21 |
+
from .hf_api import HfApi, LastCommitInfo, RepoFile
|
| 22 |
+
from .utils import HFValidationError, hf_raise_for_status, http_backoff
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Matches the special revisions that contain "/" themselves (see #1710), so that they
# are not cut at their first "/" when a path is split into revision and file path.
SPECIAL_REFS_REVISION_REGEX = re.compile(
    r"""
    (^refs\/convert\/\w+)     # `refs/convert/parquet` revisions
    |
    (^refs\/pr\/\d+)          # PR revisions
    """,
    flags=re.VERBOSE,
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class HfFileSystemResolvedPath:
    """Components of a Hugging Face file system path once it has been resolved."""

    repo_type: str
    repo_id: str
    revision: str
    path_in_repo: str
    # Text found after '@' in the path given by the user (a quoted or unquoted refs revision).
    # Kept so that `unresolve` can give back the path in the form the user wrote it.
    _raw_revision: Optional[str] = field(default=None, repr=False)

    def unresolve(self) -> str:
        """Rebuild the path string from the resolved components.

        The revision is written after '@' when a raw revision was recorded, or when the
        revision differs from the default one; otherwise it is left out. Trailing "/"
        characters are stripped from the result.
        """
        prefix = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "")
        repo_path = prefix + self.repo_id

        if self._raw_revision:
            revision_part = f"@{self._raw_revision}"
        elif self.revision != constants.DEFAULT_REVISION:
            revision_part = f"@{safe_revision(self.revision)}"
        else:
            revision_part = ""

        return f"{repo_path}{revision_part}/{self.path_in_repo}".rstrip("/")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class HfFileSystem(fsspec.AbstractFileSystem):
|
| 59 |
+
"""
|
| 60 |
+
Access a remote Hugging Face Hub repository as if were a local file system.
|
| 61 |
+
|
| 62 |
+
<Tip warning={true}>
|
| 63 |
+
|
| 64 |
+
[`HfFileSystem`] provides fsspec compatibility, which is useful for libraries that require it (e.g., reading
|
| 65 |
+
Hugging Face datasets directly with `pandas`). However, it introduces additional overhead due to this compatibility
|
| 66 |
+
layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible.
|
| 67 |
+
|
| 68 |
+
</Tip>
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
token (`str` or `bool`, *optional*):
|
| 72 |
+
A valid user access token (string). Defaults to the locally saved
|
| 73 |
+
token, which is the recommended method for authentication (see
|
| 74 |
+
https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
|
| 75 |
+
To disable authentication, pass `False`.
|
| 76 |
+
endpoint (`str`, *optional*):
|
| 77 |
+
Endpoint of the Hub. Defaults to <https://huggingface.co>.
|
| 78 |
+
Usage:
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
>>> from huggingface_hub import HfFileSystem
|
| 82 |
+
|
| 83 |
+
>>> fs = HfFileSystem()
|
| 84 |
+
|
| 85 |
+
>>> # List files
|
| 86 |
+
>>> fs.glob("my-username/my-model/*.bin")
|
| 87 |
+
['my-username/my-model/pytorch_model.bin']
|
| 88 |
+
>>> fs.ls("datasets/my-username/my-dataset", detail=False)
|
| 89 |
+
['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']
|
| 90 |
+
|
| 91 |
+
>>> # Read/write files
|
| 92 |
+
>>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
|
| 93 |
+
... data = f.read()
|
| 94 |
+
>>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
|
| 95 |
+
... f.write(data)
|
| 96 |
+
```
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
root_marker = ""
|
| 100 |
+
protocol = "hf"
|
| 101 |
+
|
| 102 |
+
def __init__(
|
| 103 |
+
self,
|
| 104 |
+
*args,
|
| 105 |
+
endpoint: Optional[str] = None,
|
| 106 |
+
token: Union[bool, str, None] = None,
|
| 107 |
+
**storage_options,
|
| 108 |
+
):
|
| 109 |
+
super().__init__(*args, **storage_options)
|
| 110 |
+
self.endpoint = endpoint or constants.ENDPOINT
|
| 111 |
+
self.token = token
|
| 112 |
+
self._api = HfApi(endpoint=endpoint, token=token)
|
| 113 |
+
# Maps (repo_type, repo_id, revision) to a 2-tuple with:
|
| 114 |
+
# * the 1st element indicating whether the repositoy and the revision exist
|
| 115 |
+
# * the 2nd element being the exception raised if the repository or revision doesn't exist
|
| 116 |
+
self._repo_and_revision_exists_cache: Dict[
|
| 117 |
+
Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
|
| 118 |
+
] = {}
|
| 119 |
+
|
| 120 |
+
def _repo_and_revision_exist(
|
| 121 |
+
self, repo_type: str, repo_id: str, revision: Optional[str]
|
| 122 |
+
) -> Tuple[bool, Optional[Exception]]:
|
| 123 |
+
if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
|
| 124 |
+
try:
|
| 125 |
+
self._api.repo_info(
|
| 126 |
+
repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT
|
| 127 |
+
)
|
| 128 |
+
except (RepositoryNotFoundError, HFValidationError) as e:
|
| 129 |
+
self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
|
| 130 |
+
self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
|
| 131 |
+
except RevisionNotFoundError as e:
|
| 132 |
+
self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
|
| 133 |
+
self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
|
| 134 |
+
else:
|
| 135 |
+
self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
|
| 136 |
+
self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
|
| 137 |
+
return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]
|
| 138 |
+
|
| 139 |
+
def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
|
| 140 |
+
"""
|
| 141 |
+
Resolve a Hugging Face file system path into its components.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
path (`str`):
|
| 145 |
+
Path to resolve.
|
| 146 |
+
revision (`str`, *optional*):
|
| 147 |
+
The revision of the repo to resolve. Defaults to the revision specified in the path.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
[`HfFileSystemResolvedPath`]: Resolved path information containing `repo_type`, `repo_id`, `revision` and `path_in_repo`.
|
| 151 |
+
|
| 152 |
+
Raises:
|
| 153 |
+
`ValueError`:
|
| 154 |
+
If path contains conflicting revision information.
|
| 155 |
+
`NotImplementedError`:
|
| 156 |
+
If trying to list repositories.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def _align_revision_in_path_with_revision(
|
| 160 |
+
revision_in_path: Optional[str], revision: Optional[str]
|
| 161 |
+
) -> Optional[str]:
|
| 162 |
+
if revision is not None:
|
| 163 |
+
if revision_in_path is not None and revision_in_path != revision:
|
| 164 |
+
raise ValueError(
|
| 165 |
+
f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
|
| 166 |
+
" are not the same."
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
revision = revision_in_path
|
| 170 |
+
return revision
|
| 171 |
+
|
| 172 |
+
path = self._strip_protocol(path)
|
| 173 |
+
if not path:
|
| 174 |
+
# can't list repositories at root
|
| 175 |
+
raise NotImplementedError("Access to repositories lists is not implemented.")
|
| 176 |
+
elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values():
|
| 177 |
+
if "/" not in path:
|
| 178 |
+
# can't list repositories at the repository type level
|
| 179 |
+
raise NotImplementedError("Access to repositories lists is not implemented.")
|
| 180 |
+
repo_type, path = path.split("/", 1)
|
| 181 |
+
repo_type = constants.REPO_TYPES_MAPPING[repo_type]
|
| 182 |
+
else:
|
| 183 |
+
repo_type = constants.REPO_TYPE_MODEL
|
| 184 |
+
if path.count("/") > 0:
|
| 185 |
+
if "@" in path:
|
| 186 |
+
repo_id, revision_in_path = path.split("@", 1)
|
| 187 |
+
if "/" in revision_in_path:
|
| 188 |
+
match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
|
| 189 |
+
if match is not None and revision in (None, match.group()):
|
| 190 |
+
# Handle `refs/convert/parquet` and PR revisions separately
|
| 191 |
+
path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
|
| 192 |
+
revision_in_path = match.group()
|
| 193 |
+
else:
|
| 194 |
+
revision_in_path, path_in_repo = revision_in_path.split("/", 1)
|
| 195 |
+
else:
|
| 196 |
+
path_in_repo = ""
|
| 197 |
+
revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
|
| 198 |
+
repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
|
| 199 |
+
if not repo_and_revision_exist:
|
| 200 |
+
_raise_file_not_found(path, err)
|
| 201 |
+
else:
|
| 202 |
+
revision_in_path = None
|
| 203 |
+
repo_id_with_namespace = "/".join(path.split("/")[:2])
|
| 204 |
+
path_in_repo_with_namespace = "/".join(path.split("/")[2:])
|
| 205 |
+
repo_id_without_namespace = path.split("/")[0]
|
| 206 |
+
path_in_repo_without_namespace = "/".join(path.split("/")[1:])
|
| 207 |
+
repo_id = repo_id_with_namespace
|
| 208 |
+
path_in_repo = path_in_repo_with_namespace
|
| 209 |
+
repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
|
| 210 |
+
if not repo_and_revision_exist:
|
| 211 |
+
if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
|
| 212 |
+
repo_id = repo_id_without_namespace
|
| 213 |
+
path_in_repo = path_in_repo_without_namespace
|
| 214 |
+
repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
|
| 215 |
+
if not repo_and_revision_exist:
|
| 216 |
+
_raise_file_not_found(path, err)
|
| 217 |
+
else:
|
| 218 |
+
_raise_file_not_found(path, err)
|
| 219 |
+
else:
|
| 220 |
+
repo_id = path
|
| 221 |
+
path_in_repo = ""
|
| 222 |
+
if "@" in path:
|
| 223 |
+
repo_id, revision_in_path = path.split("@", 1)
|
| 224 |
+
revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
|
| 225 |
+
else:
|
| 226 |
+
revision_in_path = None
|
| 227 |
+
repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
|
| 228 |
+
if not repo_and_revision_exist:
|
| 229 |
+
raise NotImplementedError("Access to repositories lists is not implemented.")
|
| 230 |
+
|
| 231 |
+
revision = revision if revision is not None else constants.DEFAULT_REVISION
|
| 232 |
+
return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path)
|
| 233 |
+
|
| 234 |
+
def invalidate_cache(self, path: Optional[str] = None) -> None:
    """
    Drop cached listings, either for a single path or for the whole filesystem.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.invalidate_cache).

    Args:
        path (`str`, *optional*):
            Path whose cache entries should be dropped. If omitted (or empty), both the directory cache and the
            repo/revision existence cache are emptied.
    """
    if not path:
        self.dircache.clear()
        self._repo_and_revision_exists_cache.clear()
        return

    resolved = self.resolve_path(path)

    # Evict the path itself and each of its ancestors, walking up until `_parent` returns an empty value.
    current = resolved.unresolve()
    while current:
        self.dircache.pop(current, None)
        current = self._parent(current)

    # The repo/revision existence cache is only touched when the path designates the repo root.
    if resolved.path_in_repo:
        return
    for cached_revision in (None, resolved.revision):
        self._repo_and_revision_exists_cache.pop((resolved.repo_type, resolved.repo_id, cached_revision), None)
|
| 261 |
+
|
| 262 |
+
def _open(
|
| 263 |
+
self,
|
| 264 |
+
path: str,
|
| 265 |
+
mode: str = "rb",
|
| 266 |
+
revision: Optional[str] = None,
|
| 267 |
+
block_size: Optional[int] = None,
|
| 268 |
+
**kwargs,
|
| 269 |
+
) -> "HfFileSystemFile":
|
| 270 |
+
if "a" in mode:
|
| 271 |
+
raise NotImplementedError("Appending to remote files is not yet supported.")
|
| 272 |
+
if block_size == 0:
|
| 273 |
+
return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
|
| 274 |
+
else:
|
| 275 |
+
return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
|
| 276 |
+
|
| 277 |
+
def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
|
| 278 |
+
resolved_path = self.resolve_path(path, revision=revision)
|
| 279 |
+
self._api.delete_file(
|
| 280 |
+
path_in_repo=resolved_path.path_in_repo,
|
| 281 |
+
repo_id=resolved_path.repo_id,
|
| 282 |
+
token=self.token,
|
| 283 |
+
repo_type=resolved_path.repo_type,
|
| 284 |
+
revision=resolved_path.revision,
|
| 285 |
+
commit_message=kwargs.get("commit_message"),
|
| 286 |
+
commit_description=kwargs.get("commit_description"),
|
| 287 |
+
)
|
| 288 |
+
self.invalidate_cache(path=resolved_path.unresolve())
|
| 289 |
+
|
| 290 |
+
def rm(
    self,
    path: str,
    recursive: bool = False,
    maxdepth: Optional[int] = None,
    revision: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Delete files from a repository in a single commit.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.rm).

    <Tip warning={true}>

        Note: When possible, use `HfApi.delete_file()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to delete.
        recursive (`bool`, *optional*):
            If True, delete directory and all its contents. Defaults to False.
        maxdepth (`int`, *optional*):
            Maximum number of subdirectories to visit when deleting recursively.
        revision (`str`, *optional*):
            The git revision to delete from.
        **kwargs:
            `commit_message` overrides the generated commit message; `commit_description` is forwarded
            as-is (defaults to `None`).
    """
    resolved_path = self.resolve_path(path, revision=revision)
    expanded_paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision)

    # Directories are skipped: only the expanded paths that are not directories become delete operations.
    paths_in_repo = []
    for expanded in expanded_paths:
        if self.isdir(expanded):
            continue
        paths_in_repo.append(self.resolve_path(expanded).path_in_repo)
    operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]

    # Default commit message; each fragment keeps its trailing space, exactly as emitted before.
    message_parts = [f"Delete {path} "]
    if recursive:
        message_parts.append("recursively ")
    if maxdepth is not None:
        message_parts.append(f"up to depth {maxdepth} ")
    default_message = "".join(message_parts)

    # TODO: use `commit_description` to list all the deleted paths?
    self._api.create_commit(
        repo_id=resolved_path.repo_id,
        repo_type=resolved_path.repo_type,
        token=self.token,
        operations=operations,
        revision=resolved_path.revision,
        commit_message=kwargs.get("commit_message", default_message),
        commit_description=kwargs.get("commit_description"),
    )
    self.invalidate_cache(path=resolved_path.unresolve())
|
| 338 |
+
|
| 339 |
+
def ls(
    self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
) -> List[Union[str, Dict[str, Any]]]:
    """
    List the contents of a directory.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.ls).

    <Tip warning={true}>

        Note: When possible, use `HfApi.list_repo_tree()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to the directory.
        detail (`bool`, *optional*):
            If True, returns a list of dictionaries containing file information. If False,
            returns a list of file paths. Defaults to True.
        refresh (`bool`, *optional*):
            If True, bypass the cache and fetch the latest data. Defaults to False.
        revision (`str`, *optional*):
            The git revision to list from.

    Returns:
        `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
        dictionaries (if detail=True).
    """
    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    # `expand_info` follows `detail` unless the caller passed it explicitly through kwargs.
    tree_kwargs = {"expand_info": detail, **kwargs}
    try:
        entries = self._ls_tree(path, refresh=refresh, revision=revision, **tree_kwargs)
    except EntryNotFoundError:
        # Not a directory: the path may still be a file, so look for it in its parent's listing.
        if not resolved_path.path_in_repo:
            _raise_file_not_found(path, None)
        siblings = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **tree_kwargs)
        entries = [entry for entry in siblings if entry["name"] == path]
        if not entries:
            _raise_file_not_found(path, None)
    if detail:
        return entries
    return [entry["name"] for entry in entries]
|
| 382 |
+
|
| 383 |
+
def _ls_tree(
    self,
    path: str,
    recursive: bool = False,
    refresh: bool = False,
    revision: Optional[str] = None,
    expand_info: bool = True,
):
    """
    List the entries under `path`, serving them from `self.dircache` when possible.

    Args:
        path (`str`):
            Path to list.
        recursive (`bool`, *optional*):
            If True, also return the entries of nested directories. Defaults to False.
        refresh (`bool`, *optional*):
            If True, ignore the cache for `path` and call `list_repo_tree`. Defaults to False.
        revision (`str`, *optional*):
            The git revision to list from.
        expand_info (`bool`, *optional*):
            Forwarded as `expand` to `list_repo_tree`. When True, cached entries whose `last_commit` is
            `None` are treated as stale and re-fetched. Defaults to True.

    Returns:
        A list of entry dictionaries (`name`, `size`, `type`, `last_commit`, plus `blob_id`/`lfs`/`security`
        for files or `tree_id` for directories).
    """
    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    root_path = HfFileSystemResolvedPath(
        resolved_path.repo_type,
        resolved_path.repo_id,
        resolved_path.revision,
        path_in_repo="",
        _raw_revision=resolved_path._raw_revision,
    ).unresolve()

    entries = []

    if path not in self.dircache or refresh:
        # Cache miss or forced refresh: fetch the tree and record each entry under its parent directory.
        tree = self._api.list_repo_tree(
            resolved_path.repo_id,
            resolved_path.path_in_repo,
            recursive=recursive,
            expand=expand_info,
            revision=resolved_path.revision,
            repo_type=resolved_path.repo_type,
        )
        for item in tree:
            if isinstance(item, RepoFile):
                entry = {
                    "name": root_path + "/" + item.path,
                    "size": item.size,
                    "type": "file",
                    "blob_id": item.blob_id,
                    "lfs": item.lfs,
                    "last_commit": item.last_commit,
                    "security": item.security,
                }
            else:
                entry = {
                    "name": root_path + "/" + item.path,
                    "size": 0,
                    "type": "directory",
                    "tree_id": item.tree_id,
                    "last_commit": item.last_commit,
                }
            self.dircache.setdefault(self._parent(entry["name"]), []).append(entry)
            entries.append(entry)
        return entries

    # Cache hit: start from the cached listing of `path`.
    cached_infos = self.dircache[path]
    entries.extend(cached_infos)

    missing_dirs = []
    if recursive:
        # Use BFS to traverse the cache and build the "recursive" output
        # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same)
        queue = deque(info for info in cached_infos if info["type"] == "directory")
        while queue:
            dir_name = queue.popleft()["name"]
            if dir_name not in self.dircache:
                missing_dirs.append(dir_name)
                continue
            children = self.dircache[dir_name]
            entries.extend(children)
            queue.extend(child for child in children if child["type"] == "directory")

    unexpanded_dirs = []
    if expand_info:
        # Parents of entries that were cached without commit information.
        unexpanded_dirs = [self._parent(info["name"]) for info in entries if info["last_commit"] is None]

    # `missing_dirs` is only filled when `recursive`, `unexpanded_dirs` only when `expand_info`,
    # so their truthiness alone tells whether the cache was sufficient.
    if not (missing_dirs or unexpanded_dirs):
        return entries

    # The dircache is incomplete: find the common path of the missing and non-expanded entries
    # and replace that whole subtree with a fresh `_ls_tree(common_path, refresh=True)`.
    stale_dirs = missing_dirs + unexpanded_dirs
    common_prefix = os.path.commonprefix(stale_dirs)
    # `commonprefix` works character-wise; take the parent if the prefix itself is not a directory.
    prefix_is_dir = common_prefix.endswith("/") or common_prefix == root_path or common_prefix in stale_dirs
    common_path = common_prefix.rstrip("/") if prefix_is_dir else self._parent(common_prefix)

    subtree_prefix = common_path + "/"
    entries = [info for info in entries if not info["name"].startswith(subtree_prefix)]
    # NOTE(review): keys are popped while iterating `self.dircache`; this presumes the cache object
    # iterates over a snapshot of its keys (a plain dict would raise) — confirm against fsspec's DirCache.
    for cached_path in self.dircache:
        if cached_path.startswith(subtree_prefix):
            self.dircache.pop(cached_path, None)
    self.dircache.pop(common_path, None)
    entries.extend(
        self._ls_tree(
            common_path,
            recursive=recursive,
            refresh=True,
            revision=revision,
            expand_info=expand_info,
        )
    )
    return entries
|
| 486 |
+
|
| 487 |
+
def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]:
    """
    Return all files below the given path.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.walk).

    Args:
        path (`str`):
            Root path to list files from.

    Returns:
        `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
    """
    # `expand_info` defaults to the value of `detail` (False when absent) to get a x10 speed boost;
    # an explicit `expand_info` in kwargs still wins.
    walk_kwargs = {"expand_info": kwargs.get("detail", False), **kwargs}
    resolved = self.resolve_path(path, revision=walk_kwargs.get("revision"))
    yield from super().walk(resolved.unresolve(), *args, **walk_kwargs)
|
| 504 |
+
|
| 505 |
+
def glob(self, path: str, **kwargs) -> List[str]:
    """
    Find files by glob-matching.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob).

    Args:
        path (`str`):
            Path pattern to match.

    Returns:
        `List[str]`: List of paths matching the pattern.
    """
    # `expand_info` defaults to the value of `detail` (False when absent) to get a x10 speed boost;
    # an explicit `expand_info` in kwargs still wins.
    glob_kwargs = {"expand_info": kwargs.get("detail", False), **kwargs}
    resolved = self.resolve_path(path, revision=glob_kwargs.get("revision"))
    return super().glob(resolved.unresolve(), **glob_kwargs)
|
| 522 |
+
|
| 523 |
+
def find(
    self,
    path: str,
    maxdepth: Optional[int] = None,
    withdirs: bool = False,
    detail: bool = False,
    refresh: bool = False,
    revision: Optional[str] = None,
    **kwargs,
) -> Union[List[str], Dict[str, Dict[str, Any]]]:
    """
    List all files below path.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.find).

    Args:
        path (`str`):
            Root path to list files from.
        maxdepth (`int`, *optional*):
            Maximum depth to descend into subdirectories. A truthy value delegates the whole call to the
            parent class implementation.
        withdirs (`bool`, *optional*):
            Include directory paths in the output. Defaults to False.
        detail (`bool`, *optional*):
            If True, returns a dict mapping paths to file information. Defaults to False.
        refresh (`bool`, *optional*):
            If True, bypass the cache and fetch the latest data. Defaults to False.
        revision (`str`, *optional*):
            The git revision to list from.

    Returns:
        `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information.
    """
    if maxdepth:
        # Depth-limited search is left to the generic implementation.
        return super().find(
            path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs
        )

    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    tree_kwargs = {"expand_info": detail, **kwargs}
    try:
        entries = self._ls_tree(
            path, recursive=True, refresh=refresh, revision=resolved_path.revision, **tree_kwargs
        )
    except EntryNotFoundError:
        # Path could be a file: report it with an empty info dict, otherwise report nothing.
        is_file = self.info(path, revision=revision, **tree_kwargs)["type"] == "file"
        by_name = {path: {}} if is_file else {}
    else:
        if withdirs:
            # If `withdirs=True`, include the directory itself to be consistent with the spec
            own_info = self.info(path, revision=resolved_path.revision, **tree_kwargs)
            if own_info["type"] == "directory":
                entries = [own_info] + entries
        else:
            entries = [entry for entry in entries if entry["type"] != "directory"]
        by_name = {entry["name"]: entry for entry in entries}

    ordered_names = sorted(by_name)
    if detail:
        return {name: by_name[name] for name in ordered_names}
    return ordered_names
|
| 583 |
+
|
| 584 |
+
def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
    """
    Copy a file within or between repositories.

    <Tip warning={true}>

        Note: When possible, use `HfApi.upload_file()` for better performance.

    </Tip>

    Args:
        path1 (`str`):
            Source path to copy from.
        path2 (`str`):
            Destination path to copy to.
        revision (`str`, *optional*):
            The git revision to copy from.
        **kwargs:
            `commit_message` overrides the generated `"Copy {path1} to {path2}"` message;
            `commit_description` is forwarded to the commit.
    """
    source = self.resolve_path(path1, revision=revision)
    target = self.resolve_path(path2, revision=revision)
    default_message = f"Copy {path1} to {path2}"

    if source.repo_type == target.repo_type and source.repo_id == target.repo_id:
        # Same repository: a single commit carrying a copy operation, no content is read here.
        copy_operation = CommitOperationCopy(
            src_path_in_repo=source.path_in_repo,
            path_in_repo=target.path_in_repo,
            src_revision=source.revision,
        )
        self._api.create_commit(
            repo_id=source.repo_id,
            repo_type=source.repo_type,
            revision=target.revision,
            commit_message=kwargs.get("commit_message", default_message),
            commit_description=kwargs.get("commit_description", ""),
            operations=[copy_operation],
        )
    else:
        # Different repositories: read the whole source file, then upload it to the destination.
        with self.open(path1, "rb", revision=source.revision) as f:
            content = f.read()
        self._api.upload_file(
            path_or_fileobj=content,
            path_in_repo=target.path_in_repo,
            repo_id=target.repo_id,
            token=self.token,
            repo_type=target.repo_type,
            revision=target.revision,
            commit_message=kwargs.get("commit_message", default_message),
            commit_description=kwargs.get("commit_description"),
        )
    self.invalidate_cache(path=source.unresolve())
    self.invalidate_cache(path=target.unresolve())
|
| 642 |
+
|
| 643 |
+
def modified(self, path: str, **kwargs) -> datetime:
    """
    Get the last modified time of a file.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.modified).

    Args:
        path (`str`):
            Path to the file.
        **kwargs:
            Forwarded unchanged to `info()`.

    Returns:
        `datetime`: Last commit date of the file.
    """
    # The date is read from the `last_commit` entry of the info dictionary.
    last_commit = self.info(path, **kwargs)["last_commit"]
    return last_commit["date"]
|
| 658 |
+
|
| 659 |
+
def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]:
    """
    Get information about a file or directory.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.info).

    <Tip warning={true}>

        Note: When possible, use `HfApi.get_paths_info()` or `HfApi.repo_info()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to get info for.
        refresh (`bool`, *optional*):
            If True, bypass the cache and fetch the latest data. Defaults to False.
        revision (`str`, *optional*):
            The git revision to get info from.
        **kwargs:
            Only `expand_info` is read (defaults to True). When False, the result is reduced to the
            `name`, `size` and `type` keys and no commit information is fetched.

    Returns:
        `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).

    """
    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    expand_info = kwargs.get(
        "expand_info", True
    )  # don't expose it as a parameter in the public API to follow the spec
    if not resolved_path.path_in_repo:
        # Path is the root directory
        out = {
            "name": path,
            "size": 0,
            "type": "directory",
        }
        if expand_info:
            # `list_repo_commits` is documented as returning the most recent commit first, so the
            # last commit of the revision is the first element (index -1 is the initial commit).
            last_commit = self._api.list_repo_commits(
                resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision
            )[0]
            out = {
                **out,
                "tree_id": None,  # TODO: tree_id of the root directory?
                "last_commit": LastCommitInfo(
                    oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at
                ),
            }
    else:
        out = None
        parent_path = self._parent(path)
        if not expand_info and parent_path not in self.dircache:
            # Fill the cache with cheap call
            self.ls(parent_path, expand_info=False)
        if parent_path in self.dircache:
            # Check if the path is in the cache
            out1 = [o for o in self.dircache[parent_path] if o["name"] == path]
            if not out1:
                _raise_file_not_found(path, None)
            out = out1[0]
        # Query the Hub when asked to refresh, when nothing was cached, or when commit details
        # are wanted but the cached entry was stored without them.
        if refresh or out is None or (expand_info and out and out["last_commit"] is None):
            paths_info = self._api.get_paths_info(
                resolved_path.repo_id,
                resolved_path.path_in_repo,
                expand=expand_info,
                revision=resolved_path.revision,
                repo_type=resolved_path.repo_type,
            )
            if not paths_info:
                _raise_file_not_found(path, None)
            path_info = paths_info[0]
            # Entry names are built relative to the unresolved repo root.
            root_path = HfFileSystemResolvedPath(
                resolved_path.repo_type,
                resolved_path.repo_id,
                resolved_path.revision,
                path_in_repo="",
                _raw_revision=resolved_path._raw_revision,
            ).unresolve()
            if isinstance(path_info, RepoFile):
                out = {
                    "name": root_path + "/" + path_info.path,
                    "size": path_info.size,
                    "type": "file",
                    "blob_id": path_info.blob_id,
                    "lfs": path_info.lfs,
                    "last_commit": path_info.last_commit,
                    "security": path_info.security,
                }
            else:
                out = {
                    "name": root_path + "/" + path_info.path,
                    "size": 0,
                    "type": "directory",
                    "tree_id": path_info.tree_id,
                    "last_commit": path_info.last_commit,
                }
        if not expand_info:
            # Without expansion only the minimal keys are returned.
            out = {k: out[k] for k in ["name", "size", "type"]}
    assert out is not None
    return out
|
| 758 |
+
|
| 759 |
+
def exists(self, path, **kwargs):
    """
    Check if a file exists.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists).

    <Tip warning={true}>

        Note: When possible, use `HfApi.file_exists()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to check.
        **kwargs:
            `refresh=True` invalidates the cache for `path` first. All keyword arguments are forwarded
            to `info()`, with `expand_info` forced to False.

    Returns:
        `bool`: True if file exists, False otherwise. Any `Exception` raised while looking the path up
        is reported as False; `KeyboardInterrupt` and `SystemExit` are not swallowed.
    """
    try:
        if kwargs.get("refresh", False):
            self.invalidate_cache(path)

        self.info(path, **{**kwargs, "expand_info": False})
        return True
    except Exception:
        # Best-effort check: every lookup failure means "does not exist", but interpreter-level
        # signals (Ctrl-C, sys.exit) must still propagate.
        return False
|
| 786 |
+
|
| 787 |
+
def isdir(self, path):
    """
    Check if a path is a directory.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isdir).

    Args:
        path (`str`):
            Path to check.

    Returns:
        `bool`: True if path is a directory, False otherwise (including when `info()` raises an `OSError`).
    """
    try:
        # Cheap lookup: commit details are not needed to read the entry type.
        entry_type = self.info(path, expand_info=False)["type"]
    except OSError:
        return False
    return entry_type == "directory"
|
| 804 |
+
|
| 805 |
+
def isfile(self, path):
    """
    Check if a path is a file.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isfile).

    Args:
        path (`str`):
            Path to check.

    Returns:
        `bool`: True if path is a file, False otherwise. Any `Exception` raised while looking the path up
        is reported as False; `KeyboardInterrupt` and `SystemExit` are not swallowed.
    """
    try:
        # Cheap lookup: commit details are not needed to read the entry type.
        return self.info(path, expand_info=False)["type"] == "file"
    except Exception:
        # Best-effort check, but interpreter-level signals (Ctrl-C, sys.exit) must still propagate.
        return False
|
| 822 |
+
|
| 823 |
+
def url(self, path: str) -> str:
    """
    Get the HTTP URL of the given path.

    Args:
        path (`str`):
            Path to get URL for.

    Returns:
        `str`: HTTP URL to access the file or directory on the Hub.
    """
    resolved = self.resolve_path(path)
    http_url = hf_hub_url(
        resolved.repo_id,
        resolved.path_in_repo,
        repo_type=resolved.repo_type,
        revision=resolved.revision,
        endpoint=self.endpoint,
    )
    if not self.isdir(path):
        return http_url
    # For directories, the first "/resolve/" segment of the generated URL is swapped for "/tree/".
    return http_url.replace("/resolve/", "/tree/", 1)
|
| 845 |
+
|
| 846 |
+
def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None:
    """
    Copy single remote file to local.

    <Tip warning={true}>

        Note: When possible, use `HfApi.hf_hub_download()` for better performance.

    </Tip>

    Args:
        rpath (`str`):
            Remote path to download from.
        lpath (`str`):
            Local path to download to. A file-like object is also accepted and is then written to directly.
        callback (`Callback`, *optional*):
            Optional callback to track download progress. Defaults to no callback.
        outfile (`IO`, *optional*):
            Optional file-like object to write to. If provided, the data is written to it instead of
            opening `lpath`.
        **kwargs:
            Only `revision` is handled here; any other keyword argument (or a callback that is neither a
            `NoOpCallback` nor a `TqdmCallback`) delegates the call to the parent implementation.
    """
    revision = kwargs.get("revision")
    unhandled_kwargs = set(kwargs.keys()) - {"revision"}
    if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0:
        # for now, let's not handle custom callbacks
        # and let's not handle custom kwargs
        return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs)

    # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883
    if isfilelike(lpath):
        outfile = lpath
    elif self.isdir(rpath):
        os.makedirs(lpath, exist_ok=True)
        return None

    if isinstance(lpath, (str, Path)):  # otherwise, let's assume it's a file-like object
        # A bare filename has an empty dirname, and `os.makedirs("")` raises: only create the parent
        # directory when there is one.
        parent_dir = os.path.dirname(lpath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Open file if not already open
    close_file = False
    if outfile is None:
        outfile = open(lpath, "wb")
        close_file = True

    # Everything after the file is opened runs inside the `try`, so a failing `resolve_path`/`info`
    # call can no longer leak the handle opened above.
    try:
        initial_pos = outfile.tell()

        # Custom implementation of `get_file` to use `http_get`.
        resolve_remote_path = self.resolve_path(rpath, revision=revision)
        expected_size = self.info(rpath, revision=revision)["size"]
        callback.set_size(expected_size)
        http_get(
            url=hf_hub_url(
                repo_id=resolve_remote_path.repo_id,
                revision=resolve_remote_path.revision,
                filename=resolve_remote_path.path_in_repo,
                repo_type=resolve_remote_path.repo_type,
                endpoint=self.endpoint,
            ),
            temp_file=outfile,
            displayed_filename=rpath,
            expected_size=expected_size,
            resume_size=0,
            headers=self._api._build_hf_headers(),
            _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None,
        )
        # Rewind to where the file was positioned before the download.
        outfile.seek(initial_pos)
    finally:
        # Close file only if we opened it ourselves
        if close_file:
            outfile.close()
|
| 916 |
+
|
| 917 |
+
@property
def transaction(self):
    """Unsupported: accessing this property always raises `NotImplementedError`.

    In fsspec this is a context within which files are committed together upon exit; it requires the file
    class to implement `.commit()` and `.discard()` for the normal and exception cases.
    """
    # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231
    # See https://github.com/huggingface/huggingface_hub/issues/1733
    reason = "Transactional commits are not supported."
    raise NotImplementedError(reason)
|
| 927 |
+
|
| 928 |
+
def start_transaction(self):
    """Unsupported non-context counterpart of `transaction`: always raises `NotImplementedError`."""
    # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241
    # See https://github.com/huggingface/huggingface_hub/issues/1733
    reason = "Transactional commits are not supported."
    raise NotImplementedError(reason)
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
|
| 936 |
+
def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
|
| 937 |
+
try:
|
| 938 |
+
self.resolved_path = fs.resolve_path(path, revision=revision)
|
| 939 |
+
except FileNotFoundError as e:
|
| 940 |
+
if "w" in kwargs.get("mode", ""):
|
| 941 |
+
raise FileNotFoundError(
|
| 942 |
+
f"{e}.\nMake sure the repository and revision exist before writing data."
|
| 943 |
+
) from e
|
| 944 |
+
raise
|
| 945 |
+
# avoid an unnecessary .info() call with expensive expand_info=True to instantiate .details
|
| 946 |
+
if kwargs.get("mode", "rb") == "rb":
|
| 947 |
+
self.details = fs.info(self.resolved_path.unresolve(), expand_info=False)
|
| 948 |
+
super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
|
| 949 |
+
self.fs: HfFileSystem
|
| 950 |
+
|
| 951 |
+
def __del__(self):
|
| 952 |
+
if not hasattr(self, "resolved_path"):
|
| 953 |
+
# Means that the constructor failed. Nothing to do.
|
| 954 |
+
return
|
| 955 |
+
return super().__del__()
|
| 956 |
+
|
| 957 |
+
def _fetch_range(self, start: int, end: int) -> bytes:
|
| 958 |
+
headers = {
|
| 959 |
+
"range": f"bytes={start}-{end - 1}",
|
| 960 |
+
**self.fs._api._build_hf_headers(),
|
| 961 |
+
}
|
| 962 |
+
url = hf_hub_url(
|
| 963 |
+
repo_id=self.resolved_path.repo_id,
|
| 964 |
+
revision=self.resolved_path.revision,
|
| 965 |
+
filename=self.resolved_path.path_in_repo,
|
| 966 |
+
repo_type=self.resolved_path.repo_type,
|
| 967 |
+
endpoint=self.fs.endpoint,
|
| 968 |
+
)
|
| 969 |
+
r = http_backoff(
|
| 970 |
+
"GET",
|
| 971 |
+
url,
|
| 972 |
+
headers=headers,
|
| 973 |
+
retry_on_status_codes=(500, 502, 503, 504),
|
| 974 |
+
timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
|
| 975 |
+
)
|
| 976 |
+
hf_raise_for_status(r)
|
| 977 |
+
return r.content
|
| 978 |
+
|
| 979 |
+
def _initiate_upload(self) -> None:
|
| 980 |
+
self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False)
|
| 981 |
+
|
| 982 |
+
def _upload_chunk(self, final: bool = False) -> None:
|
| 983 |
+
self.buffer.seek(0)
|
| 984 |
+
block = self.buffer.read()
|
| 985 |
+
self.temp_file.write(block)
|
| 986 |
+
if final:
|
| 987 |
+
self.temp_file.close()
|
| 988 |
+
self.fs._api.upload_file(
|
| 989 |
+
path_or_fileobj=self.temp_file.name,
|
| 990 |
+
path_in_repo=self.resolved_path.path_in_repo,
|
| 991 |
+
repo_id=self.resolved_path.repo_id,
|
| 992 |
+
token=self.fs.token,
|
| 993 |
+
repo_type=self.resolved_path.repo_type,
|
| 994 |
+
revision=self.resolved_path.revision,
|
| 995 |
+
commit_message=self.kwargs.get("commit_message"),
|
| 996 |
+
commit_description=self.kwargs.get("commit_description"),
|
| 997 |
+
)
|
| 998 |
+
os.remove(self.temp_file.name)
|
| 999 |
+
self.fs.invalidate_cache(
|
| 1000 |
+
path=self.resolved_path.unresolve(),
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
def read(self, length=-1):
    """Read from the remote file.

    When the file is opened in ``"rb"`` mode, nothing has been read yet
    (`self.loc == 0`) and the whole file is requested (`length` is ``None`` or
    ``-1``), the path is re-opened on the filesystem with ``block_size=0`` and
    read in one go. Every other call is delegated to the parent class.
    """
    wants_everything = length is None or length == -1
    if not (self.mode == "rb" and wants_everything and self.loc == 0):
        return super().read(length)
    # block_size=0 enables fast streaming
    with self.fs.open(self.path, "rb", block_size=0) as stream:
        return stream.read()
|
| 1014 |
+
|
| 1015 |
+
def url(self) -> str:
    """Return the URL of this file, as computed by the filesystem's `url` method."""
    filesystem = self.fs
    return filesystem.url(self.path)
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
    """Read-only, forward-only file object that streams a Hub file over HTTP.

    Data is read directly from the raw body of a streaming ``GET`` response.
    Only ``block_size=0`` and ``cache_type="none"`` are accepted, and seeking
    to a different position is rejected.
    """

    def __init__(
        self,
        fs: HfFileSystem,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        block_size: int = 0,
        cache_type: str = "none",
        **kwargs,
    ):
        """Resolve `path` on `fs` and prepare the file object; no request is sent yet.

        Raises:
            ValueError: If `block_size` is not 0, `cache_type` is not "none", or
                `mode` contains "w".
            FileNotFoundError: Propagated from `fs.resolve_path` when it raises it.
        """
        if block_size != 0:
            raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}")
        if cache_type != "none":
            raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}")
        if "w" in mode:
            raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'")
        # Let FileNotFoundError propagate. It used to be swallowed here, which
        # made the next statement fail with an unrelated AttributeError on
        # `self.resolved_path`. The former write-mode special case could never
        # trigger: write modes are rejected just above and `mode` is never
        # part of `kwargs`.
        self.resolved_path = fs.resolve_path(path, revision=revision)
        # avoid an unnecessary .info() call to instantiate .details
        self.details = {"name": self.resolved_path.unresolve(), "size": None}
        super().__init__(
            fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
        )
        self.response: Optional[Response] = None
        self.fs: HfFileSystem

    def seek(self, loc: int, whence: int = 0):
        """Accept only seeks that do not move the position; raise `ValueError` otherwise."""
        if loc == 0 and whence == 1:
            return
        if loc == self.loc and whence == 0:
            return
        raise ValueError("Cannot seek streaming HF file")

    def _open_response(self, extra_headers: Optional[dict] = None) -> None:
        """Send the streaming GET request for this file and store it in `self.response`.

        `extra_headers` are merged first so that the Hub headers take precedence
        on a key clash. The response is stored before its status is checked.
        """
        url = hf_hub_url(
            repo_id=self.resolved_path.repo_id,
            revision=self.resolved_path.revision,
            filename=self.resolved_path.path_in_repo,
            repo_type=self.resolved_path.repo_type,
            endpoint=self.fs.endpoint,
        )
        headers = {**(extra_headers or {}), **self.fs._api._build_hf_headers()}
        self.response = http_backoff(
            "GET",
            url,
            headers=headers,
            retry_on_status_codes=(500, 502, 503, 504),
            stream=True,
            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
        )
        hf_raise_for_status(self.response)

    def read(self, length: int = -1):
        """Read up to `length` bytes from the stream, or everything left when `length` is negative or `None`.

        A request is opened when there is no response yet or its raw stream is
        closed. If reading fails, the response is closed and the read is
        retried once on a new request that starts at the current offset
        (``Range: bytes=<loc>-``); a second failure is re-raised.
        """
        read_args = () if length is None or length < 0 else (length,)
        if self.response is None or self.response.raw.isclosed():
            self._open_response()
        try:
            out = self.response.raw.read(*read_args)
        except Exception:
            self.response.close()

            # Retry by recreating the connection
            self._open_response({"Range": "bytes=%d-" % self.loc})
            try:
                out = self.response.raw.read(*read_args)
            except Exception:
                self.response.close()
                raise
        self.loc += len(out)
        return out

    def url(self) -> str:
        """Return the URL of this file, as computed by the filesystem's `url` method."""
        return self.fs.url(self.path)

    def __del__(self):
        """Run the parent finalizer only if the constructor got as far as setting `resolved_path`."""
        if not hasattr(self, "resolved_path"):
            # Means that the constructor failed. Nothing to do.
            return
        return super().__del__()

    def __reduce__(self):
        """Pickle as a call to `reopen` with the filesystem, path, mode, block size and cache name."""
        return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
def safe_revision(revision: str) -> str:
    """Return `revision` unchanged when `SPECIAL_REFS_REVISION_REGEX` matches it, else its `safe_quote`d form."""
    if SPECIAL_REFS_REVISION_REGEX.match(revision):
        return revision
    return safe_quote(revision)
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
def safe_quote(s: str) -> str:
    """Percent-encode `s` with no character marked as safe, so that ``/`` is encoded too."""
    encoded = quote(s, safe="")
    return encoded
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
    """Raise `FileNotFoundError` for `path`, chained to `err`.

    The message is `path` itself, followed by a parenthesised hint when `err`
    is an instance of one of the recognised error types. The types are checked
    in order and the first match wins.
    """
    hints = (
        (RepositoryNotFoundError, "repository not found"),
        (RevisionNotFoundError, "revision not found"),
        (HFValidationError, "invalid repository id"),
    )
    msg = path
    for error_type, hint in hints:
        if isinstance(err, error_type):
            msg = f"{path} ({hint})"
            break
    raise FileNotFoundError(msg) from err
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
    """Open `path` on `fs` with the given mode, block size and cache type, and return the result of `fs.open`."""
    open_options = {"mode": mode, "block_size": block_size, "cache_type": cache_type}
    return fs.open(path, **open_options)
|
.venv/lib/python3.11/site-packages/huggingface_hub/hub_mixin.py
ADDED
|
@@ -0,0 +1,836 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import Field, asdict, dataclass, is_dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
|
| 7 |
+
|
| 8 |
+
import packaging.version
|
| 9 |
+
|
| 10 |
+
from . import constants
|
| 11 |
+
from .errors import EntryNotFoundError, HfHubHTTPError
|
| 12 |
+
from .file_download import hf_hub_download
|
| 13 |
+
from .hf_api import HfApi
|
| 14 |
+
from .repocard import ModelCard, ModelCardData
|
| 15 |
+
from .utils import (
|
| 16 |
+
SoftTemporaryDirectory,
|
| 17 |
+
is_jsonable,
|
| 18 |
+
is_safetensors_available,
|
| 19 |
+
is_simple_optional_type,
|
| 20 |
+
is_torch_available,
|
| 21 |
+
logging,
|
| 22 |
+
unwrap_simple_optional_type,
|
| 23 |
+
validate_hf_hub_args,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if is_torch_available():
|
| 28 |
+
import torch # type: ignore
|
| 29 |
+
|
| 30 |
+
if is_safetensors_available():
|
| 31 |
+
import safetensors
|
| 32 |
+
from safetensors.torch import load_model as load_model_as_safetensor
|
| 33 |
+
from safetensors.torch import save_model as save_model_as_safetensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
|
| 40 |
+
class DataclassInstance(Protocol):
    """Structural type for dataclass instances: anything exposing `__dataclass_fields__`."""

    # Class-level mapping of field name to `dataclasses.Field`.
    __dataclass_fields__: ClassVar[Dict[str, Field]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Generic variable that is either ModelHubMixin or a subclass thereof
|
| 45 |
+
T = TypeVar("T", bound="ModelHubMixin")
|
| 46 |
+
# Generic variable to represent an args type
|
| 47 |
+
ARGS_T = TypeVar("ARGS_T")
|
| 48 |
+
# Encoder half of a coder pair: takes an argument of a custom type and returns the value stored in the config.
ENCODER_T = Callable[[ARGS_T], Any]
|
| 49 |
+
# Decoder half of a coder pair: takes a value read from the config and returns an argument of the custom type.
DECODER_T = Callable[[Any], ARGS_T]
|
| 50 |
+
# A coder is the pair (encoder, decoder), in that order.
CODER_T = Tuple[ENCODER_T, DECODER_T]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Model card template used when a subclass does not pass `model_card_template`.
# Its placeholders are `card_data`, `repo_url` and `docs_url`.
DEFAULT_MODEL_CARD = """
|
| 54 |
+
---
|
| 55 |
+
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
| 56 |
+
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
| 57 |
+
{{ card_data }}
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
|
| 61 |
+
- Library: {{ repo_url | default("[More Information Needed]", true) }}
|
| 62 |
+
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
class MixinInfo:
    """Per-class information used to generate the model card of a `ModelHubMixin` subclass."""

    # Template of the model card.
    model_card_template: str
    # Metadata of the model card (language, license, tags, ...).
    model_card_data: ModelCardData
    # URL of the library repository, if any.
    repo_url: Optional[str] = None
    # URL of the library documentation, if any.
    docs_url: Optional[str] = None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ModelHubMixin:
|
| 75 |
+
"""
|
| 76 |
+
A generic mixin to integrate ANY machine learning framework with the Hub.
|
| 77 |
+
|
| 78 |
+
To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
|
| 79 |
+
have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
|
| 80 |
+
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
|
| 81 |
+
|
| 82 |
+
When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
|
| 83 |
+
`__init__` but to the class definition itself. This is useful to define metadata about the library integrating
|
| 84 |
+
[`ModelHubMixin`].
|
| 85 |
+
|
| 86 |
+
For more details on how to integrate the mixin with your library, check out the [integration guide](../guides/integrations).
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
repo_url (`str`, *optional*):
|
| 90 |
+
URL of the library repository. Used to generate model card.
|
| 91 |
+
docs_url (`str`, *optional*):
|
| 92 |
+
URL of the library documentation. Used to generate model card.
|
| 93 |
+
model_card_template (`str`, *optional*):
|
| 94 |
+
Template of the model card. Used to generate model card. Defaults to a generic template.
|
| 95 |
+
language (`str` or `List[str]`, *optional*):
|
| 96 |
+
Language supported by the library. Used to generate model card.
|
| 97 |
+
library_name (`str`, *optional*):
|
| 98 |
+
Name of the library integrating ModelHubMixin. Used to generate model card.
|
| 99 |
+
license (`str`, *optional*):
|
| 100 |
+
License of the library integrating ModelHubMixin. Used to generate model card.
|
| 101 |
+
E.g: "apache-2.0"
|
| 102 |
+
license_name (`str`, *optional*):
|
| 103 |
+
Name of the license of the library integrating ModelHubMixin. Used to generate model card.
|
| 104 |
+
Only used if `license` is set to `other`.
|
| 105 |
+
E.g: "coqui-public-model-license".
|
| 106 |
+
license_link (`str`, *optional*):
|
| 107 |
+
URL to the license of the library integrating ModelHubMixin. Used to generate model card.
|
| 108 |
+
Only used if `license` is set to `other` and `license_name` is set.
|
| 109 |
+
E.g: "https://coqui.ai/cpml".
|
| 110 |
+
pipeline_tag (`str`, *optional*):
|
| 111 |
+
Tag of the pipeline. Used to generate model card. E.g. "text-classification".
|
| 112 |
+
tags (`List[str]`, *optional*):
|
| 113 |
+
Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
|
| 114 |
+
coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
|
| 115 |
+
Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
|
| 116 |
+
jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
|
| 117 |
+
|
| 118 |
+
Example:
|
| 119 |
+
|
| 120 |
+
```python
|
| 121 |
+
>>> from huggingface_hub import ModelHubMixin
|
| 122 |
+
|
| 123 |
+
# Inherit from ModelHubMixin
|
| 124 |
+
>>> class MyCustomModel(
|
| 125 |
+
... ModelHubMixin,
|
| 126 |
+
... library_name="my-library",
|
| 127 |
+
... tags=["x-custom-tag", "arxiv:2304.12244"],
|
| 128 |
+
... repo_url="https://github.com/huggingface/my-cool-library",
|
| 129 |
+
... docs_url="https://huggingface.co/docs/my-cool-library",
|
| 130 |
+
... # ^ optional metadata to generate model card
|
| 131 |
+
... ):
|
| 132 |
+
... def __init__(self, size: int = 512, device: str = "cpu"):
|
| 133 |
+
... # define how to initialize your model
|
| 134 |
+
... super().__init__()
|
| 135 |
+
... ...
|
| 136 |
+
...
|
| 137 |
+
... def _save_pretrained(self, save_directory: Path) -> None:
|
| 138 |
+
... # define how to serialize your model
|
| 139 |
+
... ...
|
| 140 |
+
...
|
| 141 |
+
... @classmethod
|
| 142 |
+
... def from_pretrained(
|
| 143 |
+
... cls: Type[T],
|
| 144 |
+
... pretrained_model_name_or_path: Union[str, Path],
|
| 145 |
+
... *,
|
| 146 |
+
... force_download: bool = False,
|
| 147 |
+
... resume_download: Optional[bool] = None,
|
| 148 |
+
... proxies: Optional[Dict] = None,
|
| 149 |
+
... token: Optional[Union[str, bool]] = None,
|
| 150 |
+
... cache_dir: Optional[Union[str, Path]] = None,
|
| 151 |
+
... local_files_only: bool = False,
|
| 152 |
+
... revision: Optional[str] = None,
|
| 153 |
+
... **model_kwargs,
|
| 154 |
+
... ) -> T:
|
| 155 |
+
... # define how to deserialize your model
|
| 156 |
+
... ...
|
| 157 |
+
|
| 158 |
+
>>> model = MyCustomModel(size=256, device="gpu")
|
| 159 |
+
|
| 160 |
+
# Save model weights to local directory
|
| 161 |
+
>>> model.save_pretrained("my-awesome-model")
|
| 162 |
+
|
| 163 |
+
# Push model weights to the Hub
|
| 164 |
+
>>> model.push_to_hub("my-awesome-model")
|
| 165 |
+
|
| 166 |
+
# Download and initialize weights from the Hub
|
| 167 |
+
>>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
|
| 168 |
+
>>> reloaded_model.size
|
| 169 |
+
256
|
| 170 |
+
|
| 171 |
+
# Model card has been correctly populated
|
| 172 |
+
>>> from huggingface_hub import ModelCard
|
| 173 |
+
>>> card = ModelCard.load("username/my-awesome-model")
|
| 174 |
+
>>> card.data.tags
|
| 175 |
+
["arxiv:2304.12244", "model_hub_mixin", "x-custom-tag"]
|
| 176 |
+
>>> card.data.library_name
|
| 177 |
+
"my-library"
|
| 178 |
+
```
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
_hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
|
| 182 |
+
# ^ optional config attribute automatically set in `from_pretrained`
|
| 183 |
+
_hub_mixin_info: MixinInfo
|
| 184 |
+
# ^ information about the library integrating ModelHubMixin (used to generate model card)
|
| 185 |
+
_hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not
|
| 186 |
+
_hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters
|
| 187 |
+
_hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters
|
| 188 |
+
_hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded
|
| 189 |
+
_hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types
|
| 190 |
+
# ^ internal values to handle config
|
| 191 |
+
|
| 192 |
+
def __init_subclass__(
    cls,
    *,
    # Generic info for model card
    repo_url: Optional[str] = None,
    docs_url: Optional[str] = None,
    # Model card template
    model_card_template: str = DEFAULT_MODEL_CARD,
    # Model card metadata
    language: Optional[List[str]] = None,
    library_name: Optional[str] = None,
    license: Optional[str] = None,
    license_name: Optional[str] = None,
    license_link: Optional[str] = None,
    pipeline_tag: Optional[str] = None,
    tags: Optional[List[str]] = None,
    # How to encode/decode arguments with custom type into a JSON config?
    coders: Optional[
        Dict[Type, CODER_T]
        # Key is a type.
        # Value is a tuple (encoder, decoder).
        # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
    ] = None,
) -> None:
    """Inspect __init__ signature only once when subclassing + handle modelcard.

    Builds the `MixinInfo` of the new class (starting from a copy of the parent's
    one when the parent has it), applies the metadata passed as class keyword
    arguments, and caches on the class:

    - `_hub_mixin_coders` and `_hub_mixin_jsonable_custom_types`, from `coders`;
    - `_hub_mixin_init_parameters`, the parameters of `cls.__init__`;
    - `_hub_mixin_jsonable_default_values`, the encoded jsonable defaults of these
      parameters;
    - `_hub_mixin_inject_config`, whether `cls._from_pretrained` has a `config`
      parameter.

    The `tags` list passed by the caller is never modified.
    """
    super().__init_subclass__()

    # Will be reused when creating modelcard.
    # Build a new list instead of appending in place: `tags` belongs to the caller.
    tags = [*(tags or []), "model_hub_mixin"]

    # Initialize MixinInfo if not existent
    info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())

    # If parent class has a MixinInfo, inherit from it as a copy
    if hasattr(cls, "_hub_mixin_info"):
        # Inherit model card template from parent class if not explicitly set
        if model_card_template == DEFAULT_MODEL_CARD:
            info.model_card_template = cls._hub_mixin_info.model_card_template

        # Inherit from parent model card data
        info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())

        # Inherit other info
        info.docs_url = cls._hub_mixin_info.docs_url
        info.repo_url = cls._hub_mixin_info.repo_url
    cls._hub_mixin_info = info

    # Update MixinInfo with metadata.
    # An explicitly passed template needs no handling here: `info` was created with it
    # and the inheritance step above only replaces the default template.
    if repo_url is not None:
        info.repo_url = repo_url
    if docs_url is not None:
        info.docs_url = docs_url
    if language is not None:
        info.model_card_data.language = language
    if library_name is not None:
        info.model_card_data.library_name = library_name
    if license is not None:
        info.model_card_data.license = license
    if license_name is not None:
        info.model_card_data.license_name = license_name
    if license_link is not None:
        info.model_card_data.license_link = license_link
    if pipeline_tag is not None:
        info.model_card_data.pipeline_tag = pipeline_tag
    # `tags` always holds at least "model_hub_mixin" at this point.
    if info.model_card_data.tags is not None:
        info.model_card_data.tags.extend(tags)
    else:
        info.model_card_data.tags = tags

    # Deduplicate and sort the tags
    info.model_card_data.tags = sorted(set(info.model_card_data.tags))

    # Handle encoders/decoders for args
    cls._hub_mixin_coders = coders or {}
    cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())

    # Inspect __init__ signature to handle config
    cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
    cls._hub_mixin_jsonable_default_values = {
        param.name: cls._encode_arg(param.default)
        for param in cls._hub_mixin_init_parameters.values()
        if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
    }
    cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
|
| 279 |
+
|
| 280 |
+
def __new__(cls: Type[T], *args, **kwargs) -> T:
    """Create the instance and record the config it was built with.

    Three situations are handled:
        - `_hub_mixin_config` is already set: the instance is returned as is.
        - A dataclass is passed as `config`: it becomes `_hub_mixin_config`.
        - Otherwise `_hub_mixin_config` is assembled from the jsonable default
          values of `__init__`, overridden by the jsonable passed values, and
          finally by the entries of `config` when it is a dict. It is only set
          when the result is not empty.
    """
    instance = super().__new__(cls)

    # Nothing to infer when a config is already there
    if instance._hub_mixin_config is not None:
        return instance

    # Map positional arguments to their parameter names ([1:] skips `self`),
    # then let keyword arguments complete/override them
    parameter_names = list(cls._hub_mixin_init_parameters)[1:]
    passed_values = dict(zip(parameter_names, args))
    passed_values.update(kwargs)

    # A dataclass config is stored untouched
    explicit_config = passed_values.get("config")
    if is_dataclass(explicit_config):
        instance._hub_mixin_config = explicit_config
        return instance

    # Defaults first, then whatever was passed and can be serialized
    init_config = dict(cls._hub_mixin_jsonable_default_values)
    for name, value in passed_values.items():
        if instance._is_jsonable(value):  # jsonable as is, or through a custom encoder
            init_config[name] = cls._encode_arg(value)

    # A dict passed as `config` is flattened into the final config
    passed_config = init_config.pop("config", {})
    if isinstance(passed_config, dict):
        init_config.update(passed_config)

    if init_config:
        instance._hub_mixin_config = init_config
    return instance
|
| 333 |
+
|
| 334 |
+
@classmethod
def _is_jsonable(cls, value: Any) -> bool:
    """Tell whether `value` can go into the config.

    True for instances of the registered custom types; otherwise the answer of `is_jsonable`.
    """
    return isinstance(value, cls._hub_mixin_jsonable_custom_types) or is_jsonable(value)
|
| 340 |
+
|
| 341 |
+
@classmethod
def _encode_arg(cls, arg: Any) -> Any:
    """Encode `arg` with the encoder of the first registered type it is an instance of.

    `None` is returned as `None`, and an argument matching no registered type
    is returned unchanged.
    """
    for handled_type, coder in cls._hub_mixin_coders.items():
        encode, _decode = coder
        if not isinstance(arg, handled_type):
            continue
        return None if arg is None else encode(arg)
    return arg
|
| 350 |
+
|
| 351 |
+
@classmethod
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
    """Turn a value read from the config back into an argument of `expected_type`.

    Steps, in order:
        1. A simple optional type is unwrapped; a `None` value is returned as `None`.
        2. A dataclass type is loaded with `_load_dataclass`.
        3. If `expected_type` is a class and a subclass of a registered type,
           the first matching decoder is applied.
        4. Otherwise `value` is returned unchanged.
    """
    if is_simple_optional_type(expected_type):
        if value is None:
            return None
        expected_type = unwrap_simple_optional_type(expected_type)

    if is_dataclass(expected_type):
        return _load_dataclass(expected_type, value)  # type: ignore[return-value]

    target_is_class = inspect.isclass(expected_type)
    for handled_type, (_encode, decode) in cls._hub_mixin_coders.items():
        if target_is_class and issubclass(expected_type, handled_type):
            return decode(value)

    # No decoder applies
    return value
|
| 367 |
+
|
| 368 |
+
def save_pretrained(
    self,
    save_directory: Union[str, Path],
    *,
    config: Optional[Union[dict, DataclassInstance]] = None,
    repo_id: Optional[str] = None,
    push_to_hub: bool = False,
    model_card_kwargs: Optional[Dict[str, Any]] = None,
    **push_to_hub_kwargs,
) -> Optional[str]:
    """
    Save the model, its config and its model card in a local directory.

    Args:
        save_directory (`str` or `Path`):
            Directory in which the model weights and configuration are saved. Created if missing.
        config (`dict` or `DataclassInstance`, *optional*):
            Model configuration, as a key/value dictionary or a dataclass instance. Defaults to the
            config recorded on the instance.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the model to the Hub once saved.
        repo_id (`str`, *optional*):
            ID of the repository on the Hub. Only used if `push_to_hub=True`. Defaults to the name of
            `save_directory`.
        model_card_kwargs (`Dict[str, Any]`, *optional*):
            Extra arguments forwarded to the model card generation.
        push_to_hub_kwargs:
            Extra keyword arguments forwarded to [`~ModelHubMixin.push_to_hub`].
    Returns:
        `str` or `None`: what `push_to_hub` returns if `push_to_hub=True`, `None` otherwise.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Start without a config.json: a stale one must be replaced, whereas one written by
    # `_save_pretrained` below must be kept. Deleting first makes "exists afterwards" mean
    # "written by `_save_pretrained`".
    config_path = target_dir / constants.CONFIG_NAME
    config_path.unlink(missing_ok=True)

    # Framework-specific serialization
    self._save_pretrained(target_dir)

    # Write the config unless `_save_pretrained` already did
    if config is None:
        config = self._hub_mixin_config
    if config is not None:
        if is_dataclass(config):
            config = asdict(config)  # type: ignore[arg-type]
        if not config_path.exists():
            config_path.write_text(json.dumps(config, sort_keys=True, indent=2))

    # Write the model card, never overwriting an existing README.md
    readme_path = target_dir / "README.md"
    card_kwargs = {} if model_card_kwargs is None else model_card_kwargs
    if not readme_path.exists():
        self.generate_model_card(**card_kwargs).save(readme_path)

    if not push_to_hub:
        return None

    # Shallow copy so the caller's kwargs are left untouched
    hub_kwargs = push_to_hub_kwargs.copy()
    if config is not None:  # kwarg for `push_to_hub`
        hub_kwargs["config"] = config
    if repo_id is None:
        repo_id = target_dir.name  # Defaults to `save_directory` name
    return self.push_to_hub(repo_id=repo_id, model_card_kwargs=card_kwargs, **hub_kwargs)
|
| 435 |
+
|
| 436 |
+
def _save_pretrained(self, save_directory: Path) -> None:
|
| 437 |
+
"""
|
| 438 |
+
Overwrite this method in subclass to define how to save your model.
|
| 439 |
+
Check out our [integration guide](../guides/integrations) for instructions.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
save_directory (`str` or `Path`):
|
| 443 |
+
Path to directory in which the model weights and configuration will be saved.
|
| 444 |
+
"""
|
| 445 |
+
raise NotImplementedError
|
| 446 |
+
|
| 447 |
+
@classmethod
@validate_hf_hub_args
def from_pretrained(
    cls: Type[T],
    pretrained_model_name_or_path: Union[str, Path],
    *,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    local_files_only: bool = False,
    revision: Optional[str] = None,
    **model_kwargs,
) -> T:
    """
    Download a model from the Huggingface Hub and instantiate it.

    The config file (named `constants.CONFIG_NAME`) is looked up first, either in the local directory or on the
    Hub. When found, its values are used to fill in any init argument that was not explicitly passed in
    `model_kwargs`. The actual instantiation is delegated to `cls._from_pretrained`.

    Args:
        pretrained_model_name_or_path (`str`, `Path`):
            - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
            - Or a path to a `directory` containing model weights saved using
                [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
        revision (`str`, *optional*):
            Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
            Defaults to the latest commit on `main` branch.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
            the existing cache.
        resume_download (`bool`, *optional*):
            Forwarded unchanged to [`hf_hub_download`] and to `_from_pretrained`.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. By default, it will use the token
            cached when running `huggingface-cli login`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        model_kwargs (`Dict`, *optional*):
            Additional kwargs to pass to the model during initialization.

    Returns:
        The instance returned by `cls._from_pretrained`.
    """
    model_id = str(pretrained_model_name_or_path)
    config_file: Optional[str] = None
    # An existing local directory is read directly; anything else is treated as a Hub repo id.
    if os.path.isdir(model_id):
        if constants.CONFIG_NAME in os.listdir(model_id):
            config_file = os.path.join(model_id, constants.CONFIG_NAME)
        else:
            logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
    else:
        try:
            config_file = hf_hub_download(
                repo_id=model_id,
                filename=constants.CONFIG_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        except HfHubHTTPError as e:
            # A missing config is not fatal: loading continues with `config_file` left as None.
            logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

    # Read config
    config = None
    if config_file is not None:
        # Every step in this branch dereferences `config`, hence it only runs when a config file was found.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Decode custom types in config
        # (only keys that match an annotated init parameter are passed through `cls._decode_arg`)
        for key, value in config.items():
            if key in cls._hub_mixin_init_parameters:
                expected_type = cls._hub_mixin_init_parameters[key].annotation
                if expected_type is not inspect.Parameter.empty:
                    config[key] = cls._decode_arg(expected_type, value)

        # Populate model_kwargs from config
        # (explicitly passed kwargs take precedence over values read from the config file)
        for param in cls._hub_mixin_init_parameters.values():
            if param.name not in model_kwargs and param.name in config:
                model_kwargs[param.name] = config[param.name]

        # Check if `config` argument was passed at init
        if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
            # Decode `config` argument if it was passed
            config_annotation = cls._hub_mixin_init_parameters["config"].annotation
            config = cls._decode_arg(config_annotation, config)

            # Forward config to model initialization
            model_kwargs["config"] = config

        # Inject config if `**kwargs` are expected
        if is_dataclass(cls):
            # Dataclass models: only keys that are declared dataclass fields are injected.
            for key in cls.__dataclass_fields__:
                if key not in model_kwargs and key in config:
                    model_kwargs[key] = config[key]
        elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
            # The init signature has a `**kwargs` parameter: every config entry not already set is injected.
            for key, value in config.items():
                if key not in model_kwargs:
                    model_kwargs[key] = value

        # Finally, also inject if `_from_pretrained` expects it
        if cls._hub_mixin_inject_config and "config" not in model_kwargs:
            model_kwargs["config"] = config

    instance = cls._from_pretrained(
        model_id=str(model_id),
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        token=token,
        **model_kwargs,
    )

    # Implicitly set the config as instance attribute if not already set by the class
    # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
    if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
        instance._hub_mixin_config = config

    return instance
|
| 571 |
+
|
| 572 |
+
@classmethod
def _from_pretrained(
    cls: Type[T],
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: Optional[bool],
    local_files_only: bool,
    token: Optional[Union[str, bool]],
    **model_kwargs,
) -> T:
    """Framework-specific loading hook: override it in a subclass to define how a pretrained model is loaded.

    Implementations typically call [`hf_hub_download`] or [`snapshot_download`] to fetch files from the Hub before
    loading them; most of the arguments received here can be handed to those 2 methods as-is. Extra arguments can
    be accepted through "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a
    `map_location` parameter to set on which device the model should be loaded.

    Check out our [integration guide](../guides/integrations) for more instructions.

    Args:
        model_id (`str`):
            ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
        revision (`str`, *optional*):
            Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
            latest commit on `main` branch.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
            the existing cache.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`).
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. By default, it will use the token
            cached when running `huggingface-cli login`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        model_kwargs:
            Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.

    Raises:
        `NotImplementedError`: Always, unless the method is overridden in a subclass.
    """
    raise NotImplementedError
|
| 618 |
+
|
| 619 |
+
@validate_hf_hub_args
def push_to_hub(
    self,
    repo_id: str,
    *,
    config: Optional[Union[dict, DataclassInstance]] = None,
    commit_message: str = "Push model using huggingface_hub.",
    private: Optional[bool] = None,
    token: Optional[str] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
    model_card_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Upload model checkpoint to the Hub.

    The repo is created if needed, the model is saved with `save_pretrained` into a temporary directory, and
    that directory is uploaded with [`upload_folder`]. `allow_patterns` and `ignore_patterns` filter which files
    are pushed; `delete_patterns` deletes existing remote files in the same commit. See [`upload_folder`]
    reference for more details.

    Args:
        repo_id (`str`):
            ID of the repository to push to (example: `"username/my-model"`).
        config (`dict` or `DataclassInstance`, *optional*):
            Model configuration specified as a key/value dictionary or a dataclass instance.
        commit_message (`str`, *optional*):
            Message to commit while pushing.
        private (`bool`, *optional*):
            Whether the repository created should be private.
            If `None` (default), the repo will be public unless the organization's default is private.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. By default, it will use the token
            cached when running `huggingface-cli login`.
        branch (`str`, *optional*):
            The git branch on which to push the model. This defaults to `"main"`.
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.
        delete_patterns (`List[str]` or `str`, *optional*):
            If provided, remote files matching any of the patterns will be deleted from the repo.
        model_card_kwargs (`Dict[str, Any]`, *optional*):
            Additional arguments passed to the model card template to customize the model card.

    Returns:
        The url of the commit of your model in the given repository.
    """
    hub_api = HfApi(token=token)
    # `create_repo(..., exist_ok=True)` returns the repo info; keep the repo id it reports.
    repo_info = hub_api.create_repo(repo_id=repo_id, private=private, exist_ok=True)
    repo_id = repo_info.repo_id

    # Push the files to the repo in a single commit
    with SoftTemporaryDirectory() as scratch_dir:
        staging_dir = Path(scratch_dir) / repo_id
        self.save_pretrained(staging_dir, config=config, model_card_kwargs=model_card_kwargs)
        upload_options = {
            "repo_id": repo_id,
            "repo_type": "model",
            "folder_path": staging_dir,
            "commit_message": commit_message,
            "revision": branch,
            "create_pr": create_pr,
            "allow_patterns": allow_patterns,
            "ignore_patterns": ignore_patterns,
            "delete_patterns": delete_patterns,
        }
        return hub_api.upload_folder(**upload_options)
|
| 689 |
+
|
| 690 |
+
def generate_model_card(self, *args, **kwargs) -> ModelCard:
    """Build a model card from the metadata and template stored in `self._hub_mixin_info`.

    Positional arguments are accepted but not used. Keyword arguments are forwarded to
    `ModelCard.from_template`.
    """
    mixin_info = self._hub_mixin_info
    return ModelCard.from_template(
        card_data=mixin_info.model_card_data,
        template_str=mixin_info.model_card_template,
        repo_url=mixin_info.repo_url,
        docs_url=mixin_info.docs_url,
        **kwargs,
    )
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    See [`ModelHubMixin`] for more details on how to use the mixin.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin

    >>> class MyModel(
    ...         nn.Module,
    ...         PyTorchModelHubMixin,
    ...         library_name="keras-nlp",
    ...         repo_url="https://github.com/keras-team/keras-nlp",
    ...         docs_url="https://keras.io/keras_nlp/",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
    ...         super().__init__()
    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
    ...         self.linear = nn.Linear(output_size, vocab_size)

    ...     def forward(self, x):
    ...         return self.linear(x + self.param)
    >>> model = MyModel(hidden_size=256)

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    >>> model.hidden_size
    256
    ```
    """

    def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
        """Forward subclass registration to the parent, adding the `"pytorch_model_hub_mixin"` tag.

        Args:
            tags (`List[str]`, *optional*):
                Tags supplied in the class definition. The list is copied, never modified in place.
        """
        # Build a fresh list instead of appending to `tags`: appending would mutate the caller's list,
        # so a list shared between several class definitions would accumulate duplicate tags.
        kwargs["tags"] = [*(tags or []), "pytorch_model_hub_mixin"]
        return super().__init_subclass__(*args, **kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # If the object exposes a `module` attribute, that inner object is what gets saved
        # (presumably to unwrap wrappers such as DataParallel; confirm).
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        The model is first instantiated with `cls(**model_kwargs)`. If `model_id` is a local directory, weights
        are read from its `constants.SAFETENSORS_SINGLE_FILE` file. Otherwise that file is downloaded from the
        Hub, falling back to `constants.PYTORCH_WEIGHTS_NAME` when the safetensors file raises
        `EntryNotFoundError`.

        Args:
            map_location (`str`, defaults to `"cpu"`):
                Device on which the weights are loaded.
            strict (`bool`, defaults to `False`):
                Forwarded as `strict` to the underlying state-dict loading call.
        """
        model = cls(**model_kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)

        # Both download attempts share every argument except the filename.
        download_kwargs = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
        try:
            model_file = hf_hub_download(filename=constants.SAFETENSORS_SINGLE_FILE, **download_kwargs)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        except EntryNotFoundError:
            model_file = hf_hub_download(filename=constants.PYTORCH_WEIGHTS_NAME, **download_kwargs)
            return cls._load_as_pickle(model, model_file, map_location, strict)

    @classmethod
    def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load a `torch.load`-able state dict into `model` and switch the model to eval mode."""
        # `weights_only=True` is passed to `torch.load` for the pickle-based checkpoint.
        state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore
        return model

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load a safetensors file into `model`, placing the weights on `map_location`."""
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
            # Older safetensors: load first, then move the model to the requested device.
            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
            if map_location != "cpu":
                logger.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)  # type: ignore [attr-defined]
        else:
            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
        return model
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
|
| 832 |
+
"""Load a dataclass instance from a dictionary.
|
| 833 |
+
|
| 834 |
+
Fields not expected by the dataclass are ignored.
|
| 835 |
+
"""
|
| 836 |
+
return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_client.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_common.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Contains utilities used by both the sync and async inference clients."""
|
| 16 |
+
|
| 17 |
+
import base64
|
| 18 |
+
import io
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
from abc import ABC, abstractmethod
|
| 22 |
+
from contextlib import contextmanager
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import (
|
| 26 |
+
TYPE_CHECKING,
|
| 27 |
+
Any,
|
| 28 |
+
AsyncIterable,
|
| 29 |
+
BinaryIO,
|
| 30 |
+
ContextManager,
|
| 31 |
+
Dict,
|
| 32 |
+
Generator,
|
| 33 |
+
Iterable,
|
| 34 |
+
List,
|
| 35 |
+
Literal,
|
| 36 |
+
NoReturn,
|
| 37 |
+
Optional,
|
| 38 |
+
Union,
|
| 39 |
+
overload,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
from requests import HTTPError
|
| 43 |
+
|
| 44 |
+
from huggingface_hub.errors import (
|
| 45 |
+
GenerationError,
|
| 46 |
+
IncompleteGenerationError,
|
| 47 |
+
OverloadedError,
|
| 48 |
+
TextGenerationError,
|
| 49 |
+
UnknownError,
|
| 50 |
+
ValidationError,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
from ..utils import (
|
| 54 |
+
get_session,
|
| 55 |
+
is_aiohttp_available,
|
| 56 |
+
is_numpy_available,
|
| 57 |
+
is_pillow_available,
|
| 58 |
+
)
|
| 59 |
+
from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if TYPE_CHECKING:
|
| 63 |
+
from aiohttp import ClientResponse, ClientSession
|
| 64 |
+
from PIL.Image import Image
|
| 65 |
+
|
| 66 |
+
# TYPES
# Content given as a string; `_open_as_binary` treats strings starting with "http://" or "https://" as URLs
UrlT = str
# Path to a local file, either as a string or as a `Path`
PathT = Union[str, Path]
# Raw bytes or an opened binary file object
BinaryT = Union[bytes, BinaryIO]
# Any of the accepted ways of passing content: binary data, a local path or a URL
ContentT = Union[BinaryT, PathT, UrlT]

# Used to set an `Accept: image/png` header
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

# Module-level logger
logger = logging.getLogger(__name__)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass
class RequestParameters:
    """Bundle of values describing one request, as returned by `TaskProviderHelper.prepare_request`."""

    # NOTE(review): field meanings below are inferred from names and types only; confirm against callers.
    url: str  # presumably the endpoint the request is sent to
    task: str  # presumably the task name (see `TASKS_EXPECTING_IMAGES` for example values)
    model: Optional[str]  # presumably the model identifier, if any
    json: Optional[Union[str, Dict, List]]  # presumably a JSON payload
    data: Optional[ContentT]  # presumably a raw payload: bytes, file object, path or URL
    headers: Dict[str, Any]  # presumably the HTTP headers to send
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TaskProviderHelper(ABC):
    """Abstract base class defining the interface for task-specific provider helpers.

    Concrete subclasses must implement both `prepare_request` and `get_response`.
    """

    # Build the `RequestParameters` for a call from keyword-only arguments.
    @abstractmethod
    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters: ...

    # Convert a response (raw bytes or an already-parsed dict) into the value handed back to the caller.
    @abstractmethod
    def get_response(self, response: Union[bytes, Dict]) -> Any: ...
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Dataclass describing a model's status. It is used by the `get_model_status` function.
@dataclass
class ModelStatus:
    """
    This Dataclass represents the model status in the Hugging Face Inference API.

    Args:
        loaded (`bool`):
            If the model is currently loaded into Hugging Face's InferenceAPI. Models
            are loaded on-demand, leading to the user's first request taking longer.
            If a model is loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`Dict`):
            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
            replicas.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
    """

    loaded: bool
    state: str
    compute_type: Dict
    framework: str
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
## IMPORT UTILS
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _import_aiohttp():
    """Return the `aiohttp` module, raising `ImportError` when it is not installed on the machine."""
    if is_aiohttp_available():
        import aiohttp

        return aiohttp
    raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _import_numpy():
    """Make sure `numpy` is installed on the machine and return the module.

    Raises:
        ImportError: If `is_numpy_available()` reports that numpy is missing.
    """
    if not is_numpy_available():
        # Message fixed: the original read "to use deal with embeddings".
        raise ImportError("Please install numpy to deal with embeddings (`pip install numpy`).")
    import numpy

    return numpy
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _import_pil_image():
    """Make sure `PIL` is installed on the machine and return the `PIL.Image` module.

    Raises:
        ImportError: If `is_pillow_available()` reports that Pillow is missing.
    """
    if not is_pillow_available():
        # Message fixed: the original read "to use deal with images".
        raise ImportError(
            "Please install Pillow to deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    from PIL import Image

    return Image
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
## ENCODING / DECODING UTILS
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@overload
def _open_as_binary(
    content: ContentT,
) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"


@overload
def _open_as_binary(
    content: Literal[None],
) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"


@contextmanager  # type: ignore
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Context manager yielding `content` as binary data, whether it is a URL, a local path or raw bytes.

    - A string starting with "https://" or "http://" is downloaded and its body is yielded as bytes.
    - Any other string is interpreted as a local path, which must exist.
    - A `Path` is opened in "rb" mode and the file object is yielded; it is closed on exit.
    - Anything else (bytes, an already-opened file object, or `None`) is yielded unchanged.

    Raises:
        FileNotFoundError: If `content` is a non-URL string that does not point to an existing file.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
    # If content is a string => must be either a URL or a path
    if isinstance(content, str):
        if content.startswith(("https://", "http://")):
            logger.debug(f"Downloading content from {content}")
            downloaded = get_session().get(content).content  # TODO: retrieve as stream and pipe to post request ?
            yield downloaded
            return
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # Not a path: already a file-like object, bytes or None => hand it back untouched
    if not isinstance(content, Path):
        yield content
        return

    # A Path => open it for the duration of the `with` block
    logger.debug(f"Opening content from {content}")
    with content.open("rb") as stream:
        yield stream
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _b64_encode(content: ContentT) -> str:
    """Return the base64 text encoding of a raw file (image, audio).

    `content` can be bytes, an opened file, a path or a URL; it is resolved through `_open_as_binary`.
    """
    with _open_as_binary(content) as data:
        if isinstance(data, bytes):
            raw = data
        else:
            # Not bytes: read the whole payload from the yielded object
            raw = data.read()
        encoded = base64.b64encode(raw)
        return encoded.decode()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _b64_to_image(encoded_image: str) -> "Image":
    """Decode a base64-encoded string and open the result as a PIL Image."""
    pil_image = _import_pil_image()
    raw_bytes = base64.b64decode(encoded_image)
    return pil_image.open(io.BytesIO(raw_bytes))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _bytes_to_list(content: bytes) -> List:
|
| 231 |
+
"""Parse bytes from a Response object into a Python list.
|
| 232 |
+
|
| 233 |
+
Expects the response body to be JSON-encoded data.
|
| 234 |
+
|
| 235 |
+
NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
|
| 236 |
+
dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
|
| 237 |
+
"""
|
| 238 |
+
return json.loads(content.decode())
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _bytes_to_dict(content: bytes) -> Dict:
|
| 242 |
+
"""Parse bytes from a Response object into a Python dictionary.
|
| 243 |
+
|
| 244 |
+
Expects the response body to be JSON-encoded data.
|
| 245 |
+
|
| 246 |
+
NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
|
| 247 |
+
list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
|
| 248 |
+
"""
|
| 249 |
+
return json.loads(content.decode())
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _bytes_to_image(content: bytes) -> "Image":
    """Open the raw bytes of a Response object as a PIL Image.

    The body must be raw image bytes. For base64-encoded images, use `_b64_to_image` instead.
    """
    pil_image = _import_pil_image()
    buffer = io.BytesIO(content)
    return pil_image.open(buffer)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _as_dict(response: Union[bytes, Dict]) -> Dict:
|
| 262 |
+
return json.loads(response) if isinstance(response, bytes) else response
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
## PAYLOAD UTILS
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
## STREAMING UTILS
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Yield parsed text-generation items from a stream of Server-Sent-Events lines.

    Used in `InferenceClient.text_generation`.
    """
    for raw_line in bytes_output_as_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream signal raised by the formatter: stop yielding. It has to be caught
            # here because a StopIteration escaping a generator body becomes a RuntimeError.
            return
        if parsed is None:
            # Line carried no data (e.g. empty separator line).
            continue
        yield parsed
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Asynchronously yield parsed text-generation items from Server-Sent-Events lines.

    Used in `AsyncInferenceClient.text_generation`.
    """
    async for raw_line in bytes_output_as_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream signal raised by the formatter: stop yielding.
            return
        if parsed is None:
            # Line carried no data (e.g. empty separator line).
            continue
        yield parsed
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Parse one Server-Sent-Events line of a text-generation stream.

    Args:
        byte_payload (`bytes`):
            A single raw line read from the streaming response.
        details (`bool`):
            If `True`, return the whole parsed `TextGenerationStreamOutput`. Otherwise return only
            the text of the generated token.

    Returns:
        `None` if the line does not start with `data:` (e.g. an empty line), otherwise the token
        text or the parsed output, depending on `details`.

    Raises:
        `StopIteration`: If the line is the `data: [DONE]` end-of-stream signal.
        `TextGenerationError`: If the decoded payload contains an `error` field.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    payload = byte_payload.decode("utf-8")
    # Remove exactly the "data:" prefix (its presence is guaranteed by the check above).
    # `str.lstrip("data:")` / `str.rstrip("/n")` would strip any run of the characters
    # 'd', 'a', 't', ':' / '/', 'n' instead of a prefix / a newline. Surrounding whitespace
    # (including a trailing "\n") is tolerated by `json.loads`.
    json_payload = json.loads(payload[len("data:") :])

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Yield parsed chat-completion chunks from a stream of Server-Sent-Events lines.

    Used in `InferenceClient.chat_completion` if model is served with TGI.
    """
    for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream signal raised by the formatter: stop yielding. It has to be caught
            # here because a StopIteration escaping a generator body becomes a RuntimeError.
            return
        if chunk is None:
            # Line carried no data (e.g. empty separator line).
            continue
        yield chunk
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Asynchronously yield parsed chat-completion chunks from Server-Sent-Events lines.

    Used in `AsyncInferenceClient.chat_completion`.
    """
    async for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream signal raised by the formatter: stop yielding.
            return
        if chunk is None:
            # Line carried no data (e.g. empty separator line).
            continue
        yield chunk
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    """Parse one Server-Sent-Events line of a chat-completion stream.

    Args:
        byte_payload (`bytes`):
            A single raw line read from the streaming response.

    Returns:
        `None` if the line does not start with `data:` (e.g. an empty line), otherwise the parsed
        `ChatCompletionStreamOutput`.

    Raises:
        `StopIteration`: If the line is the `data: [DONE]` end-of-stream signal.
        `TextGenerationError`: If the decoded payload contains an `error` field.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    payload = byte_payload.decode("utf-8")
    # Remove exactly the "data:" prefix (its presence is guaranteed by the check above).
    # `str.lstrip("data:")` / `str.rstrip("/n")` would strip any run of the characters
    # 'd', 'a', 't', ':' / '/', 'n' instead of a prefix / a newline. Surrounding whitespace
    # (including a trailing "\n") is tolerated by `json.loads`.
    json_payload = json.loads(payload[len("data:") :])

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
|
| 369 |
+
async for byte_payload in response.content:
|
| 370 |
+
yield byte_payload.strip()
|
| 371 |
+
await client.close()
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# "TGI servers" are servers running with the `text-generation-inference` backend.
|
| 375 |
+
# This backend is the go-to solution to run large language models at scale. However,
|
| 376 |
+
# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
|
| 377 |
+
# solution is still in use.
|
| 378 |
+
#
|
| 379 |
+
# Both approaches have very similar APIs, but not exactly the same. What we do first in
|
| 380 |
+
# the `text_generation` method is to assume the model is served via TGI. If we realize
|
| 381 |
+
# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
|
| 382 |
+
# default API with a warning message. When that's the case, we remember the unsupported
|
| 383 |
+
# attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
|
| 384 |
+
#
|
| 385 |
+
# In addition, TGI servers have a built-in API route for chat-completion, which is not
|
| 386 |
+
# available on the default API. We use this route to provide a more consistent behavior
|
| 387 |
+
# when available.
|
| 388 |
+
#
|
| 389 |
+
# For more details, see https://github.com/huggingface/text-generation-inference and
|
| 390 |
+
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
|
| 391 |
+
|
| 392 |
+
# Module-level registry: model id (or None) -> names of text-generation kwargs recorded as unsupported.
_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {}
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None:
    """Append `unsupported_kwargs` to the list recorded for `model` in the module-level registry."""
    recorded = _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, [])
    recorded.extend(unsupported_kwargs)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
    """Return the kwargs recorded as unsupported for `model`, or an empty list if none are recorded."""
    try:
        return _UNSUPPORTED_TEXT_GENERATION_KWARGS[model]
    except KeyError:
        return []
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# TEXT GENERATION ERRORS
|
| 404 |
+
# ----------------------
|
| 405 |
+
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
|
| 406 |
+
# inference project (https://github.com/huggingface/text-generation-inference).
|
| 407 |
+
# ----------------------
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Raise the most specific error available for a failed text-generation request.

    The error payload is inspected for a text-generation-inference `error_type`. When one is
    present, the matching text-generation error is raised, chained to `http_error`. In every
    other case `http_error` itself is re-raised.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.
    """
    # Try to parse a Text Generation Inference error
    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error, error_type = payload.get("error"), payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # Without an error_type there is nothing more specific to report: fallback to default error
    if error_type is None:
        raise http_error

    # With an error_type => more information than `hf_raise_for_status`
    specific_error = _parse_text_generation_error(error, error_type)
    raise specific_error from http_error
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Build the exception instance matching `error_type`, carrying `error` as its message.

    Any `error_type` other than the known ones yields an `UnknownError`.
    """
    known_error_types = (
        ("generation", GenerationError),
        ("incomplete_generation", IncompleteGenerationError),
        ("overloaded", OverloadedError),
        ("validation", ValidationError),
    )
    for type_name, exception_class in known_error_types:
        if error_type == type_name:
            return exception_class(error)  # type: ignore
    return UnknownError(error)  # type: ignore
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class AudioToAudioInput(BaseInferenceType):
    """Input payload for Audio to Audio inference."""

    inputs: Any
    """Audio data given as input."""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class AudioToAudioOutputElement(BaseInferenceType):
    """One element of the inference output for the Audio To Audio task.

    Describes a generated audio file together with its label.
    """

    blob: Any
    """Generated audio file."""
    content_type: str
    """Content type of the audio file."""
    label: str
    """Label of the audio file."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import List, Literal, Optional, Union
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# String literal accepted, next to booleans, by the `early_stopping` generation parameter.
AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType):
    """Settings that parametrize the text generation process."""

    do_sample: Optional[bool] = None
    """Use sampling rather than greedy decoding when generating new tokens."""
    early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None
    """Stopping condition applied by beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """When set to a float strictly between 0 and 1, only tokens whose conditional probability
    is greater than epsilon_cutoff are sampled. The paper suggests values ranging from 3e-4
    to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model
    Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling combines locally typical sampling with epsilon sampling. When set to a
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). Intuitively,
    the latter term is the expected next token probability, scaled by sqrt(eta_cutoff). The
    paper suggests values ranging from 3e-4 to 2e-3, depending on the size of the model. See
    [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for
    more details.
    """
    max_length: Optional[int] = None
    """Maximum length (in tokens) of the generated text, input included."""
    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """Minimum length (in tokens) of the generated text, input included."""
    min_new_tokens: Optional[int] = None
    """Minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups num_beams is divided into, in order to ensure diversity among the
    different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more
    details.
    """
    num_beams: Optional[int] = None
    """Number of beams used for beam search."""
    penalty_alpha: Optional[float] = None
    """Balances the model confidence against the degeneration penalty in contrastive search
    decoding.
    """
    temperature: Optional[float] = None
    """Value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """Number of highest probability vocabulary tokens kept for top-k-filtering."""
    top_p: Optional[float] = None
    """When set to a float < 1, only the smallest set of most probable tokens whose
    probabilities add up to top_p or higher is kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token next,
    given the partial text already generated. When set to a float < 1, the smallest set of
    the most locally typical tokens whose probabilities add up to typical_p or higher is
    kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should reuse the past last key/values attentions to speed up decoding"""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class AutomaticSpeechRecognitionParameters(BaseInferenceType):
    """Extra inference parameters for Automatic Speech Recognition."""

    return_timestamps: Optional[bool] = None
    """Whether timestamps matching the generated text should be returned."""
    # Will be deprecated in the future when the renaming to `generation_parameters` is implemented in transformers
    generate_kwargs: Optional[AutomaticSpeechRecognitionGenerationParameters] = None
    """Settings that parametrize the text generation process."""
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
class AutomaticSpeechRecognitionInput(BaseInferenceType):
    """Input payload for Automatic Speech Recognition inference."""

    inputs: str
    """Input audio data, as a base64-encoded string. When no `parameters` are provided, the
    audio data can also be sent as a raw bytes payload.
    """
    parameters: Optional[AutomaticSpeechRecognitionParameters] = None
    """Extra inference parameters for Automatic Speech Recognition."""
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType):
    text: str
    """Chunk of text identified by the model."""
    timestamps: List[float]
    """Start and end timestamps matching the text."""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class AutomaticSpeechRecognitionOutput(BaseInferenceType):
    """Inference output for the Automatic Speech Recognition task."""

    text: str
    """Recognized text."""
    chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None
    """List of audio chunks identified by the model; populated when returnTimestamps is
    enabled.
    """
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class DepthEstimationInput(BaseInferenceType):
    """Input payload for Depth Estimation inference."""

    inputs: Any
    """Image data given as input."""
    parameters: Optional[Dict[str, Any]] = None
    """Extra inference parameters for Depth Estimation."""
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class DepthEstimationOutput(BaseInferenceType):
    """Inference output for the Depth Estimation task."""

    depth: Any
    """Predicted depth, as an image."""
    predicted_depth: Any
    """Predicted depth, as a tensor."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, List, Optional
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class FillMaskParameters(BaseInferenceType):
    """Extra inference parameters for Fill Mask."""

    targets: Optional[List[str]] = None
    """When passed, scores are limited to these targets instead of being looked up in the
    whole vocabulary. Targets that are not in the model vocab are tokenized and the first
    resulting token is used (with a warning, and that might be slower).
    """
    top_k: Optional[int] = None
    """When passed, overrides how many predictions are returned."""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class FillMaskInput(BaseInferenceType):
    """Input payload for Fill Mask inference."""

    inputs: str
    """Text containing masked tokens."""
    parameters: Optional[FillMaskParameters] = None
    """Extra inference parameters for Fill Mask."""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class FillMaskOutputElement(BaseInferenceType):
    """One element of the inference output for the Fill Mask task."""

    score: float
    """Probability of this prediction."""
    sequence: str
    """Input text with the mask token replaced by the prediction."""
    token: int
    """Id of the predicted token (to replace the masked one)."""
    token_str: Any
    fill_mask_output_token_str: Optional[str] = None
    """Predicted token (to replace the masked one)."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Literal, Optional
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Values accepted by the `subtask` image-segmentation parameter.
ImageSegmentationSubtask = Literal["instance", "panoptic", "semantic"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ImageSegmentationParameters(BaseInferenceType):
    """Extra inference parameters for Image Segmentation."""

    mask_threshold: Optional[float] = None
    """Threshold applied when turning the predicted masks into binary values."""
    overlap_mask_area_threshold: Optional[float] = None
    """Mask overlap threshold used to eliminate small, disconnected segments."""
    subtask: Optional["ImageSegmentationSubtask"] = None
    """Segmentation task to perform, depending on model capabilities."""
    threshold: Optional[float] = None
    """Probability threshold used to filter out predicted masks."""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class ImageSegmentationInput(BaseInferenceType):
    """Input payload for Image Segmentation inference."""

    inputs: str
    """Input image data, as a base64-encoded string. When no `parameters` are provided, the
    image data can also be sent as a raw bytes payload.
    """
    parameters: Optional[ImageSegmentationParameters] = None
    """Extra inference parameters for Image Segmentation."""
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class ImageSegmentationOutputElement(BaseInferenceType):
    """One element of the inference output for the Image Segmentation task.

    Describes a predicted mask / segment.
    """

    label: str
    """Label of the predicted segment."""
    mask: str
    """Matching mask, as a black-and-white image (base64-encoded)."""
    score: Optional[float] = None
    """Score or confidence degree the model has."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class ImageToImageTargetSize(BaseInferenceType):
    """Size, in pixel, of the output image."""

    height: int
    width: int
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class ImageToImageParameters(BaseInferenceType):
    """Extra inference parameters for Image To Image."""

    guidance_scale: Optional[float] = None
    """For diffusion models. With a higher guidance scale value, the model is encouraged to
    generate images closely linked to the text prompt, at the expense of lower image quality.
    """
    negative_prompt: Optional[str] = None
    """One prompt guiding what NOT to include in image generation."""
    num_inference_steps: Optional[int] = None
    """For diffusion models. Number of denoising steps. More denoising steps usually lead to a
    higher quality image, at the expense of slower inference.
    """
    target_size: Optional[ImageToImageTargetSize] = None
    """Size, in pixel, of the output image."""
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class ImageToImageInput(BaseInferenceType):
    """Input payload for Image To Image inference."""

    inputs: str
    """Input image data, as a base64-encoded string. When no `parameters` are provided, the
    image data can also be sent as a raw bytes payload.
    """
    parameters: Optional[ImageToImageParameters] = None
    """Extra inference parameters for Image To Image."""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class ImageToImageOutput(BaseInferenceType):
    """Inference output for the Image To Image task."""

    image: Any
    """Output image, returned as raw bytes in the payload."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Literal, Optional, Union
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# String literal accepted, next to booleans, by the `early_stopping` generation parameter.
ImageToTextEarlyStoppingEnum = Literal["never"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ImageToTextGenerationParameters(BaseInferenceType):
    """Fine-grained controls for the text generation process.

    Every field is optional and defaults to ``None``.
    """

    do_sample: Optional[bool] = None
    """Sample new tokens instead of decoding greedily."""
    early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None
    """Stopping condition applied by beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """Probability cutoff for sampling. When set to a float strictly between 0 and 1, only
    tokens whose conditional probability is greater than epsilon_cutoff are sampled. The
    paper suggests values from 3e-4 to 9e-4, depending on the size of the model. See
    [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    eta_cutoff: Optional[float] = None
    """Cutoff for eta sampling, a hybrid of locally typical sampling and epsilon sampling.
    When set to a float strictly between 0 and 1, a token is only considered if it is
    greater than either eta_cutoff or
    sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The second term is
    intuitively the expected next token probability, scaled by sqrt(eta_cutoff). The paper
    suggests values from 3e-4 to 2e-3, depending on the size of the model. See
    [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """Upper bound, in tokens, on the length of the generated text, input included."""
    max_new_tokens: Optional[int] = None
    """Upper bound on the number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """Lower bound, in tokens, on the length of the generated text, input included."""
    min_new_tokens: Optional[int] = None
    """Lower bound on the number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups num_beams is divided into, in order to ensure diversity among the
    different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more
    details.
    """
    num_beams: Optional[int] = None
    """Beam count used for beam search."""
    penalty_alpha: Optional[float] = None
    """Balances the model confidence against the degeneration penalty in contrastive search
    decoding.
    """
    temperature: Optional[float] = None
    """Value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """How many of the highest probability vocabulary tokens are kept for top-k-filtering."""
    top_p: Optional[float] = None
    """When set to a float < 1, only the smallest set of most probable tokens whose
    probabilities add up to top_p or higher is kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a
    target token next is to the expected conditional probability of predicting a random
    token next, given the partial text already generated. When set to a float < 1, only the
    smallest set of the most locally typical tokens whose probabilities add up to typical_p
    or higher is kept for generation. See [this paper](https://hf.co/papers/2202.00666) for
    more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should reuse the past key/values attentions to speed up decoding."""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class ImageToTextParameters(BaseInferenceType):
    """Optional extra inference parameters for the Image To Text task."""

    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    # Slated for deprecation once transformers implements the rename to `generation_parameters`.
    generate_kwargs: Optional[ImageToTextGenerationParameters] = None
    """Settings that parametrize the text generation process."""
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
class ImageToTextInput(BaseInferenceType):
    """Input payload for Image To Text inference."""

    inputs: Any
    """Image data to run inference on."""
    parameters: Optional[ImageToTextParameters] = None
    """Extra inference parameters for Image To Text; defaults to ``None``."""
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass
class ImageToTextOutput(BaseInferenceType):
    """Inference outputs for the Image To Text task."""

    # NOTE(review): required, typed `Any` and undocumented in the generated code;
    # presumably carries the generated text — confirm against the task schema.
    generated_text: Any
    image_to_text_output_generated_text: Optional[str] = None
    """Text generated by the model."""
|
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/object_detection.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
| 2 |
+
#
|
| 3 |
+
# See:
|
| 4 |
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
| 5 |
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from .base import BaseInferenceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class ObjectDetectionParameters(BaseInferenceType):
    """Optional extra inference parameters for the Object Detection task."""

    threshold: Optional[float] = None
    """Probability required for a prediction to be made."""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class ObjectDetectionInput(BaseInferenceType):
    """Input payload for Object Detection inference."""

    inputs: str
    """Input image data, given as a base64-encoded string. When no `parameters` are
    provided, the image data can also be sent as a raw bytes payload.
    """
    parameters: Optional[ObjectDetectionParameters] = None
    """Extra inference parameters for Object Detection; defaults to ``None``."""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class ObjectDetectionBoundingBox(BaseInferenceType):
    """Predicted bounding box, with coordinates expressed relative to the top left corner
    of the input image.
    """

    xmax: int
    """X-coordinate of the box's bottom-right corner."""
    xmin: int
    """X-coordinate of the box's top-left corner."""
    ymax: int
    """Y-coordinate of the box's bottom-right corner."""
    ymin: int
    """Y-coordinate of the box's top-left corner."""
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class ObjectDetectionOutputElement(BaseInferenceType):
    """Inference output for the Object Detection task: a bounding box with its label and
    score.
    """

    box: ObjectDetectionBoundingBox
    """Predicted bounding box, with coordinates expressed relative to the top left corner
    of the input image.
    """
    label: str
    """Label predicted for the bounding box."""
    score: float
    """Score / probability associated with the prediction."""
|