koichi12 committed on
Commit
9a8eae1
·
verified ·
1 Parent(s): 80a73eb

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__main__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/api.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/cd.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/constant.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/legacy.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/md.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/models.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/utils.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/version.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/charset_normalizer/cli/__init__.py +8 -0
  13. .venv/lib/python3.11/site-packages/charset_normalizer/cli/__main__.py +321 -0
  14. .venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__main__.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/huggingface_hub/__init__.py +1434 -0
  17. .venv/lib/python3.11/site-packages/huggingface_hub/_commit_api.py +758 -0
  18. .venv/lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py +353 -0
  19. .venv/lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py +402 -0
  20. .venv/lib/python3.11/site-packages/huggingface_hub/_local_folder.py +432 -0
  21. .venv/lib/python3.11/site-packages/huggingface_hub/_login.py +520 -0
  22. .venv/lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py +307 -0
  23. .venv/lib/python3.11/site-packages/huggingface_hub/_space_api.py +160 -0
  24. .venv/lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py +194 -0
  25. .venv/lib/python3.11/site-packages/huggingface_hub/_upload_large_folder.py +621 -0
  26. .venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py +137 -0
  27. .venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py +386 -0
  28. .venv/lib/python3.11/site-packages/huggingface_hub/community.py +355 -0
  29. .venv/lib/python3.11/site-packages/huggingface_hub/constants.py +229 -0
  30. .venv/lib/python3.11/site-packages/huggingface_hub/errors.py +329 -0
  31. .venv/lib/python3.11/site-packages/huggingface_hub/fastai_utils.py +425 -0
  32. .venv/lib/python3.11/site-packages/huggingface_hub/file_download.py +1621 -0
  33. .venv/lib/python3.11/site-packages/huggingface_hub/hf_api.py +0 -0
  34. .venv/lib/python3.11/site-packages/huggingface_hub/hf_file_system.py +1140 -0
  35. .venv/lib/python3.11/site-packages/huggingface_hub/hub_mixin.py +836 -0
  36. .venv/lib/python3.11/site-packages/huggingface_hub/inference/__init__.py +0 -0
  37. .venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_client.py +0 -0
  40. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_common.py +446 -0
  41. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py +0 -0
  42. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py +0 -0
  43. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
  44. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +115 -0
  45. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
  46. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py +48 -0
  47. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
  48. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
  49. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py +102 -0
  50. .venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/object_detection.py +59 -0
.gitattributes CHANGED
@@ -122,3 +122,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
122
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
123
  .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
124
  .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
 
 
122
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
123
  .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
124
  .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
125
+ .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.9 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/__main__.cpython-311.pyc ADDED
Binary file (394 Bytes). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/api.cpython-311.pyc ADDED
Binary file (20.8 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/cd.cpython-311.pyc ADDED
Binary file (15.9 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/constant.cpython-311.pyc ADDED
Binary file (43.6 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/legacy.cpython-311.pyc ADDED
Binary file (3.15 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/md.cpython-311.pyc ADDED
Binary file (27.6 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/models.cpython-311.pyc ADDED
Binary file (18.6 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/utils.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/__pycache__/version.cpython-311.pyc ADDED
Binary file (400 Bytes). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from .__main__ import cli_detect, query_yes_no
4
+
5
+ __all__ = (
6
+ "cli_detect",
7
+ "query_yes_no",
8
+ )
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__main__.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import sys
5
+ from json import dumps
6
+ from os.path import abspath, basename, dirname, join, realpath
7
+ from platform import python_version
8
+ from unicodedata import unidata_version
9
+
10
+ import charset_normalizer.md as md_module
11
+ from charset_normalizer import from_fp
12
+ from charset_normalizer.models import CliDetectionResult
13
+ from charset_normalizer.version import __version__
14
+
15
+
16
def query_yes_no(question: str, default: str | None = "yes") -> bool:
    """Ask a yes/no question on stdout and read the answer via input().

    :param question: Text presented to the user, followed by a "[y/n]" hint.
    :param default: Presumed answer when the user just hits <Enter>.
        Must be "yes" (the default), "no" or None (meaning an explicit
        answer is required of the user).
    :return: True for "yes", False for "no".
    :raises ValueError: If ``default`` is not "yes", "no" or None.

    Credit goes to (c) https://stackoverflow.com/questions/3041986/apt-command-line-interface-like-yes-no-input
    """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    prompts = {None: " [y/n] ", "yes": " [Y/n] ", "no": " [y/N] "}

    if default not in prompts:
        raise ValueError("invalid default answer: '%s'" % default)
    prompt = prompts[default]

    while True:
        sys.stdout.write(question + prompt)
        # Surrounding whitespace is not meaningful in a yes/no answer, so it is
        # dropped before matching (e.g. "y " is accepted as "y").
        choice = input().strip().lower()
        if default is not None and choice == "":
            return valid[default]
        if choice in valid:
            return valid[choice]
        # Unrecognised answer: explain what is expected and ask again.
        sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
47
+
48
+
49
def _close_files(files) -> None:
    """Close every file handle in ``files`` that is still open."""
    for handle in files:
        if handle.closed is False:
            handle.close()


def _result_from_match(path: str, match, is_preferred: bool) -> CliDetectionResult:
    """Build the CLI report entry describing ``match`` for the file at ``path``."""
    return CliDetectionResult(
        path,
        match.encoding,
        match.encoding_aliases,
        [cp for cp in match.could_be_from_charset if cp != match.encoding],
        match.language,
        match.alphabets,
        match.bom,
        match.percent_chaos,
        match.percent_coherence,
        None,  # unicode_path: only filled in once a normalized copy is written
        is_preferred,
    )


def _build_cli_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every option of the CLI."""
    parser = argparse.ArgumentParser(
        description="The Real First Universal Charset Detector. "
        "Discover originating encoding used on text file. "
        "Normalize text to unicode."
    )

    parser.add_argument(
        "files", type=argparse.FileType("rb"), nargs="+", help="File(s) to be analysed"
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=False,
        dest="verbose",
        help="Display complementary information about file if any. "
        "Stdout will contain logs about the detection process.",
    )
    parser.add_argument(
        "-a",
        "--with-alternative",
        action="store_true",
        default=False,
        dest="alternatives",
        help="Output complementary possibilities if any. Top-level JSON WILL be a list.",
    )
    parser.add_argument(
        "-n",
        "--normalize",
        action="store_true",
        default=False,
        dest="normalize",
        help="Permit to normalize input file. If not set, program does not write anything.",
    )
    parser.add_argument(
        "-m",
        "--minimal",
        action="store_true",
        default=False,
        dest="minimal",
        help="Only output the charset detected to STDOUT. Disabling JSON output.",
    )
    parser.add_argument(
        "-r",
        "--replace",
        action="store_true",
        default=False,
        dest="replace",
        help="Replace file when trying to normalize it instead of creating a new one.",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        dest="force",
        help="Replace file without asking if you are sure, use this flag with caution.",
    )
    parser.add_argument(
        "-i",
        "--no-preemptive",
        action="store_true",
        default=False,
        dest="no_preemptive",
        help="Disable looking at a charset declaration to hint the detector.",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        action="store",
        default=0.2,
        type=float,
        dest="threshold",
        help="Define a custom maximum amount of noise allowed in decoded content. 0. <= noise <= 1.",
    )
    parser.add_argument(
        "--version",
        action="version",
        version="Charset-Normalizer {} - Python {} - Unicode {} - SpeedUp {}".format(
            __version__,
            python_version(),
            unidata_version,
            "OFF" if md_module.__file__.lower().endswith(".py") else "ON",
        ),
        help="Show version information and exit.",
    )
    return parser


def cli_detect(argv: list[str] | None = None) -> int:
    """
    CLI assistant using ARGV and ArgumentParser

    Detects the encoding of every given file, optionally writes a normalized
    copy, then prints either a JSON report or (with --minimal) the bare
    encoding names.

    :param argv: Arguments to parse; None lets argparse read sys.argv.
    :return: 0 if everything is fine, 1 on inconsistent options,
        2 if the normalized file could not be written.
    """
    args = _build_cli_parser().parse_args(argv)

    # Option consistency checks. The input files were already opened by
    # argparse.FileType, so they must be released before bailing out.
    usage_error = None
    if args.replace and not args.normalize:
        usage_error = "Use --replace in addition of --normalize only."
    elif args.force and not args.replace:
        usage_error = "Use --force in addition of --replace only."
    elif args.threshold < 0.0 or args.threshold > 1.0:
        usage_error = "--threshold VALUE should be between 0. AND 1."

    if usage_error is not None:
        _close_files(args.files)
        print(usage_error, file=sys.stderr)
        return 1

    x_ = []

    try:
        for my_file in args.files:
            matches = from_fp(
                my_file,
                threshold=args.threshold,
                explain=args.verbose,
                preemptive_behaviour=not args.no_preemptive,
            )

            best_guess = matches.best()
            file_path = abspath(my_file.name)

            if best_guess is None:
                print(
                    'Unable to identify originating encoding for "{}". {}'.format(
                        my_file.name,
                        (
                            "Maybe try increasing maximum amount of chaos."
                            if args.threshold < 1.0
                            else ""
                        ),
                    ),
                    file=sys.stderr,
                )
                x_.append(
                    CliDetectionResult(
                        file_path,
                        None,
                        [],
                        [],
                        "Unknown",
                        [],
                        False,
                        1.0,
                        0.0,
                        None,
                        True,
                    )
                )
                my_file.close()
                continue

            # Keep a handle on this file's preferred entry: it is the one that
            # must receive the normalized path, wherever it sits in the list.
            best_result = _result_from_match(file_path, best_guess, True)
            x_.append(best_result)

            if len(matches) > 1 and args.alternatives:
                x_.extend(
                    _result_from_match(file_path, el, False)
                    for el in matches
                    if el != best_guess
                )

            if args.normalize:
                if best_guess.encoding.startswith("utf"):
                    print(
                        '"{}" file does not need to be normalized, as it already came from unicode.'.format(
                            my_file.name
                        ),
                        file=sys.stderr,
                    )
                    my_file.close()
                    continue

                resolved_path = realpath(my_file.name)
                dir_path = dirname(resolved_path)
                o_: list[str] = basename(resolved_path).split(".")

                if not args.replace:
                    # New file: slot the encoding name in before the last
                    # dot-separated part of the original file name.
                    o_.insert(-1, best_guess.encoding)
                    my_file.close()
                elif not args.force and not query_yes_no(
                    'Are you sure to normalize "{}" by replacing it ?'.format(
                        my_file.name
                    ),
                    "no",
                ):
                    # The user declined the in-place replacement.
                    my_file.close()
                    continue

                try:
                    best_result.unicode_path = join(dir_path, ".".join(o_))

                    with open(best_result.unicode_path, "wb") as fp:
                        fp.write(best_guess.output())
                except OSError as e:
                    print(str(e), file=sys.stderr)
                    return 2

            my_file.close()
    finally:
        # Covers the early "return 2" and unexpected exceptions: no handle
        # opened by argparse is left dangling, including files not yet reached.
        _close_files(args.files)

    if args.minimal is False:
        print(
            dumps(
                [el.__dict__ for el in x_] if len(x_) > 1 else x_[0].__dict__,
                ensure_ascii=True,
                indent=4,
            )
        )
    else:
        for my_file in args.files:
            print(
                ", ".join(
                    [
                        el.encoding or "undefined"
                        for el in x_
                        if el.path == abspath(my_file.name)
                    ]
                )
            )

    return 0
318
+
319
+
320
if __name__ == "__main__":
    # Propagate cli_detect()'s return code as the process exit status so that
    # failures (1 or 2) are visible to the calling shell.
    sys.exit(cli_detect())
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (366 Bytes). View file
 
.venv/lib/python3.11/site-packages/charset_normalizer/cli/__pycache__/__main__.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/__init__.py ADDED
@@ -0,0 +1,1434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # ***********
16
+ # `huggingface_hub` init has 2 modes:
17
+ # - Normal usage:
18
+ # If imported to use it, all modules and functions are lazy-loaded. This means
19
+ # they exist at top level in module but are imported only the first time they are
20
+ # used. This way, `from huggingface_hub import something` will import `something`
21
+ # quickly without the hassle of importing all the features from `huggingface_hub`.
22
+ # - Static check:
23
+ # If statically analyzed, all modules and functions are loaded normally. This way
24
+ # static typing check works properly as well as autocomplete in text editors and
25
+ # IDEs.
26
+ #
27
+ # The static model imports are done inside the `if TYPE_CHECKING:` statement at
28
+ # the bottom of this file. Since module/functions imports are duplicated, it is
29
+ # mandatory to make sure to add them twice when adding one. This is checked in the
30
+ # `make quality` command.
31
+ #
32
+ # To update the static imports, please run the following command and commit the changes.
33
+ # ```
34
+ # # Use script
35
+ # python utils/check_static_imports.py --update-file
36
+ #
37
+ # # Or run style on codebase
38
+ # make style
39
+ # ```
40
+ #
41
+ # ***********
42
+ # Lazy loader vendored from https://github.com/scientific-python/lazy_loader
43
+ import importlib
44
+ import os
45
+ import sys
46
+ from typing import TYPE_CHECKING
47
+
48
+
49
+ __version__ = "0.28.1"
50
+
51
+ # Alphabetical order of definitions is ensured in tests
52
+ # WARNING: any comment added in this dictionary definition will be lost when
53
+ # re-generating the file !
54
+ _SUBMOD_ATTRS = {
55
+ "_commit_scheduler": [
56
+ "CommitScheduler",
57
+ ],
58
+ "_inference_endpoints": [
59
+ "InferenceEndpoint",
60
+ "InferenceEndpointError",
61
+ "InferenceEndpointStatus",
62
+ "InferenceEndpointTimeoutError",
63
+ "InferenceEndpointType",
64
+ ],
65
+ "_login": [
66
+ "auth_list",
67
+ "auth_switch",
68
+ "interpreter_login",
69
+ "login",
70
+ "logout",
71
+ "notebook_login",
72
+ ],
73
+ "_snapshot_download": [
74
+ "snapshot_download",
75
+ ],
76
+ "_space_api": [
77
+ "SpaceHardware",
78
+ "SpaceRuntime",
79
+ "SpaceStage",
80
+ "SpaceStorage",
81
+ "SpaceVariable",
82
+ ],
83
+ "_tensorboard_logger": [
84
+ "HFSummaryWriter",
85
+ ],
86
+ "_webhooks_payload": [
87
+ "WebhookPayload",
88
+ "WebhookPayloadComment",
89
+ "WebhookPayloadDiscussion",
90
+ "WebhookPayloadDiscussionChanges",
91
+ "WebhookPayloadEvent",
92
+ "WebhookPayloadMovedTo",
93
+ "WebhookPayloadRepo",
94
+ "WebhookPayloadUrl",
95
+ "WebhookPayloadWebhook",
96
+ ],
97
+ "_webhooks_server": [
98
+ "WebhooksServer",
99
+ "webhook_endpoint",
100
+ ],
101
+ "community": [
102
+ "Discussion",
103
+ "DiscussionComment",
104
+ "DiscussionCommit",
105
+ "DiscussionEvent",
106
+ "DiscussionStatusChange",
107
+ "DiscussionTitleChange",
108
+ "DiscussionWithDetails",
109
+ ],
110
+ "constants": [
111
+ "CONFIG_NAME",
112
+ "FLAX_WEIGHTS_NAME",
113
+ "HUGGINGFACE_CO_URL_HOME",
114
+ "HUGGINGFACE_CO_URL_TEMPLATE",
115
+ "PYTORCH_WEIGHTS_NAME",
116
+ "REPO_TYPE_DATASET",
117
+ "REPO_TYPE_MODEL",
118
+ "REPO_TYPE_SPACE",
119
+ "TF2_WEIGHTS_NAME",
120
+ "TF_WEIGHTS_NAME",
121
+ ],
122
+ "fastai_utils": [
123
+ "_save_pretrained_fastai",
124
+ "from_pretrained_fastai",
125
+ "push_to_hub_fastai",
126
+ ],
127
+ "file_download": [
128
+ "HfFileMetadata",
129
+ "_CACHED_NO_EXIST",
130
+ "get_hf_file_metadata",
131
+ "hf_hub_download",
132
+ "hf_hub_url",
133
+ "try_to_load_from_cache",
134
+ ],
135
+ "hf_api": [
136
+ "Collection",
137
+ "CollectionItem",
138
+ "CommitInfo",
139
+ "CommitOperation",
140
+ "CommitOperationAdd",
141
+ "CommitOperationCopy",
142
+ "CommitOperationDelete",
143
+ "DatasetInfo",
144
+ "GitCommitInfo",
145
+ "GitRefInfo",
146
+ "GitRefs",
147
+ "HfApi",
148
+ "ModelInfo",
149
+ "RepoUrl",
150
+ "SpaceInfo",
151
+ "User",
152
+ "UserLikes",
153
+ "WebhookInfo",
154
+ "WebhookWatchedItem",
155
+ "accept_access_request",
156
+ "add_collection_item",
157
+ "add_space_secret",
158
+ "add_space_variable",
159
+ "auth_check",
160
+ "cancel_access_request",
161
+ "change_discussion_status",
162
+ "comment_discussion",
163
+ "create_branch",
164
+ "create_collection",
165
+ "create_commit",
166
+ "create_discussion",
167
+ "create_inference_endpoint",
168
+ "create_pull_request",
169
+ "create_repo",
170
+ "create_tag",
171
+ "create_webhook",
172
+ "dataset_info",
173
+ "delete_branch",
174
+ "delete_collection",
175
+ "delete_collection_item",
176
+ "delete_file",
177
+ "delete_folder",
178
+ "delete_inference_endpoint",
179
+ "delete_repo",
180
+ "delete_space_secret",
181
+ "delete_space_storage",
182
+ "delete_space_variable",
183
+ "delete_tag",
184
+ "delete_webhook",
185
+ "disable_webhook",
186
+ "duplicate_space",
187
+ "edit_discussion_comment",
188
+ "enable_webhook",
189
+ "file_exists",
190
+ "get_collection",
191
+ "get_dataset_tags",
192
+ "get_discussion_details",
193
+ "get_full_repo_name",
194
+ "get_inference_endpoint",
195
+ "get_model_tags",
196
+ "get_paths_info",
197
+ "get_repo_discussions",
198
+ "get_safetensors_metadata",
199
+ "get_space_runtime",
200
+ "get_space_variables",
201
+ "get_token_permission",
202
+ "get_user_overview",
203
+ "get_webhook",
204
+ "grant_access",
205
+ "list_accepted_access_requests",
206
+ "list_collections",
207
+ "list_datasets",
208
+ "list_inference_endpoints",
209
+ "list_liked_repos",
210
+ "list_models",
211
+ "list_organization_members",
212
+ "list_papers",
213
+ "list_pending_access_requests",
214
+ "list_rejected_access_requests",
215
+ "list_repo_commits",
216
+ "list_repo_files",
217
+ "list_repo_likers",
218
+ "list_repo_refs",
219
+ "list_repo_tree",
220
+ "list_spaces",
221
+ "list_user_followers",
222
+ "list_user_following",
223
+ "list_webhooks",
224
+ "merge_pull_request",
225
+ "model_info",
226
+ "move_repo",
227
+ "paper_info",
228
+ "parse_safetensors_file_metadata",
229
+ "pause_inference_endpoint",
230
+ "pause_space",
231
+ "preupload_lfs_files",
232
+ "reject_access_request",
233
+ "rename_discussion",
234
+ "repo_exists",
235
+ "repo_info",
236
+ "repo_type_and_id_from_hf_id",
237
+ "request_space_hardware",
238
+ "request_space_storage",
239
+ "restart_space",
240
+ "resume_inference_endpoint",
241
+ "revision_exists",
242
+ "run_as_future",
243
+ "scale_to_zero_inference_endpoint",
244
+ "set_space_sleep_time",
245
+ "space_info",
246
+ "super_squash_history",
247
+ "unlike",
248
+ "update_collection_item",
249
+ "update_collection_metadata",
250
+ "update_inference_endpoint",
251
+ "update_repo_settings",
252
+ "update_repo_visibility",
253
+ "update_webhook",
254
+ "upload_file",
255
+ "upload_folder",
256
+ "upload_large_folder",
257
+ "whoami",
258
+ ],
259
+ "hf_file_system": [
260
+ "HfFileSystem",
261
+ "HfFileSystemFile",
262
+ "HfFileSystemResolvedPath",
263
+ "HfFileSystemStreamFile",
264
+ ],
265
+ "hub_mixin": [
266
+ "ModelHubMixin",
267
+ "PyTorchModelHubMixin",
268
+ ],
269
+ "inference._client": [
270
+ "InferenceClient",
271
+ "InferenceTimeoutError",
272
+ ],
273
+ "inference._generated._async_client": [
274
+ "AsyncInferenceClient",
275
+ ],
276
+ "inference._generated.types": [
277
+ "AudioClassificationInput",
278
+ "AudioClassificationOutputElement",
279
+ "AudioClassificationOutputTransform",
280
+ "AudioClassificationParameters",
281
+ "AudioToAudioInput",
282
+ "AudioToAudioOutputElement",
283
+ "AutomaticSpeechRecognitionEarlyStoppingEnum",
284
+ "AutomaticSpeechRecognitionGenerationParameters",
285
+ "AutomaticSpeechRecognitionInput",
286
+ "AutomaticSpeechRecognitionOutput",
287
+ "AutomaticSpeechRecognitionOutputChunk",
288
+ "AutomaticSpeechRecognitionParameters",
289
+ "ChatCompletionInput",
290
+ "ChatCompletionInputFunctionDefinition",
291
+ "ChatCompletionInputFunctionName",
292
+ "ChatCompletionInputGrammarType",
293
+ "ChatCompletionInputGrammarTypeType",
294
+ "ChatCompletionInputMessage",
295
+ "ChatCompletionInputMessageChunk",
296
+ "ChatCompletionInputMessageChunkType",
297
+ "ChatCompletionInputStreamOptions",
298
+ "ChatCompletionInputTool",
299
+ "ChatCompletionInputToolChoiceClass",
300
+ "ChatCompletionInputToolChoiceEnum",
301
+ "ChatCompletionInputURL",
302
+ "ChatCompletionOutput",
303
+ "ChatCompletionOutputComplete",
304
+ "ChatCompletionOutputFunctionDefinition",
305
+ "ChatCompletionOutputLogprob",
306
+ "ChatCompletionOutputLogprobs",
307
+ "ChatCompletionOutputMessage",
308
+ "ChatCompletionOutputToolCall",
309
+ "ChatCompletionOutputTopLogprob",
310
+ "ChatCompletionOutputUsage",
311
+ "ChatCompletionStreamOutput",
312
+ "ChatCompletionStreamOutputChoice",
313
+ "ChatCompletionStreamOutputDelta",
314
+ "ChatCompletionStreamOutputDeltaToolCall",
315
+ "ChatCompletionStreamOutputFunction",
316
+ "ChatCompletionStreamOutputLogprob",
317
+ "ChatCompletionStreamOutputLogprobs",
318
+ "ChatCompletionStreamOutputTopLogprob",
319
+ "ChatCompletionStreamOutputUsage",
320
+ "DepthEstimationInput",
321
+ "DepthEstimationOutput",
322
+ "DocumentQuestionAnsweringInput",
323
+ "DocumentQuestionAnsweringInputData",
324
+ "DocumentQuestionAnsweringOutputElement",
325
+ "DocumentQuestionAnsweringParameters",
326
+ "FeatureExtractionInput",
327
+ "FeatureExtractionInputTruncationDirection",
328
+ "FillMaskInput",
329
+ "FillMaskOutputElement",
330
+ "FillMaskParameters",
331
+ "ImageClassificationInput",
332
+ "ImageClassificationOutputElement",
333
+ "ImageClassificationOutputTransform",
334
+ "ImageClassificationParameters",
335
+ "ImageSegmentationInput",
336
+ "ImageSegmentationOutputElement",
337
+ "ImageSegmentationParameters",
338
+ "ImageSegmentationSubtask",
339
+ "ImageToImageInput",
340
+ "ImageToImageOutput",
341
+ "ImageToImageParameters",
342
+ "ImageToImageTargetSize",
343
+ "ImageToTextEarlyStoppingEnum",
344
+ "ImageToTextGenerationParameters",
345
+ "ImageToTextInput",
346
+ "ImageToTextOutput",
347
+ "ImageToTextParameters",
348
+ "ObjectDetectionBoundingBox",
349
+ "ObjectDetectionInput",
350
+ "ObjectDetectionOutputElement",
351
+ "ObjectDetectionParameters",
352
+ "Padding",
353
+ "QuestionAnsweringInput",
354
+ "QuestionAnsweringInputData",
355
+ "QuestionAnsweringOutputElement",
356
+ "QuestionAnsweringParameters",
357
+ "SentenceSimilarityInput",
358
+ "SentenceSimilarityInputData",
359
+ "SummarizationInput",
360
+ "SummarizationOutput",
361
+ "SummarizationParameters",
362
+ "SummarizationTruncationStrategy",
363
+ "TableQuestionAnsweringInput",
364
+ "TableQuestionAnsweringInputData",
365
+ "TableQuestionAnsweringOutputElement",
366
+ "TableQuestionAnsweringParameters",
367
+ "Text2TextGenerationInput",
368
+ "Text2TextGenerationOutput",
369
+ "Text2TextGenerationParameters",
370
+ "Text2TextGenerationTruncationStrategy",
371
+ "TextClassificationInput",
372
+ "TextClassificationOutputElement",
373
+ "TextClassificationOutputTransform",
374
+ "TextClassificationParameters",
375
+ "TextGenerationInput",
376
+ "TextGenerationInputGenerateParameters",
377
+ "TextGenerationInputGrammarType",
378
+ "TextGenerationOutput",
379
+ "TextGenerationOutputBestOfSequence",
380
+ "TextGenerationOutputDetails",
381
+ "TextGenerationOutputFinishReason",
382
+ "TextGenerationOutputPrefillToken",
383
+ "TextGenerationOutputToken",
384
+ "TextGenerationStreamOutput",
385
+ "TextGenerationStreamOutputStreamDetails",
386
+ "TextGenerationStreamOutputToken",
387
+ "TextToAudioEarlyStoppingEnum",
388
+ "TextToAudioGenerationParameters",
389
+ "TextToAudioInput",
390
+ "TextToAudioOutput",
391
+ "TextToAudioParameters",
392
+ "TextToImageInput",
393
+ "TextToImageOutput",
394
+ "TextToImageParameters",
395
+ "TextToImageTargetSize",
396
+ "TextToSpeechEarlyStoppingEnum",
397
+ "TextToSpeechGenerationParameters",
398
+ "TextToSpeechInput",
399
+ "TextToSpeechOutput",
400
+ "TextToSpeechParameters",
401
+ "TextToVideoInput",
402
+ "TextToVideoOutput",
403
+ "TextToVideoParameters",
404
+ "TokenClassificationAggregationStrategy",
405
+ "TokenClassificationInput",
406
+ "TokenClassificationOutputElement",
407
+ "TokenClassificationParameters",
408
+ "TranslationInput",
409
+ "TranslationOutput",
410
+ "TranslationParameters",
411
+ "TranslationTruncationStrategy",
412
+ "TypeEnum",
413
+ "VideoClassificationInput",
414
+ "VideoClassificationOutputElement",
415
+ "VideoClassificationOutputTransform",
416
+ "VideoClassificationParameters",
417
+ "VisualQuestionAnsweringInput",
418
+ "VisualQuestionAnsweringInputData",
419
+ "VisualQuestionAnsweringOutputElement",
420
+ "VisualQuestionAnsweringParameters",
421
+ "ZeroShotClassificationInput",
422
+ "ZeroShotClassificationOutputElement",
423
+ "ZeroShotClassificationParameters",
424
+ "ZeroShotImageClassificationInput",
425
+ "ZeroShotImageClassificationOutputElement",
426
+ "ZeroShotImageClassificationParameters",
427
+ "ZeroShotObjectDetectionBoundingBox",
428
+ "ZeroShotObjectDetectionInput",
429
+ "ZeroShotObjectDetectionOutputElement",
430
+ "ZeroShotObjectDetectionParameters",
431
+ ],
432
+ "inference_api": [
433
+ "InferenceApi",
434
+ ],
435
+ "keras_mixin": [
436
+ "KerasModelHubMixin",
437
+ "from_pretrained_keras",
438
+ "push_to_hub_keras",
439
+ "save_pretrained_keras",
440
+ ],
441
+ "repocard": [
442
+ "DatasetCard",
443
+ "ModelCard",
444
+ "RepoCard",
445
+ "SpaceCard",
446
+ "metadata_eval_result",
447
+ "metadata_load",
448
+ "metadata_save",
449
+ "metadata_update",
450
+ ],
451
+ "repocard_data": [
452
+ "CardData",
453
+ "DatasetCardData",
454
+ "EvalResult",
455
+ "ModelCardData",
456
+ "SpaceCardData",
457
+ ],
458
+ "repository": [
459
+ "Repository",
460
+ ],
461
+ "serialization": [
462
+ "StateDictSplit",
463
+ "get_tf_storage_size",
464
+ "get_torch_storage_id",
465
+ "get_torch_storage_size",
466
+ "load_state_dict_from_file",
467
+ "load_torch_model",
468
+ "save_torch_model",
469
+ "save_torch_state_dict",
470
+ "split_state_dict_into_shards_factory",
471
+ "split_tf_state_dict_into_shards",
472
+ "split_torch_state_dict_into_shards",
473
+ ],
474
+ "serialization._dduf": [
475
+ "DDUFEntry",
476
+ "export_entries_as_dduf",
477
+ "export_folder_as_dduf",
478
+ "read_dduf_file",
479
+ ],
480
+ "utils": [
481
+ "CacheNotFound",
482
+ "CachedFileInfo",
483
+ "CachedRepoInfo",
484
+ "CachedRevisionInfo",
485
+ "CorruptedCacheException",
486
+ "DeleteCacheStrategy",
487
+ "HFCacheInfo",
488
+ "HfFolder",
489
+ "cached_assets_path",
490
+ "configure_http_backend",
491
+ "dump_environment_info",
492
+ "get_session",
493
+ "get_token",
494
+ "logging",
495
+ "scan_cache_dir",
496
+ ],
497
+ }
498
+
499
+ # WARNING: __all__ is generated automatically, Any manual edit will be lost when re-generating this file !
500
+ #
501
+ # To update the static imports, please run the following command and commit the changes.
502
+ # ```
503
+ # # Use script
504
+ # python utils/check_all_variable.py --update
505
+ #
506
+ # # Or run style on codebase
507
+ # make style
508
+ # ```
509
+
510
+ __all__ = [
511
+ "AsyncInferenceClient",
512
+ "AudioClassificationInput",
513
+ "AudioClassificationOutputElement",
514
+ "AudioClassificationOutputTransform",
515
+ "AudioClassificationParameters",
516
+ "AudioToAudioInput",
517
+ "AudioToAudioOutputElement",
518
+ "AutomaticSpeechRecognitionEarlyStoppingEnum",
519
+ "AutomaticSpeechRecognitionGenerationParameters",
520
+ "AutomaticSpeechRecognitionInput",
521
+ "AutomaticSpeechRecognitionOutput",
522
+ "AutomaticSpeechRecognitionOutputChunk",
523
+ "AutomaticSpeechRecognitionParameters",
524
+ "CONFIG_NAME",
525
+ "CacheNotFound",
526
+ "CachedFileInfo",
527
+ "CachedRepoInfo",
528
+ "CachedRevisionInfo",
529
+ "CardData",
530
+ "ChatCompletionInput",
531
+ "ChatCompletionInputFunctionDefinition",
532
+ "ChatCompletionInputFunctionName",
533
+ "ChatCompletionInputGrammarType",
534
+ "ChatCompletionInputGrammarTypeType",
535
+ "ChatCompletionInputMessage",
536
+ "ChatCompletionInputMessageChunk",
537
+ "ChatCompletionInputMessageChunkType",
538
+ "ChatCompletionInputStreamOptions",
539
+ "ChatCompletionInputTool",
540
+ "ChatCompletionInputToolChoiceClass",
541
+ "ChatCompletionInputToolChoiceEnum",
542
+ "ChatCompletionInputURL",
543
+ "ChatCompletionOutput",
544
+ "ChatCompletionOutputComplete",
545
+ "ChatCompletionOutputFunctionDefinition",
546
+ "ChatCompletionOutputLogprob",
547
+ "ChatCompletionOutputLogprobs",
548
+ "ChatCompletionOutputMessage",
549
+ "ChatCompletionOutputToolCall",
550
+ "ChatCompletionOutputTopLogprob",
551
+ "ChatCompletionOutputUsage",
552
+ "ChatCompletionStreamOutput",
553
+ "ChatCompletionStreamOutputChoice",
554
+ "ChatCompletionStreamOutputDelta",
555
+ "ChatCompletionStreamOutputDeltaToolCall",
556
+ "ChatCompletionStreamOutputFunction",
557
+ "ChatCompletionStreamOutputLogprob",
558
+ "ChatCompletionStreamOutputLogprobs",
559
+ "ChatCompletionStreamOutputTopLogprob",
560
+ "ChatCompletionStreamOutputUsage",
561
+ "Collection",
562
+ "CollectionItem",
563
+ "CommitInfo",
564
+ "CommitOperation",
565
+ "CommitOperationAdd",
566
+ "CommitOperationCopy",
567
+ "CommitOperationDelete",
568
+ "CommitScheduler",
569
+ "CorruptedCacheException",
570
+ "DDUFEntry",
571
+ "DatasetCard",
572
+ "DatasetCardData",
573
+ "DatasetInfo",
574
+ "DeleteCacheStrategy",
575
+ "DepthEstimationInput",
576
+ "DepthEstimationOutput",
577
+ "Discussion",
578
+ "DiscussionComment",
579
+ "DiscussionCommit",
580
+ "DiscussionEvent",
581
+ "DiscussionStatusChange",
582
+ "DiscussionTitleChange",
583
+ "DiscussionWithDetails",
584
+ "DocumentQuestionAnsweringInput",
585
+ "DocumentQuestionAnsweringInputData",
586
+ "DocumentQuestionAnsweringOutputElement",
587
+ "DocumentQuestionAnsweringParameters",
588
+ "EvalResult",
589
+ "FLAX_WEIGHTS_NAME",
590
+ "FeatureExtractionInput",
591
+ "FeatureExtractionInputTruncationDirection",
592
+ "FillMaskInput",
593
+ "FillMaskOutputElement",
594
+ "FillMaskParameters",
595
+ "GitCommitInfo",
596
+ "GitRefInfo",
597
+ "GitRefs",
598
+ "HFCacheInfo",
599
+ "HFSummaryWriter",
600
+ "HUGGINGFACE_CO_URL_HOME",
601
+ "HUGGINGFACE_CO_URL_TEMPLATE",
602
+ "HfApi",
603
+ "HfFileMetadata",
604
+ "HfFileSystem",
605
+ "HfFileSystemFile",
606
+ "HfFileSystemResolvedPath",
607
+ "HfFileSystemStreamFile",
608
+ "HfFolder",
609
+ "ImageClassificationInput",
610
+ "ImageClassificationOutputElement",
611
+ "ImageClassificationOutputTransform",
612
+ "ImageClassificationParameters",
613
+ "ImageSegmentationInput",
614
+ "ImageSegmentationOutputElement",
615
+ "ImageSegmentationParameters",
616
+ "ImageSegmentationSubtask",
617
+ "ImageToImageInput",
618
+ "ImageToImageOutput",
619
+ "ImageToImageParameters",
620
+ "ImageToImageTargetSize",
621
+ "ImageToTextEarlyStoppingEnum",
622
+ "ImageToTextGenerationParameters",
623
+ "ImageToTextInput",
624
+ "ImageToTextOutput",
625
+ "ImageToTextParameters",
626
+ "InferenceApi",
627
+ "InferenceClient",
628
+ "InferenceEndpoint",
629
+ "InferenceEndpointError",
630
+ "InferenceEndpointStatus",
631
+ "InferenceEndpointTimeoutError",
632
+ "InferenceEndpointType",
633
+ "InferenceTimeoutError",
634
+ "KerasModelHubMixin",
635
+ "ModelCard",
636
+ "ModelCardData",
637
+ "ModelHubMixin",
638
+ "ModelInfo",
639
+ "ObjectDetectionBoundingBox",
640
+ "ObjectDetectionInput",
641
+ "ObjectDetectionOutputElement",
642
+ "ObjectDetectionParameters",
643
+ "PYTORCH_WEIGHTS_NAME",
644
+ "Padding",
645
+ "PyTorchModelHubMixin",
646
+ "QuestionAnsweringInput",
647
+ "QuestionAnsweringInputData",
648
+ "QuestionAnsweringOutputElement",
649
+ "QuestionAnsweringParameters",
650
+ "REPO_TYPE_DATASET",
651
+ "REPO_TYPE_MODEL",
652
+ "REPO_TYPE_SPACE",
653
+ "RepoCard",
654
+ "RepoUrl",
655
+ "Repository",
656
+ "SentenceSimilarityInput",
657
+ "SentenceSimilarityInputData",
658
+ "SpaceCard",
659
+ "SpaceCardData",
660
+ "SpaceHardware",
661
+ "SpaceInfo",
662
+ "SpaceRuntime",
663
+ "SpaceStage",
664
+ "SpaceStorage",
665
+ "SpaceVariable",
666
+ "StateDictSplit",
667
+ "SummarizationInput",
668
+ "SummarizationOutput",
669
+ "SummarizationParameters",
670
+ "SummarizationTruncationStrategy",
671
+ "TF2_WEIGHTS_NAME",
672
+ "TF_WEIGHTS_NAME",
673
+ "TableQuestionAnsweringInput",
674
+ "TableQuestionAnsweringInputData",
675
+ "TableQuestionAnsweringOutputElement",
676
+ "TableQuestionAnsweringParameters",
677
+ "Text2TextGenerationInput",
678
+ "Text2TextGenerationOutput",
679
+ "Text2TextGenerationParameters",
680
+ "Text2TextGenerationTruncationStrategy",
681
+ "TextClassificationInput",
682
+ "TextClassificationOutputElement",
683
+ "TextClassificationOutputTransform",
684
+ "TextClassificationParameters",
685
+ "TextGenerationInput",
686
+ "TextGenerationInputGenerateParameters",
687
+ "TextGenerationInputGrammarType",
688
+ "TextGenerationOutput",
689
+ "TextGenerationOutputBestOfSequence",
690
+ "TextGenerationOutputDetails",
691
+ "TextGenerationOutputFinishReason",
692
+ "TextGenerationOutputPrefillToken",
693
+ "TextGenerationOutputToken",
694
+ "TextGenerationStreamOutput",
695
+ "TextGenerationStreamOutputStreamDetails",
696
+ "TextGenerationStreamOutputToken",
697
+ "TextToAudioEarlyStoppingEnum",
698
+ "TextToAudioGenerationParameters",
699
+ "TextToAudioInput",
700
+ "TextToAudioOutput",
701
+ "TextToAudioParameters",
702
+ "TextToImageInput",
703
+ "TextToImageOutput",
704
+ "TextToImageParameters",
705
+ "TextToImageTargetSize",
706
+ "TextToSpeechEarlyStoppingEnum",
707
+ "TextToSpeechGenerationParameters",
708
+ "TextToSpeechInput",
709
+ "TextToSpeechOutput",
710
+ "TextToSpeechParameters",
711
+ "TextToVideoInput",
712
+ "TextToVideoOutput",
713
+ "TextToVideoParameters",
714
+ "TokenClassificationAggregationStrategy",
715
+ "TokenClassificationInput",
716
+ "TokenClassificationOutputElement",
717
+ "TokenClassificationParameters",
718
+ "TranslationInput",
719
+ "TranslationOutput",
720
+ "TranslationParameters",
721
+ "TranslationTruncationStrategy",
722
+ "TypeEnum",
723
+ "User",
724
+ "UserLikes",
725
+ "VideoClassificationInput",
726
+ "VideoClassificationOutputElement",
727
+ "VideoClassificationOutputTransform",
728
+ "VideoClassificationParameters",
729
+ "VisualQuestionAnsweringInput",
730
+ "VisualQuestionAnsweringInputData",
731
+ "VisualQuestionAnsweringOutputElement",
732
+ "VisualQuestionAnsweringParameters",
733
+ "WebhookInfo",
734
+ "WebhookPayload",
735
+ "WebhookPayloadComment",
736
+ "WebhookPayloadDiscussion",
737
+ "WebhookPayloadDiscussionChanges",
738
+ "WebhookPayloadEvent",
739
+ "WebhookPayloadMovedTo",
740
+ "WebhookPayloadRepo",
741
+ "WebhookPayloadUrl",
742
+ "WebhookPayloadWebhook",
743
+ "WebhookWatchedItem",
744
+ "WebhooksServer",
745
+ "ZeroShotClassificationInput",
746
+ "ZeroShotClassificationOutputElement",
747
+ "ZeroShotClassificationParameters",
748
+ "ZeroShotImageClassificationInput",
749
+ "ZeroShotImageClassificationOutputElement",
750
+ "ZeroShotImageClassificationParameters",
751
+ "ZeroShotObjectDetectionBoundingBox",
752
+ "ZeroShotObjectDetectionInput",
753
+ "ZeroShotObjectDetectionOutputElement",
754
+ "ZeroShotObjectDetectionParameters",
755
+ "_CACHED_NO_EXIST",
756
+ "_save_pretrained_fastai",
757
+ "accept_access_request",
758
+ "add_collection_item",
759
+ "add_space_secret",
760
+ "add_space_variable",
761
+ "auth_check",
762
+ "auth_list",
763
+ "auth_switch",
764
+ "cached_assets_path",
765
+ "cancel_access_request",
766
+ "change_discussion_status",
767
+ "comment_discussion",
768
+ "configure_http_backend",
769
+ "create_branch",
770
+ "create_collection",
771
+ "create_commit",
772
+ "create_discussion",
773
+ "create_inference_endpoint",
774
+ "create_pull_request",
775
+ "create_repo",
776
+ "create_tag",
777
+ "create_webhook",
778
+ "dataset_info",
779
+ "delete_branch",
780
+ "delete_collection",
781
+ "delete_collection_item",
782
+ "delete_file",
783
+ "delete_folder",
784
+ "delete_inference_endpoint",
785
+ "delete_repo",
786
+ "delete_space_secret",
787
+ "delete_space_storage",
788
+ "delete_space_variable",
789
+ "delete_tag",
790
+ "delete_webhook",
791
+ "disable_webhook",
792
+ "dump_environment_info",
793
+ "duplicate_space",
794
+ "edit_discussion_comment",
795
+ "enable_webhook",
796
+ "export_entries_as_dduf",
797
+ "export_folder_as_dduf",
798
+ "file_exists",
799
+ "from_pretrained_fastai",
800
+ "from_pretrained_keras",
801
+ "get_collection",
802
+ "get_dataset_tags",
803
+ "get_discussion_details",
804
+ "get_full_repo_name",
805
+ "get_hf_file_metadata",
806
+ "get_inference_endpoint",
807
+ "get_model_tags",
808
+ "get_paths_info",
809
+ "get_repo_discussions",
810
+ "get_safetensors_metadata",
811
+ "get_session",
812
+ "get_space_runtime",
813
+ "get_space_variables",
814
+ "get_tf_storage_size",
815
+ "get_token",
816
+ "get_token_permission",
817
+ "get_torch_storage_id",
818
+ "get_torch_storage_size",
819
+ "get_user_overview",
820
+ "get_webhook",
821
+ "grant_access",
822
+ "hf_hub_download",
823
+ "hf_hub_url",
824
+ "interpreter_login",
825
+ "list_accepted_access_requests",
826
+ "list_collections",
827
+ "list_datasets",
828
+ "list_inference_endpoints",
829
+ "list_liked_repos",
830
+ "list_models",
831
+ "list_organization_members",
832
+ "list_papers",
833
+ "list_pending_access_requests",
834
+ "list_rejected_access_requests",
835
+ "list_repo_commits",
836
+ "list_repo_files",
837
+ "list_repo_likers",
838
+ "list_repo_refs",
839
+ "list_repo_tree",
840
+ "list_spaces",
841
+ "list_user_followers",
842
+ "list_user_following",
843
+ "list_webhooks",
844
+ "load_state_dict_from_file",
845
+ "load_torch_model",
846
+ "logging",
847
+ "login",
848
+ "logout",
849
+ "merge_pull_request",
850
+ "metadata_eval_result",
851
+ "metadata_load",
852
+ "metadata_save",
853
+ "metadata_update",
854
+ "model_info",
855
+ "move_repo",
856
+ "notebook_login",
857
+ "paper_info",
858
+ "parse_safetensors_file_metadata",
859
+ "pause_inference_endpoint",
860
+ "pause_space",
861
+ "preupload_lfs_files",
862
+ "push_to_hub_fastai",
863
+ "push_to_hub_keras",
864
+ "read_dduf_file",
865
+ "reject_access_request",
866
+ "rename_discussion",
867
+ "repo_exists",
868
+ "repo_info",
869
+ "repo_type_and_id_from_hf_id",
870
+ "request_space_hardware",
871
+ "request_space_storage",
872
+ "restart_space",
873
+ "resume_inference_endpoint",
874
+ "revision_exists",
875
+ "run_as_future",
876
+ "save_pretrained_keras",
877
+ "save_torch_model",
878
+ "save_torch_state_dict",
879
+ "scale_to_zero_inference_endpoint",
880
+ "scan_cache_dir",
881
+ "set_space_sleep_time",
882
+ "snapshot_download",
883
+ "space_info",
884
+ "split_state_dict_into_shards_factory",
885
+ "split_tf_state_dict_into_shards",
886
+ "split_torch_state_dict_into_shards",
887
+ "super_squash_history",
888
+ "try_to_load_from_cache",
889
+ "unlike",
890
+ "update_collection_item",
891
+ "update_collection_metadata",
892
+ "update_inference_endpoint",
893
+ "update_repo_settings",
894
+ "update_repo_visibility",
895
+ "update_webhook",
896
+ "upload_file",
897
+ "upload_folder",
898
+ "upload_large_folder",
899
+ "webhook_endpoint",
900
+ "whoami",
901
+ ]
902
+
903
+
904
def _attach(package_name, submodules=None, submod_attrs=None):
    """Build lazy `__getattr__` / `__dir__` hooks for a package.

    Typically, modules import submodules and attributes as follows:

    ```py
    import mysubmodule
    import anothersubmodule

    from .foo import someattr
    ```

    The idea is to replace a package's `__getattr__`, `__dir__`, such that all imports
    work exactly the way they would with normal imports, except that the import occurs
    upon first use.

    The typical way to call this function, replacing the above imports, is:

    ```python
    __getattr__, __dir__ = _attach(
        __name__,
        ['mysubmodule', 'anothersubmodule'],
        {'foo': ['someattr']}
    )
    ```
    This functionality requires Python 3.7 or higher.

    Args:
        package_name (`str`):
            Typically use `__name__`.
        submodules (iterable of `str`, *optional*):
            Names of submodules to attach. Copied into a `set` internally;
            `None` means no submodules.
        submod_attrs (`dict`, *optional*):
            Dictionary of submodule -> list of attributes / functions.
            These attributes are imported as they are used. `None` means
            no attributes.

    Returns:
        A `(__getattr__, __dir__)` tuple. `__getattr__(name)` imports and returns
        the requested submodule or attribute, and raises `AttributeError` for a
        name that is neither. `__dir__()` returns the `__all__` list of the module
        in which this function is defined.
    """
    if submod_attrs is None:
        submod_attrs = {}

    submodules = set() if submodules is None else set(submodules)

    # Reverse index: exported attribute name -> submodule that defines it.
    attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs}

    def _import(module_path):
        # Report the failing module on stdout, then let the original error propagate.
        try:
            return importlib.import_module(module_path)
        except Exception as e:
            print(f"Error importing {module_path}: {e}")
            raise

    def __getattr__(name):
        if name in submodules:
            return _import(f"{package_name}.{name}")

        if name not in attr_to_modules:
            raise AttributeError(f"No {package_name} attribute {name}")

        owner = attr_to_modules[name]
        attr = getattr(_import(f"{package_name}.{owner}"), name)

        # If the attribute lives in a file (module) with the same
        # name as the attribute, ensure that the attribute and *not*
        # the module is accessible on the package.
        if name == owner:
            sys.modules[package_name].__dict__[name] = attr

        return attr

    def __dir__():
        return __all__

    return __getattr__, __dir__
985
+
986
+
987
# Install the module-level attribute hooks (PEP 562) so that the names listed in
# `_SUBMOD_ATTRS` are imported from their submodule on first access.
__getattr__, __dir__ = _attach(
    __name__,
    submodules=[],
    submod_attrs=_SUBMOD_ATTRS,
)

# Any non-empty EAGER_IMPORT value resolves every name in `__all__` right away
# instead of waiting for first access.
if os.getenv("EAGER_IMPORT", ""):
    for attr in __all__:
        __getattr__(attr)
992
+
993
+ # WARNING: any content below this statement is generated automatically. Any manual edit
994
+ # will be lost when re-generating this file !
995
+ #
996
+ # To update the static imports, please run the following command and commit the changes.
997
+ # ```
998
+ # # Use script
999
+ # python utils/check_static_imports.py --update
1000
+ #
1001
+ # # Or run style on codebase
1002
+ # make style
1003
+ # ```
1004
+ if TYPE_CHECKING: # pragma: no cover
1005
+ from ._commit_scheduler import CommitScheduler # noqa: F401
1006
+ from ._inference_endpoints import (
1007
+ InferenceEndpoint, # noqa: F401
1008
+ InferenceEndpointError, # noqa: F401
1009
+ InferenceEndpointStatus, # noqa: F401
1010
+ InferenceEndpointTimeoutError, # noqa: F401
1011
+ InferenceEndpointType, # noqa: F401
1012
+ )
1013
+ from ._login import (
1014
+ auth_list, # noqa: F401
1015
+ auth_switch, # noqa: F401
1016
+ interpreter_login, # noqa: F401
1017
+ login, # noqa: F401
1018
+ logout, # noqa: F401
1019
+ notebook_login, # noqa: F401
1020
+ )
1021
+ from ._snapshot_download import snapshot_download # noqa: F401
1022
+ from ._space_api import (
1023
+ SpaceHardware, # noqa: F401
1024
+ SpaceRuntime, # noqa: F401
1025
+ SpaceStage, # noqa: F401
1026
+ SpaceStorage, # noqa: F401
1027
+ SpaceVariable, # noqa: F401
1028
+ )
1029
+ from ._tensorboard_logger import HFSummaryWriter # noqa: F401
1030
+ from ._webhooks_payload import (
1031
+ WebhookPayload, # noqa: F401
1032
+ WebhookPayloadComment, # noqa: F401
1033
+ WebhookPayloadDiscussion, # noqa: F401
1034
+ WebhookPayloadDiscussionChanges, # noqa: F401
1035
+ WebhookPayloadEvent, # noqa: F401
1036
+ WebhookPayloadMovedTo, # noqa: F401
1037
+ WebhookPayloadRepo, # noqa: F401
1038
+ WebhookPayloadUrl, # noqa: F401
1039
+ WebhookPayloadWebhook, # noqa: F401
1040
+ )
1041
+ from ._webhooks_server import (
1042
+ WebhooksServer, # noqa: F401
1043
+ webhook_endpoint, # noqa: F401
1044
+ )
1045
+ from .community import (
1046
+ Discussion, # noqa: F401
1047
+ DiscussionComment, # noqa: F401
1048
+ DiscussionCommit, # noqa: F401
1049
+ DiscussionEvent, # noqa: F401
1050
+ DiscussionStatusChange, # noqa: F401
1051
+ DiscussionTitleChange, # noqa: F401
1052
+ DiscussionWithDetails, # noqa: F401
1053
+ )
1054
+ from .constants import (
1055
+ CONFIG_NAME, # noqa: F401
1056
+ FLAX_WEIGHTS_NAME, # noqa: F401
1057
+ HUGGINGFACE_CO_URL_HOME, # noqa: F401
1058
+ HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401
1059
+ PYTORCH_WEIGHTS_NAME, # noqa: F401
1060
+ REPO_TYPE_DATASET, # noqa: F401
1061
+ REPO_TYPE_MODEL, # noqa: F401
1062
+ REPO_TYPE_SPACE, # noqa: F401
1063
+ TF2_WEIGHTS_NAME, # noqa: F401
1064
+ TF_WEIGHTS_NAME, # noqa: F401
1065
+ )
1066
+ from .fastai_utils import (
1067
+ _save_pretrained_fastai, # noqa: F401
1068
+ from_pretrained_fastai, # noqa: F401
1069
+ push_to_hub_fastai, # noqa: F401
1070
+ )
1071
+ from .file_download import (
1072
+ _CACHED_NO_EXIST, # noqa: F401
1073
+ HfFileMetadata, # noqa: F401
1074
+ get_hf_file_metadata, # noqa: F401
1075
+ hf_hub_download, # noqa: F401
1076
+ hf_hub_url, # noqa: F401
1077
+ try_to_load_from_cache, # noqa: F401
1078
+ )
1079
+ from .hf_api import (
1080
+ Collection, # noqa: F401
1081
+ CollectionItem, # noqa: F401
1082
+ CommitInfo, # noqa: F401
1083
+ CommitOperation, # noqa: F401
1084
+ CommitOperationAdd, # noqa: F401
1085
+ CommitOperationCopy, # noqa: F401
1086
+ CommitOperationDelete, # noqa: F401
1087
+ DatasetInfo, # noqa: F401
1088
+ GitCommitInfo, # noqa: F401
1089
+ GitRefInfo, # noqa: F401
1090
+ GitRefs, # noqa: F401
1091
+ HfApi, # noqa: F401
1092
+ ModelInfo, # noqa: F401
1093
+ RepoUrl, # noqa: F401
1094
+ SpaceInfo, # noqa: F401
1095
+ User, # noqa: F401
1096
+ UserLikes, # noqa: F401
1097
+ WebhookInfo, # noqa: F401
1098
+ WebhookWatchedItem, # noqa: F401
1099
+ accept_access_request, # noqa: F401
1100
+ add_collection_item, # noqa: F401
1101
+ add_space_secret, # noqa: F401
1102
+ add_space_variable, # noqa: F401
1103
+ auth_check, # noqa: F401
1104
+ cancel_access_request, # noqa: F401
1105
+ change_discussion_status, # noqa: F401
1106
+ comment_discussion, # noqa: F401
1107
+ create_branch, # noqa: F401
1108
+ create_collection, # noqa: F401
1109
+ create_commit, # noqa: F401
1110
+ create_discussion, # noqa: F401
1111
+ create_inference_endpoint, # noqa: F401
1112
+ create_pull_request, # noqa: F401
1113
+ create_repo, # noqa: F401
1114
+ create_tag, # noqa: F401
1115
+ create_webhook, # noqa: F401
1116
+ dataset_info, # noqa: F401
1117
+ delete_branch, # noqa: F401
1118
+ delete_collection, # noqa: F401
1119
+ delete_collection_item, # noqa: F401
1120
+ delete_file, # noqa: F401
1121
+ delete_folder, # noqa: F401
1122
+ delete_inference_endpoint, # noqa: F401
1123
+ delete_repo, # noqa: F401
1124
+ delete_space_secret, # noqa: F401
1125
+ delete_space_storage, # noqa: F401
1126
+ delete_space_variable, # noqa: F401
1127
+ delete_tag, # noqa: F401
1128
+ delete_webhook, # noqa: F401
1129
+ disable_webhook, # noqa: F401
1130
+ duplicate_space, # noqa: F401
1131
+ edit_discussion_comment, # noqa: F401
1132
+ enable_webhook, # noqa: F401
1133
+ file_exists, # noqa: F401
1134
+ get_collection, # noqa: F401
1135
+ get_dataset_tags, # noqa: F401
1136
+ get_discussion_details, # noqa: F401
1137
+ get_full_repo_name, # noqa: F401
1138
+ get_inference_endpoint, # noqa: F401
1139
+ get_model_tags, # noqa: F401
1140
+ get_paths_info, # noqa: F401
1141
+ get_repo_discussions, # noqa: F401
1142
+ get_safetensors_metadata, # noqa: F401
1143
+ get_space_runtime, # noqa: F401
1144
+ get_space_variables, # noqa: F401
1145
+ get_token_permission, # noqa: F401
1146
+ get_user_overview, # noqa: F401
1147
+ get_webhook, # noqa: F401
1148
+ grant_access, # noqa: F401
1149
+ list_accepted_access_requests, # noqa: F401
1150
+ list_collections, # noqa: F401
1151
+ list_datasets, # noqa: F401
1152
+ list_inference_endpoints, # noqa: F401
1153
+ list_liked_repos, # noqa: F401
1154
+ list_models, # noqa: F401
1155
+ list_organization_members, # noqa: F401
1156
+ list_papers, # noqa: F401
1157
+ list_pending_access_requests, # noqa: F401
1158
+ list_rejected_access_requests, # noqa: F401
1159
+ list_repo_commits, # noqa: F401
1160
+ list_repo_files, # noqa: F401
1161
+ list_repo_likers, # noqa: F401
1162
+ list_repo_refs, # noqa: F401
1163
+ list_repo_tree, # noqa: F401
1164
+ list_spaces, # noqa: F401
1165
+ list_user_followers, # noqa: F401
1166
+ list_user_following, # noqa: F401
1167
+ list_webhooks, # noqa: F401
1168
+ merge_pull_request, # noqa: F401
1169
+ model_info, # noqa: F401
1170
+ move_repo, # noqa: F401
1171
+ paper_info, # noqa: F401
1172
+ parse_safetensors_file_metadata, # noqa: F401
1173
+ pause_inference_endpoint, # noqa: F401
1174
+ pause_space, # noqa: F401
1175
+ preupload_lfs_files, # noqa: F401
1176
+ reject_access_request, # noqa: F401
1177
+ rename_discussion, # noqa: F401
1178
+ repo_exists, # noqa: F401
1179
+ repo_info, # noqa: F401
1180
+ repo_type_and_id_from_hf_id, # noqa: F401
1181
+ request_space_hardware, # noqa: F401
1182
+ request_space_storage, # noqa: F401
1183
+ restart_space, # noqa: F401
1184
+ resume_inference_endpoint, # noqa: F401
1185
+ revision_exists, # noqa: F401
1186
+ run_as_future, # noqa: F401
1187
+ scale_to_zero_inference_endpoint, # noqa: F401
1188
+ set_space_sleep_time, # noqa: F401
1189
+ space_info, # noqa: F401
1190
+ super_squash_history, # noqa: F401
1191
+ unlike, # noqa: F401
1192
+ update_collection_item, # noqa: F401
1193
+ update_collection_metadata, # noqa: F401
1194
+ update_inference_endpoint, # noqa: F401
1195
+ update_repo_settings, # noqa: F401
1196
+ update_repo_visibility, # noqa: F401
1197
+ update_webhook, # noqa: F401
1198
+ upload_file, # noqa: F401
1199
+ upload_folder, # noqa: F401
1200
+ upload_large_folder, # noqa: F401
1201
+ whoami, # noqa: F401
1202
+ )
1203
+ from .hf_file_system import (
1204
+ HfFileSystem, # noqa: F401
1205
+ HfFileSystemFile, # noqa: F401
1206
+ HfFileSystemResolvedPath, # noqa: F401
1207
+ HfFileSystemStreamFile, # noqa: F401
1208
+ )
1209
+ from .hub_mixin import (
1210
+ ModelHubMixin, # noqa: F401
1211
+ PyTorchModelHubMixin, # noqa: F401
1212
+ )
1213
+ from .inference._client import (
1214
+ InferenceClient, # noqa: F401
1215
+ InferenceTimeoutError, # noqa: F401
1216
+ )
1217
+ from .inference._generated._async_client import AsyncInferenceClient # noqa: F401
1218
+ from .inference._generated.types import (
1219
+ AudioClassificationInput, # noqa: F401
1220
+ AudioClassificationOutputElement, # noqa: F401
1221
+ AudioClassificationOutputTransform, # noqa: F401
1222
+ AudioClassificationParameters, # noqa: F401
1223
+ AudioToAudioInput, # noqa: F401
1224
+ AudioToAudioOutputElement, # noqa: F401
1225
+ AutomaticSpeechRecognitionEarlyStoppingEnum, # noqa: F401
1226
+ AutomaticSpeechRecognitionGenerationParameters, # noqa: F401
1227
+ AutomaticSpeechRecognitionInput, # noqa: F401
1228
+ AutomaticSpeechRecognitionOutput, # noqa: F401
1229
+ AutomaticSpeechRecognitionOutputChunk, # noqa: F401
1230
+ AutomaticSpeechRecognitionParameters, # noqa: F401
1231
+ ChatCompletionInput, # noqa: F401
1232
+ ChatCompletionInputFunctionDefinition, # noqa: F401
1233
+ ChatCompletionInputFunctionName, # noqa: F401
1234
+ ChatCompletionInputGrammarType, # noqa: F401
1235
+ ChatCompletionInputGrammarTypeType, # noqa: F401
1236
+ ChatCompletionInputMessage, # noqa: F401
1237
+ ChatCompletionInputMessageChunk, # noqa: F401
1238
+ ChatCompletionInputMessageChunkType, # noqa: F401
1239
+ ChatCompletionInputStreamOptions, # noqa: F401
1240
+ ChatCompletionInputTool, # noqa: F401
1241
+ ChatCompletionInputToolChoiceClass, # noqa: F401
1242
+ ChatCompletionInputToolChoiceEnum, # noqa: F401
1243
+ ChatCompletionInputURL, # noqa: F401
1244
+ ChatCompletionOutput, # noqa: F401
1245
+ ChatCompletionOutputComplete, # noqa: F401
1246
+ ChatCompletionOutputFunctionDefinition, # noqa: F401
1247
+ ChatCompletionOutputLogprob, # noqa: F401
1248
+ ChatCompletionOutputLogprobs, # noqa: F401
1249
+ ChatCompletionOutputMessage, # noqa: F401
1250
+ ChatCompletionOutputToolCall, # noqa: F401
1251
+ ChatCompletionOutputTopLogprob, # noqa: F401
1252
+ ChatCompletionOutputUsage, # noqa: F401
1253
+ ChatCompletionStreamOutput, # noqa: F401
1254
+ ChatCompletionStreamOutputChoice, # noqa: F401
1255
+ ChatCompletionStreamOutputDelta, # noqa: F401
1256
+ ChatCompletionStreamOutputDeltaToolCall, # noqa: F401
1257
+ ChatCompletionStreamOutputFunction, # noqa: F401
1258
+ ChatCompletionStreamOutputLogprob, # noqa: F401
1259
+ ChatCompletionStreamOutputLogprobs, # noqa: F401
1260
+ ChatCompletionStreamOutputTopLogprob, # noqa: F401
1261
+ ChatCompletionStreamOutputUsage, # noqa: F401
1262
+ DepthEstimationInput, # noqa: F401
1263
+ DepthEstimationOutput, # noqa: F401
1264
+ DocumentQuestionAnsweringInput, # noqa: F401
1265
+ DocumentQuestionAnsweringInputData, # noqa: F401
1266
+ DocumentQuestionAnsweringOutputElement, # noqa: F401
1267
+ DocumentQuestionAnsweringParameters, # noqa: F401
1268
+ FeatureExtractionInput, # noqa: F401
1269
+ FeatureExtractionInputTruncationDirection, # noqa: F401
1270
+ FillMaskInput, # noqa: F401
1271
+ FillMaskOutputElement, # noqa: F401
1272
+ FillMaskParameters, # noqa: F401
1273
+ ImageClassificationInput, # noqa: F401
1274
+ ImageClassificationOutputElement, # noqa: F401
1275
+ ImageClassificationOutputTransform, # noqa: F401
1276
+ ImageClassificationParameters, # noqa: F401
1277
+ ImageSegmentationInput, # noqa: F401
1278
+ ImageSegmentationOutputElement, # noqa: F401
1279
+ ImageSegmentationParameters, # noqa: F401
1280
+ ImageSegmentationSubtask, # noqa: F401
1281
+ ImageToImageInput, # noqa: F401
1282
+ ImageToImageOutput, # noqa: F401
1283
+ ImageToImageParameters, # noqa: F401
1284
+ ImageToImageTargetSize, # noqa: F401
1285
+ ImageToTextEarlyStoppingEnum, # noqa: F401
1286
+ ImageToTextGenerationParameters, # noqa: F401
1287
+ ImageToTextInput, # noqa: F401
1288
+ ImageToTextOutput, # noqa: F401
1289
+ ImageToTextParameters, # noqa: F401
1290
+ ObjectDetectionBoundingBox, # noqa: F401
1291
+ ObjectDetectionInput, # noqa: F401
1292
+ ObjectDetectionOutputElement, # noqa: F401
1293
+ ObjectDetectionParameters, # noqa: F401
1294
+ Padding, # noqa: F401
1295
+ QuestionAnsweringInput, # noqa: F401
1296
+ QuestionAnsweringInputData, # noqa: F401
1297
+ QuestionAnsweringOutputElement, # noqa: F401
1298
+ QuestionAnsweringParameters, # noqa: F401
1299
+ SentenceSimilarityInput, # noqa: F401
1300
+ SentenceSimilarityInputData, # noqa: F401
1301
+ SummarizationInput, # noqa: F401
1302
+ SummarizationOutput, # noqa: F401
1303
+ SummarizationParameters, # noqa: F401
1304
+ SummarizationTruncationStrategy, # noqa: F401
1305
+ TableQuestionAnsweringInput, # noqa: F401
1306
+ TableQuestionAnsweringInputData, # noqa: F401
1307
+ TableQuestionAnsweringOutputElement, # noqa: F401
1308
+ TableQuestionAnsweringParameters, # noqa: F401
1309
+ Text2TextGenerationInput, # noqa: F401
1310
+ Text2TextGenerationOutput, # noqa: F401
1311
+ Text2TextGenerationParameters, # noqa: F401
1312
+ Text2TextGenerationTruncationStrategy, # noqa: F401
1313
+ TextClassificationInput, # noqa: F401
1314
+ TextClassificationOutputElement, # noqa: F401
1315
+ TextClassificationOutputTransform, # noqa: F401
1316
+ TextClassificationParameters, # noqa: F401
1317
+ TextGenerationInput, # noqa: F401
1318
+ TextGenerationInputGenerateParameters, # noqa: F401
1319
+ TextGenerationInputGrammarType, # noqa: F401
1320
+ TextGenerationOutput, # noqa: F401
1321
+ TextGenerationOutputBestOfSequence, # noqa: F401
1322
+ TextGenerationOutputDetails, # noqa: F401
1323
+ TextGenerationOutputFinishReason, # noqa: F401
1324
+ TextGenerationOutputPrefillToken, # noqa: F401
1325
+ TextGenerationOutputToken, # noqa: F401
1326
+ TextGenerationStreamOutput, # noqa: F401
1327
+ TextGenerationStreamOutputStreamDetails, # noqa: F401
1328
+ TextGenerationStreamOutputToken, # noqa: F401
1329
+ TextToAudioEarlyStoppingEnum, # noqa: F401
1330
+ TextToAudioGenerationParameters, # noqa: F401
1331
+ TextToAudioInput, # noqa: F401
1332
+ TextToAudioOutput, # noqa: F401
1333
+ TextToAudioParameters, # noqa: F401
1334
+ TextToImageInput, # noqa: F401
1335
+ TextToImageOutput, # noqa: F401
1336
+ TextToImageParameters, # noqa: F401
1337
+ TextToImageTargetSize, # noqa: F401
1338
+ TextToSpeechEarlyStoppingEnum, # noqa: F401
1339
+ TextToSpeechGenerationParameters, # noqa: F401
1340
+ TextToSpeechInput, # noqa: F401
1341
+ TextToSpeechOutput, # noqa: F401
1342
+ TextToSpeechParameters, # noqa: F401
1343
+ TextToVideoInput, # noqa: F401
1344
+ TextToVideoOutput, # noqa: F401
1345
+ TextToVideoParameters, # noqa: F401
1346
+ TokenClassificationAggregationStrategy, # noqa: F401
1347
+ TokenClassificationInput, # noqa: F401
1348
+ TokenClassificationOutputElement, # noqa: F401
1349
+ TokenClassificationParameters, # noqa: F401
1350
+ TranslationInput, # noqa: F401
1351
+ TranslationOutput, # noqa: F401
1352
+ TranslationParameters, # noqa: F401
1353
+ TranslationTruncationStrategy, # noqa: F401
1354
+ TypeEnum, # noqa: F401
1355
+ VideoClassificationInput, # noqa: F401
1356
+ VideoClassificationOutputElement, # noqa: F401
1357
+ VideoClassificationOutputTransform, # noqa: F401
1358
+ VideoClassificationParameters, # noqa: F401
1359
+ VisualQuestionAnsweringInput, # noqa: F401
1360
+ VisualQuestionAnsweringInputData, # noqa: F401
1361
+ VisualQuestionAnsweringOutputElement, # noqa: F401
1362
+ VisualQuestionAnsweringParameters, # noqa: F401
1363
+ ZeroShotClassificationInput, # noqa: F401
1364
+ ZeroShotClassificationOutputElement, # noqa: F401
1365
+ ZeroShotClassificationParameters, # noqa: F401
1366
+ ZeroShotImageClassificationInput, # noqa: F401
1367
+ ZeroShotImageClassificationOutputElement, # noqa: F401
1368
+ ZeroShotImageClassificationParameters, # noqa: F401
1369
+ ZeroShotObjectDetectionBoundingBox, # noqa: F401
1370
+ ZeroShotObjectDetectionInput, # noqa: F401
1371
+ ZeroShotObjectDetectionOutputElement, # noqa: F401
1372
+ ZeroShotObjectDetectionParameters, # noqa: F401
1373
+ )
1374
+ from .inference_api import InferenceApi # noqa: F401
1375
+ from .keras_mixin import (
1376
+ KerasModelHubMixin, # noqa: F401
1377
+ from_pretrained_keras, # noqa: F401
1378
+ push_to_hub_keras, # noqa: F401
1379
+ save_pretrained_keras, # noqa: F401
1380
+ )
1381
+ from .repocard import (
1382
+ DatasetCard, # noqa: F401
1383
+ ModelCard, # noqa: F401
1384
+ RepoCard, # noqa: F401
1385
+ SpaceCard, # noqa: F401
1386
+ metadata_eval_result, # noqa: F401
1387
+ metadata_load, # noqa: F401
1388
+ metadata_save, # noqa: F401
1389
+ metadata_update, # noqa: F401
1390
+ )
1391
+ from .repocard_data import (
1392
+ CardData, # noqa: F401
1393
+ DatasetCardData, # noqa: F401
1394
+ EvalResult, # noqa: F401
1395
+ ModelCardData, # noqa: F401
1396
+ SpaceCardData, # noqa: F401
1397
+ )
1398
+ from .repository import Repository # noqa: F401
1399
+ from .serialization import (
1400
+ StateDictSplit, # noqa: F401
1401
+ get_tf_storage_size, # noqa: F401
1402
+ get_torch_storage_id, # noqa: F401
1403
+ get_torch_storage_size, # noqa: F401
1404
+ load_state_dict_from_file, # noqa: F401
1405
+ load_torch_model, # noqa: F401
1406
+ save_torch_model, # noqa: F401
1407
+ save_torch_state_dict, # noqa: F401
1408
+ split_state_dict_into_shards_factory, # noqa: F401
1409
+ split_tf_state_dict_into_shards, # noqa: F401
1410
+ split_torch_state_dict_into_shards, # noqa: F401
1411
+ )
1412
+ from .serialization._dduf import (
1413
+ DDUFEntry, # noqa: F401
1414
+ export_entries_as_dduf, # noqa: F401
1415
+ export_folder_as_dduf, # noqa: F401
1416
+ read_dduf_file, # noqa: F401
1417
+ )
1418
+ from .utils import (
1419
+ CachedFileInfo, # noqa: F401
1420
+ CachedRepoInfo, # noqa: F401
1421
+ CachedRevisionInfo, # noqa: F401
1422
+ CacheNotFound, # noqa: F401
1423
+ CorruptedCacheException, # noqa: F401
1424
+ DeleteCacheStrategy, # noqa: F401
1425
+ HFCacheInfo, # noqa: F401
1426
+ HfFolder, # noqa: F401
1427
+ cached_assets_path, # noqa: F401
1428
+ configure_http_backend, # noqa: F401
1429
+ dump_environment_info, # noqa: F401
1430
+ get_session, # noqa: F401
1431
+ get_token, # noqa: F401
1432
+ logging, # noqa: F401
1433
+ scan_cache_dir, # noqa: F401
1434
+ )
.venv/lib/python3.11/site-packages/huggingface_hub/_commit_api.py ADDED
@@ -0,0 +1,758 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Type definitions and utilities for the `create_commit` API
3
+ """
4
+
5
+ import base64
6
+ import io
7
+ import os
8
+ import warnings
9
+ from collections import defaultdict
10
+ from contextlib import contextmanager
11
+ from dataclasses import dataclass, field
12
+ from itertools import groupby
13
+ from pathlib import Path, PurePosixPath
14
+ from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union
15
+
16
+ from tqdm.contrib.concurrent import thread_map
17
+
18
+ from . import constants
19
+ from .errors import EntryNotFoundError
20
+ from .file_download import hf_hub_url
21
+ from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
22
+ from .utils import (
23
+ FORBIDDEN_FOLDERS,
24
+ chunk_iterable,
25
+ get_session,
26
+ hf_raise_for_status,
27
+ logging,
28
+ sha,
29
+ tqdm_stream_file,
30
+ validate_hf_hub_args,
31
+ )
32
+ from .utils import tqdm as hf_tqdm
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from .hf_api import RepoFile
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ UploadMode = Literal["lfs", "regular"]
43
+
44
+ # Max is 1,000 per request on the Hub for HfApi.get_paths_info
45
+ # Otherwise we get:
46
+ # HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters
47
+ # See https://github.com/huggingface/huggingface_hub/issues/1503
48
+ FETCH_LFS_BATCH_SIZE = 500
49
+
50
+
51
@dataclass
class CommitOperationDelete:
    """
    Describes the deletion of a single file or of a whole folder in a Hub repository.

    Args:
        path_in_repo (`str`):
            Path relative to the repo root, e.g. `"checkpoints/1fec34a/weights.bin"` for a file
            or `"checkpoints/1fec34a/"` for a folder.
        is_folder (`bool` or `Literal["auto"]`, *optional*)
            Whether the path designates a folder. With `"auto"` (the default) the type is guessed
            from the path: a trailing `"/"` means folder, anything else means file. Pass `True`
            or `False` to set it explicitly.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `is_folder` is neither a `bool` nor `"auto"`.
    """

    path_in_repo: str
    is_folder: Union[bool, Literal["auto"]] = "auto"

    def __post_init__(self) -> None:
        """Normalize `path_in_repo` and resolve `is_folder` to a plain `bool`."""
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        if self.is_folder == "auto":
            # Guess the path type from the (validated) path itself.
            self.is_folder = self.path_in_repo.endswith("/")
            return
        if isinstance(self.is_folder, bool):
            return
        raise ValueError(
            f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'."
        )
80
+
81
+
82
@dataclass
class CommitOperationCopy:
    """
    Describes the copy of a file inside a Hub repository.

    Limitations:
      - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it
      - Cross-repository copies are not supported.

    Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub.

    Args:
        src_path_in_repo (`str`):
            Path (relative to the repo root) of the file to copy, e.g. `"checkpoints/1fec34a/weights.bin"`.
        path_in_repo (`str`):
            Path (relative to the repo root) of the copy, e.g. `"checkpoints/1fec34a/weights_copy.bin"`.
        src_revision (`str`, *optional*):
            Git revision from which the source file is read. Any valid git revision is accepted.
            Defaults to the target commit revision.
    """

    # NOTE(review): the "only LFS files" limitation above may be outdated, since `_fetch_files_to_copy`
    # downloads the raw content of regular files. Confirm against the commit payload builder.

    src_path_in_repo: str
    path_in_repo: str
    src_revision: Optional[str] = None
    # OID of the source file, once fetched from the Hub.
    # Used to determine whether the commit would be empty.
    _src_oid: Optional[str] = None
    # OID of the file currently at the destination path (if any), once fetched from the Hub.
    # Used to determine whether the commit would be empty.
    _dest_oid: Optional[str] = None

    def __post_init__(self) -> None:
        """Normalize both the source and the destination paths."""
        for attr_name in ("src_path_in_repo", "path_in_repo"):
            setattr(self, attr_name, _validate_path_in_repo(getattr(self, attr_name)))
116
+
117
+
118
@dataclass
class CommitOperationAdd:
    """
    Data structure holding necessary info to upload a file to a repository on the Hub.

    Args:
        path_in_repo (`str`):
            Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
        path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`):
            Either:
            - a path to a local file (as `str` or `pathlib.Path`) to upload
            - a buffer of bytes (`bytes`) holding the content of the file to upload
            - a "file object" (subclass of `io.BufferedIOBase`), typically obtained
                with `open(path, "rb")`. It must support `seek()` and `tell()` methods.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both
            `seek()` and `tell()`.
    """

    path_in_repo: str
    path_or_fileobj: Union[str, Path, bytes, BinaryIO]
    upload_info: UploadInfo = field(init=False, repr=False)

    # Internal attributes

    # set to "lfs" or "regular" once known
    _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None)

    # set to True if .gitignore rules prevent the file from being uploaded as LFS
    # (server-side check)
    _should_ignore: Optional[bool] = field(init=False, repr=False, default=None)

    # set to the remote OID of the file if it has already been uploaded
    # useful to determine if a commit will be empty or not
    _remote_oid: Optional[str] = field(init=False, repr=False, default=None)

    # set to True once the file has been uploaded as LFS
    _is_uploaded: bool = field(init=False, repr=False, default=False)

    # set to True once the file has been committed
    _is_committed: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validates `path_or_fileobj` and compute `upload_info`."""
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        # Validate `path_or_fileobj` value
        if isinstance(self.path_or_fileobj, Path):
            self.path_or_fileobj = str(self.path_or_fileobj)
        if isinstance(self.path_or_fileobj, str):
            # Keep the user-expanded path: the file is validated against the expanded location, so the
            # very same location must be the one opened later on (`open()` does not expand "~").
            self.path_or_fileobj = os.path.expanduser(self.path_or_fileobj)
            path_or_fileobj = os.path.normpath(self.path_or_fileobj)
            if not os.path.isfile(path_or_fileobj):
                raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system")
        elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
            # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode
            raise ValueError(
                "path_or_fileobj must be either an instance of str, bytes or"
                " io.BufferedIOBase. If you passed a file-like object, make sure it is"
                " in binary mode."
            )
        if isinstance(self.path_or_fileobj, io.BufferedIOBase):
            # Probe both methods now so that an unusable buffer is rejected at construction time.
            try:
                self.path_or_fileobj.tell()
                self.path_or_fileobj.seek(0, os.SEEK_CUR)
            except (OSError, AttributeError) as exc:
                raise ValueError(
                    "path_or_fileobj is a file-like object but does not implement seek() and tell()"
                ) from exc

        # Compute "upload_info" attribute
        if isinstance(self.path_or_fileobj, str):
            self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, bytes):
            self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
        else:
            self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)

    @contextmanager
    def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]:
        """
        A context manager that yields a file-like object allowing to read the underlying
        data behind `path_or_fileobj`.

        When `path_or_fileobj` is a file object, the object itself is yielded and its position is
        restored to what it was on entry when the context exits, including when the `with` body raises.

        Args:
            with_tqdm (`bool`, *optional*, defaults to `False`):
                If True, iterating over the file object will display a progress bar. Only
                works if the file-like object is a path to a file. Pure bytes and buffers
                are not supported.

        Example:

        ```python
        >>> operation = CommitOperationAdd(
        ...        path_in_repo="remote/dir/weights.h5",
        ...        path_or_fileobj="./local/weights.h5",
        ... )
        CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5')

        >>> with operation.as_file() as file:
        ...     content = file.read()

        >>> with operation.as_file(with_tqdm=True) as file:
        ...     while True:
        ...         data = file.read(1024)
        ...         if not data:
        ...              break
        config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]

        >>> with operation.as_file(with_tqdm=True) as file:
        ...     requests.put(..., data=file)
        config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
        ```
        """
        if isinstance(self.path_or_fileobj, (str, Path)):
            if with_tqdm:
                with tqdm_stream_file(self.path_or_fileobj) as file:
                    yield file
            else:
                with open(self.path_or_fileobj, "rb") as file:
                    yield file
        elif isinstance(self.path_or_fileobj, bytes):
            yield io.BytesIO(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
            prev_pos = self.path_or_fileobj.tell()
            try:
                yield self.path_or_fileobj
            finally:
                # Always rewind, even on error: the buffer belongs to the caller and may be read again
                # (e.g. to retry), so it must not be left at an arbitrary offset.
                self.path_or_fileobj.seek(prev_pos, io.SEEK_SET)

    def b64content(self) -> bytes:
        """
        The base64-encoded content of `path_or_fileobj`

        Returns: `bytes`
        """
        with self.as_file() as file:
            return base64.b64encode(file.read())

    @property
    def _local_oid(self) -> Optional[str]:
        """Return the OID of the local file.

        This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
        If the file did not change, we won't upload it again to prevent empty commits.

        For LFS files, the OID corresponds to the SHA256 of the file content (used as LFS ref).
        For regular files, the OID corresponds to the SHA1 of the file content.
        Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1 of the
              pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes
              and more convenient client-side.

        Returns `None` as long as the upload mode is not known.
        """
        if self._upload_mode is None:
            return None
        elif self._upload_mode == "lfs":
            return self.upload_info.sha256.hex()
        else:
            # Regular file => compute sha1
            # => no need to read by chunk since the file is guaranteed to be <=5MB.
            with self.as_file() as file:
                return sha.git_hash(file.read())
282
+
283
+
284
+ def _validate_path_in_repo(path_in_repo: str) -> str:
285
+ # Validate `path_in_repo` value to prevent a server-side issue
286
+ if path_in_repo.startswith("/"):
287
+ path_in_repo = path_in_repo[1:]
288
+ if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"):
289
+ raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
290
+ if path_in_repo.startswith("./"):
291
+ path_in_repo = path_in_repo[2:]
292
+ for forbidden in FORBIDDEN_FOLDERS:
293
+ if any(part == forbidden for part in path_in_repo.split("/")):
294
+ raise ValueError(
295
+ f"Invalid `path_in_repo` in CommitOperation: cannot update files under a '{forbidden}/' folder (path:"
296
+ f" '{path_in_repo}')."
297
+ )
298
+ return path_in_repo
299
+
300
+
301
+ CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete]
302
+
303
+
304
def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:
    """
    Emit a warning when operations of a single commit are expected to overwrite each other.

    Operations are examined in order:
    - Several `CommitOperationAdd` targeting the same filepath trigger a warning.
    - A `CommitOperationDelete` removing a filepath (or a folder containing a filepath)
      previously updated by a `CommitOperationAdd` triggers a warning.
    - A `CommitOperationDelete` followed by a `CommitOperationAdd` on the same filepath does
      not trigger anything. Deleting before uploading is usually useless, but it legitimately
      happens when a user deletes an entire folder and then adds new files to it.

    Args:
        operations (`List` of `CommitOperation`):
            The operations of the commit, in the order they will be applied.
    """
    # Number of additions seen so far, per file path and per parent folder.
    additions_seen: Dict[str, int] = defaultdict(int)

    for op in operations:
        target = op.path_in_repo

        if isinstance(op, CommitOperationAdd):
            if additions_seen[target] > 0:
                warnings.warn(
                    "About to update multiple times the same file in the same commit:"
                    f" '{target}'. This can cause undesired inconsistencies in"
                    " your repo."
                )
            # Count the file itself, then every parent folder
            # => makes it possible to warn if a folder deletion overwrites some contained files
            additions_seen[target] += 1
            for folder in PurePosixPath(target).parents:
                additions_seen[str(folder)] += 1

        if isinstance(op, CommitOperationDelete) and additions_seen[str(PurePosixPath(target))] > 0:
            if op.is_folder:
                message = (
                    "About to delete a folder containing files that have just been"
                    f" updated within the same commit: '{target}'. This can"
                    " cause undesired inconsistencies in your repo."
                )
            else:
                message = (
                    "About to delete a file that have just been updated within the"
                    f" same commit: '{target}'. This can cause undesired"
                    " inconsistencies in your repo."
                )
            warnings.warn(message)
348
+
349
+
350
@validate_hf_hub_args
def _upload_lfs_files(
    *,
    additions: List[CommitOperationAdd],
    repo_type: str,
    repo_id: str,
    headers: Dict[str, str],
    endpoint: Optional[str] = None,
    num_threads: int = 5,
    revision: Optional[str] = None,
):
    """
    Uploads the content of `additions` to the Hub using the large file storage protocol.

    Relevant external documentation:
        - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md

    Args:
        additions (`List` of `CommitOperationAdd`):
            The files to be uploaded
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, *optional*):
            Forwarded as-is to the LFS batch request and to each file upload.
        num_threads (`int`, *optional*):
            The number of concurrent threads to use when uploading. Defaults to 5.
        revision (`str`, *optional*):
            The git revision to upload to.

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If an upload failed for any reason
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the server returns malformed responses
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
            If the LFS batch endpoint returned an HTTP error.
    """
    # Step 1: retrieve upload instructions from the LFS batch endpoint.
    #         Upload instructions are retrieved by chunk of 256 files to avoid reaching
    #         the payload limit.
    batch_actions: List[Dict] = []
    for chunk in chunk_iterable(additions, chunk_size=256):
        batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
            upload_infos=[op.upload_info for op in chunk],
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            endpoint=endpoint,
            headers=headers,
            token=None,  # already passed in 'headers'
        )

        # If at least 1 error, we do not retrieve information for other chunks
        if batch_errors_chunk:
            message = "\n".join(
                [
                    f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
                    for err in batch_errors_chunk
                ]
            )
            raise ValueError(f"LFS batch endpoint returned errors:\n{message}")

        batch_actions += batch_actions_chunk

    # Batch actions only carry the OID of the file: map it back to the operation it belongs to.
    oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions}

    # Step 2: ignore files that have already been uploaded
    filtered_actions = []
    for action in batch_actions:
        if action.get("actions") is None:
            logger.debug(
                f"Content of file {oid2addop[action['oid']].path_in_repo} is already"
                " present upstream - skipping upload."
            )
        else:
            filtered_actions.append(action)

    if not filtered_actions:
        logger.debug("No LFS files to upload.")
        return

    # Step 3: upload files concurrently according to these instructions
    def _wrapped_lfs_upload(batch_action) -> None:
        # Resolve the operation *before* entering the `try`: the error handler below needs it to
        # build its message, and would otherwise fail with an `UnboundLocalError` hiding the real
        # error whenever the lookup itself is what raised.
        operation = oid2addop[batch_action["oid"]]
        try:
            lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint)
        except Exception as exc:
            raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc

    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        # With `hf_transfer` enabled, files are uploaded one after the other (no thread pool here).
        logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.")
        for action in hf_tqdm(filtered_actions, name="huggingface_hub.lfs_upload"):
            _wrapped_lfs_upload(action)
    elif len(filtered_actions) == 1:
        logger.debug("Uploading 1 LFS file to the Hub")
        _wrapped_lfs_upload(filtered_actions[0])
    else:
        logger.debug(
            f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently"
        )
        thread_map(
            _wrapped_lfs_upload,
            filtered_actions,
            desc=f"Upload {len(filtered_actions)} LFS files",
            max_workers=num_threads,
            tqdm_class=hf_tqdm,
        )
459
+
460
+
461
+ def _validate_preupload_info(preupload_info: dict):
462
+ files = preupload_info.get("files")
463
+ if not isinstance(files, list):
464
+ raise ValueError("preupload_info is improperly formatted")
465
+ for file_info in files:
466
+ if not (
467
+ isinstance(file_info, dict)
468
+ and isinstance(file_info.get("path"), str)
469
+ and isinstance(file_info.get("uploadMode"), str)
470
+ and (file_info["uploadMode"] in ("lfs", "regular"))
471
+ ):
472
+ raise ValueError("preupload_info is improperly formatted:")
473
+ return preupload_info
474
+
475
+
476
@validate_hf_hub_args
def _fetch_upload_modes(
    additions: Iterable[CommitOperationAdd],
    repo_type: str,
    repo_id: str,
    headers: Dict[str, str],
    revision: str,
    endpoint: Optional[str] = None,
    create_pr: bool = False,
    gitignore_content: Optional[str] = None,
) -> None:
    """
    Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob
    or as git LFS blob. Input `additions` are mutated in-place with the upload mode.

    For each addition, `_upload_mode`, `_should_ignore` and `_remote_oid` are set from the server response.
    Empty files are always set to the `"regular"` upload mode.

    Args:
        additions (`Iterable` of :class:`CommitOperationAdd`):
            Iterable of :class:`CommitOperationAdd` describing the files to
            upload to the Hub.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        revision (`str`):
            The git revision to upload the files to. Can be any valid git revision.
        endpoint (`str`, *optional*):
            Base URL of the Hub. Defaults to `constants.ENDPOINT`.
        create_pr (`bool`, *optional*, defaults to `False`):
            If `True`, the request is sent with the `create_pr=1` query parameter.
        gitignore_content (`str`, *optional*):
            The content of the `.gitignore` file to know which files should be ignored. The order of priority
            is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present
            in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub
            (if any).
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    endpoint = endpoint if endpoint is not None else constants.ENDPOINT

    # `additions` is traversed several times (requests, then mutation). Materialize it once so that
    # one-shot iterables (e.g. generators) are fully processed instead of being silently skipped.
    additions = list(additions)

    # Fetch upload mode (LFS or regular) chunk by chunk.
    upload_modes: Dict[str, UploadMode] = {}
    should_ignore_info: Dict[str, bool] = {}
    oid_info: Dict[str, Optional[str]] = {}

    for chunk in chunk_iterable(additions, 256):
        payload: Dict = {
            "files": [
                {
                    "path": op.path_in_repo,
                    "sample": base64.b64encode(op.upload_info.sample).decode("ascii"),
                    "size": op.upload_info.size,
                }
                for op in chunk
            ]
        }
        if gitignore_content is not None:
            payload["gitIgnore"] = gitignore_content

        resp = get_session().post(
            f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
            json=payload,
            headers=headers,
            params={"create_pr": "1"} if create_pr else None,
        )
        hf_raise_for_status(resp)
        preupload_info = _validate_preupload_info(resp.json())
        for file in preupload_info["files"]:
            path = file["path"]
            upload_modes[path] = file["uploadMode"]
            should_ignore_info[path] = file["shouldIgnore"]
            # "oid" is optional in the response
            oid_info[path] = file.get("oid")

    # Set upload mode for each addition operation
    for addition in additions:
        addition._upload_mode = upload_modes[addition.path_in_repo]
        addition._should_ignore = should_ignore_info[addition.path_in_repo]
        addition._remote_oid = oid_info[addition.path_in_repo]

    # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented)
    # => empty files are uploaded as "regular" to still allow users to commit them.
    for addition in additions:
        if addition.upload_info.size == 0:
            addition._upload_mode = "regular"
559
+
560
+
561
@validate_hf_hub_args
def _fetch_files_to_copy(
    copies: Iterable[CommitOperationCopy],
    repo_type: str,
    repo_id: str,
    headers: Dict[str, str],
    revision: str,
    endpoint: Optional[str] = None,
) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]:
    """
    Fetch information about the files to copy.

    For LFS files, we only need their metadata (file size and sha256) while for regular files
    we need to download the raw content from the Hub.

    Each operation is also mutated in-place: `_src_oid` and `_dest_oid` are set to the blob id of
    the source file and of the file currently at the destination path (`None` if not found).

    Args:
        copies (`Iterable` of :class:`CommitOperationCopy`):
            Iterable of :class:`CommitOperationCopy` describing the files to
            copy on the Hub.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        revision (`str`):
            The git revision to upload the files to. Can be any valid git revision.
        endpoint (`str`, *optional*):
            Forwarded as-is to `HfApi` and `hf_hub_url`.

    Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]]`
        Key is the file path and revision of the file to copy.
        Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files).

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
        [`~errors.EntryNotFoundError`]
            If a file to copy is missing on the repo.
        [`NotImplementedError`](https://docs.python.org/3/library/exceptions.html#NotImplementedError)
            If a source path is a folder.
    """
    from .hf_api import HfApi, RepoFolder

    # `copies` is traversed twice (destinations, then sources). Materialize it once so that one-shot
    # iterables (e.g. generators) are not exhausted after the first pass.
    copies = list(copies)

    hf_api = HfApi(endpoint=endpoint, headers=headers)
    files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {}
    # Store (path, revision) -> oid mapping
    oid_info: Dict[Tuple[str, Optional[str]], Optional[str]] = {}
    # 1. Fetch OIDs for destination paths in batches.
    dest_paths = [op.path_in_repo for op in copies]
    for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE):
        dest_repo_files = hf_api.get_paths_info(
            repo_id=repo_id,
            paths=dest_paths[offset : offset + FETCH_LFS_BATCH_SIZE],
            revision=revision,
            repo_type=repo_type,
        )
        for file in dest_repo_files:
            if not isinstance(file, RepoFolder):
                oid_info[(file.path, revision)] = file.blob_id

    # 2. Group by source revision and fetch source file info in batches.
    #    `groupby` only merges *consecutive* items: sort by revision first (stable sort, `None` first)
    #    so that each revision is fetched once even when operations are interleaved.
    copies_by_revision = sorted(copies, key=lambda op: (op.src_revision is not None, op.src_revision or ""))
    for src_revision, operations in groupby(copies_by_revision, key=lambda op: op.src_revision):
        operations = list(operations)  # type: ignore
        src_paths = [op.src_path_in_repo for op in operations]
        for offset in range(0, len(src_paths), FETCH_LFS_BATCH_SIZE):
            src_repo_files = hf_api.get_paths_info(
                repo_id=repo_id,
                paths=src_paths[offset : offset + FETCH_LFS_BATCH_SIZE],
                revision=src_revision or revision,
                repo_type=repo_type,
            )

            for src_repo_file in src_repo_files:
                if isinstance(src_repo_file, RepoFolder):
                    raise NotImplementedError("Copying a folder is not implemented.")
                oid_info[(src_repo_file.path, src_revision)] = src_repo_file.blob_id
                # If it's an LFS file, store the RepoFile object. Otherwise, download raw bytes.
                if src_repo_file.lfs:
                    files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file
                else:
                    # TODO: (optimization) download regular files to copy concurrently
                    url = hf_hub_url(
                        endpoint=endpoint,
                        repo_type=repo_type,
                        repo_id=repo_id,
                        revision=src_revision or revision,
                        filename=src_repo_file.path,
                    )
                    response = get_session().get(url, headers=headers)
                    hf_raise_for_status(response)
                    files_to_copy[(src_repo_file.path, src_revision)] = response.content
        # 3. Ensure all operations found a corresponding file in the Hub
        #    and track src/dest OIDs for each operation.
        for operation in operations:
            if (operation.src_path_in_repo, src_revision) not in files_to_copy:
                raise EntryNotFoundError(
                    f"Cannot copy {operation.src_path_in_repo} at revision "
                    f"{src_revision or revision}: file is missing on repo."
                )
            operation._src_oid = oid_info.get((operation.src_path_in_repo, operation.src_revision))
            operation._dest_oid = oid_info.get((operation.path_in_repo, revision))
    return files_to_copy
661
+
662
+
663
def _prepare_commit_payload(
    operations: Iterable[CommitOperation],
    files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]],
    commit_message: str,
    commit_description: Optional[str] = None,
    parent_commit: Optional[str] = None,
) -> Iterable[Dict[str, Any]]:
    """
    Lazily yield the items of the payload to POST to the `/commit` API of the Hub.

    Items are produced one at a time so that the payload can be streamed as ndjson in the POST request:
    a single "header" item carrying the commit metadata comes first, followed by one item per operation
    that is not ignored.

    For more information, see:
        - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073
        - http://ndjson.org/

    Args:
        operations: The commit operations (add / delete / copy) to serialize.
        files_to_copy: Mapping `(src_path_in_repo, src_revision)` -> raw bytes (regular file) or an object
            exposing `.lfs` (LFS file), looked up for every copy operation.
        commit_message: Stored under the header's "summary" key.
        commit_description: Stored under the header's "description" key. `None` is sent as an empty string.
        parent_commit: If not `None`, stored under the header's "parentCommit" key.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If an operation is of an unknown kind, or if the `files_to_copy` entry of a copy operation is
            neither raw bytes nor an object with truthy `.lfs`.
    """
    # 1. Header item with the commit metadata
    header: Dict[str, Any] = {
        "summary": commit_message,
        "description": "" if commit_description is None else commit_description,
    }
    if parent_commit is not None:
        header["parentCommit"] = parent_commit
    yield {"key": "header", "value": header}

    skipped = 0

    # 2. One item per operation
    for op in operations:
        if isinstance(op, CommitOperationAdd):
            # Ignored files never reach the payload
            if op._should_ignore:
                logger.debug(f"Skipping file '{op.path_in_repo}' in commit (ignored by gitignore file).")
                skipped += 1
                continue

            # 2.a. Regular file: content is inlined as base64
            if op._upload_mode == "regular":
                yield {
                    "key": "file",
                    "value": {
                        "content": op.b64content().decode(),
                        "path": op.path_in_repo,
                        "encoding": "base64",
                    },
                }
                continue

            # 2.b. LFS file: only the pointer information is sent
            if op._upload_mode == "lfs":
                yield {
                    "key": "lfsFile",
                    "value": {
                        "path": op.path_in_repo,
                        "algo": "sha256",
                        "oid": op.upload_info.sha256.hex(),
                        "size": op.upload_info.size,
                    },
                }
                continue
            # Any other upload mode falls through to the checks below (and ultimately to the final error).

        # 2.c. Deletion of a file or folder
        if isinstance(op, CommitOperationDelete):
            yield {
                "key": "deletedFolder" if op.is_folder else "deletedFile",
                "value": {"path": op.path_in_repo},
            }
            continue

        # 2.d. Copy of a file
        if isinstance(op, CommitOperationCopy):
            source = files_to_copy[(op.src_path_in_repo, op.src_revision)]
            if isinstance(source, bytes):
                yield {
                    "key": "file",
                    "value": {
                        "content": base64.b64encode(source).decode(),
                        "path": op.path_in_repo,
                        "encoding": "base64",
                    },
                }
            elif source.lfs:
                yield {
                    "key": "lfsFile",
                    "value": {
                        "path": op.path_in_repo,
                        "algo": "sha256",
                        "oid": source.lfs.sha256,
                    },
                }
            else:
                raise ValueError(
                    "Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info."
                )
            continue

        # 2.e. Never expected to happen
        raise ValueError(
            f"Unknown operation to commit. Operation: {op}. Upload mode:"
            f" {getattr(op, '_upload_mode', None)}"
        )

    if skipped > 0:
        logger.info(f"Skipped {skipped} file(s) in commit (ignored by gitignore file).")
.venv/lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import logging
3
+ import os
4
+ import time
5
+ from concurrent.futures import Future
6
+ from dataclasses import dataclass
7
+ from io import SEEK_END, SEEK_SET, BytesIO
8
+ from pathlib import Path
9
+ from threading import Lock, Thread
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi
13
+ from .utils import filter_repo_objects
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass(frozen=True)
class _FileToUpload:
    """Temporary dataclass to store info about files to upload. Not meant to be used directly."""

    # Path of the file on the local filesystem (found by globbing the scheduler's resolved folder).
    local_path: Path
    # Destination path in the repo: the scheduler's `path_in_repo` prefix followed by the relative path.
    path_in_repo: str
    # File size (`st_size`) at listing time; the upload is capped to this many bytes.
    size_limit: int
    # Modification time (`st_mtime`) at listing time; remembered after a commit to detect later changes.
    last_modified: float
27
+
28
+
29
class CommitScheduler:
    """
    Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).

    The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
    properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
    with the `stop` method. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
    to learn more about how to use it.

    Args:
        repo_id (`str`):
            The id of the repo to commit to.
        folder_path (`str` or `Path`):
            Path to the local folder to upload regularly.
        every (`int` or `float`, *optional*):
            The number of minutes between each commit. Defaults to 5 minutes.
        path_in_repo (`str`, *optional*):
            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
            of the repository.
        repo_type (`str`, *optional*):
            The type of the repo to commit to. Defaults to `model`.
        revision (`str`, *optional*):
            The revision of the repo to commit to. Defaults to `main`.
        private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
        token (`str`, *optional*):
            The token to use to commit to the repo. Defaults to the token saved on the machine.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are uploaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not uploaded.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performances on the repo when it grows too large.
        hf_api (`HfApi`, *optional*):
            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `every` is not strictly positive, or if `folder_path` points to an existing file.

    Example:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    # Scheduler uploads every 10 minutes
    >>> csv_path = Path("watched_folder/data.csv")
    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)

    >>> with csv_path.open("a") as f:
    ...     f.write("first line")

    # Some time later (...)
    >>> with csv_path.open("a") as f:
    ...     f.write("second line")
    ```

    Example using a context manager:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
    ...     csv_path = Path("watched_folder/data.csv")
    ...     with csv_path.open("a") as f:
    ...         f.write("first line")
    ...     (...)
    ...     with csv_path.open("a") as f:
    ...         f.write("second line")

    # Scheduler is now stopped and last commit has been triggered
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = None,
        repo_type: Optional[str] = None,
        revision: Optional[str] = None,
        private: Optional[bool] = None,
        token: Optional[str] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        squash_history: bool = False,
        hf_api: Optional["HfApi"] = None,
    ) -> None:
        # Fail fast: reject an invalid schedule before creating the local folder or the remote repo.
        # `not every > 0` (rather than `every <= 0`) also rejects NaN.
        if not every > 0:
            raise ValueError(f"'every' must be a positive integer, not '{every}'.")

        self.api = hf_api or HfApi(token=token)

        # Folder
        self.folder_path = Path(folder_path).expanduser().resolve()
        self.path_in_repo = path_in_repo or ""
        self.allow_patterns = allow_patterns

        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS

        if self.folder_path.is_file():
            raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.")
        self.folder_path.mkdir(parents=True, exist_ok=True)

        # Repository
        repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True)
        self.repo_id = repo_url.repo_id
        self.repo_type = repo_type
        self.revision = revision
        self.token = token

        # Keep track of already uploaded files
        self.last_uploaded: Dict[Path, float] = {}  # key is local path, value is timestamp

        # Scheduler
        self.lock = Lock()
        self.every = every
        self.squash_history = squash_history

        # Must be set BEFORE the background thread starts: the thread immediately triggers `_push_to_hub`, which
        # reads this flag. Setting it afterwards could raise an AttributeError in the first scheduled commit.
        self.__stopped = False

        logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.")
        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
        self._scheduler_thread.start()
        atexit.register(self._push_to_hub)

    def stop(self) -> None:
        """Stop the scheduler.

        A stopped scheduler cannot be restarted. Mostly for tests purposes.
        """
        self.__stopped = True

    def __enter__(self) -> "CommitScheduler":
        """Return the scheduler itself (it is already running once constructed)."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Trigger a last commit, wait for it to complete, then stop the scheduler."""
        # Upload last changes before exiting
        self.trigger().result()
        self.stop()
        return

    def _run_scheduler(self) -> None:
        """Dumb thread waiting between each scheduled push to Hub."""
        while True:
            self.last_future = self.trigger()
            time.sleep(self.every * 60)
            if self.__stopped:
                break

    def trigger(self) -> Future:
        """Trigger a `push_to_hub` and return a future.

        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
        immediately, without waiting for the next scheduled commit.
        """
        return self.api.run_as_future(self._push_to_hub)

    def _push_to_hub(self) -> Optional[CommitInfo]:
        """Run `push_to_hub` (and the optional history squash), logging then re-raising any error.

        Returns `None` without doing anything if the scheduler has been stopped.
        """
        if self.__stopped:  # If stopped, already scheduled commits are ignored
            return None

        logger.info("(Background) scheduled commit triggered.")
        try:
            value = self.push_to_hub()
            if self.squash_history:
                logger.info("(Background) squashing repo history.")
                self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision)
            return value
        except Exception as e:
            logger.error(f"Error while pushing to Hub: {e}")  # Depending on the setup, error might be silenced
            raise

    def push_to_hub(self) -> Optional[CommitInfo]:
        """
        Push folder to the Hub and return the commit info.

        <Tip warning={true}>

        This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
        queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
        issues.

        </Tip>

        The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
        uploads only changed files. If no changes are found, the method returns without committing anything. If you want
        to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
        for example to compress data together in a single file before committing. For more details and examples, check
        out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
        """
        # Check files to upload (with lock)
        with self.lock:
            logger.debug("Listing files to upload for scheduled commit.")

            # List files from folder (taken from `_prepare_upload_folder_additions`)
            relpath_to_abspath = {
                path.relative_to(self.folder_path).as_posix(): path
                for path in sorted(self.folder_path.glob("**/*"))  # sorted to be deterministic
                if path.is_file()
            }
            prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""

            # Filter with pattern + filter out unchanged files + retrieve current file size
            files_to_upload: List[_FileToUpload] = []
            for relpath in filter_repo_objects(
                relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns
            ):
                local_path = relpath_to_abspath[relpath]
                stat = local_path.stat()
                # A file is (re-)uploaded if never uploaded or if its mtime differs from the last uploaded one
                if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime:
                    files_to_upload.append(
                        _FileToUpload(
                            local_path=local_path,
                            path_in_repo=prefix + relpath,
                            size_limit=stat.st_size,
                            last_modified=stat.st_mtime,
                        )
                    )

        # Return if nothing to upload
        if len(files_to_upload) == 0:
            logger.debug("Dropping schedule commit: no changed file to upload.")
            return None

        # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
        logger.debug("Removing unchanged files since previous scheduled commit.")
        add_operations = [
            CommitOperationAdd(
                # Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
                path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit),
                path_in_repo=file_to_upload.path_in_repo,
            )
            for file_to_upload in files_to_upload
        ]

        # Upload files (append mode expected - no need for lock)
        logger.debug("Uploading files for scheduled commit.")
        commit_info = self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=add_operations,
            commit_message="Scheduled Commit",
            revision=self.revision,
        )

        # Successful commit: keep track of the latest "last_modified" for each file
        for file in files_to_upload:
            self.last_uploaded[file.local_path] = file.last_modified
        return commit_info
281
+
282
+
283
class PartialFileIO(BytesIO):
    """A file-like object that reads only the first part of a file.

    Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
    file is uploaded (i.e. the part that was available when the filesystem was first scanned).

    In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
    disturbance for the user. The object is passed to `CommitOperationAdd`.

    Only supports `read`, `tell` and `seek` methods.

    Args:
        file_path (`str` or `Path`):
            Path to the file to read.
        size_limit (`int`):
            The maximum number of bytes to read from the file. If the file is larger than this, only the first part
            will be read (and uploaded).
    """

    def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
        self._file_path = Path(file_path)
        self._file = self._file_path.open("rb")
        # Never expose more than what is actually on disk right now
        self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)

    def __del__(self) -> None:
        """Close the underlying file, if it was ever opened."""
        # `_file` does not exist when `__init__` failed before/while opening the file (e.g. missing path).
        # Looking it up defensively avoids a spurious "Exception ignored in __del__" on top of the real error.
        file = getattr(self, "_file", None)
        if file is not None:
            file.close()
        return super().__del__()

    def __repr__(self) -> str:
        return f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"

    def __len__(self) -> int:
        """Return the number of readable bytes (the effective size limit)."""
        return self._size_limit

    def __getattribute__(self, name: str):
        """Give access to private/dunder attributes and the 3 supported methods; reject everything else."""
        if name.startswith("_") or name in ("read", "tell", "seek"):  # only 3 public methods supported
            return super().__getattribute__(name)
        raise NotImplementedError(f"PartialFileIO does not support '{name}'.")

    def tell(self) -> int:
        """Return the current file position."""
        return self._file.tell()

    def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
        """Change the stream position to the given offset.

        Behavior is the same as a regular file, except that the position is capped to the size limit.
        """
        if __whence == SEEK_END:
            # SEEK_END => set from the truncated end
            __offset = len(self) + __offset
            __whence = SEEK_SET

        pos = self._file.seek(__offset, __whence)
        if pos > self._size_limit:
            return self._file.seek(self._size_limit)
        return pos

    def read(self, __size: Optional[int] = -1) -> bytes:
        """Read at most `__size` bytes from the file.

        Behavior is the same as a regular file, except that it is capped to the size limit.
        """
        current = self._file.tell()
        if __size is None or __size < 0:
            # Read until file limit
            truncated_size = self._size_limit - current
        else:
            # Read until file limit or __size
            truncated_size = min(__size, self._size_limit - current)
        return self._file.read(truncated_size)
.venv/lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from dataclasses import dataclass, field
3
+ from datetime import datetime
4
+ from enum import Enum
5
+ from typing import TYPE_CHECKING, Dict, Optional, Union
6
+
7
+ from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError
8
+
9
+ from .inference._client import InferenceClient
10
+ from .inference._generated._async_client import AsyncInferenceClient
11
+ from .utils import get_session, logging, parse_datetime
12
+
13
+
14
+ if TYPE_CHECKING:
15
+ from .hf_api import HfApi
16
+
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
class InferenceEndpointStatus(str, Enum):
    """String-valued enumeration of the states an Inference Endpoint can be in.

    Members subclass `str`, so they compare equal to their raw value (e.g. `"running"`).
    NOTE(review): values presumably mirror the state strings returned by the Inference Endpoints API — confirm.
    """

    PENDING = "pending"
    INITIALIZING = "initializing"
    UPDATING = "updating"
    UPDATE_FAILED = "updateFailed"
    RUNNING = "running"
    PAUSED = "paused"
    FAILED = "failed"  # `InferenceEndpoint.wait` raises `InferenceEndpointError` on this status
    SCALED_TO_ZERO = "scaledToZero"
30
+
31
+
32
class InferenceEndpointType(str, Enum):
    """String-valued enumeration of the Inference Endpoint types (public, protected, private).

    Members subclass `str`, so they compare equal to their raw value (e.g. `"public"`).
    """

    PUBLIC = "public"
    # Historical misspelling (lowercase "l"), kept as an alias of `PUBLIC` so existing code keeps working.
    PUBlIC = "public"
    PROTECTED = "protected"
    PRIVATE = "private"
36
+
37
+
38
+ @dataclass
39
+ class InferenceEndpoint:
40
+ """
41
+ Contains information about a deployed Inference Endpoint.
42
+
43
+ Args:
44
+ name (`str`):
45
+ The unique name of the Inference Endpoint.
46
+ namespace (`str`):
47
+ The namespace where the Inference Endpoint is located.
48
+ repository (`str`):
49
+ The name of the model repository deployed on this Inference Endpoint.
50
+ status ([`InferenceEndpointStatus`]):
51
+ The current status of the Inference Endpoint.
52
+ url (`str`, *optional*):
53
+ The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
54
+ framework (`str`):
55
+ The machine learning framework used for the model.
56
+ revision (`str`):
57
+ The specific model revision deployed on the Inference Endpoint.
58
+ task (`str`):
59
+ The task associated with the deployed model.
60
+ created_at (`datetime.datetime`):
61
+ The timestamp when the Inference Endpoint was created.
62
+ updated_at (`datetime.datetime`):
63
+ The timestamp of the last update of the Inference Endpoint.
64
+ type ([`InferenceEndpointType`]):
65
+ The type of the Inference Endpoint (public, protected, private).
66
+ raw (`Dict`):
67
+ The raw dictionary data returned from the API.
68
+ token (`str` or `bool`, *optional*):
69
+ Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the
70
+ locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server.
71
+
72
+ Example:
73
+ ```python
74
+ >>> from huggingface_hub import get_inference_endpoint
75
+ >>> endpoint = get_inference_endpoint("my-text-to-image")
76
+ >>> endpoint
77
+ InferenceEndpoint(name='my-text-to-image', ...)
78
+
79
+ # Get status
80
+ >>> endpoint.status
81
+ 'running'
82
+ >>> endpoint.url
83
+ 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'
84
+
85
+ # Run inference
86
+ >>> endpoint.client.text_to_image(...)
87
+
88
+ # Pause endpoint to save $$$
89
+ >>> endpoint.pause()
90
+
91
+ # ...
92
+ # Resume and wait for deployment
93
+ >>> endpoint.resume()
94
+ >>> endpoint.wait()
95
+ >>> endpoint.client.text_to_image(...)
96
+ ```
97
+ """
98
+
99
+ # Field in __repr__
100
+ name: str = field(init=False)
101
+ namespace: str
102
+ repository: str = field(init=False)
103
+ status: InferenceEndpointStatus = field(init=False)
104
+ url: Optional[str] = field(init=False)
105
+
106
+ # Other fields
107
+ framework: str = field(repr=False, init=False)
108
+ revision: str = field(repr=False, init=False)
109
+ task: str = field(repr=False, init=False)
110
+ created_at: datetime = field(repr=False, init=False)
111
+ updated_at: datetime = field(repr=False, init=False)
112
+ type: InferenceEndpointType = field(repr=False, init=False)
113
+
114
+ # Raw dict from the API
115
+ raw: Dict = field(repr=False)
116
+
117
+ # Internal fields
118
+ _token: Union[str, bool, None] = field(repr=False, compare=False)
119
+ _api: "HfApi" = field(repr=False, compare=False)
120
+
121
@classmethod
def from_raw(
    cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None
) -> "InferenceEndpoint":
    """Alternate constructor: build the object from the raw dictionary describing the endpoint.

    When `api` is omitted a default `HfApi()` client is created; when `token` is omitted it falls back to the
    token of that client. Every other field is derived from `raw` in `__post_init__`.
    """
    if api is None:
        # Imported lazily, only when no client was supplied by the caller
        from .hf_api import HfApi

        api = HfApi()

    effective_token = api.token if token is None else token

    # All other fields are populated in __post_init__
    return cls(raw=raw, namespace=namespace, _token=effective_token, _api=api)
135
+
136
def __post_init__(self) -> None:
    """Populate fields from raw dictionary.

    Runs right after the dataclass-generated `__init__` and delegates to `self._populate_from_raw()`
    (defined outside this excerpt), the same helper `fetch`, `update` and `pause` call after replacing `self.raw`.
    """
    self._populate_from_raw()
139
+
140
@property
def client(self) -> InferenceClient:
    """Synchronous client bound to this Inference Endpoint.

    Returns:
        [`InferenceClient`]: an inference client pointing to the deployed endpoint.

    Raises:
        [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
    """
    if self.url is not None:
        return InferenceClient(
            model=self.url,
            token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
        )

    # No URL means the endpoint is not deployed yet: nothing to point a client at.
    raise InferenceEndpointError(
        "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
        "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
    )
159
+
160
@property
def async_client(self) -> AsyncInferenceClient:
    """Asyncio-compatible client bound to this Inference Endpoint.

    Returns:
        [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.

    Raises:
        [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
    """
    if self.url is not None:
        return AsyncInferenceClient(
            model=self.url,
            token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
        )

    # No URL means the endpoint is not deployed yet: nothing to point a client at.
    raise InferenceEndpointError(
        "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
        "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
    )
179
+
180
def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
    """Wait for the Inference Endpoint to be deployed.

    Information from the server will be fetched every `refresh_every` seconds. If the Inference Endpoint is not
    deployed after `timeout` seconds, a [`InferenceEndpointTimeoutError`] will be raised. The [`InferenceEndpoint`]
    will be mutated in place with the latest data.

    Args:
        timeout (`int`, *optional*):
            The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
            indefinitely.
        refresh_every (`int`, *optional*):
            The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `timeout` is negative or `refresh_every` is not strictly positive.
        [`InferenceEndpointError`]
            If the Inference Endpoint ended up in a failed state (deployment failed or update failed).
        [`InferenceEndpointTimeoutError`]
            If the Inference Endpoint is not deployed after `timeout` seconds.
    """
    if timeout is not None and timeout < 0:
        raise ValueError("`timeout` cannot be negative.")
    if refresh_every <= 0:
        raise ValueError("`refresh_every` must be positive.")

    start = time.time()
    while True:
        if self.url is not None:
            # Means the URL is provisioned => check if the endpoint is reachable
            response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
            if response.status_code == 200:
                logger.info("Inference Endpoint is ready to be used.")
                return self
        if self.status == InferenceEndpointStatus.FAILED:
            raise InferenceEndpointError(
                f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
            )
        # A failed update is also a failed state: without this check an unreachable endpoint stuck in
        # `updateFailed` would make `wait()` loop forever when no timeout is given.
        if self.status == InferenceEndpointStatus.UPDATE_FAILED:
            raise InferenceEndpointError(
                f"Inference Endpoint {self.name} failed to update. Please check the logs for more information."
            )
        if timeout is not None:
            if time.time() - start > timeout:
                raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
        logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
        time.sleep(refresh_every)
        self.fetch()
226
+
227
def fetch(self) -> "InferenceEndpoint":
    """Refresh this object with the latest server-side information about the Inference Endpoint.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    latest = self._api.get_inference_endpoint(
        name=self.name,
        namespace=self.namespace,
        token=self._token,  # type: ignore [arg-type]
    )
    # Adopt the fresh payload, then re-derive every field from it
    self.raw = latest.raw
    self._populate_from_raw()
    return self
237
+
238
def update(
    self,
    *,
    # Compute update
    accelerator: Optional[str] = None,
    instance_size: Optional[str] = None,
    instance_type: Optional[str] = None,
    min_replica: Optional[int] = None,
    max_replica: Optional[int] = None,
    scale_to_zero_timeout: Optional[int] = None,
    # Model update
    repository: Optional[str] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    task: Optional[str] = None,
    custom_image: Optional[Dict] = None,
    secrets: Optional[Dict[str, str]] = None,
) -> "InferenceEndpoint":
    """Update the Inference Endpoint.

    Either the compute configuration, the deployed model, or both can be updated. All arguments are optional but
    at least one must be provided.

    This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
    latest data from the server.

    Args:
        accelerator (`str`, *optional*):
            The hardware accelerator to be used for inference (e.g. `"cpu"`).
        instance_size (`str`, *optional*):
            The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
        instance_type (`str`, *optional*):
            The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
        min_replica (`int`, *optional*):
            The minimum number of replicas (instances) to keep running for the Inference Endpoint.
        max_replica (`int`, *optional*):
            The maximum number of replicas (instances) to scale to for the Inference Endpoint.
        scale_to_zero_timeout (`int`, *optional*):
            The duration in minutes before an inactive endpoint is scaled to zero.

        repository (`str`, *optional*):
            The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
        framework (`str`, *optional*):
            The machine learning framework used for the model (e.g. `"custom"`).
        revision (`str`, *optional*):
            The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
        task (`str`, *optional*):
            The task on which to deploy the model (e.g. `"text-classification"`).
        custom_image (`Dict`, *optional*):
            A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
            Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
        secrets (`Dict[str, str]`, *optional*):
            Secret values to inject in the container environment.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    # Group the forwarded settings the same way the signature does
    compute_settings = {
        "accelerator": accelerator,
        "instance_size": instance_size,
        "instance_type": instance_type,
        "min_replica": min_replica,
        "max_replica": max_replica,
        "scale_to_zero_timeout": scale_to_zero_timeout,
    }
    model_settings = {
        "repository": repository,
        "framework": framework,
        "revision": revision,
        "task": task,
        "custom_image": custom_image,
        "secrets": secrets,
    }

    # Make API call
    updated = self._api.update_inference_endpoint(
        name=self.name,
        namespace=self.namespace,
        **compute_settings,
        **model_settings,
        token=self._token,  # type: ignore [arg-type]
    )

    # Mutate current object
    self.raw = updated.raw
    self._populate_from_raw()
    return self
317
+
318
def pause(self) -> "InferenceEndpoint":
    """Pause the Inference Endpoint.

    While paused, the Inference Endpoint is not charged. Use [`InferenceEndpoint.resume`] to bring it back at
    any time. Pausing is not the same as [`InferenceEndpoint.scale_to_zero`]: an endpoint scaled to zero is
    restarted automatically when a request is made to it.

    This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
    latest data from the server.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    refreshed = self._api.pause_inference_endpoint(
        name=self.name,
        namespace=self.namespace,
        token=self._token,  # type: ignore [arg-type]
    )

    # Adopt the server's payload, then re-derive every public attribute from it.
    self.raw = refreshed.raw
    self._populate_from_raw()
    return self
335
+
336
def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
    """Resume the Inference Endpoint.

    This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
    latest data from the server.

    Args:
        running_ok (`bool`, *optional*):
            If `True`, the method will not raise an error if the Inference Endpoint is already running.
            Defaults to `True`.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    refreshed = self._api.resume_inference_endpoint(
        name=self.name,
        namespace=self.namespace,
        running_ok=running_ok,
        token=self._token,  # type: ignore [arg-type]
    )

    # Adopt the server's payload, then re-derive every public attribute from it.
    self.raw = refreshed.raw
    self._populate_from_raw()
    return self
356
+
357
def scale_to_zero(self) -> "InferenceEndpoint":
    """Scale Inference Endpoint to zero.

    An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it,
    with a cold start delay. This is different from pausing the Inference Endpoint with
    [`InferenceEndpoint.pause`], which would require a manual resume with [`InferenceEndpoint.resume`].

    This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place
    with the latest data from the server.

    Returns:
        [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
    """
    refreshed = self._api.scale_to_zero_inference_endpoint(
        name=self.name,
        namespace=self.namespace,
        token=self._token,  # type: ignore [arg-type]
    )

    # Adopt the server's payload, then re-derive every public attribute from it.
    self.raw = refreshed.raw
    self._populate_from_raw()
    return self
374
+
375
def delete(self) -> None:
    """Delete the Inference Endpoint.

    This operation is not reversible. If the goal is only to stop being charged for the Inference Endpoint,
    prefer pausing it with [`InferenceEndpoint.pause`] or scaling it to zero with
    [`InferenceEndpoint.scale_to_zero`].

    This is an alias for [`HfApi.delete_inference_endpoint`].
    """
    self._api.delete_inference_endpoint(
        name=self.name,
        namespace=self.namespace,
        token=self._token,  # type: ignore [arg-type]
    )
384
+
385
def _populate_from_raw(self) -> None:
    """Refresh the instance attributes from the `self.raw` dictionary.

    Called in __post_init__ + each time the Inference Endpoint is updated.
    """
    raw = self.raw

    # Repr fields
    self.name = raw["name"]
    model_info = raw["model"]
    self.repository = model_info["repository"]
    status_info = raw["status"]
    self.status = status_info["state"]
    # "url" is the only key read with `.get`: it may be absent from the status payload.
    self.url = status_info.get("url")

    # Other fields
    self.framework = model_info["framework"]
    self.revision = model_info["revision"]
    self.task = model_info["task"]
    self.created_at = parse_datetime(status_info["createdAt"])
    self.updated_at = parse_datetime(status_info["updatedAt"])
    self.type = raw["type"]
.venv/lib/python3.11/site-packages/huggingface_hub/_local_folder.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains utilities to handle the `../.cache/huggingface` folder in local directories.
16
+
17
+ First discussed in https://github.com/huggingface/huggingface_hub/issues/1738 to store
18
+ download metadata when downloading files from the hub to a local directory (without
19
+ using the cache).
20
+
21
+ ./.cache/huggingface folder structure:
22
+ [4.0K] data
23
+ ├── [4.0K] .cache
24
+ │ └── [4.0K] huggingface
25
+ │ └── [4.0K] download
26
+ │ ├── [ 16] file.parquet.metadata
27
+ │ ├── [ 16] file.txt.metadata
28
+ │ └── [4.0K] folder
29
+ │ └── [ 16] file.parquet.metadata
30
+
31
+ ├── [6.5G] file.parquet
32
+ ├── [1.5K] file.txt
33
+ └── [4.0K] folder
34
+ └── [ 16] file.parquet
35
+
36
+
37
+ Download metadata file structure:
38
+ ```
39
+ # file.txt.metadata
40
+ 11c5a3d5811f50298f278a704980280950aedb10
41
+ a16a55fda99d2f2e7b69cce5cf93ff4ad3049930
42
+ 1712656091.123
43
+
44
+ # file.parquet.metadata
45
+ 11c5a3d5811f50298f278a704980280950aedb10
46
+ 7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421
47
+ 1712656091.123
48
+
49
+ ```
50
+ """
51
+
52
+ import base64
53
+ import hashlib
54
+ import logging
55
+ import os
56
+ import time
57
+ from dataclasses import dataclass
58
+ from pathlib import Path
59
+ from typing import Optional
60
+
61
+ from .utils import WeakFileLock
62
+
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+
67
@dataclass
class LocalDownloadFilePaths:
    """
    Bundle of paths involved when downloading one file into a local dir.

    Returned by [`get_local_download_paths`].

    Attributes:
        file_path (`Path`):
            Path where the file will be saved.
        lock_path (`Path`):
            Path to the lock file used to ensure atomicity when reading/writing metadata.
        metadata_path (`Path`):
            Path to the metadata file.
    """

    file_path: Path
    lock_path: Path
    metadata_path: Path

    def incomplete_path(self, etag: str) -> Path:
        """Return the path where a file will be temporarily downloaded before being moved to `file_path`."""
        # The temporary file sits next to the metadata file; its name combines a hash of the
        # metadata file name with the etag.
        hashed_name = _short_hash(self.metadata_path.name)
        return self.metadata_path.parent / f"{hashed_name}.{etag}.incomplete"
90
+
91
+
92
@dataclass(frozen=True)
class LocalUploadFilePaths:
    """
    Immutable bundle of paths involved when uploading one file from a local dir.

    Returned by [`get_local_upload_paths`].

    Attributes:
        path_in_repo (`str`):
            Path of the file in the repo.
        file_path (`Path`):
            Path where the file will be saved.
        lock_path (`Path`):
            Path to the lock file used to ensure atomicity when reading/writing metadata.
        metadata_path (`Path`):
            Path to the metadata file.
    """

    path_in_repo: str
    file_path: Path
    lock_path: Path
    metadata_path: Path
114
+
115
+
116
@dataclass
class LocalDownloadFileMetadata:
    """
    Metadata recorded for a file that was downloaded into a local directory.

    Attributes:
        filename (`str`):
            Path of the file in the repo.
        commit_hash (`str`):
            Commit hash of the file in the repo.
        etag (`str`):
            ETag of the file in the repo. Used to check if the file has changed.
            For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash.
        timestamp (`float`):
            Unix timestamp of when the metadata was saved i.e. when the metadata was accurate.
    """

    filename: str
    commit_hash: str
    etag: str
    timestamp: float
137
+
138
+
139
+ @dataclass
140
+ class LocalUploadFileMetadata:
141
+ """
142
+ Metadata about a file in the local directory related to an upload process.
143
+ """
144
+
145
+ size: int
146
+
147
+ # Default values correspond to "we don't know yet"
148
+ timestamp: Optional[float] = None
149
+ should_ignore: Optional[bool] = None
150
+ sha256: Optional[str] = None
151
+ upload_mode: Optional[str] = None
152
+ is_uploaded: bool = False
153
+ is_committed: bool = False
154
+
155
+ def save(self, paths: LocalUploadFilePaths) -> None:
156
+ """Save the metadata to disk."""
157
+ with WeakFileLock(paths.lock_path):
158
+ with paths.metadata_path.open("w") as f:
159
+ new_timestamp = time.time()
160
+ f.write(str(new_timestamp) + "\n")
161
+
162
+ f.write(str(self.size)) # never None
163
+ f.write("\n")
164
+
165
+ if self.should_ignore is not None:
166
+ f.write(str(int(self.should_ignore)))
167
+ f.write("\n")
168
+
169
+ if self.sha256 is not None:
170
+ f.write(self.sha256)
171
+ f.write("\n")
172
+
173
+ if self.upload_mode is not None:
174
+ f.write(self.upload_mode)
175
+ f.write("\n")
176
+
177
+ f.write(str(int(self.is_uploaded)) + "\n")
178
+ f.write(str(int(self.is_committed)) + "\n")
179
+
180
+ self.timestamp = new_timestamp
181
+
182
+
183
def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFilePaths:
    """Compute paths to the files related to a download process.

    Folders containing the paths are all guaranteed to exist.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        [`LocalDownloadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path, incomplete_path).

    Raises:
        `ValueError`: on Windows, if the converted filename starts with or contains a `..` path component.
    """
    on_windows = os.name == "nt"

    # `filename` uses the Hub convention ('/' separators): rebuild it with the OS separator.
    relative_name = os.path.join(*filename.split("/"))
    if on_windows and (relative_name.startswith("..\\") or "\\..\\" in relative_name):
        raise ValueError(
            f"Invalid filename: cannot handle filename '{relative_name}' on Windows. Please ask the repository"
            " owner to rename this file."
        )

    target_file = local_dir / relative_name
    metadata_file = _huggingface_dir(local_dir) / "download" / f"{relative_name}.metadata"
    lock_file = metadata_file.with_suffix(".lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix
    if on_windows and not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_file)) > 255:
        extended_prefix = "\\\\?\\"
        target_file = Path(extended_prefix + os.path.abspath(target_file))
        lock_file = Path(extended_prefix + os.path.abspath(lock_file))
        metadata_file = Path(extended_prefix + os.path.abspath(metadata_file))

    for folder in (target_file.parent, metadata_file.parent):
        folder.mkdir(parents=True, exist_ok=True)
    return LocalDownloadFilePaths(file_path=target_file, lock_path=lock_file, metadata_path=metadata_file)
221
+
222
+
223
def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths:
    """Compute paths to the files related to an upload process.

    Folders containing the paths are all guaranteed to exist.

    Args:
        local_dir (`Path`):
            Path to the local directory that is uploaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        [`LocalUploadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path).

    Raises:
        `ValueError`: on Windows, if the converted filename starts with or contains a `..` path component.
    """
    on_windows = os.name == "nt"

    # `filename` uses the Hub convention ('/' separators): rebuild it with the OS separator.
    relative_name = os.path.join(*filename.split("/"))
    if on_windows and (relative_name.startswith("..\\") or "\\..\\" in relative_name):
        raise ValueError(
            f"Invalid filename: cannot handle filename '{relative_name}' on Windows. Please ask the repository"
            " owner to rename this file."
        )

    target_file = local_dir / relative_name
    metadata_file = _huggingface_dir(local_dir) / "upload" / f"{relative_name}.metadata"
    lock_file = metadata_file.with_suffix(".lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix
    if on_windows and not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_file)) > 255:
        extended_prefix = "\\\\?\\"
        target_file = Path(extended_prefix + os.path.abspath(target_file))
        lock_file = Path(extended_prefix + os.path.abspath(lock_file))
        metadata_file = Path(extended_prefix + os.path.abspath(metadata_file))

    for folder in (target_file.parent, metadata_file.parent):
        folder.mkdir(parents=True, exist_ok=True)
    return LocalUploadFilePaths(
        path_in_repo=filename,
        file_path=target_file,
        lock_path=lock_file,
        metadata_path=metadata_file,
    )
263
+
264
+
265
def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDownloadFileMetadata]:
    """Read metadata about a file in the local directory related to a download process.

    The metadata file holds three lines: commit hash, etag and the timestamp at which it was written.
    A metadata file that cannot be parsed is removed from disk (best effort) and treated as missing.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        `[LocalDownloadFileMetadata]` or `None`: the metadata if it exists, is valid and is not older than the
        downloaded file; `None` otherwise.
    """
    paths = get_local_download_paths(local_dir, filename)
    with WeakFileLock(paths.lock_path):
        if not paths.metadata_path.exists():
            return None

        try:
            with paths.metadata_path.open() as f:
                commit_hash = f.readline().strip()
                etag = f.readline().strip()
                timestamp = float(f.readline().strip())
            metadata = LocalDownloadFileMetadata(
                filename=filename,
                commit_hash=commit_hash,
                etag=etag,
                timestamp=timestamp,
            )
        except Exception as e:
            # remove the metadata file if it is corrupted / not the right format
            logger.warning(
                f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue."
            )
            try:
                paths.metadata_path.unlink()
            except Exception as unlink_error:
                logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {unlink_error}")
            # Nothing usable was parsed: stop here instead of falling through to code that
            # dereferences `metadata` (which would raise UnboundLocalError).
            return None

        try:
            # check if the file exists and hasn't been modified since the metadata was saved
            stat = paths.file_path.stat()
        except FileNotFoundError:
            # file does not exist => metadata is outdated
            return None

        # allow 1s difference as stat.st_mtime might not be precise
        if stat.st_mtime - 1 <= metadata.timestamp:
            return metadata
        logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.")
    return None
313
+
314
+
315
def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetadata:
    """Read metadata about a file in the local directory related to an upload process.

    TODO: factorize logic with `read_download_metadata`.

    The metadata file holds one value per line: timestamp, size, should_ignore, sha256, upload_mode,
    is_uploaded, is_committed. An empty line means "unknown". A metadata file that cannot be parsed is
    removed from disk (best effort) and treated as missing.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.

    Return:
        `[LocalUploadFileMetadata]`: the stored metadata if it exists, is valid and is not older than the
        file on disk. Otherwise a fresh metadata object in which only `size` is set.

    Raises:
        `FileNotFoundError`: if no usable metadata is found and the file itself does not exist, since its
        size cannot be read.
    """
    paths = get_local_upload_paths(local_dir, filename)
    with WeakFileLock(paths.lock_path):
        # Stays `None` when the metadata file is missing or corrupted, so that the checks below
        # are skipped instead of dereferencing an unbound name.
        metadata: Optional[LocalUploadFileMetadata] = None

        if paths.metadata_path.exists():
            try:
                with paths.metadata_path.open() as f:
                    timestamp = float(f.readline().strip())

                    size = int(f.readline().strip())  # never None

                    _should_ignore = f.readline().strip()
                    should_ignore = None if _should_ignore == "" else bool(int(_should_ignore))

                    _sha256 = f.readline().strip()
                    sha256 = None if _sha256 == "" else _sha256

                    _upload_mode = f.readline().strip()
                    upload_mode = None if _upload_mode == "" else _upload_mode
                    if upload_mode not in (None, "regular", "lfs"):
                        raise ValueError(f"Invalid upload mode in metadata {paths.path_in_repo}: {upload_mode}")

                    is_uploaded = bool(int(f.readline().strip()))
                    is_committed = bool(int(f.readline().strip()))

                metadata = LocalUploadFileMetadata(
                    timestamp=timestamp,
                    size=size,
                    should_ignore=should_ignore,
                    sha256=sha256,
                    upload_mode=upload_mode,
                    is_uploaded=is_uploaded,
                    is_committed=is_committed,
                )
            except Exception as e:
                # remove the metadata file if it is corrupted / not the right format
                logger.warning(
                    f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue."
                )
                try:
                    paths.metadata_path.unlink()
                except Exception as unlink_error:
                    logger.warning(
                        f"Could not remove corrupted metadata file {paths.metadata_path}: {unlink_error}"
                    )

        if metadata is not None:
            # TODO: can we do better?
            if (
                metadata.timestamp is not None
                and metadata.is_uploaded  # file was uploaded
                and not metadata.is_committed  # but not committed
                and time.time() - metadata.timestamp > 20 * 3600  # and it's been more than 20 hours
            ):  # => we consider it as garbage-collected by S3
                metadata.is_uploaded = False

            # check if the file exists and hasn't been modified since the metadata was saved
            try:
                if metadata.timestamp is not None and paths.file_path.stat().st_mtime <= metadata.timestamp:
                    return metadata
                logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.")
            except FileNotFoundError:
                # file does not exist => metadata is outdated
                pass

    # empty metadata => we don't know anything except its size
    return LocalUploadFileMetadata(size=paths.file_path.stat().st_size)
391
+
392
+
393
def write_download_metadata(local_dir: Path, filename: str, commit_hash: str, etag: str) -> None:
    """Write metadata about a file in the local directory related to a download process.

    The metadata file is overwritten with three lines: the commit hash, the etag and the current time.

    Args:
        local_dir (`Path`):
            Path to the local directory in which files are downloaded.
        filename (`str`):
            Path of the file in the repo.
        commit_hash (`str`):
            Commit hash to record on the first line.
        etag (`str`):
            ETag to record on the second line.
    """
    paths = get_local_download_paths(local_dir, filename)
    with WeakFileLock(paths.lock_path):
        with paths.metadata_path.open("w") as f:
            # The timestamp is taken while the lock is held and the file is open.
            f.writelines([f"{commit_hash}\n", f"{etag}\n", f"{time.time()}\n"])
404
+
405
+
406
def _huggingface_dir(local_dir: Path) -> Path:
    """Return the path to the `.cache/huggingface` directory in a local directory.

    The directory is created if needed. If it has no `.gitignore` yet, one containing `*` is written
    on a best-effort basis: failures to take the lock or to write the file are silently ignored.
    """
    cache_dir = local_dir / ".cache" / "huggingface"
    cache_dir.mkdir(exist_ok=True, parents=True)

    gitignore = cache_dir / ".gitignore"
    if gitignore.exists():
        # Already there: do not overwrite it on later calls.
        return cache_dir

    # Create a .gitignore file in the .cache/huggingface directory if it doesn't exist
    # Should be thread-safe enough like this.
    gitignore_lock = cache_dir / ".gitignore.lock"
    try:
        with WeakFileLock(gitignore_lock, timeout=0.1):
            gitignore.write_text("*")
    except (IndexError, OSError):
        # OSError covers TimeoutError, FileNotFoundError, PermissionError, etc.
        # NOTE(review): IndexError is swallowed as well; presumably an edge case of the lock
        # implementation -- confirm before narrowing.
        pass

    # The lock file is only a helper: clean it up whenever possible.
    try:
        gitignore_lock.unlink()
    except OSError:
        pass
    return cache_dir
429
+
430
+
431
+ def _short_hash(filename: str) -> str:
432
+ return base64.urlsafe_b64encode(hashlib.sha1(filename.encode()).digest()).decode()
.venv/lib/python3.11/site-packages/huggingface_hub/_login.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains methods to log in to the Hub."""
15
+
16
+ import os
17
+ import subprocess
18
+ from getpass import getpass
19
+ from pathlib import Path
20
+ from typing import Optional
21
+
22
+ from . import constants
23
+ from .commands._cli_utils import ANSI
24
+ from .utils import (
25
+ capture_output,
26
+ get_token,
27
+ is_google_colab,
28
+ is_notebook,
29
+ list_credential_helpers,
30
+ logging,
31
+ run_subprocess,
32
+ set_git_credential,
33
+ unset_git_credential,
34
+ )
35
+ from .utils._auth import (
36
+ _get_token_by_name,
37
+ _get_token_from_environment,
38
+ _get_token_from_file,
39
+ _get_token_from_google_colab,
40
+ _save_stored_tokens,
41
+ _save_token,
42
+ get_stored_tokens,
43
+ )
44
+ from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _HF_LOGO_ASCII = """
50
+ _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
51
+ _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
52
+ _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
53
+ _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
54
+ _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
55
+ """
56
+
57
+
58
@_deprecate_arguments(
    version="1.0",
    deprecated_args="write_permission",
    custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.",
)
@_deprecate_positional_args(version="1.0")
def login(
    token: Optional[str] = None,
    *,
    add_to_git_credential: bool = False,
    new_session: bool = True,
    write_permission: bool = False,
) -> None:
    """Login the machine to access the Hub.

    The `token` is persisted in cache and set as a git credential. Once done, the machine
    is logged in and the access token will be available across all `huggingface_hub`
    components. If `token` is not provided, it will be prompted to the user either with
    a widget (in a notebook) or via the terminal.

    To log in from outside of a script, one can also use `huggingface-cli login` which is
    a cli command that wraps [`login`].

    <Tip>

    [`login`] is a drop-in replacement method for [`notebook_login`] as it wraps and
    extends its capabilities.

    </Tip>

    <Tip>

    When the token is not passed, [`login`] will automatically detect if the script runs
    in a notebook or not. However, this detection might not be accurate due to the
    variety of notebooks that exist nowadays. If that is the case, you can always force
    the UI by using [`notebook_login`] or [`interpreter_login`].

    </Tip>

    Args:
        token (`str`, *optional*):
            User access token to generate from https://huggingface.co/settings/token.
        add_to_git_credential (`bool`, defaults to `False`):
            If `True`, token will be set as git credential. If no git credential helper
            is configured, a warning will be displayed to the user. If `token` is `None`,
            the value of `add_to_git_credential` is ignored and will be prompted again
            to the end user.
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`):
            Ignored and deprecated argument.
    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If an organization token is passed. Only personal account tokens are valid
            to log in.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If token is invalid.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If running in a notebook but `ipywidgets` is not installed.
    """
    # No token given: delegate to the interactive flow matching the runtime.
    if token is None:
        if is_notebook():
            notebook_login(new_session=new_session)
        else:
            interpreter_login(new_session=new_session)
        return

    # Token given: log in directly, only hinting about the git credential option.
    if not add_to_git_credential:
        logger.info(
            "The token has not been saved to the git credentials helper. Pass "
            "`add_to_git_credential=True` in this function directly or "
            "`--add-to-git-credential` if using via `huggingface-cli` if "
            "you want to set the git credential as well."
        )
    _login(token, add_to_git_credential=add_to_git_credential)
131
+
132
+
133
def logout(token_name: Optional[str] = None) -> None:
    """Logout the machine from the Hub.

    Token is deleted from the machine and removed from git credential.

    Args:
        token_name (`str`, *optional*):
            Name of the access token to logout from. If `None`, will logout from all saved access tokens.
    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If the access token name is not found.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError):
            If a token is still provided by a Google Colab secret or by the environment after the
            stored tokens have been removed.
    """
    if get_token() is None:
        if not get_stored_tokens():  # No active token and no saved access tokens
            logger.warning("Not logged in!")
            return

    if token_name:
        _logout_from_token(token_name)
        logger.info(f"Successfully logged out from access token: {token_name}.")
    else:
        # Delete all saved access tokens and token
        for stored_file in (constants.HF_TOKEN_PATH, constants.HF_STORED_TOKENS_PATH):
            Path(stored_file).unlink(missing_ok=True)
        logger.info("Successfully logged out from all access tokens.")

    unset_git_credential()

    # Check if still logged in
    if _get_token_from_google_colab() is not None:
        raise EnvironmentError(
            "You are automatically logged in using a Google Colab secret.\n"
            "To log out, you must unset the `HF_TOKEN` secret in your Colab settings."
        )
    if _get_token_from_environment() is not None:
        raise EnvironmentError(
            "Token has been deleted from your machine but you are still logged in.\n"
            "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables."
        )
173
+
174
+
175
def auth_switch(token_name: str, add_to_git_credential: bool = False) -> None:
    """Switch to a different access token.

    Args:
        token_name (`str`):
            Name of the access token to switch to.
        add_to_git_credential (`bool`, defaults to `False`):
            If `True`, token will be set as git credential. If no git credential helper
            is configured, a warning will be displayed to the user. If `token` is `None`,
            the value of `add_to_git_credential` is ignored and will be prompted again
            to the end user.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If the access token name is not found.
    """
    stored_token = _get_token_by_name(token_name)
    if not stored_token:
        raise ValueError(f"Access token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}")

    # Write token to HF_TOKEN_PATH
    _set_active_token(token_name, add_to_git_credential)
    logger.info(f"The current active token is: {token_name}")

    # Warn when the environment provides a different token than the one just selected.
    env_token = _get_token_from_environment()
    if env_token is None or env_token == stored_token:
        return
    logger.warning(
        "The environment variable `HF_TOKEN` is set and will override the access token you've just switched to."
    )
202
+
203
+
204
def auth_list() -> None:
    """List all stored access tokens.

    Prints a two-column table (name, masked token) to stdout. The row whose token equals the currently
    active token is flagged with a leading `*`. A note is logged when the environment variable takes
    precedence, or when none of the stored tokens is active.
    """
    tokens = get_stored_tokens()

    if not tokens:
        logger.info("No access tokens found.")
        return

    # Find current token (if several names store the same value, the last one wins)
    current_token = get_token()
    current_token_name = None
    for token_name, token in tokens.items():
        if token == current_token:
            current_token_name = token_name

    # Print header
    # Every data row starts with a 2-character prefix (marker + space) and the separator reserves
    # 2 extra dashes, so the header is padded with 2 spaces to keep the `|` column aligned.
    max_offset = max(len("token"), *(len(token_name) for token_name in tokens)) + 2
    print(f"  {{:<{max_offset}}}| {{:<15}}".format("name", "token"))
    print("-" * (max_offset + 2) + "|" + "-" * 15)

    # Print saved access tokens
    for token_name, token in tokens.items():
        # Only the first 3 and last 4 characters of a token are displayed.
        masked_token = f"{token[:3]}****{token[-4:]}" if token != "<not set>" else token
        is_current = "*" if token == current_token else " "

        print(f"{is_current} {{:<{max_offset}}}| {{:<15}}".format(token_name, masked_token))

    if _get_token_from_environment():
        logger.warning(
            "\nNote: Environment variable `HF_TOKEN` is set and is the current active token independently from the stored tokens listed above."
        )
    elif current_token_name is None:
        logger.warning(
            "\nNote: No active token is set and no environment variable `HF_TOKEN` is found. Use `huggingface-cli login` to log in."
        )
238
+
239
+
240
+ ###
241
+ # Interpreter-based login (text)
242
+ ###
243
+
244
+
245
@_deprecate_arguments(
    version="1.0",
    deprecated_args="write_permission",
    custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.",
)
@_deprecate_positional_args(version="1.0")
def interpreter_login(*, new_session: bool = True, write_permission: bool = False) -> None:
    """
    Prompt for a token in the terminal and log in with it.

    Terminal counterpart of [`notebook_login`]: it is what [`login`] does when no token
    is passed and the code is not running in a notebook. Call it directly to force the
    text prompt instead of the notebook widget.

    For more details, see [`login`].

    Args:
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`):
            Ignored and deprecated argument.
    """
    if not new_session:
        # Caller only wants a login when nothing is configured yet.
        if get_token() is not None:
            logger.info("User is already logged in.")
            return

    from .commands.delete_cache import _ask_for_confirmation_no_tui

    print(_HF_LOGO_ASCII)

    already_has_token = get_token() is not None
    if already_has_token:
        logger.info(
            " A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out."
        )
        logger.info(" Setting a new token will erase the existing one.")

    logger.info(
        " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens ."
    )
    if os.name == "nt":
        # Windows-only hint on how to paste into the hidden prompt.
        logger.info("Token can be pasted using 'Right-Click'.")

    # `getpass` keeps the secret from being echoed to the terminal.
    entered_token = getpass("Enter your token (input will not be visible): ")
    wants_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?")

    _login(token=entered_token, add_to_git_credential=wants_git_credential)
291
+
292
+
293
+ ###
294
+ # Notebook-based login (widget)
295
+ ###
296
+
297
# HTML banner warning the user to click login right after typing a password.
# NOTE(review): not referenced in the visible part of this file - confirm it is still used.
NOTEBOOK_LOGIN_PASSWORD_HTML = """<center> <img
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
alt='Hugging Face'> <br> Immediately click login after typing your password or
it might be stored in plain text in this notebook file. </center>"""


# HTML shown above the token input of the `notebook_login` widget.
NOTEBOOK_LOGIN_TOKEN_HTML_START = """<center> <img
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
alt='Hugging Face'> <br> Copy a token from <a
href="https://huggingface.co/settings/tokens" target="_blank">your Hugging Face
tokens page</a> and paste it below. <br> Immediately click login after copying
your token or it might be stored in plain text in this notebook file. </center>"""


# HTML shown below the login button of the `notebook_login` widget.
# NOTE(review): the trailing `</center>` has no matching `<center>` inside this string.
NOTEBOOK_LOGIN_TOKEN_HTML_END = """
<b>Pro Tip:</b> If you don't already have one, you can create a dedicated
'notebooks' token with 'write' access, that you can then easily reuse for all
notebooks. </center>"""
315
+
316
+
317
@_deprecate_arguments(
    version="1.0",
    deprecated_args="write_permission",
    custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.",
)
@_deprecate_positional_args(version="1.0")
def notebook_login(*, new_session: bool = True, write_permission: bool = False) -> None:
    """
    Displays a widget to log in to the HF website and store the token.

    This is equivalent to [`login`] without passing a token when run in a notebook.
    [`notebook_login`] is useful if you want to force the use of the notebook widget
    instead of a prompt in the terminal.

    For more details, see [`login`].

    Args:
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`):
            Ignored and deprecated argument.

    Raises:
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError):
            If `ipywidgets` or `IPython` cannot be imported.
    """
    try:
        import ipywidgets.widgets as widgets  # type: ignore
        from IPython.display import display  # type: ignore
    except ImportError as import_error:
        # Chain explicitly so the traceback shows which import actually failed.
        raise ImportError(
            "The `notebook_login` function can only be used in a notebook (Jupyter or"
            " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`."
        ) from import_error
    if not new_session and get_token() is not None:
        logger.info("User is already logged in.")
        return

    box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%")

    token_widget = widgets.Password(description="Token:")
    git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?")
    token_finish_button = widgets.Button(description="Login")

    login_token_widget = widgets.VBox(
        [
            widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START),
            token_widget,
            git_checkbox_widget,
            token_finish_button,
            widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END),
        ],
        layout=box_layout,
    )
    display(login_token_widget)

    # On click events
    def login_token_event(t):
        """Event handler for the login button (`t` is the argument passed by `on_click`, unused)."""
        token = token_widget.value
        add_to_git_credential = git_checkbox_widget.value
        # Erase token and clear value to make sure it's not saved in the notebook.
        token_widget.value = ""
        # Hide inputs
        login_token_widget.children = [widgets.Label("Connecting...")]
        try:
            with capture_output() as captured:
                _login(token, add_to_git_credential=add_to_git_credential)
            message = captured.getvalue()
        except Exception as error:
            # Deliberately broad: any login failure is shown in the widget instead of raised.
            message = str(error)
        # Print result (success message or error), one label per non-blank line.
        login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()]

    token_finish_button.on_click(login_token_event)
388
+
389
+
390
+ ###
391
+ # Login private helpers
392
+ ###
393
+
394
+
395
def _login(
    token: str,
    add_to_git_credential: bool,
) -> None:
    """Validate `token` against the Hub, store it locally and make it the active token.

    Args:
        token (`str`):
            The access token to log in with. Tokens starting with `"api_org"` are rejected.
        add_to_git_credential (`bool`):
            Forwarded to `_set_active_token`, which handles the git credential helper.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If `token` starts with `"api_org"` (organization token).
    """
    from .hf_api import whoami  # avoid circular import

    if token.startswith("api_org"):
        raise ValueError("You must use your personal account token, not an organization token.")

    # `whoami` is what actually validates the token; its payload also carries the
    # token's role and display name.
    token_info = whoami(token)
    access_token_info = token_info["auth"]["accessToken"]
    permission = access_token_info["role"]
    logger.info(f"Token is valid (permission: {permission}).")

    token_name = access_token_info["displayName"]
    # Store token locally
    _save_token(token=token, token_name=token_name)
    # Set active token
    _set_active_token(token_name=token_name, add_to_git_credential=add_to_git_credential)
    logger.info("Login successful.")
    if _get_token_from_environment():
        logger.warning(
            "Note: Environment variable `HF_TOKEN` is set and is the current active token independently from the token you've just configured."
        )
    else:
        logger.info(f"The current active token is: `{token_name}`")
420
+
421
+
422
def _logout_from_token(token_name: str) -> None:
    """Logout from a specific access token.

    Removes `token_name` from the stored tokens. If the removed token is also the one
    saved in the active token file, that file is deleted as well.

    This function does not raise when the name is unknown: if no tokens are stored or
    `token_name` is not among them, it returns without doing anything.

    Args:
        token_name (`str`):
            The name of the access token to logout from.
    """
    stored_tokens = get_stored_tokens()
    # Nothing stored, or unknown name: nothing to do.
    if not stored_tokens or token_name not in stored_tokens:
        return

    removed_token = stored_tokens.pop(token_name)
    _save_stored_tokens(stored_tokens)

    # The deleted token was also the active one: drop the active token file too.
    if removed_token == _get_token_from_file():
        logger.warning(f"Active token '{token_name}' has been deleted.")
        Path(constants.HF_TOKEN_PATH).unlink(missing_ok=True)
443
+
444
+
445
def _set_active_token(
    token_name: str,
    add_to_git_credential: bool,
) -> None:
    """Make the stored token called `token_name` the active one.

    The token value is looked up by name, optionally handed to the configured git
    credential helpers, then written to `constants.HF_TOKEN_PATH`.

    Args:
        token_name (`str`):
            The name of the token to set as active.
        add_to_git_credential (`bool`):
            If `True`, also save the token through the git credential helpers when
            at least one is configured; otherwise a warning is logged.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If no token is stored under `token_name`.
    """
    active_token = _get_token_by_name(token_name)
    if not active_token:
        raise ValueError(f"Token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}")

    if add_to_git_credential:
        if _is_git_credential_helper_configured():
            set_git_credential(active_token)
            helper_names = ",".join(list_credential_helpers())
            logger.info(f"Your token has been saved in your configured git credential helpers ({helper_names}).")
        else:
            logger.warning("Token has not been saved to git credential helper.")

    # Persist the token value as the active one, creating parent folders if needed.
    token_file = Path(constants.HF_TOKEN_PATH)
    token_file.parent.mkdir(parents=True, exist_ok=True)
    token_file.write_text(active_token)
    logger.info(f"Your token has been saved to {constants.HF_TOKEN_PATH}")
472
+
473
+
474
def _is_git_credential_helper_configured() -> bool:
    """Tell whether a git credential helper is available.

    Returns `True` when at least one helper is listed, or when running in Google Colab,
    in which case the "store" helper is set globally first. Otherwise a red warning
    explaining how to configure a helper is printed and `False` is returned.
    """
    if list_credential_helpers():
        # At least one helper is set: nothing to warn about.
        return True

    # Google Colab is special-cased so that users do not see the warning below.
    # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710
    if is_google_colab():
        _set_store_as_git_credential_helper_globally()
        return True

    # No helper and not in Colab: warn the user.
    warning_text = (
        "Cannot authenticate through git-credential as no helper is defined on your machine.\n"
        "You might have to re-authenticate when pushing to the Hugging Face Hub.\n"
        "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n\n"
        "git config --global credential.helper store\n\n"
        "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details."
    )
    print(ANSI.red(warning_text))
    return False
503
+
504
+
505
def _set_store_as_git_credential_helper_globally() -> None:
    """Set globally the credential.helper to `store`.

    To be used only in Google Colab as we assume the user doesn't care about the git
    credential config. It is the only particular case where we don't want to display the
    warning message in [`notebook_login()`].

    Related:
    - https://github.com/huggingface/huggingface_hub/issues/1043
    - https://github.com/huggingface/huggingface_hub/issues/1051
    - https://git-scm.com/docs/git-credential-store

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError):
            If the `git config` command fails. The message is the command's stderr and
            the original `CalledProcessError` is attached as the cause.
    """
    try:
        run_subprocess("git config --global credential.helper store")
    except subprocess.CalledProcessError as exc:
        # Chain explicitly so the failing command and return code stay in the traceback.
        raise EnvironmentError(exc.stderr) from exc
.venv/lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Dict, List, Literal, Optional, Union
4
+
5
+ import requests
6
+ from tqdm.auto import tqdm as base_tqdm
7
+ from tqdm.contrib.concurrent import thread_map
8
+
9
+ from . import constants
10
+ from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
11
+ from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
12
+ from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
13
+ from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
14
+ from .utils import tqdm as hf_tqdm
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
@validate_hf_hub_args
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[base_tqdm] = None,
    headers: Optional[Dict[str, str]] = None,
    endpoint: Optional[str] = None,
    # Deprecated args
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    resume_download: Optional[bool] = None,
) -> str:
    """Download repo files.

    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
    a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder. You can also filter which files to download using
    `allow_patterns` and `ignore_patterns`.

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
    cache-system, it's optimized for regularly pulling the latest version of a repository.

    An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
    configured. It is also not possible to filter which files to download when cloning a repository using git.

    If the Hub cannot be queried (offline mode, connection error, HTTP error, or `local_files_only=True`), the
    function falls back to an already existing snapshot folder in the cache, or to a non-empty `local_dir`, before
    raising.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded files will be placed under this directory.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in the local cache.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*):
            Number of concurrent threads to download files (1 thread = 1 file download).
            Defaults to 8.
        tqdm_class (`tqdm`, *optional*):
            If provided, overwrites the default behavior for the progress bar. Passed
            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
            Note that the `tqdm_class` is not passed to each individual download.
            Defaults to the custom HF progress bar that can be disabled by setting
            `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
        headers (`dict`, *optional*):
            Additional headers to include in the request. Those headers take precedence over the others.
        endpoint (`str`, *optional*):
            Hub endpoint, forwarded to `HfApi` for the repo listing and to `hf_hub_download` for each file.
        local_dir_use_symlinks (`bool` or `"auto"`, *optional*):
            Deprecated argument. Forwarded unchanged to `hf_hub_download`.
        resume_download (`bool`, *optional*):
            Deprecated argument. Forwarded unchanged to `hf_hub_download`.

    Returns:
        `str`: folder path of the repo snapshot.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` and the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
            ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if some parameter value is invalid.
    """
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE
    if revision is None:
        revision = constants.DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
    api_call_error: Optional[Exception] = None
    if not local_files_only:
        # try/except logic to handle different errors => taken from `hf_hub_download`
        try:
            # if we have internet connection we want to list files to download
            api = HfApi(
                library_name=library_name,
                library_version=library_version,
                user_agent=user_agent,
                endpoint=endpoint,
                headers=headers,
            )
            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ) as error:
            # Internet connection is down
            # => will try to use local files only
            api_call_error = error
        except RevisionNotFoundError:
            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
            raise
        except requests.HTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            api_call_error = error

    # At this stage, if `repo_info` is None it means either:
    # - internet connection is down
    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
    # - repo is private/gated and invalid/missing token sent
    # - Hub is down
    # => let's look if we can find the appropriate folder in the cache:
    #    - if the specified revision is a commit hash, look inside "snapshots".
    #    - if the specified revision is a branch or tag, look inside "refs".
    # => if local_dir is not None, we will return the path to the local folder if it exists.
    if repo_info is None:
        # Try to get which commit hash corresponds to the specified revision
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.exists(ref_path):
                # retrieve commit_hash from refs file
                with open(ref_path) as f:
                    commit_hash = f.read()

        # Try to locate snapshot folder for this commit hash
        if commit_hash is not None:
            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
            if os.path.exists(snapshot_folder):
                # Snapshot folder exists => let's return it
                # (but we can't check if all the files are actually there)
                return snapshot_folder
        # If local_dir is not None, return it if it exists and is not empty
        if local_dir is not None:
            local_dir = Path(local_dir)
            if local_dir.is_dir() and any(local_dir.iterdir()):
                logger.warning(
                    f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
                )
                return str(local_dir.resolve())
        # If we couldn't find the appropriate folder on disk, raise an error.
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
                "'local_files_only=False' as input."
            )
        elif isinstance(api_call_error, OfflineModeIsEnabled):
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                "'HF_HUB_OFFLINE=0' as environment variable."
            ) from api_call_error
        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
            # Repo not found or gated => let's raise the actual error
            raise api_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
                " and try again."
            ) from api_call_error

    # At this stage, internet connection is up and running
    # => let's download the files!
    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    # if passed revision is not identical to commit_hash
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        try:
            os.makedirs(os.path.dirname(ref_path), exist_ok=True)
            with open(ref_path, "w") as f:
                f.write(commit_hash)
        except OSError as e:
            # Best effort: a missing ref only prevents later offline look-ups by branch/tag name.
            logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")

    # we pass the commit_hash to hf_hub_download
    # so no network call happens if we already
    # have the file locally.
    def _inner_hf_hub_download(repo_file: str):
        return hf_hub_download(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            endpoint=endpoint,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            force_download=force_download,
            token=token,
            headers=headers,
        )

    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        # when using hf_transfer we don't want extra parallelism
        # from the one hf_transfer provides
        for file in filtered_repo_files:
            _inner_hf_hub_download(file)
    else:
        thread_map(
            _inner_hf_hub_download,
            filtered_repo_files,
            desc=f"Fetching {len(filtered_repo_files)} files",
            max_workers=max_workers,
            # User can use its own tqdm class or the default one from `huggingface_hub.utils`
            tqdm_class=tqdm_class or hf_tqdm,
        )

    if local_dir is not None:
        return str(os.path.realpath(local_dir))
    return snapshot_folder
.venv/lib/python3.11/site-packages/huggingface_hub/_space_api.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2019-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from datetime import datetime
17
+ from enum import Enum
18
+ from typing import Dict, Optional
19
+
20
+ from huggingface_hub.utils import parse_datetime
21
+
22
+
23
class SpaceStage(str, Enum):
    """
    Possible stages of a Space on the Hub.

    Members are also `str` instances, so they compare equal to their raw value:
    ```py
    assert SpaceStage.BUILDING == "BUILDING"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url).
    """

    # Each member's value is its own name.
    # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo)
    NO_APP_FILE = "NO_APP_FILE"
    CONFIG_ERROR = "CONFIG_ERROR"
    BUILDING = "BUILDING"
    BUILD_ERROR = "BUILD_ERROR"
    RUNNING = "RUNNING"
    RUNNING_BUILDING = "RUNNING_BUILDING"
    RUNTIME_ERROR = "RUNTIME_ERROR"
    DELETING = "DELETING"
    STOPPED = "STOPPED"
    PAUSED = "PAUSED"
46
+
47
+
48
class SpaceHardware(str, Enum):
    """
    Hardware flavors on which a Space can run on the Hub.

    Members are also `str` instances, so they compare equal to their raw value:
    ```py
    assert SpaceHardware.CPU_BASIC == "cpu-basic"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L73 (private url).
    """

    # Values are the lowercase, dash-separated form of the member names.
    CPU_BASIC = "cpu-basic"
    CPU_UPGRADE = "cpu-upgrade"
    T4_SMALL = "t4-small"
    T4_MEDIUM = "t4-medium"
    L4X1 = "l4x1"
    L4X4 = "l4x4"
    ZERO_A10G = "zero-a10g"
    A10G_SMALL = "a10g-small"
    A10G_LARGE = "a10g-large"
    A10G_LARGEX2 = "a10g-largex2"
    A10G_LARGEX4 = "a10g-largex4"
    A100_LARGE = "a100-large"
    V5E_1X1 = "v5e-1x1"
    V5E_2X2 = "v5e-2x2"
    V5E_2X4 = "v5e-2x4"
75
+
76
+
77
class SpaceStorage(str, Enum):
    """
    Persistent storage tiers available for a Space on the Hub.

    Members are also `str` instances, so they compare equal to their raw value:
    ```py
    assert SpaceStorage.SMALL == "small"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url).
    """

    # Values are the lowercase form of the member names.
    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"
92
+
93
+
94
@dataclass
class SpaceRuntime:
    """
    Contains information about the current runtime of a Space.

    Instances are built from the raw server payload (a `dict`), not from individual fields.

    Args:
        stage (`str`):
            Current stage of the space. Example: RUNNING.
        hardware (`str` or `None`):
            Current hardware of the space. Example: "cpu-basic". Can be `None` if Space
            is `BUILDING` for the first time.
        requested_hardware (`str` or `None`):
            Requested hardware. Can be different than `hardware` especially if the request
            has just been made. Example: "t4-medium". Can be `None` if no hardware has
            been requested yet.
        sleep_time (`int` or `None`):
            Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the
            Space will never go to sleep if it's running on an upgraded hardware, while it will go to sleep after 48
            hours on a free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time.
        storage (`str` or `None`):
            Value of the `"storage"` key of the server payload. `None` if the key is absent.
        raw (`dict`):
            Raw response from the server. Contains more information about the Space
            runtime like number of replicas, number of cpu, memory size,...
    """

    stage: SpaceStage
    hardware: Optional[SpaceHardware]
    requested_hardware: Optional[SpaceHardware]
    sleep_time: Optional[int]
    storage: Optional[SpaceStorage]
    raw: Dict

    def __init__(self, data: Dict) -> None:
        """Populate the fields from the server payload `data`.

        Raises:
            `KeyError`: If `data` has no `"stage"` key.
        """
        # `or {}` also covers a payload where "hardware" is present but null,
        # which `data.get("hardware", {})` alone would not handle.
        hardware_info = data.get("hardware") or {}

        self.stage = data["stage"]
        self.hardware = hardware_info.get("current")
        self.requested_hardware = hardware_info.get("requested")
        self.sleep_time = data.get("gcTimeout")  # server-side name of the sleep time
        self.storage = data.get("storage")
        self.raw = data
132
+
133
+
134
@dataclass
class SpaceVariable:
    """
    Describes one variable of a Space.

    Instances are built from the variable's key and the `dict` holding its properties.

    Args:
        key (`str`):
            Variable key. Example: `"MODEL_REPO_ID"`
        value (`str`):
            Variable value. Example: `"the_model_repo_id"`.
        description (`str` or None):
            Description of the variable. Example: `"Model Repo ID of the implemented model"`.
        updated_at (`datetime` or None):
            datetime of the last update of the variable, parsed from the `"updatedAt"` entry
            (`None` if that entry is missing or null).
    """

    key: str
    value: str
    description: Optional[str]
    updated_at: Optional[datetime]

    def __init__(self, key: str, values: Dict) -> None:
        raw_updated_at = values.get("updatedAt")

        self.key = key
        self.value = values["value"]
        self.description = values.get("description")
        # Only parse when a timestamp was actually provided.
        if raw_updated_at is None:
            self.updated_at = None
        else:
            self.updated_at = parse_datetime(raw_updated_at)
.venv/lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains a logger to push training logs to the Hub, using Tensorboard."""
15
+
16
+ from pathlib import Path
17
+ from typing import TYPE_CHECKING, List, Optional, Union
18
+
19
+ from ._commit_scheduler import CommitScheduler
20
+ from .errors import EntryNotFoundError
21
+ from .repocard import ModelCard
22
+ from .utils import experimental
23
+
24
+
25
+ # Depending on user's setup, SummaryWriter can come either from 'tensorboardX'
26
+ # or from 'torch.utils.tensorboard'. Both are compatible so let's try to load
27
+ # from either of them.
28
+ try:
29
+ from tensorboardX import SummaryWriter
30
+
31
+ is_summary_writer_available = True
32
+
33
+ except ImportError:
34
+ try:
35
+ from torch.utils.tensorboard import SummaryWriter
36
+
37
+ is_summary_writer_available = False
38
+ except ImportError:
39
+ # Dummy class to avoid failing at import. Will raise on instance creation.
40
+ SummaryWriter = object
41
+ is_summary_writer_available = False
42
+
43
+ if TYPE_CHECKING:
44
+ from tensorboardX import SummaryWriter
45
+
46
+
47
class HFSummaryWriter(SummaryWriter):
    """
    Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.

    Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
    thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
    issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
    minutes (default to every 5 minutes).

    <Tip warning={true}>

    `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.

    </Tip>

    Args:
        repo_id (`str`):
            The id of the repo to which the logs will be pushed.
        logdir (`str`, *optional*):
            The directory where the logs will be written. If not specified, a local directory will be created by the
            underlying `SummaryWriter` object.
        commit_every (`int` or `float`, *optional*):
            The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performances on the repo when it grows too large.
        repo_type (`str`, *optional*):
            The type of the repo to which the logs will be pushed. Defaults to "model".
        repo_revision (`str`, *optional*):
            The revision of the repo to which the logs will be pushed. Defaults to "main".
        repo_private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's
            default is private. This value is ignored if the repo already exists.
        path_in_repo (`str`, *optional*):
            The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/". The name
            of the local log directory is always appended to this path.
        repo_allow_patterns (`List[str]` or `str`, *optional*):
            A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
        repo_ignore_patterns (`List[str]` or `str`, *optional*):
            A list of patterns to exclude in the upload. Check out the
            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
        token (`str`, *optional*):
            Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
            details.
        kwargs:
            Additional keyword arguments passed to `SummaryWriter`.

    Examples:
    ```diff
    # Taken from https://pytorch.org/docs/stable/tensorboard.html
    - from torch.utils.tensorboard import SummaryWriter
    + from huggingface_hub import HFSummaryWriter

    import numpy as np

    - writer = SummaryWriter()
    + writer = HFSummaryWriter(repo_id="username/my-trained-model")

    for n_iter in range(100):
        writer.add_scalar('Loss/train', np.random.random(), n_iter)
        writer.add_scalar('Loss/test', np.random.random(), n_iter)
        writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
        writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
    ```

    ```py
    >>> from huggingface_hub import HFSummaryWriter

    # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager
    >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger:
    ...     logger.add_scalar("a", 1)
    ...     logger.add_scalar("b", 2)
    ```
    """

    @experimental
    def __new__(cls, *args, **kwargs) -> "HFSummaryWriter":
        # Fail at instantiation time (not import time) when no usable SummaryWriter backend was found.
        if not is_summary_writer_available:
            raise ImportError(
                "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade"
                " tensorboardX` first."
            )
        return super().__new__(cls)

    def __init__(
        self,
        repo_id: str,
        *,
        logdir: Optional[str] = None,
        commit_every: Union[int, float] = 5,
        squash_history: bool = False,
        repo_type: Optional[str] = None,
        repo_revision: Optional[str] = None,
        repo_private: Optional[bool] = None,
        path_in_repo: Optional[str] = "tensorboard",
        repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*",
        repo_ignore_patterns: Optional[Union[List[str], str]] = None,
        token: Optional[str] = None,
        **kwargs,
    ):
        # Initialize SummaryWriter
        super().__init__(logdir=logdir, **kwargs)

        # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it.
        if not isinstance(self.logdir, str):
            raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")

        # Append logdir name to `path_in_repo`
        if path_in_repo is None or path_in_repo == "":
            path_in_repo = Path(self.logdir).name
        else:
            path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name

        # Initialize scheduler
        self.scheduler = CommitScheduler(
            folder_path=self.logdir,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=repo_revision,
            private=repo_private,
            token=token,
            allow_patterns=repo_allow_patterns,
            ignore_patterns=repo_ignore_patterns,
            every=commit_every,
            squash_history=squash_history,
        )

        # Exposing some high-level info at root level
        self.repo_id = self.scheduler.repo_id
        self.repo_type = self.scheduler.repo_type
        self.repo_revision = self.scheduler.revision

        # Add `hf-summary-writer` tag to the model card metadata
        try:
            card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type)
        except EntryNotFoundError:
            card = ModelCard("")
        # `or []` guards against a card whose "tags" entry exists but is empty/None,
        # which would otherwise make the membership test below raise a TypeError.
        tags = card.data.get("tags", []) or []
        if "hf-summary-writer" not in tags:
            tags.append("hf-summary-writer")
            card.data["tags"] = tags
            # Only push when the tag was actually added, to avoid a useless commit.
            card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Trigger a final push to the Hub when exiting the logger's context manager.

        The push is scheduled through the commit scheduler and this method waits for it
        to complete (`future.result()`), so leaving the `with` block blocks until done.
        """
        super().__exit__(exc_type, exc_val, exc_tb)
        future = self.scheduler.trigger()
        future.result()
.venv/lib/python3.11/site-packages/huggingface_hub/_upload_large_folder.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import enum
16
+ import logging
17
+ import os
18
+ import queue
19
+ import shutil
20
+ import sys
21
+ import threading
22
+ import time
23
+ import traceback
24
+ from datetime import datetime
25
+ from pathlib import Path
26
+ from threading import Lock
27
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
28
+
29
+ from . import constants
30
+ from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes
31
+ from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata
32
+ from .constants import DEFAULT_REVISION, REPO_TYPES
33
+ from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm
34
+ from .utils._cache_manager import _format_size
35
+ from .utils.sha import sha_fileobj
36
+
37
+
38
+ if TYPE_CHECKING:
39
+ from .hf_api import HfApi
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+ WAITING_TIME_IF_NO_TASKS = 10 # seconds
44
+ MAX_NB_REGULAR_FILES_PER_COMMIT = 75
45
+ MAX_NB_LFS_FILES_PER_COMMIT = 150
46
+
47
+
48
def upload_large_folder_internal(
    api: "HfApi",
    repo_id: str,
    folder_path: Union[str, Path],
    *,
    repo_type: str,  # Repo type is required!
    revision: Optional[str] = None,
    private: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    num_workers: Optional[int] = None,
    print_report: bool = True,
    print_report_every: int = 60,
):
    """Upload a large folder to the Hub in the most resilient way possible.

    See [`HfApi.upload_large_folder`] for the full documentation.

    Raises:
        ValueError: if `repo_type` is `None` or not one of `REPO_TYPES`, or if `folder_path` is not a directory.
    """
    # 1. Check args and setup
    if repo_type is None:
        raise ValueError(
            "For large uploads, `repo_type` is explicitly required. Please set it to `model`, `dataset` or `space`."
            " If you are using the CLI, pass it as `--repo-type=model`."
        )
    if repo_type not in REPO_TYPES:
        raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
    if revision is None:
        revision = DEFAULT_REVISION

    folder_path = Path(folder_path).expanduser().resolve()
    if not folder_path.is_dir():
        raise ValueError(f"Provided path: '{folder_path}' is not a directory")

    if ignore_patterns is None:
        ignore_patterns = []
    elif isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    # Build a new list instead of using `+=`, which would extend the caller's own list in place
    # (and make the default patterns pile up if the same list is reused across calls).
    ignore_patterns = [*ignore_patterns, *DEFAULT_IGNORE_PATTERNS]

    if num_workers is None:
        nb_cores = os.cpu_count() or 1
        num_workers = max(nb_cores - 2, 2)  # Use all but 2 cores, or at least 2 cores

    # 2. Create repo if missing
    repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True)
    logger.info(f"Repo created: {repo_url}")
    repo_id = repo_url.repo_id

    # 3. List files to upload
    filtered_paths_list = filter_repo_objects(
        (path.relative_to(folder_path).as_posix() for path in folder_path.glob("**/*") if path.is_file()),
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
    )
    paths_list = [get_local_upload_paths(folder_path, relpath) for relpath in filtered_paths_list]
    logger.info(f"Found {len(paths_list)} candidate files to upload")

    # Read metadata for each file
    items = [
        (paths, read_upload_metadata(folder_path, paths.path_in_repo))
        for paths in tqdm(paths_list, desc="Recovering from metadata files")
    ]

    # 4. Start workers
    status = LargeUploadStatus(items)
    threads = [
        threading.Thread(
            target=_worker_job,
            kwargs={
                "status": status,
                "api": api,
                "repo_id": repo_id,
                "repo_type": repo_type,
                "revision": revision,
            },
        )
        for _ in range(num_workers)
    ]

    for thread in threads:
        thread.start()

    # 5. Print regular reports
    if print_report:
        print("\n\n" + status.current_report())
    last_report_ts = time.time()
    while True:
        # Poll once per second so that completion is detected quickly,
        # while the report itself is only refreshed every `print_report_every` seconds.
        time.sleep(1)
        if time.time() - last_report_ts >= print_report_every:
            if print_report:
                _print_overwrite(status.current_report())
            last_report_ts = time.time()
        if status.is_done():
            # Use the module logger (not the root `logging` functions) like the rest of this module.
            logger.info("Is done: exiting main loop")
            break

    for thread in threads:
        thread.join()

    logger.info(status.current_report())
    logger.info("Upload is complete!")
149
+
150
+
151
+ ####################
152
+ # Logic to manage workers and synchronize tasks
153
+ ####################
154
+
155
+
156
class WorkerJob(enum.Enum):
    """Kinds of task a worker thread can be handed by the scheduler."""

    # Hash a single local file.
    SHA256 = enum.auto()
    # Ask the server how a batch of files must be uploaded.
    GET_UPLOAD_MODE = enum.auto()
    # Upload a single LFS file ahead of the commit.
    PREUPLOAD_LFS = enum.auto()
    # Commit a batch of ready files.
    COMMIT = enum.auto()
    # if no tasks are available but we don't want to exit
    WAIT = enum.auto()
162
+
163
+
164
+ JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata]
165
+
166
+
167
class LargeUploadStatus:
    """Contains information, queues and tasks for a large upload process."""

    def __init__(self, items: List[JOB_ITEM_T]):
        self.items = items

        # One queue per pipeline stage; items move from one to the next as they progress.
        self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.queue_preupload_lfs: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.queue_commit: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
        self.lock = Lock()

        # Number of workers currently busy with each kind of job.
        self.nb_workers_sha256: int = 0
        self.nb_workers_get_upload_mode: int = 0
        self.nb_workers_preupload_lfs: int = 0
        self.nb_workers_commit: int = 0
        self.nb_workers_waiting: int = 0
        self.last_commit_attempt: Optional[float] = None

        self._started_at = datetime.now()

        # Setup queues: each item resumes at the first stage its metadata says is not done yet.
        for entry in self.items:
            paths, metadata = entry
            if metadata.sha256 is None:
                target = self.queue_sha256
            elif metadata.upload_mode is None:
                target = self.queue_get_upload_mode
            elif metadata.upload_mode == "lfs" and not metadata.is_uploaded:
                target = self.queue_preupload_lfs
            elif not metadata.is_committed:
                target = self.queue_commit
            else:
                target = None

            if target is None:
                logger.debug(f"Skipping file {paths.path_in_repo} (already uploaded and committed)")
            else:
                target.put(entry)

    def current_report(self) -> str:
        """Generate a report of the current status of the large upload."""
        nb_hashed = size_hashed = 0
        nb_preuploaded = size_preuploaded = 0
        nb_committed = size_committed = 0
        nb_lfs = nb_lfs_unsure = 0
        total_files = total_size = 0
        ignored_files = 0

        with self.lock:
            for _, metadata in self.items:
                # Ignored files are counted separately and excluded from every other total.
                if metadata.should_ignore:
                    ignored_files += 1
                    continue

                file_size = metadata.size
                total_files += 1
                total_size += file_size

                if metadata.sha256 is not None:
                    nb_hashed += 1
                    size_hashed += file_size

                mode = metadata.upload_mode
                if mode == "lfs":
                    nb_lfs += 1
                if mode is None:
                    # Upload mode not fetched yet: the file may or may not end up as LFS.
                    nb_lfs_unsure += 1

                if metadata.is_uploaded:
                    nb_preuploaded += 1
                    size_preuploaded += file_size
                if metadata.is_committed:
                    nb_committed += 1
                    size_committed += file_size

            total_size_str = _format_size(total_size)

            now = datetime.now()
            now_str = now.strftime("%Y-%m-%d %H:%M:%S")
            elapsed_str = str(now - self._started_at).split(".")[0]  # remove milliseconds

            unsure_str = f" (+{nb_lfs_unsure} unsure)" if nb_lfs_unsure > 0 else ""

            parts = [
                # Header
                "\n" + "-" * 10,
                f" {now_str} ({elapsed_str}) ",
                "-" * 10 + "\n",
                # Per-stage file counters
                "Files: ",
                f"hashed {nb_hashed}/{total_files} ({_format_size(size_hashed)}/{total_size_str}) | ",
                f"pre-uploaded: {nb_preuploaded}/{nb_lfs} ({_format_size(size_preuploaded)}/{total_size_str})",
                unsure_str,
                f" | committed: {nb_committed}/{total_files} ({_format_size(size_committed)}/{total_size_str})",
                f" | ignored: {ignored_files}\n",
                # Worker activity
                "Workers: ",
                f"hashing: {self.nb_workers_sha256} | ",
                f"get upload mode: {self.nb_workers_get_upload_mode} | ",
                f"pre-uploading: {self.nb_workers_preupload_lfs} | ",
                f"committing: {self.nb_workers_commit} | ",
                f"waiting: {self.nb_workers_waiting}\n",
                "-" * 51,
            ]
            return "".join(parts)

    def is_done(self) -> bool:
        """Return True once every item is either committed or flagged to be ignored."""
        with self.lock:
            for _, metadata in self.items:
                if not (metadata.is_committed or metadata.should_ignore):
                    return False
            return True
267
+
268
+
269
def _worker_job(
    status: LargeUploadStatus,
    api: "HfApi",
    repo_id: str,
    repo_type: str,
    revision: str,
):
    """
    Main process for a worker. The worker will perform tasks based on the priority list until all files are uploaded
    and committed. If no tasks are available, the worker will wait for 10 seconds before checking again.

    If a task fails for any reason, the item(s) are put back in the queue for another worker to pick up. The error
    is logged at error level and the full traceback at debug level.

    Read `upload_large_folder` docstring for more information on how tasks are prioritized.
    """
    while True:
        # Determine next task (None means everything is processed => exit the worker)
        next_job = _determine_next_job(status)
        if next_job is None:
            return
        job, items = next_job

        # Perform task
        if job == WorkerJob.SHA256:
            item = items[0]  # single item
            try:
                _compute_sha256(item)
                status.queue_get_upload_mode.put(item)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to compute sha256: {e}")
                # The traceback used to be formatted and thrown away; emit it so failures can be diagnosed.
                logger.debug(traceback.format_exc())
                status.queue_sha256.put(item)

            with status.lock:
                status.nb_workers_sha256 -= 1

        elif job == WorkerJob.GET_UPLOAD_MODE:
            try:
                _get_upload_mode(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to get upload mode: {e}")
                logger.debug(traceback.format_exc())

            # Items are either:
            # - dropped (if should_ignore)
            # - put in LFS queue (if LFS)
            # - put in commit queue (if regular)
            # - or put back (if error occurred).
            for item in items:
                _, metadata = item
                if metadata.should_ignore:
                    continue
                if metadata.upload_mode == "lfs":
                    status.queue_preupload_lfs.put(item)
                elif metadata.upload_mode == "regular":
                    status.queue_commit.put(item)
                else:
                    status.queue_get_upload_mode.put(item)

            with status.lock:
                status.nb_workers_get_upload_mode -= 1

        elif job == WorkerJob.PREUPLOAD_LFS:
            item = items[0]  # single item
            try:
                _preupload_lfs(item, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
                status.queue_commit.put(item)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to preupload LFS: {e}")
                logger.debug(traceback.format_exc())
                status.queue_preupload_lfs.put(item)

            with status.lock:
                status.nb_workers_preupload_lfs -= 1

        elif job == WorkerJob.COMMIT:
            try:
                _commit(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(f"Failed to commit: {e}")
                logger.debug(traceback.format_exc())
                for item in items:
                    status.queue_commit.put(item)
            with status.lock:
                # Recorded on success and failure alike: the scheduler throttles on *attempts*.
                status.last_commit_attempt = time.time()
                status.nb_workers_commit -= 1

        elif job == WorkerJob.WAIT:
            time.sleep(WAITING_TIME_IF_NO_TASKS)
            with status.lock:
                status.nb_workers_waiting -= 1
370
+
371
+
372
def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]]:
    """Pick the next job for a worker, following a fixed priority list.

    The whole decision is taken while holding `status.lock`, and the matching `nb_workers_*` counter is
    incremented before returning so that other workers see the job as taken.

    Returns:
        A `(job, items)` tuple, `(WorkerJob.WAIT, [])` when nothing can be done right now, or `None` when every
        item is committed or ignored (the worker should exit).
    """
    with status.lock:
        # 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file)
        if (
            status.nb_workers_commit == 0
            and status.queue_commit.qsize() > 0
            and status.last_commit_attempt is not None
            and time.time() - status.last_commit_attempt > 5 * 60
        ):
            status.nb_workers_commit += 1
            logger.debug("Job: commit (more than 5 minutes since last commit attempt)")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 2. Commit if at least 150 files are ready to commit
        elif status.nb_workers_commit == 0 and status.queue_commit.qsize() >= 150:
            status.nb_workers_commit += 1
            logger.debug("Job: commit (>=150 files ready)")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 3. Get upload mode if at least 10 files
        elif status.queue_get_upload_mode.qsize() >= 10:
            status.nb_workers_get_upload_mode += 1
            logger.debug("Job: get upload mode (>10 files ready)")
            return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, 50))

        # 4. Preupload LFS file if at least 1 file and no worker is preuploading LFS
        elif status.queue_preupload_lfs.qsize() > 0 and status.nb_workers_preupload_lfs == 0:
            status.nb_workers_preupload_lfs += 1
            logger.debug("Job: preupload LFS (no other worker preuploading LFS)")
            return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs))

        # 5. Compute sha256 if at least 1 file and no worker is computing sha256
        elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0:
            status.nb_workers_sha256 += 1
            logger.debug("Job: sha256 (no other worker computing sha256)")
            return (WorkerJob.SHA256, _get_one(status.queue_sha256))

        # 6. Get upload mode if at least 1 file and no worker is getting upload mode
        elif status.queue_get_upload_mode.qsize() > 0 and status.nb_workers_get_upload_mode == 0:
            status.nb_workers_get_upload_mode += 1
            logger.debug("Job: get upload mode (no other worker getting upload mode)")
            return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, 50))

        # 7. Preupload LFS file if at least 1 file
        #    Skip if hf_transfer is enabled and there is already a worker preuploading LFS
        elif status.queue_preupload_lfs.qsize() > 0 and (
            status.nb_workers_preupload_lfs == 0 or not constants.HF_HUB_ENABLE_HF_TRANSFER
        ):
            status.nb_workers_preupload_lfs += 1
            logger.debug("Job: preupload LFS")
            return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs))

        # 8. Compute sha256 if at least 1 file
        elif status.queue_sha256.qsize() > 0:
            status.nb_workers_sha256 += 1
            logger.debug("Job: sha256")
            return (WorkerJob.SHA256, _get_one(status.queue_sha256))

        # 9. Get upload mode if at least 1 file
        elif status.queue_get_upload_mode.qsize() > 0:
            status.nb_workers_get_upload_mode += 1
            logger.debug("Job: get upload mode")
            return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, 50))

        # 10. Commit if at least 1 file and 1 min since last commit attempt
        elif (
            status.nb_workers_commit == 0
            and status.queue_commit.qsize() > 0
            and status.last_commit_attempt is not None
            and time.time() - status.last_commit_attempt > 1 * 60
        ):
            status.nb_workers_commit += 1
            logger.debug("Job: commit (1 min since last commit attempt)")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 11. Commit if at least 1 file, all other queues are empty and no worker is busy on another stage
        #     e.g. when it's the last commit
        elif (
            status.nb_workers_commit == 0
            and status.queue_commit.qsize() > 0
            and status.queue_sha256.qsize() == 0
            and status.queue_get_upload_mode.qsize() == 0
            and status.queue_preupload_lfs.qsize() == 0
            and status.nb_workers_sha256 == 0
            and status.nb_workers_get_upload_mode == 0
            and status.nb_workers_preupload_lfs == 0
        ):
            status.nb_workers_commit += 1
            logger.debug("Job: commit")
            return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit))

        # 12. If every item is committed or ignored, exit
        #     (checked inline rather than via `status.is_done()`, which would try to re-acquire `status.lock`)
        elif all(metadata.is_committed or metadata.should_ignore for _, metadata in status.items):
            logger.info("All files have been processed! Exiting worker.")
            return None

        # 13. If no task is available, wait
        else:
            status.nb_workers_waiting += 1
            logger.debug(f"No task available, waiting... ({WAITING_TIME_IF_NO_TASKS}s)")
            return (WorkerJob.WAIT, [])
473
+
474
+
475
+ ####################
476
+ # Atomic jobs (sha256, get_upload_mode, preupload_lfs, commit)
477
+ ####################
478
+
479
+
480
def _compute_sha256(item: JOB_ITEM_T) -> None:
    """Hash the local file (unless already hashed) and persist the metadata."""
    paths, metadata = item
    already_hashed = metadata.sha256 is not None
    if not already_hashed:
        with paths.file_path.open("rb") as stream:
            digest = sha_fileobj(stream)
        metadata.sha256 = digest.hex()
    # Metadata is saved in both cases, even when the hash was already known.
    metadata.save(paths)
487
+
488
+
489
def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
    """Fetch the upload mode of a batch of files and record it in their metadata.

    The same call also tells whether each file should be ignored; both values are
    copied from the operation objects to the metadata, which is then saved.
    """
    operations = [_build_hacky_operation(entry) for entry in items]
    _fetch_upload_modes(
        additions=operations,
        repo_type=repo_type,
        repo_id=repo_id,
        headers=api._build_hf_headers(),
        revision=revision,
    )
    # `operations` is index-aligned with `items`.
    for (paths, metadata), operation in zip(items, operations):
        metadata.upload_mode = operation._upload_mode
        metadata.should_ignore = operation._should_ignore
        metadata.save(paths)
507
+
508
+
509
def _preupload_lfs(item: JOB_ITEM_T, api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
    """Upload a single LFS file ahead of the commit, then flag it as uploaded."""
    paths, metadata = item
    operation = _build_hacky_operation(item)
    api.preupload_lfs_files(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        additions=[operation],
    )

    # Only reached when the upload call did not raise.
    metadata.is_uploaded = True
    metadata.save(paths)
522
+
523
+
524
def _commit(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
    """Create a single commit with all given files, then flag each of them as committed."""
    operations = [_build_hacky_operation(entry) for entry in items]
    api.create_commit(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        operations=operations,
        commit_message="Add files using upload-large-folder tool",
    )
    # Only reached when the commit call did not raise.
    for entry in items:
        paths, metadata = entry
        metadata.is_committed = True
        metadata.save(paths)
537
+
538
+
539
+ ####################
540
+ # Hacks with CommitOperationAdd to bypass checks/sha256 calculation
541
+ ####################
542
+
543
+
544
class HackyCommitOperationAdd(CommitOperationAdd):
    """`CommitOperationAdd` variant whose post-init step only normalises the file path."""

    def __post_init__(self) -> None:
        # The parent's `__post_init__` is deliberately not called here;
        # a `Path` is simply converted to its string form.
        source = self.path_or_fileobj
        if isinstance(source, Path):
            self.path_or_fileobj = str(source)
548
+
549
+
550
def _build_hacky_operation(item: JOB_ITEM_T) -> HackyCommitOperationAdd:
    """Build an add-operation for the file, reusing the sha256 and size stored in its metadata.

    Raises:
        ValueError: if the metadata has no sha256 yet.
    """
    paths, metadata = item
    operation = HackyCommitOperationAdd(path_in_repo=paths.path_in_repo, path_or_fileobj=paths.file_path)

    # Only the first bytes of the file are read, to serve as the upload "sample".
    with paths.file_path.open("rb") as stream:
        head = stream.peek(512)[:512]

    if metadata.sha256 is None:
        raise ValueError("sha256 must have been computed by now!")

    digest = bytes.fromhex(metadata.sha256)
    operation.upload_info = UploadInfo(sha256=digest, size=metadata.size, sample=head)
    return operation
559
+
560
+
561
+ ####################
562
+ # Misc helpers
563
+ ####################
564
+
565
+
566
def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]:
    """Pop a single item from the queue and return it wrapped in a list."""
    first = queue.get()
    return [first]
568
+
569
+
570
def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> List[JOB_ITEM_T]:
    """Pop up to `n` items from the queue (fewer if the queue holds less)."""
    # The number of items to take is fixed up front from the current queue size.
    count = min(queue.qsize(), n)
    batch = []
    for _ in range(count):
        batch.append(queue.get())
    return batch
572
+
573
+
574
def _get_items_to_commit(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]:
    """Special case for commit job: the number of items to commit depends on the type of files."""
    # A batch stops growing as soon as it holds MAX_NB_REGULAR_FILES_PER_COMMIT regular files
    # or MAX_NB_LFS_FILES_PER_COMMIT LFS files, or when the queue is drained.
    batch: List[JOB_ITEM_T] = []
    lfs_count = 0
    regular_count = 0
    while (
        queue.qsize() != 0
        and lfs_count < MAX_NB_LFS_FILES_PER_COMMIT
        and regular_count < MAX_NB_REGULAR_FILES_PER_COMMIT
    ):
        entry = queue.get()
        batch.append(entry)
        # Anything that is not explicitly "lfs" counts as a regular file.
        if entry[1].upload_mode == "lfs":
            lfs_count += 1
        else:
            regular_count += 1
    return batch
596
+
597
+
598
def _print_overwrite(report: str) -> None:
    """Print a report, overwriting the previous lines.

    Since tqdm in using `sys.stderr` to (re-)write progress bars, we need to use `sys.stdout`
    to print the report.

    Note: works well only if no other process is writing to `sys.stdout`!
    """
    report += "\n"
    width = shutil.get_terminal_size().columns
    report_lines = report.splitlines()

    # A line longer than the terminal wraps, so it occupies several rows on screen.
    rows_to_clear = 0
    for line in report_lines:
        rows_to_clear += len(line) // width + 1

    # For each row: clear the current line, then move the cursor up one line.
    sys.stdout.write("\r\033[K\033[F" * rows_to_clear)

    # Print the new report, filling remaining space with whitespace
    padding = " " * (width - len(report_lines[-1]))
    sys.stdout.write(report)
    sys.stdout.write(padding)
    sys.stdout.flush()
.venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains data structures to parse the webhooks payload."""
16
+
17
+ from typing import List, Literal, Optional
18
+
19
+ from .utils import is_pydantic_available
20
+
21
+
22
if is_pydantic_available():
    from pydantic import BaseModel
else:
    # pydantic is optional: provide a stand-in so that this module can still be imported without it.
    # The ImportError is deferred until one of the models is actually instantiated.

    class BaseModel:  # type: ignore [no-redef]
        def __init__(self, *args, **kwargs) -> None:
            message = (
                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
            )
            raise ImportError(message)
34
+
35
+
36
+ # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
37
+ # are not in use anymore. To keep in sync when format is updated in
38
+ # https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link).
39
+
40
+
41
# Closed sets of string values accepted in the payload models of this module.
WebhookEvent_T = Literal["create", "delete", "move", "update"]
RepoChangeEvent_T = Literal["add", "move", "remove", "update"]
RepoType_T = Literal["dataset", "model", "space"]
DiscussionStatus_T = Literal["closed", "draft", "open", "merged"]

# Only version 3 of the webhook format is accepted.
SupportedWebhookVersion = Literal[3]
65
+
66
+
67
class ObjectId(BaseModel):
    # A single required string identifier. The payload models that carry an `id` subclass this model.
    id: str
69
+
70
+
71
class WebhookPayloadUrl(BaseModel):
    # `web` is always required.
    web: str
    # `api` may be absent from the payload, in which case it is None.
    api: Optional[str] = None
74
+
75
+
76
class WebhookPayloadMovedTo(BaseModel):
    # Both fields are required; `owner` is a nested object holding only an `id`.
    name: str
    owner: ObjectId
79
+
80
+
81
class WebhookPayloadWebhook(ObjectId):
    # Inherits `id`; `version` must match `SupportedWebhookVersion` (currently the literal 3).
    version: SupportedWebhookVersion
83
+
84
+
85
class WebhookPayloadEvent(BaseModel):
    # `action` is restricted to the values of `WebhookEvent_T`.
    action: WebhookEvent_T
    # `scope` is a free-form string: no set of allowed values is enforced by this model.
    scope: str
88
+
89
+
90
class WebhookPayloadDiscussionChanges(BaseModel):
    # `base` is required.
    base: str
    # `mergeCommitId` is optional and defaults to None (camelCase kept to match the payload key).
    mergeCommitId: Optional[str] = None
93
+
94
+
95
class WebhookPayloadComment(ObjectId):
    # Inherits `id`. `author` is a nested object holding only an `id`.
    author: ObjectId
    hidden: bool
    # The comment text is optional and defaults to None.
    content: Optional[str] = None
    url: WebhookPayloadUrl
100
+
101
+
102
class WebhookPayloadDiscussion(ObjectId):
    # Inherits `id`. Field names are camelCase where the payload keys are.
    num: int
    author: ObjectId
    url: WebhookPayloadUrl
    title: str
    isPullRequest: bool
    # Restricted to the values of `DiscussionStatus_T`.
    status: DiscussionStatus_T
    # The two fields below are optional and default to None when missing from the payload.
    changes: Optional[WebhookPayloadDiscussionChanges] = None
    pinned: Optional[bool] = None
111
+
112
+
113
class WebhookPayloadRepo(ObjectId):
    # Inherits `id`. `owner` is a nested object holding only an `id`.
    owner: ObjectId
    # Optional, None when missing from the payload.
    head_sha: Optional[str] = None
    name: str
    private: bool
    # Optional, None when missing from the payload.
    subdomain: Optional[str] = None
    tags: Optional[List[str]] = None
    # Same three values as `RepoType_T`, spelled out inline.
    type: Literal["dataset", "model", "space"]
    url: WebhookPayloadUrl
122
+
123
+
124
class WebhookPayloadUpdatedRef(BaseModel):
    # `ref` is required.
    ref: str
    # Either SHA may be missing from the payload, in which case it is None.
    oldSha: Optional[str] = None
    newSha: Optional[str] = None
128
+
129
+
130
class WebhookPayload(BaseModel):
    # Top-level model of a webhook request body. `event`, `repo` and `webhook` are always required.
    event: WebhookPayloadEvent
    repo: WebhookPayloadRepo
    # Optional sections: each defaults to None when missing from the payload.
    discussion: Optional[WebhookPayloadDiscussion] = None
    comment: Optional[WebhookPayloadComment] = None
    webhook: WebhookPayloadWebhook
    movedTo: Optional[WebhookPayloadMovedTo] = None
    updatedRefs: Optional[List[WebhookPayloadUpdatedRef]] = None
.venv/lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
16
+
17
+ import atexit
18
+ import inspect
19
+ import os
20
+ from functools import wraps
21
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
22
+
23
+ from .utils import experimental, is_fastapi_available, is_gradio_available
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ import gradio as gr
28
+ from fastapi import Request
29
+
30
+ if is_fastapi_available():
31
+ from fastapi import FastAPI, Request
32
+ from fastapi.responses import JSONResponse
33
+ else:
34
+ # Will fail at runtime if FastAPI is not available
35
+ FastAPI = Request = JSONResponse = None # type: ignore [misc, assignment]
36
+
37
+
38
+ _global_app: Optional["WebhooksServer"] = None
39
+ _is_local = os.environ.get("SPACE_ID") is None
40
+
41
+
42
@experimental
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to be
    called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
            webhook endpoints are opened without any security.

    Example:

        ```python
        import gradio as gr
        from huggingface_hub import WebhooksServer, WebhookPayload

        with gr.Blocks() as ui:
            ...

        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

        @app.add_webhook("/say_hello")
        async def hello(payload: WebhookPayload):
            return {"message": "hello"}

        app.launch()
        ```
    """

    def __new__(cls, *args, **kwargs) -> "WebhooksServer":
        # Fail at instantiation time (rather than at launch) when an optional dependency is missing.
        if not is_gradio_available():
            raise ImportError(
                "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
                " first."
            )
        if not is_fastapi_available():
            raise ImportError(
                "You must have `fastapi` installed to use `WebhooksServer`. Please run `pip install --upgrade fastapi`"
                " first."
            )
        return super().__new__(cls)

    def __init__(
        self,
        ui: Optional["gr.Blocks"] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        self._ui = ui

        # The explicit argument wins; the `WEBHOOK_SECRET` environment variable is used when it is None or empty.
        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        # Maps the absolute route ("/webhooks/...") to the function handling it.
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Returns:
            When used with a path (or with no argument), a decorator that registers the function and returns it
            unchanged. When used directly as `@app.add_webhook`, the decorated function itself.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.launch()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        @wraps(FastAPI.post)
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func
            # Hand the function back: without this, the decorated name would be rebound to `None`.
            return func

        return _inner_post

    def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
        """Launch the Gradio app and register webhooks to the underlying FastAPI server.

        Args:
            prevent_thread_lock (`bool`, defaults to `False`):
                If `True`, return once the server is started instead of blocking the calling thread.
            **launch_kwargs:
                Forwarded to Gradio when launching the app. `share` defaults to `True` when running locally.
        """
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        #   - as non-blocking so that webhooks can be added afterwards
        #   - as shared if launch locally (to debug webhooks)
        launch_kwargs.setdefault("share", _is_local)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        space_host = os.environ.get("SPACE_HOST")
        url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url)
        url = url.strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        if not prevent_thread_lock:
            ui.block_thread()

    def _get_default_ui(self) -> "gr.Blocks":
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        import gradio as gr

        # Pick the hint first. Writing it inline as `"Go to..." + a if cond else b` would drop the
        # "Go to..." sentence in the `else` case, because `+` binds tighter than the conditional expression.
        if _is_local:
            location_hint = (
                "\nYou app is running locally. Please look at the logs to check the full URL you need to set."
            )
        else:
            location_hint = (
                "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
            )

        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Huggingface Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )
            gr.Markdown("Go to https://huggingface.co/settings/webhooks to setup your webhooks." + location_hint)
        return ui
228
+
229
+
230
@experimental
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator that registers a function as a webhook endpoint on a shared, implicitly created [`WebhooksServer`].

    It is the quickest way to get started. Use [`WebhooksServer`] directly when you need more control (custom
    landing page or webhook secret). Applying the decorator to several functions registers them all on the same
    server.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        path (`str`, optional):
            The URL path to register the webhook function. When omitted, the function name is used instead.
            Either way, the webhook is registered under `/webhooks`.

    Examples:
        Default usage: the function name becomes the path and the server is started automatically at exit
        (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: start the server yourself through the `launch` attribute added to the decorated function.
        This is useful when running in a notebook.

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually
        trigger_training.launch()
        ```
    """
    if not callable(path):
        # Called as `@webhook_endpoint(...)` or `@webhook_endpoint()`: build the actual decorator.

        @wraps(WebhooksServer.add_webhook)
        def _register(func: Callable) -> Callable:
            server = _get_global_app()
            server.add_webhook(path)(func)

            # Only the very first registration schedules the launch at interpreter exit.
            if len(server.registered_webhooks) == 1:
                atexit.register(server.launch)

            @wraps(server.launch)
            def _start_immediately():
                # Launch right away and cancel the launch that was scheduled at exit.
                atexit.unregister(server.launch)
                server.launch()

            func.launch = _start_immediately  # type: ignore
            return func

        return _register

    # Called as a bare `@webhook_endpoint`: `path` is in fact the decorated function.
    return webhook_endpoint()(path)
312
+
313
+
314
def _get_global_app() -> WebhooksServer:
    """Return the module-wide [`WebhooksServer`], instantiating it (with default arguments) on first call."""
    global _global_app
    if _global_app is not None:
        return _global_app
    _global_app = WebhooksServer()
    return _global_app
319
+
320
+
321
+ def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
322
+ if webhook_secret is None:
323
+ print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
324
+ print(
325
+ "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
326
+ "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
327
+ )
328
+ print(
329
+ "For more details about webhook secrets, please refer to"
330
+ " https://huggingface.co/docs/hub/webhooks#webhook-secret."
331
+ )
332
+ else:
333
+ print("Webhook secret is correctly defined.")
334
+
335
+
336
+ def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
337
+ """Returns the anchor to a given webhook in the docs (experimental)"""
338
+ return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
339
+
340
+
341
def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route
    parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
    object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
    the function will return a 403 error (forbidden).

    Inspired by https://stackoverflow.com/a/33112180.

    Args:
        func (`Callable`):
            The webhook function to protect. It can be either sync or async.
        webhook_secret (`str`):
            The secret that the "x-webhook-secret" header of each request must match.

    Returns:
        An async function that checks the secret and then delegates to `func`.
    """
    # Imported locally so that this helper does not rely on a new module-level import.
    import hmac

    initial_sig = inspect.signature(func)
    # Compare bytes rather than str: `hmac.compare_digest` rejects str values with non-ASCII characters.
    expected_secret = webhook_secret.encode("utf-8")

    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        # The header is attacker-controlled: use a constant-time comparison so that the response time does not
        # leak how many leading characters of the secret were guessed correctly.
        if not hmac.compare_digest(request_secret.encode("utf-8"), expected_secret):
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if "request" in initial_sig.parameters:
            kwargs["request"] = request

        # Handle both sync and async routes
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        else:
            return func(**kwargs)

    # Update signature to include request
    if "request" not in initial_sig.parameters:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func
.venv/lib/python3.11/site-packages/huggingface_hub/community.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data structures to interact with Discussions and Pull Requests on the Hub.
3
+
4
+ See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)
5
+ for more information on Pull Requests, Discussions, and the community tab.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from typing import List, Literal, Optional, Union
11
+
12
+ from . import constants
13
+ from .utils import parse_datetime
14
+
15
+
16
+ DiscussionStatus = Literal["open", "closed", "merged", "draft"]
17
+
18
+
19
@dataclass
class Discussion:
    """
    Summary of a Discussion or Pull Request opened on a Hub repo.

    Instances of this dataclass are not meant to be created by hand.

    Attributes:
        title (`str`):
            Title of the Discussion / Pull Request.
        status (`str`):
            Current status. One of:
                * `"open"`
                * `"closed"`
                * `"merged"` (Pull Requests only)
                * `"draft"` (Pull Requests only)
        num (`int`):
            Number of the Discussion / Pull Request.
        repo_id (`str`):
            Id (`"{namespace}/{repo_name}"`) of the repo the Discussion / Pull Request was opened on.
        repo_type (`str`):
            Type of that repo. Possible values are: `"model"`, `"dataset"`, `"space"`.
        author (`str`):
            Username of the author. Can be `"deleted"` if the user has been deleted since.
        is_pull_request (`bool`):
            `True` for a Pull Request, `False` for a plain Discussion.
        created_at (`datetime`):
            When the Discussion / Pull Request was created.
        endpoint (`str`):
            Endpoint of the Hub. Default is https://huggingface.co.
        git_reference (`str`, *optional*):
            (property) Git reference to push changes to if this is a Pull Request, `None` otherwise.
        url (`str`):
            (property) URL of the discussion on the Hub.
    """

    title: str
    status: DiscussionStatus
    num: int
    repo_id: str
    repo_type: str
    author: str
    is_pull_request: bool
    created_at: datetime
    endpoint: str

    @property
    def git_reference(self) -> Optional[str]:
        """Git reference (`refs/pr/{num}`) accepting pushes for a Pull Request, `None` for a plain Discussion."""
        return f"refs/pr/{self.num}" if self.is_pull_request else None

    @property
    def url(self) -> str:
        """URL of this Discussion / Pull Request on the Hub."""
        is_model_repo = self.repo_type is None or self.repo_type == constants.REPO_TYPE_MODEL
        # Model repos (or an unset repo type) have no prefix; other types are prefixed with their plural name.
        type_prefix = "" if is_model_repo else f"{self.repo_type}s/"
        return f"{self.endpoint}/{type_prefix}{self.repo_id}/discussions/{self.num}"
85
+
86
+
87
@dataclass
class DiscussionWithDetails(Discussion):
    """
    A [`Discussion`] enriched with its events and, for Pull Requests, merge-related details.

    Attributes:
        title (`str`):
            Title of the Discussion / Pull Request.
        status (`str`):
            Current status. One of:
                * `"open"`
                * `"closed"`
                * `"merged"` (Pull Requests only)
                * `"draft"` (Pull Requests only)
        num (`int`):
            Number of the Discussion / Pull Request.
        repo_id (`str`):
            Id (`"{namespace}/{repo_name}"`) of the repo the Discussion / Pull Request was opened on.
        repo_type (`str`):
            Type of that repo. Possible values are: `"model"`, `"dataset"`, `"space"`.
        author (`str`):
            Username of the author. Can be `"deleted"` if the user has been deleted since.
        is_pull_request (`bool`):
            `True` for a Pull Request, `False` for a plain Discussion.
        created_at (`datetime`):
            When the Discussion / Pull Request was created.
        events (`list` of [`DiscussionEvent`]):
            The [`DiscussionEvent`] objects of this Discussion or Pull Request.
        conflicting_files (`Union[List[str], bool, None]`, *optional*):
            The conflicting files if this is a Pull Request.
            `None` if `self.is_pull_request` is `False`.
            `True` if there are conflicting files but the list can't be retrieved.
        target_branch (`str`, *optional*):
            Branch the changes are to be merged into if this is a Pull Request, `None` otherwise.
        merge_commit_oid (`str`, *optional*):
            OID / SHA of the merge commit if this is a merged Pull Request, `None` otherwise.
        diff (`str`, *optional*):
            The git diff if this is a Pull Request, `None` otherwise.
        endpoint (`str`):
            Endpoint of the Hub. Default is https://huggingface.co.
        git_reference (`str`, *optional*):
            (property) Git reference to push changes to if this is a Pull Request, `None` otherwise.
        url (`str`):
            (property) URL of the discussion on the Hub.
    """

    events: List["DiscussionEvent"]
    conflicting_files: Union[List[str], bool, None]
    target_branch: Optional[str]
    merge_commit_oid: Optional[str]
    diff: Optional[str]
144
+
145
+
146
@dataclass
class DiscussionEvent:
    """
    Base class of the events making up a Discussion or Pull Request.

    Prefer the concrete subclasses:
        * [`DiscussionComment`]
        * [`DiscussionStatusChange`]
        * [`DiscussionCommit`]
        * [`DiscussionTitleChange`]

    Attributes:
        id (`str`):
            ID of the event, as a hexadecimal string.
        type (`str`):
            Type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            Username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
    """

    id: str
    type: str
    created_at: datetime
    author: str

    # The original event data, kept around in case it needs to be accessed later.
    _event: dict
177
+
178
+
179
@dataclass
class DiscussionComment(DiscussionEvent):
    """A comment posted in a Discussion / Pull Request.

    Subclass of [`DiscussionEvent`].


    Attributes:
        id (`str`):
            ID of the event, as a hexadecimal string.
        type (`str`):
            Type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            Username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        content (`str`):
            Raw markdown content of the comment. Mentions, links and images are not rendered.
        edited (`bool`):
            Whether or not this comment has been edited.
        hidden (`bool`):
            Whether or not this comment has been hidden.
    """

    content: str
    edited: bool
    hidden: bool

    @property
    def rendered(self) -> str:
        """The rendered comment, as a HTML string"""
        latest = self._event["data"]["latest"]
        return latest["html"]

    @property
    def last_edited_at(self) -> datetime:
        """The last edit time, as a `datetime` object."""
        latest = self._event["data"]["latest"]
        return parse_datetime(latest["updatedAt"])

    @property
    def last_edited_by(self) -> str:
        """Name of the author of the latest version, or `"deleted"` when no author name is available."""
        latest = self._event["data"]["latest"]
        editor = latest.get("author", {})
        return editor.get("name", "deleted")

    @property
    def edit_history(self) -> List[dict]:
        """The edit history of the comment"""
        data = self._event["data"]
        return data["history"]

    @property
    def number_of_edits(self) -> int:
        """Number of entries in the edit history."""
        return len(self.edit_history)
232
+
233
+
234
@dataclass
class DiscussionStatusChange(DiscussionEvent):
    """Event recording that the status of a Discussion / Pull Request changed.

    Subclass of [`DiscussionEvent`].

    Attributes:
        id (`str`):
            ID of the event, as a hexadecimal string.
        type (`str`):
            Type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            Username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        new_status (`str`):
            Status of the Discussion / Pull Request after the change.
            It can be one of:
                * `"open"`
                * `"closed"`
                * `"merged"` (only for Pull Requests)
    """

    new_status: str
260
+
261
+
262
@dataclass
class DiscussionCommit(DiscussionEvent):
    """Event recording a commit made in a Pull Request.

    Subclass of [`DiscussionEvent`].

    Attributes:
        id (`str`):
            ID of the event, as a hexadecimal string.
        type (`str`):
            Type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            Username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        summary (`str`):
            Summary of the commit.
        oid (`str`):
            OID / SHA of the commit, as a hexadecimal string.
    """

    summary: str
    oid: str
287
+
288
+
289
@dataclass
class DiscussionTitleChange(DiscussionEvent):
    """Event recording that a Discussion / Pull Request was renamed.

    Subclass of [`DiscussionEvent`].

    Attributes:
        id (`str`):
            ID of the event, as a hexadecimal string.
        type (`str`):
            Type of the event.
        created_at (`datetime`):
            A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
            object holding the creation timestamp for the event.
        author (`str`):
            Username of the Discussion / Pull Request author.
            Can be `"deleted"` if the user has been deleted since.
        old_title (`str`):
            Title of the Discussion / Pull Request before the change.
        new_title (`str`):
            Title after the change.
    """

    old_title: str
    new_title: str
314
+
315
+
316
def deserialize_event(event: dict) -> DiscussionEvent:
    """Build the [`DiscussionEvent`] (or the matching subclass) described by a raw event dict.

    The `"type"` entry selects the class: `"comment"`, `"status-change"`, `"commit"` and `"title-change"` map to
    their dedicated subclass. Any other type yields a plain [`DiscussionEvent`].
    """
    # Fields shared by every event class. The raw dict is kept as-is under `_event`.
    shared = {
        "id": event["id"],
        "type": event["type"],
        "created_at": parse_datetime(event["createdAt"]),
        "author": event.get("author", {}).get("name", "deleted"),
        "_event": event,
    }
    kind: str = shared["type"]

    if kind == "comment":
        payload = event["data"]
        return DiscussionComment(
            **shared,
            edited=payload["edited"],
            hidden=payload["hidden"],
            content=payload["latest"]["raw"],
        )
    elif kind == "status-change":
        return DiscussionStatusChange(**shared, new_status=event["data"]["status"])
    elif kind == "commit":
        payload = event["data"]
        return DiscussionCommit(**shared, summary=payload["subject"], oid=payload["oid"])
    elif kind == "title-change":
        payload = event["data"]
        return DiscussionTitleChange(**shared, old_title=payload["from"], new_title=payload["to"])

    # Unknown event type: fall back to the generic class.
    return DiscussionEvent(**shared)
.venv/lib/python3.11/site-packages/huggingface_hub/constants.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import typing
4
+ from typing import Literal, Optional, Tuple
5
+
6
+
7
+ # Possible values for env variables
8
+
9
+
10
# Upper-cased spellings accepted as "true" for boolean environment variables.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# Same set, extended with "AUTO".
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}
12
+
13
+
14
def _is_true(value: Optional[str]) -> bool:
    """Return `True` if `value`, upper-cased, is one of `ENV_VARS_TRUE_VALUES`; `None` is `False`."""
    return value is not None and value.upper() in ENV_VARS_TRUE_VALUES
18
+
19
+
20
+ def _as_int(value: Optional[str]) -> Optional[int]:
21
+ if value is None:
22
+ return None
23
+ return int(value)
24
+
25
+
26
# Constants for file downloads

# Well-known file names
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"
REPOCARD_NAME = "README.md"

# Default timeouts
DEFAULT_ETAG_TIMEOUT = 10
DEFAULT_DOWNLOAD_TIMEOUT = 10
DEFAULT_REQUEST_TIMEOUT = 10

# Transfer tuning
DOWNLOAD_CHUNK_SIZE = 10 * 1024**2
HF_TRANSFER_CONCURRENCY = 100

# Constants for serialization

PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin"  # Unsafe pickle: use safetensors instead
SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors"
TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5"

# Constants for safetensors repos

SAFETENSORS_SINGLE_FILE = "model.safetensors"
SAFETENSORS_INDEX_FILE = "model.safetensors.index.json"
SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000

# Timeout of acquiring file lock and logging the attempt
FILELOCK_LOG_EVERY_SECONDS = 10
54
+
55
# Git-related constants

DEFAULT_REVISION = "main"
REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}")

HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"

# Staging mode is switched on through the HUGGINGFACE_CO_STAGING env variable.
_staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING"))

_HF_DEFAULT_ENDPOINT = "https://huggingface.co"
_HF_DEFAULT_STAGING_ENDPOINT = "https://hub-ci.huggingface.co"
# An explicit, non-empty HF_ENDPOINT (trailing "/" stripped) wins over the defaults.
ENDPOINT = os.getenv("HF_ENDPOINT", "").rstrip("/") or (
    _HF_DEFAULT_STAGING_ENDPOINT if _staging_mode else _HF_DEFAULT_ENDPOINT
)

HUGGINGFACE_CO_URL_TEMPLATE = f"{ENDPOINT}/{{repo_id}}/resolve/{{revision}}/{{filename}}"
HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit"
HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag"
HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size"

INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co")

# See https://huggingface.co/docs/inference-endpoints/index
INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2"

# Proxy for third-party providers
INFERENCE_PROXY_TEMPLATE = f"{ENDPOINT}/api/inference-proxy/{{provider}}"
82
+
83
REPO_ID_SEPARATOR = "--"
# ^ this substring is not allowed in repo_ids on hf.co
# and is the canonical one we use for serialization of repo ids elsewhere.


REPO_TYPE_DATASET = "dataset"
REPO_TYPE_SPACE = "space"
REPO_TYPE_MODEL = "model"
# `None` is listed as a valid repo type alongside the three explicit ones.
REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"]

# URL prefix per repo type; the model type has no entry.
REPO_TYPES_URL_PREFIXES = {
    REPO_TYPE_DATASET: "datasets/",
    REPO_TYPE_SPACE: "spaces/",
}
# Plural spelling -> repo type.
REPO_TYPES_MAPPING = {
    "datasets": REPO_TYPE_DATASET,
    "spaces": REPO_TYPE_SPACE,
    "models": REPO_TYPE_MODEL,
}

DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)
DiscussionStatusFilter = Literal["all", "open", "closed"]
# Annotated with the status literal (not the type literal) so that it matches the values it holds.
DISCUSSION_STATUS: Tuple[DiscussionStatusFilter, ...] = typing.get_args(DiscussionStatusFilter)

# Webhook subscription types
WEBHOOK_DOMAIN_T = Literal["repo", "discussions"]
111
+
112
# default cache
default_home = os.path.join(os.path.expanduser("~"), ".cache")
# Resolution order: $HF_HOME, else $XDG_CACHE_HOME/huggingface, else ~/.cache/huggingface.
HF_HOME = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"))
)
hf_cache_home = HF_HOME  # for backward compatibility. TODO: remove this in 1.0.0

default_cache_path = os.path.join(HF_HOME, "hub")
default_assets_cache_path = os.path.join(HF_HOME, "assets")

# Legacy env variables
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path)

# New env variables (fall back to the legacy values above)
HF_HUB_CACHE = os.getenv("HF_HUB_CACHE", HUGGINGFACE_HUB_CACHE)
HF_ASSETS_CACHE = os.getenv("HF_ASSETS_CACHE", HUGGINGFACE_ASSETS_CACHE)

HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE"))

# Opt-out from telemetry requests: any one of these variables set to a true value is enough.
HF_HUB_DISABLE_TELEMETRY = any(
    _is_true(os.environ.get(_telemetry_var))
    for _telemetry_var in (
        "HF_HUB_DISABLE_TELEMETRY",  # HF-specific env variable
        "DISABLE_TELEMETRY",
        "DO_NOT_TRACK",  # https://consoledonottrack.com/
    )
)

# In the past, token was stored in a hardcoded location
# `_OLD_HF_TOKEN_PATH` is deprecated and will be removed "at some point".
# See https://github.com/huggingface/huggingface_hub/issues/1232
_OLD_HF_TOKEN_PATH = os.path.expanduser("~/.huggingface/token")
HF_TOKEN_PATH = os.environ.get("HF_TOKEN_PATH", os.path.join(HF_HOME, "token"))
HF_STORED_TOKENS_PATH = os.path.join(os.path.dirname(HF_TOKEN_PATH), "stored_tokens")

if _staging_mode:
    # In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens
    # NOTE(review): `HF_HUB_CACHE` and `HF_STORED_TOKENS_PATH` were computed above and are not
    # refreshed here - confirm that this is intended.
    _staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging")
    HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub")
    _OLD_HF_TOKEN_PATH = os.path.join(_staging_home, "_old_token")
    HF_TOKEN_PATH = os.path.join(_staging_home, "token")
155
+
156
# Tri-state flag:
#   - `True`: progress bars are disabled globally and cannot be enabled programmatically.
#   - `False`: progress bars are enabled and cannot be disabled programmatically.
#   - `None` (env variable not set): the user is free to enable/disable them programmatically.
# TL;DR: env variable has priority over code
__HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS")
HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = (
    None if __HF_HUB_DISABLE_PROGRESS_BARS is None else _is_true(__HF_HUB_DISABLE_PROGRESS_BARS)
)

# Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))

# Disable warning when using experimental features
HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))

# Disable sending the cached token by default in all HTTP requests to the Hub
HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))

# Enable fast-download using external dependency "hf_transfer"
# See:
# - https://pypi.org/project/hf-transfer/
# - https://github.com/huggingface/hf_transfer (private)
HF_HUB_ENABLE_HF_TRANSFER: bool = _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER"))


# UNUSED
# We don't use symlinks in local dir anymore.
HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = (
    _as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024
)

# Used to override the etag timeout on a system level.
# Because of the `or`, an unset variable and a value of 0 both select the default.
HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT

# Used to override the get request timeout on a system level (same fallback rule as above)
HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT
193
+
194
# Frameworks handled by the InferenceAPI service. Useful to scan endpoints and check which models
# are deployed and running. Since 95% of the models are using the top 4 frameworks listed below,
# only those are scanned by default. The full list of supported frameworks is kept as well, in
# case we want to scan all of them.
MAIN_INFERENCE_API_FRAMEWORKS = [
    "diffusers",
    "sentence-transformers",
    "text-generation-inference",
    "transformers",
]

ALL_INFERENCE_API_FRAMEWORKS = [
    *MAIN_INFERENCE_API_FRAMEWORKS,
    "adapter-transformers",
    "allennlp",
    "asteroid",
    "bertopic",
    "doctr",
    "espnet",
    "fairseq",
    "fastai",
    "fasttext",
    "flair",
    "k2",
    "keras",
    "mindspore",
    "nemo",
    "open_clip",
    "paddlenlp",
    "peft",
    "pyannote-audio",
    "sklearn",
    "spacy",
    "span-marker",
    "speechbrain",
    "stanza",
    "timm",
]
.venv/lib/python3.11/site-packages/huggingface_hub/errors.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Contains all custom errors."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ from requests import HTTPError, Response
7
+
8
+
9
+ # CACHE ERRORS
10
+
11
+
12
class CacheNotFound(Exception):
    """Raised when the Huggingface cache cannot be found.

    Attributes:
        cache_dir (`str` or `Path`):
            The cache directory passed to the constructor.
    """

    cache_dir: Union[str, Path]

    def __init__(self, msg: str, cache_dir: Union[str, Path], *args, **kwargs):
        self.cache_dir = cache_dir
        super().__init__(msg, *args, **kwargs)
20
+
21
+
22
class CorruptedCacheException(Exception):
    """Raised when the Huggingface cache-system has an unexpected structure."""
24
+
25
+
26
+ # HEADERS ERRORS
27
+
28
+
29
class LocalTokenNotFoundError(EnvironmentError):
    """Raised when a local token is required but none can be found."""
31
+
32
+
33
+ # HTTP ERRORS
34
+
35
+
36
class OfflineModeIsEnabled(ConnectionError):
    """Raised when a request is attempted while the `HF_HUB_OFFLINE=1` environment variable is set."""
38
+
39
+
40
class HfHubHTTPError(HTTPError):
    """
    Base class for every custom HTTP error raised in HF Hub.

    Any HTTPError is converted at least into a `HfHubHTTPError`. If some information is
    sent back by the server, it will be added to the error message.

    Added details:
    - Request id from "X-Request-Id" header if exists. If not, fallback to "X-Amzn-Trace-Id" header if exists.
    - Server error message from the header "X-Error-Message".
    - Server error message if we can find one in the response body.

    Attributes:
        request_id (`str`, *optional*):
            Request id read from the response headers, or `None` when no response is given.
        server_message (`str`, *optional*):
            The `server_message` passed to the constructor.

    Example:
    ```py
        import requests
        from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError

        response = get_session().post(...)
        try:
            hf_raise_for_status(response)
        except HfHubHTTPError as e:
            print(str(e)) # formatted message
            e.request_id, e.server_message # details returned by server

            # Complete the error message with additional information once it's raised
            e.append_to_message("\n`create_commit` expects the repository to exist.")
            raise
    ```
    """

    def __init__(self, message: str, response: Optional[Response] = None, *, server_message: Optional[str] = None):
        if response is None:
            request_id = None
        else:
            headers = response.headers
            request_id = headers.get("x-request-id") or headers.get("X-Amzn-Trace-Id")
        self.request_id = request_id
        self.server_message = server_message

        request = None if response is None else response.request
        super().__init__(
            message,
            response=response,  # type: ignore [arg-type]
            request=request,  # type: ignore [arg-type]
        )

    def append_to_message(self, additional_message: str) -> None:
        """Append additional information to the `HfHubHTTPError` initial message."""
        initial_message = self.args[0]
        # Only the first positional argument (the message) is extended; the rest is kept as is.
        self.args = (initial_message + additional_message, *self.args[1:])
87
+
88
+
89
+ # INFERENCE CLIENT ERRORS
90
+
91
+
92
class InferenceTimeoutError(HTTPError, TimeoutError):
    """Raised when a model is unavailable or when the request times out."""
94
+
95
+
96
+ # INFERENCE ENDPOINT ERRORS
97
+
98
+
99
class InferenceEndpointError(Exception):
    """Generic error raised when dealing with Inference Endpoints."""
101
+
102
+
103
class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError):
    """Raised on a timeout while waiting for an Inference Endpoint."""
105
+
106
+
107
+ # SAFETENSORS ERRORS
108
+
109
+
110
class SafetensorsParsingError(Exception):
    """Raised when the metadata of a safetensors file cannot be parsed.

    This happens when the file is not a safetensors file or does not respect the specification.
    """
115
+
116
+
117
class NotASafetensorsRepoError(Exception):
    """Raised when a repo is not a Safetensors repo.

    A Safetensors repo has either a `model.safetensors` or a `model.safetensors.index.json` file.
    """
121
+
122
+
123
+ # TEXT GENERATION ERRORS
124
+
125
+
126
class TextGenerationError(HTTPError):
    """Generic error raised when text-generation went wrong."""
128
+
129
+
130
+ # Text Generation Inference Errors
131
class ValidationError(TextGenerationError):
    """Validation error reported by the server."""
133
+
134
+
135
class GenerationError(TextGenerationError):
    """Subclass of `TextGenerationError` that adds no behavior of its own."""
137
+
138
+
139
class OverloadedError(TextGenerationError):
    """Subclass of `TextGenerationError` that adds no behavior of its own."""
141
+
142
+
143
class IncompleteGenerationError(TextGenerationError):
    """Subclass of `TextGenerationError` that adds no behavior of its own."""
145
+
146
+
147
class UnknownError(TextGenerationError):
    """Subclass of `TextGenerationError` that adds no behavior of its own."""
149
+
150
+
151
+ # VALIDATION ERRORS
152
+
153
+
154
class HFValidationError(ValueError):
    """Generic error raised by the `huggingface_hub` validators.

    Subclass of [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError).
    """
159
+
160
+
161
+ # FILE METADATA ERRORS
162
+
163
+
164
class FileMetadataError(OSError):
    """Raised when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash).

    Subclass of `OSError` for backward compatibility.
    """
169
+
170
+
171
+ # REPOSITORY ERRORS
172
+
173
+
174
class RepositoryNotFoundError(HfHubHTTPError):
    """
    Raised when a hf.co URL is accessed with an invalid repository name, or with the
    name of a private repo the user does not have access to.

    Example:

    ```py
    >>> from huggingface_hub import model_info
    >>> model_info("<non_existent_repository>")
    (...)
    huggingface_hub.errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP)

    Repository Not Found for url: https://huggingface.co/api/models/%3Cnon_existent_repository%3E.
    Please make sure you specified the correct `repo_id` and `repo_type`.
    If the repo is private, make sure you are authenticated.
    Invalid username or password.
    ```
    """
193
+
194
+
195
class GatedRepoError(RepositoryNotFoundError):
    """
    Raised when a gated repository is accessed by a user who is not on its authorized list.

    Note: subclasses `RepositoryNotFoundError` to ensure backward compatibility.

    Example:

    ```py
    >>> from huggingface_hub import model_info
    >>> model_info("<gated_repository>")
    (...)
    huggingface_hub.errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa)

    Cannot access gated repo for url https://huggingface.co/api/models/ardent-figment/gated-model.
    Access to model ardent-figment/gated-model is restricted and you are not in the authorized list.
    Visit https://huggingface.co/ardent-figment/gated-model to ask for access.
    ```
    """
215
+
216
+
217
class DisabledRepoError(HfHubHTTPError):
    """
    Raised when the repository being accessed has been disabled by its author.

    Example:

    ```py
    >>> from huggingface_hub import dataset_info
    >>> dataset_info("laion/laion-art")
    (...)
    huggingface_hub.errors.DisabledRepoError: 403 Client Error. (Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8)

    Cannot access repository for url https://huggingface.co/api/datasets/laion/laion-art.
    Access to this resource is disabled.
    ```
    """
233
+
234
+
235
+ # REVISION ERROR
236
+
237
+
238
class RevisionNotFoundError(HfHubHTTPError):
    """
    Raised when a hf.co URL is accessed with a valid repository but an invalid revision.

    Example:

    ```py
    >>> from huggingface_hub import hf_hub_download
    >>> hf_hub_download('bert-base-cased', 'config.json', revision='<non-existent-revision>')
    (...)
    huggingface_hub.errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX)

    Revision Not Found for url: https://huggingface.co/bert-base-cased/resolve/%3Cnon-existent-revision%3E/config.json.
    ```
    """
254
+
255
+
256
+ # ENTRY ERRORS
257
class EntryNotFoundError(HfHubHTTPError):
    """
    Raised when a hf.co URL is accessed with a valid repository and revision but an
    invalid filename.

    Example:

    ```py
    >>> from huggingface_hub import hf_hub_download
    >>> hf_hub_download('bert-base-cased', '<non-existent-file>')
    (...)
    huggingface_hub.errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x)

    Entry Not Found for url: https://huggingface.co/bert-base-cased/resolve/main/%3Cnon-existent-file%3E.
    ```
    """
273
+
274
+
275
class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError):
    """
    Raised when a file or snapshot is not on the disk while the network is disabled or
    unavailable (connection issue). The entry may exist on the Hub.

    Note: the `ValueError` base is kept to ensure backward compatibility.
    Note: because of `EntryNotFoundError`, `LocalEntryNotFoundError` derives from `HTTPError`
          even when it is not a network issue.

    Example:

    ```py
    >>> from huggingface_hub import hf_hub_download
    >>> hf_hub_download('bert-base-cased', '<non-cached-file>', local_files_only=True)
    (...)
    huggingface_hub.errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False.
    ```
    """

    def __init__(self, message: str):
        # There is never a response attached to this error: pass `None` explicitly.
        super().__init__(message, response=None)
296
+
297
+
298
+ # REQUEST ERROR
299
class BadRequestError(HfHubHTTPError, ValueError):
    """
    Raised by `hf_raise_for_status` when the server answers with a HTTP 400 error.

    Example:

    ```py
    >>> resp = requests.post("hf.co/api/check", ...)
    >>> hf_raise_for_status(resp, endpoint_name="check")
    huggingface_hub.errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX)
    ```
    """
311
+
312
+
313
+ # DDUF file format ERROR
314
+
315
+
316
class DDUFError(Exception):
    """Base class for all errors related to the DDUF format."""
318
+
319
+
320
class DDUFCorruptedFileError(DDUFError):
    """Raised when the DDUF file is corrupted."""
322
+
323
+
324
class DDUFExportError(DDUFError):
    """Base class for errors raised during DDUF export."""
326
+
327
+
328
class DDUFInvalidEntryNameError(DDUFExportError):
    """Raised when the entry name is invalid."""
.venv/lib/python3.11/site-packages/huggingface_hub/fastai_utils.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from pickle import DEFAULT_PROTOCOL, PicklingError
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from packaging import version
8
+
9
+ from huggingface_hub import constants, snapshot_download
10
+ from huggingface_hub.hf_api import HfApi
11
+ from huggingface_hub.utils import (
12
+ SoftTemporaryDirectory,
13
+ get_fastai_version,
14
+ get_fastcore_version,
15
+ get_python_version,
16
+ )
17
+
18
+ from .utils import logging, validate_hf_hub_args
19
+ from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility...
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
def _check_fastai_fastcore_versions(
    fastai_min_version: str = "2.4",
    fastcore_min_version: str = "1.3.27",
):
    """
    Checks that the installed fastai and fastcore versions are compatible for pickle serialization.

    Args:
        fastai_min_version (`str`, *optional*):
            The minimum fastai version supported.
        fastcore_min_version (`str`, *optional*):
            The minimum fastcore version supported.

    <Tip>
    Raises the following error:

        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
          if the fastai or fastcore libraries are not available or are of an invalid version.

    </Tip>
    """
    fastai_version = get_fastai_version()
    fastcore_version = get_fastcore_version()

    # Both libraries are required, so each version string is checked against the "N/A" marker.
    # (`(a or b) == "N/A"` would only ever look at the first truthy string and let the other
    # one reach `version.Version("N/A")` below.)
    if "N/A" in (fastai_version, fastcore_version):
        raise ImportError(
            f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
            f" required. Currently using fastai=={fastai_version} and"
            f" fastcore=={fastcore_version}."
        )

    if version.Version(fastai_version) < version.Version(fastai_min_version):
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
            f" fastai>={fastai_min_version} version, but you are using fastai version"
            f" {fastai_version} which is incompatible. Upgrade with `pip install"
            " fastai==2.5.6`."
        )

    if version.Version(fastcore_version) < version.Version(fastcore_min_version):
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
            f" fastcore>={fastcore_min_version} version, but you are using fastcore"
            f" version {fastcore_version} which is incompatible. Upgrade with"
            " `pip install fastcore==1.3.27`."
        )
72
+
73
+
74
def _pinned_requirement_version(requirements: List[str], package: str) -> Optional[str]:
    """Return the version given for `package` in `requirements`: `""` if unversioned, `None` if absent."""
    for requirement in requirements:
        if requirement.startswith(package):
            # "fastai=2.4" and "fastai==2.4" must both yield "2.4": `partition` splits on the
            # first "=", `lstrip` drops the second one of a "==" pin.
            return str(requirement).partition("=")[2].lstrip("=")
    return None


def _check_fastai_fastcore_pyproject_versions(
    storage_folder: str,
    fastai_min_version: str = "2.4",
    fastcore_min_version: str = "1.3.27",
):
    """
    Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions
    that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist
    or does not contain versions for fastai and fastcore, then it logs a warning.

    Args:
        storage_folder (`str`):
            Folder to look for the `pyproject.toml` file.
        fastai_min_version (`str`, *optional*):
            The minimum fastai version supported.
        fastcore_min_version (`str`, *optional*):
            The minimum fastcore version supported.

    <Tip>
    Raises the following errors:

        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
          if the `toml` module is not installed.
        - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
          if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.

    </Tip>
    """

    try:
        import toml
    except ModuleNotFoundError as err:
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module."
            " Install it with `pip install toml`."
        ) from err

    # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages.
    if not os.path.isfile(f"{storage_folder}/pyproject.toml"):
        logger.warning(
            "There is no `pyproject.toml` in the repository that contains the fastai"
            " `Learner`. The `pyproject.toml` would allow us to verify that your fastai"
            " and fastcore versions are compatible with those of the model you want to"
            " load."
        )
        return
    pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml")

    if "build-system" not in pyproject_toml:
        logger.warning(
            "There is no `build-system` section in the pyproject.toml of the repository"
            " that contains the fastai `Learner`. The `build-system` would allow us to"
            " verify that your fastai and fastcore versions are compatible with those"
            " of the model you want to load."
        )
        return
    build_system_toml = pyproject_toml["build-system"]

    if "requires" not in build_system_toml:
        logger.warning(
            "There is no `requires` section in the pyproject.toml of the repository"
            " that contains the fastai `Learner`. The `requires` would allow us to"
            " verify that your fastai and fastcore versions are compatible with those"
            " of the model you want to load."
        )
        return
    package_versions = build_system_toml["requires"]

    # Extracts the fastai and fastcore versions from `pyproject.toml` if available.
    # A package listed without a version (e.g. "fastai" instead of "fastai=2.4") is not checked.
    fastai_version = _pinned_requirement_version(package_versions, "fastai")
    if fastai_version is None:
        logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.")
    elif fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version):
        raise ImportError(
            "`from_pretrained_fastai` requires"
            f" fastai>={fastai_min_version} version but the model to load uses"
            f" {fastai_version} which is incompatible."
        )

    fastcore_version = _pinned_requirement_version(package_versions, "fastcore")
    if fastcore_version is None:
        logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.")
    elif fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version):
        raise ImportError(
            "`from_pretrained_fastai` requires"
            f" fastcore>={fastcore_min_version} version, but you are using fastcore"
            f" version {fastcore_version} which is incompatible."
        )
169
+
170
+
171
# Default model card content. `_create_model_card` writes it to `README.md` when the
# target directory does not have one yet.
README_TEMPLATE = """---
tags:
- fastai
---

# Amazing!

🥳 Congratulations on hosting your fastai model on the Hugging Face Hub!

# Some next steps
1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))!

2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)).

3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)!

Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card.


---


# Model card

## Model description
More information needed

## Intended uses & limitations
More information needed

## Training and evaluation data
More information needed
"""
204
+
205
# Default `pyproject.toml` content. This is a module-level f-string, so the python, fastai and
# fastcore versions are looked up once, when this module is imported, not when the file is written.
PYPROJECT_TEMPLATE = f"""[build-system]
requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"]
build-backend = "setuptools.build_meta:__legacy__"
"""
209
+
210
+
211
def _create_model_card(repo_dir: Path):
    """
    Writes the default model card to `repo_dir`, unless a `README.md` is already there.

    Args:
        repo_dir (`Path`):
            Directory where model card is created.
    """
    card_path = repo_dir / "README.md"

    # Never overwrite an existing model card.
    if card_path.exists():
        return

    card_path.write_text(README_TEMPLATE, encoding="utf-8")
224
+
225
+
226
def _create_model_pyproject(repo_dir: Path):
    """
    Writes the default `pyproject.toml` to `repo_dir`, unless one is already there.

    Args:
        repo_dir (`Path`):
            Directory where `pyproject.toml` is created.
    """
    toml_path = repo_dir / "pyproject.toml"

    # Never overwrite an existing `pyproject.toml`.
    if toml_path.exists():
        return

    toml_path.write_text(PYPROJECT_TEMPLATE, encoding="utf-8")
239
+
240
+
241
def _save_pretrained_fastai(
    learner,
    save_directory: Union[str, Path],
    config: Optional[Dict[str, Any]] = None,
):
    """
    Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used.

    Besides the pickled learner (`model.pkl`), the directory receives the `config` as a JSON file
    (when provided), and a default `README.md` and `pyproject.toml` if they do not exist yet.

    Args:
        learner (`Learner`):
            The `fastai.Learner` you'd like to save.
        save_directory (`str` or `Path`):
            Specific directory in which you want to save the fastai learner.
        config (`dict`, *optional*):
            Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.

    <Tip>

    Raises the following errors:

        - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
          if the config file provided is not a dictionary.
        - `PicklingError`
          if the learner cannot be pickled.

    </Tip>
    """
    _check_fastai_fastcore_versions()

    save_path = Path(save_directory)
    os.makedirs(save_path, exist_ok=True)

    # If the user provides a config, it is saved as is next to the model.
    if config is not None:
        if not isinstance(config, dict):
            raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
        config_path = os.path.join(save_path, constants.CONFIG_NAME)
        # Explicit encoding so that the file does not depend on the platform default.
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f)

    _create_model_card(save_path)
    _create_model_pyproject(save_path)

    # learner.export saves the model in `self.path`.
    learner.path = save_path
    try:
        learner.export(
            fname="model.pkl",
            pickle_protocol=DEFAULT_PROTOCOL,
        )
    except PicklingError as err:
        # Chain the original error so that the actual pickling failure stays visible.
        raise PicklingError(
            "You are using a lambda function, i.e., an anonymous function. `pickle`"
            " cannot pickle function objects and requires that all functions have"
            " names. One possible solution is to name the function."
        ) from err
295
+
296
+
297
@validate_hf_hub_args
def from_pretrained_fastai(
    repo_id: str,
    revision: Optional[str] = None,
):
    """
    Load pretrained fastai model from the Hub or from a local directory.

    Args:
        repo_id (`str`):
            The location where the pickled fastai.Learner is. It can be either of the two:
                - Hosted on the Hugging Face Hub. E.g.: 'espejelomar/fastai-pet-breeds-classification' or 'distilgpt2'.
                  The repo is fetched with `snapshot_download` at the requested `revision`.
                - Hosted locally. `repo_id` would be a directory containing the pickle (`model.pkl`) and a
                  pyproject.toml indicating the fastai and fastcore versions used to build the `fastai.Learner`.
                  E.g.: `./my_model_directory/`.
        revision (`str`, *optional*):
            Revision at which the repo's files are downloaded. See documentation of `snapshot_download`.
            Not used when `repo_id` is an existing local directory.

    Returns:
        The `fastai.Learner` model in the `repo_id` repo.
    """
    _check_fastai_fastcore_versions()

    if os.path.isdir(repo_id):
        # An existing local directory is used directly, no download.
        storage_folder = repo_id
    else:
        # `snapshot_download` returns the folder where the repo snapshot was stored
        # (no `cache_dir` is passed, so its default cache location applies).
        storage_folder = snapshot_download(
            repo_id=repo_id,
            revision=revision,
            library_name="fastai",
            library_version=get_fastai_version(),
        )

    _check_fastai_fastcore_pyproject_versions(storage_folder)

    # Imported lazily so this module can be imported without fastai installed.
    from fastai.learner import load_learner  # type: ignore

    model_file = os.path.join(storage_folder, "model.pkl")
    return load_learner(model_file)
340
+
341
+
342
@validate_hf_hub_args
def push_to_hub_fastai(
    learner,
    *,
    repo_id: str,
    commit_message: str = "Push FastAI model using huggingface_hub.",
    private: Optional[bool] = None,
    token: Optional[str] = None,
    config: Optional[dict] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
    api_endpoint: Optional[str] = None,
):
    """
    Upload learner checkpoint files to the Hub.

    The repo is created if needed (`exist_ok=True`), the learner is exported to a temporary directory and
    that directory is uploaded in a single commit.

    Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
    `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
    details.

    Args:
        learner (`Learner`):
            The `fastai.Learner` you'd like to push to the Hub.
        repo_id (`str`):
            The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be your individual account or an organization to which you have write access (for example, 'stanfordnlp/stanza-de').
        commit_message (`str`, *optional*):
            Message to commit while pushing. Defaults to `"Push FastAI model using huggingface_hub."`.
        private (`bool`, *optional*):
            Whether or not the repository created should be private.
            If `None` (default), the repo will be public unless the organization's default is private.
        token (`str`, *optional*):
            The Hugging Face account token to use as HTTP bearer authorization for remote files. Forwarded to
            `create_repo` and `upload_folder`.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        branch (`str`, *optional*):
            The git branch on which to push the model. This defaults to
            the default branch as specified in your repository, which
            defaults to `"main"`.
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit. Forwarded to `upload_folder`.
        api_endpoint (`str`, *optional*):
            The API endpoint to use when pushing the model to the hub.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.
        delete_patterns (`List[str]` or `str`, *optional*):
            If provided, remote files matching any of the patterns will be deleted from the repo.

    Returns:
        The value returned by `upload_folder` (documented upstream as the url of the commit of your model in
        the given repository).

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if the user is not logged in to the Hugging Face Hub (raised by the underlying API calls, not by
            this function directly).
    """
    _check_fastai_fastcore_versions()
    api = HfApi(endpoint=api_endpoint)

    # `create_repo` returns the canonical repo id, which replaces the user-provided one from here on.
    created_repo = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
    repo_id = created_repo.repo_id

    # Push the files to the repo in a single commit
    with SoftTemporaryDirectory() as scratch_dir:
        export_dir = Path(scratch_dir) / repo_id
        _save_pretrained_fastai(learner, export_dir, config=config)
        return api.upload_folder(
            repo_id=repo_id,
            token=token,
            folder_path=export_dir,
            commit_message=commit_message,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py ADDED
@@ -0,0 +1,1621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import copy
3
+ import errno
4
+ import inspect
5
+ import os
6
+ import re
7
+ import shutil
8
+ import stat
9
+ import time
10
+ import uuid
11
+ import warnings
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union
15
+ from urllib.parse import quote, urlparse
16
+
17
+ import requests
18
+
19
+ from . import (
20
+ __version__, # noqa: F401 # for backward compatibility
21
+ constants,
22
+ )
23
+ from ._local_folder import get_local_download_paths, read_download_metadata, write_download_metadata
24
+ from .constants import (
25
+ HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 # for backward compatibility
26
+ HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility
27
+ )
28
+ from .errors import (
29
+ EntryNotFoundError,
30
+ FileMetadataError,
31
+ GatedRepoError,
32
+ LocalEntryNotFoundError,
33
+ RepositoryNotFoundError,
34
+ RevisionNotFoundError,
35
+ )
36
+ from .utils import (
37
+ OfflineModeIsEnabled,
38
+ SoftTemporaryDirectory,
39
+ WeakFileLock,
40
+ build_hf_headers,
41
+ get_fastai_version, # noqa: F401 # for backward compatibility
42
+ get_fastcore_version, # noqa: F401 # for backward compatibility
43
+ get_graphviz_version, # noqa: F401 # for backward compatibility
44
+ get_jinja_version, # noqa: F401 # for backward compatibility
45
+ get_pydot_version, # noqa: F401 # for backward compatibility
46
+ get_session,
47
+ get_tf_version, # noqa: F401 # for backward compatibility
48
+ get_torch_version, # noqa: F401 # for backward compatibility
49
+ hf_raise_for_status,
50
+ is_fastai_available, # noqa: F401 # for backward compatibility
51
+ is_fastcore_available, # noqa: F401 # for backward compatibility
52
+ is_graphviz_available, # noqa: F401 # for backward compatibility
53
+ is_jinja_available, # noqa: F401 # for backward compatibility
54
+ is_pydot_available, # noqa: F401 # for backward compatibility
55
+ is_tf_available, # noqa: F401 # for backward compatibility
56
+ is_torch_available, # noqa: F401 # for backward compatibility
57
+ logging,
58
+ reset_sessions,
59
+ tqdm,
60
+ validate_hf_hub_args,
61
+ )
62
+ from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
63
+ from .utils._typing import HTTP_METHOD_T
64
+ from .utils.sha import sha_fileobj
65
+ from .utils.tqdm import is_tqdm_disabled
66
+
67
+
68
logger = logging.get_logger(__name__)

# Sentinel returned when trying to load a file from cache but the file is known not to exist in the
# remote repo. Compare against it with `is`.
_CACHED_NO_EXIST = object()
# Annotation alias for the sentinel above (a bare `object()` has no more precise type).
_CACHED_NO_EXIST_T = Any

# Regex to get filename from a "Content-Disposition" header for CDN-served files.
# Only matches the quoted form followed by a semicolon: `filename="<name>";`.
HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P<filename>.*?)";')

# Regex to check if the revision IS directly a commit_hash (40 lowercase hex characters)
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")

# Regex to check if the file etag IS a valid sha256 (64 lowercase hex characters)
REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$")

# Per-directory cache of the symlink-support probe, keyed by resolved directory path.
# Filled lazily by `are_symlinks_supported`.
_are_symlinks_supported_in_dir: Dict[str, bool] = {}
84
+
85
+
86
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
    """Return whether the symlinks are supported on the machine.

    Symlink support depends on the mounted disk, so the probe is run against one precise folder and its
    result is memoized per resolved folder path. When `cache_dir` is omitted, the default HF cache
    directory (`constants.HF_HUB_CACHE`) is probed.

    The probe creates the folder if needed, then tries to create a relative symlink inside a temporary
    sub-directory. On failure a warning is emitted once per folder, unless
    `constants.HF_HUB_DISABLE_SYMLINKS_WARNING` is set.

    Args:
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.

    Returns: [bool] Whether symlinks are supported in the directory.
    """
    # Defaults to HF cache
    folder = constants.HF_HUB_CACHE if cache_dir is None else cache_dir
    cache_dir = str(Path(folder).expanduser().resolve())  # make it unique

    # Probe only once per cache directory: later calls are answered from the memo.
    if cache_dir in _are_symlinks_supported_in_dir:
        return _are_symlinks_supported_in_dir[cache_dir]

    # Optimistic default, flipped below if the probe fails.
    _are_symlinks_supported_in_dir[cache_dir] = True

    os.makedirs(cache_dir, exist_ok=True)
    with SoftTemporaryDirectory(dir=cache_dir) as tmpdir:
        probe_dir = Path(tmpdir)
        src_path = probe_dir / "dummy_file_src"
        src_path.touch()
        dst_path = probe_dir / "dummy_file_dst"

        # Relative source path as in `_create_symlink`
        relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
        try:
            os.symlink(relative_src, dst_path)
        except OSError:
            # Likely running on Windows
            _are_symlinks_supported_in_dir[cache_dir] = False

            if not constants.HF_HUB_DISABLE_SYMLINKS_WARNING:
                message = (
                    "`huggingface_hub` cache-system uses symlinks by default to"
                    " efficiently store duplicated files but your machine does not"
                    f" support them in {cache_dir}. Caching files will still work"
                    " but in a degraded version that might require more space on"
                    " your disk. This warning can be disabled by setting the"
                    " `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For"
                    " more details, see"
                    " https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
                )
                if os.name == "nt":
                    windows_hint = (
                        "\nTo support symlinks on Windows, you either need to"
                        " activate Developer Mode or to run Python as an"
                        " administrator. In order to activate developer mode,"
                        " see this article:"
                        " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
                    )
                    message = message + windows_hint
                warnings.warn(message)

    return _are_symlinks_supported_in_dir[cache_dir]
143
+
144
+
145
@dataclass(frozen=True)
class HfFileMetadata:
    """Data structure containing information about a file versioned on the Hub.

    Returned by [`get_hf_file_metadata`] based on a URL. Instances are immutable (`frozen=True`).

    Args:
        commit_hash (`str`, *optional*):
            The commit_hash related to the file.
        etag (`str`, *optional*):
            Etag of the file on the server.
        location (`str`):
            Location where to download the file. Can be a Hub url or not (CDN).
        size (`int`, *optional*):
            Size of the file. In case of an LFS file, contains the size of the actual
            LFS file, not the pointer.
    """

    commit_hash: Optional[str]
    etag: Optional[str]
    location: str
    size: Optional[int]
167
+
168
+
169
@validate_hf_hub_args
def hf_hub_url(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    endpoint: Optional[str] = None,
) -> str:
    """Construct the URL of a file from the given information.

    The resolved address can either be a huggingface.co-hosted url, or a link to
    Cloudfront (a Content Delivery Network, or CDN) for large files which are
    more than a few MBs.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) name and a repo name separated
            by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the repo. When set to a non-empty
            string it is prepended to `filename` (`"{subfolder}/{filename}"`).
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `constants.DEFAULT_REVISION`. It is fully URL-quoted
            (including `/`) before being inserted in the URL.
        endpoint (`str`, *optional*):
            If provided and the generated URL starts with `constants.ENDPOINT`, that prefix is
            replaced by `endpoint`.

    Raises:
        `ValueError`: if `repo_type` is not one of `constants.REPO_TYPES`.

    Example:

    ```python
    >>> from huggingface_hub import hf_hub_url

    >>> hf_hub_url(
    ...     repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin"
    ... )
    'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin'
    ```

    Notes:

        Cloudfront is replicated over the globe so downloads are way faster for
        the end user (and it also lowers our bandwidth costs).

        Cloudfront aggressively caches files by default (default TTL is 24
        hours), however this is not an issue here because we implement a
        git-based versioning system on huggingface.co, which means that we store
        the files on S3/Cloudfront in a content-addressable way (i.e., the file
        name is its hash). Using content-addressable filenames means cache can't
        ever be stale.

        In terms of client-side caching from this library, we base our caching
        on the objects' entity tag (`ETag`), which is an identifier of a
        specific version of a resource [1]_. An object's ETag is: its git-sha1
        if stored in git, or its sha256 if stored in git-lfs.

    References:

    -   [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
    """
    # An empty subfolder is equivalent to no subfolder.
    if subfolder is not None and subfolder != "":
        filename = f"{subfolder}/{filename}"

    if repo_type not in constants.REPO_TYPES:
        raise ValueError("Invalid repo type")

    # Some repo types are addressed with a URL prefix in front of the repo id.
    if repo_type in constants.REPO_TYPES_URL_PREFIXES:
        repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id

    effective_revision = constants.DEFAULT_REVISION if revision is None else revision
    url = HUGGINGFACE_CO_URL_TEMPLATE.format(
        repo_id=repo_id,
        revision=quote(effective_revision, safe=""),
        filename=quote(filename),
    )

    # Update endpoint if provided
    if endpoint is None or not url.startswith(constants.ENDPOINT):
        return url
    return endpoint + url[len(constants.ENDPOINT) :]
256
+
257
+
258
def _request_wrapper(
    method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params
) -> requests.Response:
    """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when
    `allow_redirects=False`.

    Args:
        method (`str`):
            HTTP method, such as 'GET' or 'HEAD'.
        url (`str`):
            The URL of the resource to fetch.
        follow_relative_redirects (`bool`, *optional*, defaults to `False`)
            If True, relative redirection (redirection to the same site) will be resolved even when the
            `allow_redirects` kwarg is set to False. Useful when we want to follow a redirection to a renamed
            repository without following redirection to a CDN.
        **params (`dict`, *optional*):
            Params to pass to `requests.request`.

    Returns:
        `requests.Response`: the response of the last request performed. A 3xx response is returned as-is when
        its target is absolute (different host) or when it carries no `Location` header.

    Raises:
        Whatever `hf_raise_for_status` raises for the response of any request performed.
    """
    # Recursively follow relative redirects
    if follow_relative_redirects:
        response = _request_wrapper(
            method=method,
            url=url,
            follow_relative_redirects=False,
            **params,
        )

        # If redirection, we redirect only relative paths.
        # This is useful in case of a renamed repository.
        if 300 <= response.status_code <= 399:
            # Not every 3xx response carries a `Location` header (e.g. 304 Not Modified):
            # without one there is nothing to follow, so hand the response back untouched.
            location = response.headers.get("Location")
            if location is None:
                return response

            parsed_target = urlparse(location)
            if parsed_target.netloc == "":
                # This means it is a relative 'location' headers, as allowed by RFC 7231.
                # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
                # We want to follow this relative redirect !
                #
                # Highly inspired by `resolve_redirects` from requests library.
                # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
                # Only the path of the current URL is swapped for the redirect target's path.
                next_url = urlparse(url)._replace(path=parsed_target.path).geturl()
                return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params)
        return response

    # Base case: perform the request and raise a Hub-specific error on failure.
    response = get_session().request(method=method, url=url, **params)
    hf_raise_for_status(response)
    return response
304
+
305
+
306
def http_get(
    url: str,
    temp_file: BinaryIO,
    *,
    proxies: Optional[Dict] = None,
    resume_size: float = 0,
    headers: Optional[Dict[str, str]] = None,
    expected_size: Optional[int] = None,
    displayed_filename: Optional[str] = None,
    _nb_retries: int = 5,
    _tqdm_bar: Optional[tqdm] = None,
) -> None:
    """
    Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.

    If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a
    transient error (network outage?). We log a warning message and try to resume the download a few times before
    giving up. The method gives up after 5 attempts if no new data has been received from the server.

    Args:
        url (`str`):
            The URL of the file to download.
        temp_file (`BinaryIO`):
            The file-like object where to save the file.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
        resume_size (`float`, *optional*):
            The number of bytes already downloaded. If set to 0 (default), the whole file is downloaded. If set
            to a positive number, the download will resume at the given position.
        headers (`dict`, *optional*):
            Dictionary of HTTP Headers to send with the request.
        expected_size (`int`, *optional*):
            The expected size of the file to download. If set, the download will raise an error if the size of the
            received content is different from the expected one.
        displayed_filename (`str`, *optional*):
            The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If
            not set, the filename is guessed from the URL or the `Content-Disposition` header.
        _nb_retries (`int`, *optional*):
            Internal. Number of resume attempts left; reset to 5 each time a chunk is received.
        _tqdm_bar (`tqdm`, *optional*):
            Internal. Existing progress bar to update instead of creating (and closing) a new one.

    Raises:
        `ValueError`: if `HF_HUB_ENABLE_HF_TRANSFER` is enabled but `hf_transfer` cannot be imported.
        `RuntimeError`: if the `hf_transfer` download fails.
        `EnvironmentError`: if the size of the downloaded file differs from `expected_size`.
        `requests.ConnectionError` / `requests.ReadTimeout`: if streaming keeps failing once retries are exhausted.
    """
    if expected_size is not None and resume_size == expected_size:
        # If the file is already fully downloaded, we don't need to download it again.
        return

    # `hf_transfer` stays `None` unless it is enabled, applicable (no resume, no proxies) and importable.
    hf_transfer = None
    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        if resume_size != 0:
            warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method")
        elif proxies is not None:
            warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method")
        else:
            try:
                import hf_transfer  # type: ignore[no-redef]
            except ImportError:
                raise ValueError(
                    "Fast download using 'hf_transfer' is enabled"
                    " (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is not"
                    " available in your environment. Try `pip install hf_transfer`."
                )

    # Keep the caller's headers untouched: a retry must start again from them, without the `Range` added below.
    initial_headers = headers
    headers = copy.deepcopy(headers) or {}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)

    r = _request_wrapper(
        method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT
    )
    hf_raise_for_status(r)
    content_length = r.headers.get("Content-Length")

    # NOTE: 'total' is the total number of bytes to download, not the number of bytes in the file.
    #       If the file is compressed, the number of bytes in the saved file will be higher than 'total'.
    total = resume_size + int(content_length) if content_length is not None else None

    if displayed_filename is None:
        displayed_filename = url
        content_disposition = r.headers.get("Content-Disposition")
        if content_disposition is not None:
            match = HEADER_FILENAME_PATTERN.search(content_disposition)
            if match is not None:
                # Means file is on CDN
                displayed_filename = match.groupdict()["filename"]

    # Truncate filename if too long to display
    if len(displayed_filename) > 40:
        displayed_filename = f"(…){displayed_filename[-40:]}"

    # `{actual_size}` is deliberately left as a placeholder, filled with `.format(...)` at raise time.
    consistency_error_message = (
        f"Consistency check failed: file should be of size {expected_size} but has size"
        f" {{actual_size}} ({displayed_filename}).\nThis is usually due to network issues while downloading the file."
        " Please retry with `force_download=True`."
    )

    # Stream file to buffer
    progress_cm: tqdm = (
        tqdm(  # type: ignore[assignment]
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc=displayed_filename,
            disable=is_tqdm_disabled(logger.getEffectiveLevel()),
            name="huggingface_hub.http_get",
        )
        if _tqdm_bar is None
        else contextlib.nullcontext(_tqdm_bar)
        # ^ `contextlib.nullcontext` mimics a context manager that does nothing
        #   Makes it easier to use the same code path for both cases but in the latter
        #   case, the progress bar is not closed when exiting the context manager.
    )

    with progress_cm as progress:
        # Fast path: delegate to `hf_transfer`, only for files larger than 5 download chunks.
        if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE:
            supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
            if not supports_callback:
                warnings.warn(
                    "You are using an outdated version of `hf_transfer`. "
                    "Consider upgrading to latest version to enable progress bars "
                    "using `pip install -U hf_transfer`."
                )
            try:
                hf_transfer.download(
                    url=url,
                    filename=temp_file.name,
                    max_files=constants.HF_TRANSFER_CONCURRENCY,
                    chunk_size=constants.DOWNLOAD_CHUNK_SIZE,
                    headers=headers,
                    parallel_failures=3,
                    max_retries=5,
                    **({"callback": progress.update} if supports_callback else {}),
                )
            except Exception as e:
                raise RuntimeError(
                    "An error occurred while downloading using `hf_transfer`. Consider"
                    " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
                ) from e
            if not supports_callback:
                # No incremental updates were possible: jump the bar to completion.
                progress.update(total)
            # `hf_transfer` writes to the path directly, so check the on-disk size rather than `temp_file.tell()`.
            if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
                raise EnvironmentError(
                    consistency_error_message.format(
                        actual_size=os.path.getsize(temp_file.name),
                    )
                )
            return

        # Regular path: stream the response body chunk by chunk.
        new_resume_size = resume_size
        try:
            for chunk in r.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
                    new_resume_size += len(chunk)
                    # Some data has been downloaded from the server so we reset the number of retries.
                    _nb_retries = 5
        except (requests.ConnectionError, requests.ReadTimeout) as e:
            # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
            # a transient error (network outage?). We log a warning message and try to resume the download a few times
            # before giving up. The retry mechanism is basic but should be enough in most cases.
            if _nb_retries <= 0:
                logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
                raise
            logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
            time.sleep(1)
            reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects
            return http_get(
                url=url,
                temp_file=temp_file,
                proxies=proxies,
                resume_size=new_resume_size,
                headers=initial_headers,
                expected_size=expected_size,
                # Keep the same label on the progress bar / error message across resume attempts
                # (re-truncating an already truncated name yields the same string).
                displayed_filename=displayed_filename,
                _nb_retries=_nb_retries - 1,
                _tqdm_bar=_tqdm_bar,
            )

    if expected_size is not None and expected_size != temp_file.tell():
        raise EnvironmentError(
            consistency_error_message.format(
                actual_size=temp_file.tell(),
            )
        )
486
+
487
+
488
+ def _normalize_etag(etag: Optional[str]) -> Optional[str]:
489
+ """Normalize ETag HTTP header, so it can be used to create nice filepaths.
490
+
491
+ The HTTP spec allows two forms of ETag:
492
+ ETag: W/"<etag_value>"
493
+ ETag: "<etag_value>"
494
+
495
+ For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
496
+ more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.
497
+
498
+ Args:
499
+ etag (`str`, *optional*): HTTP header
500
+
501
+ Returns:
502
+ `str` or `None`: string that can be used as a nice directory name.
503
+ Returns `None` if input is None.
504
+ """
505
+ if etag is None:
506
+ return None
507
+ return etag.lstrip("W/").strip('"')
508
+
509
+
510
def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
    """Backward-compatible alias of `_create_symlink`, used in `transformers` conversion script."""
    # Pure pass-through: every argument is forwarded unchanged.
    return _create_symlink(src, dst, new_blob=new_blob)
513
+
514
+
515
def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
    """Make `dst` point to `src`: as a symbolic link when possible, as a real file otherwise.

    A relative link target is tried first. Relative paths have 2 advantages:
    - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder
      will not break.
    - Relative paths seem to be better handled on Windows. Issue was reported 3 times in less than a week when
      changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398,
      https://github.com/huggingface/diffusers/issues/2729 and
      https://github.com/huggingface/transformers/pull/22228.
      NOTE: The issue with absolute paths doesn't happen on admin mode.

    When linking from the cache to a local folder, a relative path may be impossible to compute (paths on
    different volumes). In that case the absolute path of `src` is used as link target.

    The resulting layout looks something like
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd

    If symlinks cannot be created on this platform (most likely Windows), the workaround is to have the actual
    file at `dst`:
    - `new_blob=True`: the file is new, so it is moved to `dst`.
    - `new_blob=False`: the blob may already be referenced elsewhere, so it is duplicated on disk in order not to
      break the existing cache.

    In case symlinks are not supported, a warning message is displayed to the user once when loading
    `huggingface_hub`. It can be disabled with the `DISABLE_SYMLINKS_WARNING` environment variable.

    Args:
        src (`str`): Path of the existing file to point to.
        dst (`str`): Path of the pointer to create. Any existing file at this location is removed first.
        new_blob (`bool`, defaults to `False`): Selects move (`True`) vs copy (`False`) in the no-symlink fallback.
    """
    # Start from a clean slate; failure to remove (e.g. nothing there yet) is deliberately ignored.
    try:
        os.remove(dst)
    except OSError:
        pass

    source_abs = os.path.abspath(os.path.expanduser(src))
    pointer_abs = os.path.abspath(os.path.expanduser(dst))
    pointer_parent = os.path.dirname(pointer_abs)

    # Preferred link target: `src` expressed relatively to the folder holding the pointer.
    try:
        source_rel = os.path.relpath(source_abs, pointer_parent)
    except ValueError:
        # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink
        # to a local_dir instead of within the cache directory.
        # See https://docs.python.org/3/library/os.path.html#os.path.relpath
        source_rel = None

    # Decide whether symlinks can be used at all for this (src, dst) pair.
    try:
        can_symlink = are_symlinks_supported(os.path.commonpath([source_abs, pointer_abs]))
    except ValueError:
        # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos.
        # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
        can_symlink = os.name != "nt"
    except PermissionError:
        # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
        # by the user via `local_dir`). Test symlink support in the destination folder instead.
        can_symlink = are_symlinks_supported(pointer_parent)
    except OSError as e:
        # Only errno=30 (read-only commonpath on Linux/MacOS) is handled; anything else is propagated.
        if e.errno != errno.EROFS:
            raise
        can_symlink = are_symlinks_supported(pointer_parent)

    if can_symlink:
        link_target = source_rel or source_abs
        logger.debug(f"Creating pointer from {link_target} to {pointer_abs}")
        try:
            os.symlink(link_target, pointer_abs)
            return
        except FileExistsError:
            points_to_same_blob = os.path.islink(pointer_abs) and os.path.realpath(pointer_abs) == os.path.realpath(
                source_abs
            )
            if not points_to_same_blob:
                # Very unlikely: a file was created at `dst` exactly between `os.remove` and `os.symlink` and it
                # is not a symlink to the expected blob.
                raise
            # The pointer already targets the right blob: most likely the same file has been cached twice
            # concurrently (exactly between `os.remove` and `os.symlink`). Nothing left to do.
            return
        except PermissionError:
            # src and dst are not in the same volume (e.g. download to local dir) and symlinks are supported on
            # both volumes but not between them => fall through to a hard copy/move.
            pass

    # No usable symlink => put the real file at `dst`.
    if new_blob:
        logger.info(f"Symlink not supported. Moving file from {source_abs} to {pointer_abs}")
        shutil.move(source_abs, pointer_abs, copy_function=_copy_no_matter_what)
    else:
        logger.info(f"Symlink not supported. Copying file from {source_abs} to {pointer_abs}")
        shutil.copyfile(source_abs, pointer_abs)
607
+
608
+
609
def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
    """Persist the mapping `revision -> commit_hash` under `<storage_folder>/refs/<revision>`.

    `revision` is expected to be a tag, a branch or a truncated commit hash. Nothing is written when `revision`
    already equals `commit_hash`, or when the ref file already contains `commit_hash`.
    """
    if revision == commit_hash:
        return

    ref_file = Path(storage_folder, "refs", revision)
    ref_file.parent.mkdir(parents=True, exist_ok=True)

    # Write only when the stored value actually changes: a needless write could raise a useless error when the
    # repo is already cached and the user has no write access to the cache folder.
    # See https://github.com/huggingface/huggingface_hub/issues/1216.
    up_to_date = ref_file.exists() and ref_file.read_text() == commit_hash
    if not up_to_date:
        ref_file.write_text(commit_hash)
622
+
623
+
624
@validate_hf_hub_args
def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
    """Serialize a hf.co repo id and type into a single, non-nested folder name safe for disk storage.

    The result is the pluralized repo type followed by every `/`-separated part of `repo_id`, all joined with
    `constants.REPO_ID_SEPARATOR`.

    Example: models--julien-c--EsperBERTo-small
    """
    # Splitting on `/` guarantees no path separator from the repo id survives in the folder name.
    segments = repo_id.split("/")
    segments.insert(0, f"{repo_type}s")
    return constants.REPO_ID_SEPARATOR.join(segments)
634
+
635
+
636
def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None:
    """Emit a warning when the disk holding `target_dir` has less free space than `expected_size`.

    Disk usage is queried on `target_dir` first, then on each of its parents in turn. The first location whose
    usage can be read settles the outcome; locations raising `OSError` are skipped. If none can be read, the
    function returns silently.

    Args:
        expected_size (`int`):
            The expected size of the file in bytes.
        target_dir (`str` or `Path`):
            The directory where the file will be stored after downloading.
    """
    destination = Path(target_dir)
    for candidate in (destination, *destination.parents):
        try:
            free_bytes = shutil.disk_usage(candidate).free
        except OSError:
            # Path does not exist or its disk usage cannot be checked => try the next parent.
            continue
        if free_bytes < expected_size:
            warnings.warn(
                "Not enough free disk space to download the file. "
                f"The expected file size is: {expected_size / 1e6:.2f} MB. "
                f"The target location {destination} only has {free_bytes / 1e6:.2f} MB free disk space."
            )
        return
659
+
660
+
661
@validate_hf_hub_args
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    user_agent: Union[Dict, str, None] = None,
    force_download: bool = False,
    proxies: Optional[Dict] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    headers: Optional[Dict[str, str]] = None,
    endpoint: Optional[str] = None,
    resume_download: Optional[bool] = None,
    force_filename: Optional[str] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
) -> str:
    """Download a given file if it's not already present in the local cache.

    The cache file layout looks like this:
    - The cache directory contains one subfolder per repo_id (namespaced by repo type)
    - inside each repo folder:
        - refs is a list of the latest known revision => commit_hash pairs
        - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on
          whether they're LFS files or not)
        - snapshots contains one subfolder per commit, each "commit" contains the subset of the files
          that have been resolved at that particular commit. Each filename is a symlink to the blob
          at that particular commit.

    ```
    [  96]  .
    └── [ 160]  models--julien-c--EsperBERTo-small
        ├── [ 160]  blobs
        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
        ├── [  96]  refs
        │   └── [  40]  main
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
    ```

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using
    this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of
    `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as
    the main cache-system, it's optimized for regularly pulling the latest version of a repository.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the model repo. When set to a non-empty string,
            it is prepended to `filename` with a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `constants.DEFAULT_REVISION`.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Defaults to `constants.HF_HUB_CACHE`.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded file will be placed under this directory.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in
            the local cache.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `constants.DEFAULT_ETAG_TIMEOUT`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`. Overridden by
            `constants.HF_HUB_ETAG_TIMEOUT` whenever that value differs from the default.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        headers (`dict`, *optional*):
            Additional headers to be sent with the request.
        endpoint (`str`, *optional*):
            Forwarded as-is to the download helpers.
        resume_download (`bool`, *optional*):
            Deprecated. Passing any non-`None` value only triggers a `FutureWarning`.
        force_filename (`str`, *optional*):
            Deprecated. Passing any non-`None` value only triggers a `FutureWarning`.
        local_dir_use_symlinks (`bool` or `"auto"`, *optional*, defaults to `"auto"`):
            Deprecated and ignored. A warning is emitted if `local_dir` is set and this value is not `"auto"`.

    Returns:
        `str`: Local path of file or if networking is off, last version of file cached on disk.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`~utils.EntryNotFoundError`]
            If the file to download cannot be found.
        [`~utils.LocalEntryNotFoundError`]
            If network is disabled or unavailable and file is not found in cache.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` but the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If some parameter value is invalid.
    """
    # The environment-level timeout wins over the value passed by the caller.
    if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT:
        etag_timeout = constants.HF_HUB_ETAG_TIMEOUT

    # Deprecated arguments: warn, then ignore.
    if force_filename is not None:
        warnings.warn(
            "The `force_filename` parameter is deprecated as a new caching system, "
            "which keeps the filenames as they are on the Hub, is now in place.",
            FutureWarning,
        )
    if resume_download is not None:
        warnings.warn(
            "`resume_download` is deprecated and will be removed in version 1.0.0. "
            "Downloads always resume when possible. "
            "If you want to force a new download, use `force_download=True`.",
            FutureWarning,
        )

    # Fill in defaults and normalize `Path` inputs to `str`.
    revision = constants.DEFAULT_REVISION if revision is None else revision
    cache_dir = constants.HF_HUB_CACHE if cache_dir is None else cache_dir
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if isinstance(local_dir, Path):
        local_dir = str(local_dir)

    if subfolder is not None and subfolder != "":
        # This is used to create a URL, and not a local path, hence the forward slash.
        filename = f"{subfolder}/{filename}"

    repo_type = "model" if repo_type is None else repo_type
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    hf_headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        headers=headers,
    )

    # Everything both download strategies need, apart from the destination-specific `local_dir`.
    common_kwargs = dict(
        # File info
        repo_id=repo_id,
        repo_type=repo_type,
        filename=filename,
        revision=revision,
        # HTTP info
        endpoint=endpoint,
        etag_timeout=etag_timeout,
        headers=hf_headers,
        proxies=proxies,
        token=token,
        # Additional options
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
    )

    if local_dir is None:
        return _hf_hub_download_to_cache_dir(**common_kwargs)

    if local_dir_use_symlinks != "auto":
        warnings.warn(
            "`local_dir_use_symlinks` parameter is deprecated and will be ignored. "
            "The process to download files to a local folder has been updated and do "
            "not rely on symlinks anymore. You only need to pass a destination folder "
            "as`local_dir`.\n"
            "For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder."
        )
    return _hf_hub_download_to_local_dir(local_dir=local_dir, **common_kwargs)
878
+
879
+
880
def _hf_hub_download_to_cache_dir(
    *,
    # Destination
    cache_dir: str,
    # File info
    repo_id: str,
    filename: str,
    repo_type: str,
    revision: str,
    # HTTP info
    endpoint: Optional[str],
    etag_timeout: float,
    headers: Dict[str, str],
    proxies: Optional[Dict],
    token: Optional[Union[bool, str]],
    # Additional options
    local_files_only: bool,
    force_download: bool,
) -> str:
    """Download a given file to a cache folder, if not already present.

    Method should not be called directly. Please use `hf_hub_download` instead.

    Steps, in order:
    1. If `revision` is a commit hash and the pointer file already exists, return it (unless `force_download`).
    2. Fetch metadata from the server. On failure, fall back to a previously cached file when one can be found
       for the requested revision, otherwise raise.
    3. Store the `revision -> commit_hash` ref, then reuse an existing pointer/blob when allowed.
    4. Otherwise download the blob under a file lock and create the pointer to it.

    Returns:
        `str`: Path of the pointer file inside the `snapshots` folder of the cached repo.

    Raises:
        `ValueError`: On Windows, if `filename` contains a `..` path component.
    """
    repo_folder = repo_folder_name(repo_id=repo_id, repo_type=repo_type)
    locks_dir = os.path.join(cache_dir, ".locks")
    storage_folder = os.path.join(cache_dir, repo_folder)

    # cross platform transcription of filename, to be used as a local file path.
    relative_filename = os.path.join(*filename.split("/"))
    if os.name == "nt" and (relative_filename.startswith("..\\") or "\\..\\" in relative_filename):
        raise ValueError(
            f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository"
            " owner to rename this file."
        )

    revision_is_commit_hash = REGEX_COMMIT_HASH.match(revision) is not None

    # if user provides a commit_hash and they already have the file on disk, shortcut everything.
    if revision_is_commit_hash:
        pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
        if not force_download and os.path.exists(pointer_path):
            return pointer_path

    # Try to get metadata (etag, commit_hash, url, size) from the server.
    # If we can't, a HEAD request error is returned.
    (url_to_download, etag, commit_hash, expected_size, head_call_error) = _get_metadata_or_catch_error(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
        proxies=proxies,
        etag_timeout=etag_timeout,
        headers=headers,
        token=token,
        local_files_only=local_files_only,
        storage_folder=storage_folder,
        relative_filename=relative_filename,
    )

    # etag can be None for several reasons:
    # 1. we passed local_files_only.
    # 2. we don't have a connection
    # 3. Hub is down (HTTP 500, 503, 504)
    # 4. repo is not found -for example private or gated- and invalid/missing token sent
    # 5. Hub is blocked by a firewall or proxy is not set correctly.
    # => Try to get the last downloaded one from the specified revision.
    #
    # If the specified revision is a commit hash, look inside "snapshots".
    # If the specified revision is a branch or tag, look inside "refs".
    if head_call_error is not None:
        # Couldn't make a HEAD call => let's try to find a local file
        if not force_download:
            commit_hash = None
            if revision_is_commit_hash:
                commit_hash = revision
            else:
                ref_path = os.path.join(storage_folder, "refs", revision)
                if os.path.isfile(ref_path):
                    with open(ref_path) as ref_file:
                        commit_hash = ref_file.read()

            # Return pointer file if exists
            if commit_hash is not None:
                pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
                if os.path.exists(pointer_path):
                    return pointer_path

        # Otherwise, raise appropriate error
        _raise_on_head_call_error(head_call_error, force_download, local_files_only)

    # From now on, etag, commit_hash, url and size are not None.
    assert etag is not None, "etag must have been retrieved from server"
    assert commit_hash is not None, "commit_hash must have been retrieved from server"
    assert url_to_download is not None, "file location must have been retrieved from server"
    assert expected_size is not None, "expected_size must have been retrieved from server"

    blob_path = os.path.join(storage_folder, "blobs", etag)
    pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
    for folder in (os.path.dirname(blob_path), os.path.dirname(pointer_path)):
        os.makedirs(folder, exist_ok=True)

    # if passed revision is not identical to commit_hash
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

    # If file already exists, return it (except if force_download=True)
    if not force_download:
        if os.path.exists(pointer_path):
            return pointer_path

        if os.path.exists(blob_path):
            # we have the blob already, but not the pointer
            _create_symlink(blob_path, pointer_path, new_blob=False)
            return pointer_path

    # Prevent parallel downloads of the same file with a lock.
    # etag could be duplicated across repos, hence the per-repo lock folder.
    lock_path = os.path.join(locks_dir, repo_folder, f"{etag}.lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix.
    if os.name == "nt":
        if len(os.path.abspath(lock_path)) > 255:
            lock_path = "\\\\?\\" + os.path.abspath(lock_path)
        if len(os.path.abspath(blob_path)) > 255:
            blob_path = "\\\\?\\" + os.path.abspath(blob_path)

    Path(lock_path).parent.mkdir(parents=True, exist_ok=True)
    with WeakFileLock(lock_path):
        _download_to_tmp_and_move(
            incomplete_path=Path(blob_path + ".incomplete"),
            destination_path=Path(blob_path),
            url_to_download=url_to_download,
            proxies=proxies,
            headers=headers,
            expected_size=expected_size,
            filename=filename,
            force_download=force_download,
        )
        if not os.path.exists(pointer_path):
            _create_symlink(blob_path, pointer_path, new_blob=True)

    return pointer_path
1023
+
1024
+
1025
def _hf_hub_download_to_local_dir(
    *,
    # Destination
    local_dir: Union[str, Path],
    # File info
    repo_id: str,
    repo_type: str,
    filename: str,
    revision: str,
    # HTTP info
    endpoint: Optional[str],
    etag_timeout: float,
    headers: Dict[str, str],
    proxies: Optional[Dict],
    token: Union[bool, str, None],
    # Additional options
    cache_dir: str,
    force_download: bool,
    local_files_only: bool,
) -> str:
    """Download a given file to a local folder, if not already present.

    Method should not be called directly. Please use `hf_hub_download` instead.

    Steps, in order:
    1. If `revision` is a commit hash matching the stored metadata and the file exists, return it.
    2. Fetch metadata from the server. On failure, default to the existing local file if any, otherwise raise.
    3. If the local file exists and is proven up-to-date (etag in metadata, or sha256 of the content when no
       metadata exists), refresh the metadata and return it.
    4. Otherwise copy the file from the cache if it is there, else download it.

    Returns:
        `str`: Path of the file inside `local_dir`.
    """
    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix.
    if os.name == "nt" and len(os.path.abspath(local_dir)) > 255:
        local_dir = "\\\\?\\" + os.path.abspath(local_dir)
    local_dir = Path(local_dir)
    paths = get_local_download_paths(local_dir=local_dir, filename=filename)
    local_metadata = read_download_metadata(local_dir=local_dir, filename=filename)

    def _save_metadata() -> None:
        # `commit_hash` and `etag` are resolved at call time, i.e. after the server metadata has been fetched.
        write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)

    # Local file exists + metadata exists + commit_hash matches => return file
    if (
        not force_download
        and REGEX_COMMIT_HASH.match(revision)
        and paths.file_path.is_file()
        and local_metadata is not None
        and local_metadata.commit_hash == revision
    ):
        return str(paths.file_path)

    # Local file doesn't exist or commit_hash doesn't match => we need the etag
    (url_to_download, etag, commit_hash, expected_size, head_call_error) = _get_metadata_or_catch_error(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
        proxies=proxies,
        etag_timeout=etag_timeout,
        headers=headers,
        token=token,
        local_files_only=local_files_only,
    )

    if head_call_error is not None:
        # No HEAD call but local file exists => default to local file
        if not force_download and paths.file_path.is_file():
            logger.warning(
                f"Couldn't access the Hub to check for update but local file already exists. Defaulting to existing file. (error: {head_call_error})"
            )
            return str(paths.file_path)
        # Otherwise => raise
        _raise_on_head_call_error(head_call_error, force_download, local_files_only)

    # From now on, etag, commit_hash, url and size are not None.
    assert etag is not None, "etag must have been retrieved from server"
    assert commit_hash is not None, "commit_hash must have been retrieved from server"
    assert url_to_download is not None, "file location must have been retrieved from server"
    assert expected_size is not None, "expected_size must have been retrieved from server"

    # Local file exists => check if it's up-to-date
    if not force_download and paths.file_path.is_file():
        if local_metadata is None:
            # No metadata + etag is a sha256
            # => means it's an LFS file (large)
            # => let's compute local hash and compare
            # => if match, update metadata and return file
            if REGEX_SHA256.match(etag) is not None:
                with open(paths.file_path, "rb") as local_file:
                    file_hash = sha_fileobj(local_file).hex()
                if file_hash == etag:
                    _save_metadata()
                    return str(paths.file_path)
        elif local_metadata.etag == etag:
            # etag matches => update metadata and return file
            _save_metadata()
            return str(paths.file_path)

    # Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache)

    # If we are lucky enough, the file is already in the cache => copy it
    if not force_download:
        cached_path = try_to_load_from_cache(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            revision=commit_hash,
            repo_type=repo_type,
        )
        if isinstance(cached_path, str):
            with WeakFileLock(paths.lock_path):
                paths.file_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(cached_path, paths.file_path)
                _save_metadata()
            return str(paths.file_path)

    # Otherwise, let's download the file!
    with WeakFileLock(paths.lock_path):
        paths.file_path.unlink(missing_ok=True)  # delete outdated file first
        _download_to_tmp_and_move(
            incomplete_path=paths.incomplete_path(etag),
            destination_path=paths.file_path,
            url_to_download=url_to_download,
            proxies=proxies,
            headers=headers,
            expected_size=expected_size,
            filename=filename,
            force_download=force_download,
        )

    _save_metadata()
    return str(paths.file_path)
1149
+
1150
+
1151
@validate_hf_hub_args
def try_to_load_from_cache(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> Union[str, _CACHED_NO_EXIST_T, None]:
    """
    Explores the cache to return the latest cached file for a given revision if found.

    This function will not raise any exception if the file is not cached.

    Args:
        cache_dir (`str` or `os.PathLike`):
            The folder where the cached files lie. Defaults to `constants.HF_HUB_CACHE`.
        repo_id (`str`):
            The ID of the repo on huggingface.co.
        filename (`str`):
            The filename to look for inside `repo_id`.
        revision (`str`, *optional*):
            The specific model version to use (branch, tag or commit hash). Will default to `"main"` if it's not
            provided. If a ref file exists for it in the cache, it is resolved to the stored commit hash.
        repo_type (`str`, *optional*):
            The type of the repository. Will default to `"model"`.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
            Will return `None` if the file was not cached. Otherwise:
            - The exact path to the cached file if it's found in the cache
            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact
              was cached.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `repo_type` is not one of `constants.REPO_TYPES`.

    Example:

    ```python
    from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST

    filepath = try_to_load_from_cache(repo_id="julien-c/EsperBERTo-small", filename="README.md")
    if isinstance(filepath, str):
        # file exists and is cached
        ...
    elif filepath is _CACHED_NO_EXIST:
        # non-existence of file is cached
        ...
    else:
        # file is not cached
        ...
    ```
    """
    if revision is None:
        revision = "main"
    if repo_type is None:
        repo_type = "model"
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE

    # Use the same helper as the download path so that lookups and writes always agree on the folder name.
    repo_cache = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None

    refs_dir = os.path.join(repo_cache, "refs")
    snapshots_dir = os.path.join(repo_cache, "snapshots")
    no_exist_dir = os.path.join(repo_cache, ".no_exist")

    # Resolve refs (for instance to convert main to the associated commit sha)
    if os.path.isdir(refs_dir):
        revision_file = os.path.join(refs_dir, revision)
        if os.path.isfile(revision_file):
            with open(revision_file) as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if os.path.isfile(os.path.join(no_exist_dir, revision, filename)):
        return _CACHED_NO_EXIST

    # Check if revision folder exists. `isdir` (rather than `exists`) keeps the "never raises" promise: a stray
    # file named `snapshots` would otherwise make `os.listdir` fail.
    if not os.path.isdir(snapshots_dir):
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = os.path.join(snapshots_dir, revision, filename)
    return cached_file if os.path.isfile(cached_file) else None
1242
+
1243
+
1244
@validate_hf_hub_args
def get_hf_file_metadata(
    url: str,
    token: Union[bool, str, None] = None,
    proxies: Optional[Dict] = None,
    timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    headers: Optional[Dict[str, str]] = None,
) -> HfFileMetadata:
    """Fetch metadata of a file versioned on the Hub for a given url.

    A single `HEAD` request is sent (relative redirects are followed) and the metadata is read from the
    response headers.

    Args:
        url (`str`):
            File url, for example returned by [`hf_hub_url`].
        token (`str` or `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If `False` or `None`, no token is provided.
                - If a string, it's used as the authentication token.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        timeout (`float`, *optional*, defaults to `constants.DEFAULT_REQUEST_TIMEOUT`):
            How many seconds to wait for the server to send metadata before giving up.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        headers (`dict`, *optional*):
            Additional headers to be sent with the request.

    Returns:
        A [`HfFileMetadata`] object containing metadata such as location, etag, size and
        commit_hash.
    """
    request_headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        headers=headers,
    )
    # prevent any compression => we want to know the real size of the file
    request_headers["Accept-Encoding"] = "identity"

    # Retrieve metadata
    response = _request_wrapper(
        method="HEAD",
        url=url,
        headers=request_headers,
        allow_redirects=False,
        follow_relative_redirects=True,
        proxies=proxies,
        timeout=timeout,
    )
    hf_raise_for_status(response)

    response_headers = response.headers

    # We favor a custom header indicating the etag of the linked resource, and
    # we fallback to the regular etag header.
    raw_etag = response_headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response_headers.get("ETag")

    # Either from response headers (if redirected) or defaults to request url.
    # Do not use directly `url`, as `_request_wrapper` might have followed relative
    # redirects.
    location = response_headers.get("Location") or response.request.url  # type: ignore

    # Same preference order for the size: linked-resource header first, then the regular one.
    raw_size = response_headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or response_headers.get(
        "Content-Length"
    )

    return HfFileMetadata(
        commit_hash=response_headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT),
        etag=_normalize_etag(raw_etag),
        location=location,
        size=_int_or_none(raw_size),
    )
1319
+
1320
+
1321
def _get_metadata_or_catch_error(
    *,
    repo_id: str,
    filename: str,
    repo_type: str,
    revision: str,
    endpoint: Optional[str],
    proxies: Optional[Dict],
    etag_timeout: Optional[float],
    headers: Dict[str, str],  # mutated inplace!
    token: Union[bool, str, None],
    local_files_only: bool,
    relative_filename: Optional[str] = None,  # only used to store `.no_exists` in cache
    storage_folder: Optional[str] = None,  # only used to store `.no_exists` in cache
) -> Union[
    # Either an exception is caught and returned
    Tuple[None, None, None, None, Exception],
    # Or the metadata is returned as
    # `(url_to_download, etag, commit_hash, expected_size, None)`
    Tuple[str, str, str, int, None],
]:
    """Get metadata for a file on the Hub, safely handling network issues.

    Returns a 5-tuple `(url_to_download, etag, commit_hash, expected_size, error)`.
    When `local_files_only=True`, the first four items are `None` and `error` is an
    `OfflineModeIsEnabled` instance. When the metadata request succeeds, `error` is
    `None`. When a connection error, timeout, HTTP error or `FileMetadataError` is
    caught, it is returned as `error` instead of being raised, and the values that had
    not been computed yet keep their initial value (`None`, or the Hub url for
    `url_to_download`).

    SSL errors, proxy errors, `RevisionNotFoundError` and `EntryNotFoundError` are
    re-raised to the caller. If `storage_folder` and `relative_filename` are provided and
    the response carries a commit hash, an `EntryNotFoundError` is first recorded as an
    empty file under `<storage_folder>/.no_exist/<commit_hash>/<relative_filename>`.

    NOTE: This function mutates `headers` inplace! It removes the `authorization` header
          if the file is a LFS blob and the domain of the url is different from the
          domain of the location (typically an S3 bucket).
    """
    if local_files_only:
        # Offline mode: do not touch the network, report it as an error value instead.
        return (
            None,
            None,
            None,
            None,
            OfflineModeIsEnabled(
                f"Cannot access file since 'local_files_only=True' as been set. (repo_id: {repo_id}, repo_type: {repo_type}, revision: {revision}, filename: {filename})"
            ),
        )

    url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)
    url_to_download: str = url
    etag: Optional[str] = None
    commit_hash: Optional[str] = None
    expected_size: Optional[int] = None
    head_error_call: Optional[Exception] = None

    # Try to get metadata from the server.
    # Do not raise yet if the file is not found or not accessible.
    # NOTE: `local_files_only` is always False here because of the early return above.
    if not local_files_only:
        try:
            try:
                metadata = get_hf_file_metadata(
                    url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token
                )
            except EntryNotFoundError as http_error:
                if storage_folder is not None and relative_filename is not None:
                    # Cache the non-existence of the file
                    commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT)
                    if commit_hash is not None:
                        no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
                        try:
                            no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
                            no_exist_file_path.touch()
                        except OSError as e:
                            # Best effort only: failing to write the marker must not hide the original error.
                            logger.error(
                                f"Could not cache non-existence of file. Will ignore error and continue. Error: {e}"
                            )
                        _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
                # Always propagate: the outer handler re-raises `EntryNotFoundError` as well.
                raise

            # Commit hash must exist
            commit_hash = metadata.commit_hash
            if commit_hash is None:
                raise FileMetadataError(
                    "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue"
                    " prevents you from downloading resources from https://huggingface.co. Please check your firewall"
                    " and proxy settings and make sure your SSL certificates are updated."
                )

            # Etag must exist
            # If we don't have any of those, raise an error.
            etag = metadata.etag
            if etag is None:
                raise FileMetadataError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )

            # Size must exist
            expected_size = metadata.size
            if expected_size is None:
                raise FileMetadataError("Distant resource does not have a Content-Length.")

            # In case of a redirect, save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            #
            # If url domain is different => we are downloading from a CDN => url is signed => don't send auth
            # If url domain is the same => redirect due to repo rename AND downloading a regular file => keep auth
            if url != metadata.location:
                url_to_download = metadata.location
                if urlparse(url).netloc != urlparse(metadata.location).netloc:
                    # Remove authorization header when downloading a LFS blob
                    headers.pop("authorization", None)
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            # (this clause must stay before the generic `ConnectionError` handler below)
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ) as error:
            # Otherwise, our Internet connection is down.
            # etag is None
            head_error_call = error
        except (RevisionNotFoundError, EntryNotFoundError):
            # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted)
            # (this clause must stay before the generic `requests.HTTPError` handler below)
            raise
        except requests.HTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            head_error_call = error
        except FileMetadataError as error:
            # Multiple reasons for a FileMetadataError:
            # - Wrong network configuration (proxy, firewall, SSL certificates)
            # - Inconsistency on the Hub
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            head_error_call = error

    # Sanity check: at this point we must have either an etag or a captured error.
    if not (local_files_only or etag is not None or head_error_call is not None):
        raise RuntimeError("etag is empty due to uncovered problems")

    return (url_to_download, etag, commit_hash, expected_size, head_error_call)  # type: ignore [return-value]
1460
+
1461
+
1462
+ def _raise_on_head_call_error(head_call_error: Exception, force_download: bool, local_files_only: bool) -> NoReturn:
1463
+ """Raise an appropriate error when the HEAD call failed and we cannot locate a local file."""
1464
+
1465
+ # No head call => we cannot force download.
1466
+ if force_download:
1467
+ if local_files_only:
1468
+ raise ValueError("Cannot pass 'force_download=True' and 'local_files_only=True' at the same time.")
1469
+ elif isinstance(head_call_error, OfflineModeIsEnabled):
1470
+ raise ValueError("Cannot pass 'force_download=True' when offline mode is enabled.") from head_call_error
1471
+ else:
1472
+ raise ValueError("Force download failed due to the above error.") from head_call_error
1473
+
1474
+ # No head call + couldn't find an appropriate file on disk => raise an error.
1475
+ if local_files_only:
1476
+ raise LocalEntryNotFoundError(
1477
+ "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable"
1478
+ " hf.co look-ups and downloads online, set 'local_files_only' to False."
1479
+ )
1480
+ elif isinstance(head_call_error, RepositoryNotFoundError) or isinstance(head_call_error, GatedRepoError):
1481
+ # Repo not found or gated => let's raise the actual error
1482
+ raise head_call_error
1483
+ else:
1484
+ # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
1485
+ raise LocalEntryNotFoundError(
1486
+ "An error happened while trying to locate the file on the Hub and we cannot find the requested files"
1487
+ " in the local cache. Please check your connection and try again or make sure your Internet connection"
1488
+ " is on."
1489
+ ) from head_call_error
1490
+
1491
+
1492
+ def _download_to_tmp_and_move(
1493
+ incomplete_path: Path,
1494
+ destination_path: Path,
1495
+ url_to_download: str,
1496
+ proxies: Optional[Dict],
1497
+ headers: Dict[str, str],
1498
+ expected_size: Optional[int],
1499
+ filename: str,
1500
+ force_download: bool,
1501
+ ) -> None:
1502
+ """Download content from a URL to a destination path.
1503
+
1504
+ Internal logic:
1505
+ - return early if file is already downloaded
1506
+ - resume download if possible (from incomplete file)
1507
+ - do not resume download if `force_download=True` or `HF_HUB_ENABLE_HF_TRANSFER=True`
1508
+ - check disk space before downloading
1509
+ - download content to a temporary file
1510
+ - set correct permissions on temporary file
1511
+ - move the temporary file to the destination path
1512
+
1513
+ Both `incomplete_path` and `destination_path` must be on the same volume to avoid a local copy.
1514
+ """
1515
+ if destination_path.exists() and not force_download:
1516
+ # Do nothing if already exists (except if force_download=True)
1517
+ return
1518
+
1519
+ if incomplete_path.exists() and (force_download or (constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies)):
1520
+ # By default, we will try to resume the download if possible.
1521
+ # However, if the user has set `force_download=True` or if `hf_transfer` is enabled, then we should
1522
+ # not resume the download => delete the incomplete file.
1523
+ message = f"Removing incomplete file '{incomplete_path}'"
1524
+ if force_download:
1525
+ message += " (force_download=True)"
1526
+ elif constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies:
1527
+ message += " (hf_transfer=True)"
1528
+ logger.info(message)
1529
+ incomplete_path.unlink(missing_ok=True)
1530
+
1531
+ with incomplete_path.open("ab") as f:
1532
+ resume_size = f.tell()
1533
+ message = f"Downloading '{filename}' to '{incomplete_path}'"
1534
+ if resume_size > 0 and expected_size is not None:
1535
+ message += f" (resume from {resume_size}/{expected_size})"
1536
+ logger.info(message)
1537
+
1538
+ if expected_size is not None: # might be None if HTTP header not set correctly
1539
+ # Check disk space in both tmp and destination path
1540
+ _check_disk_space(expected_size, incomplete_path.parent)
1541
+ _check_disk_space(expected_size, destination_path.parent)
1542
+
1543
+ http_get(
1544
+ url_to_download,
1545
+ f,
1546
+ proxies=proxies,
1547
+ resume_size=resume_size,
1548
+ headers=headers,
1549
+ expected_size=expected_size,
1550
+ )
1551
+
1552
+ logger.info(f"Download complete. Moving file to {destination_path}")
1553
+ _chmod_and_move(incomplete_path, destination_path)
1554
+
1555
+
1556
+ def _int_or_none(value: Optional[str]) -> Optional[int]:
1557
+ try:
1558
+ return int(value) # type: ignore
1559
+ except (TypeError, ValueError):
1560
+ return None
1561
+
1562
+
1563
def _chmod_and_move(src: Path, dst: Path) -> None:
    """Give `src` the default permissions of the cache folder, then move it to `dst`.

    The process `umask` is not read directly as there is no convenient thread-safe way
    to get it. Instead, a throwaway file is created two levels above `dst` and its mode
    is copied onto `src`.

    See:
    - About umask: https://docs.python.org/3/library/os.html#os.umask
    - Thread-safety: https://stackoverflow.com/a/70343066
    - About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591
    - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141
    - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215
    """
    # Probe file used to discover the mode given to newly created files.
    probe = dst.parent.parent / f"tmp_{uuid.uuid4()}"
    try:
        probe.touch()
        default_mode = stat.S_IMODE(probe.stat().st_mode)
        os.chmod(str(src), default_mode)
    except OSError as e:
        # Permissions are best effort: the move below still happens.
        logger.warning(
            f"Could not set the permissions on the file '{src}'. Error: {e}.\nContinuing without setting permissions."
        )
    finally:
        try:
            probe.unlink()
        except OSError:
            # The probe does not exist if `probe.touch()` failed => nothing to clean up.
            # See https://github.com/huggingface/huggingface_hub/issues/2359
            pass

    shutil.move(str(src), str(dst), copy_function=_copy_no_matter_what)
1595
+
1596
+
1597
+ def _copy_no_matter_what(src: str, dst: str) -> None:
1598
+ """Copy file from src to dst.
1599
+
1600
+ If `shutil.copy2` fails, fallback to `shutil.copyfile`.
1601
+ """
1602
+ try:
1603
+ # Copy file with metadata and permission
1604
+ # Can fail e.g. if dst is an S3 mount
1605
+ shutil.copy2(src, dst)
1606
+ except OSError:
1607
+ # Copy only file content
1608
+ shutil.copyfile(src, dst)
1609
+
1610
+
1611
+ def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
1612
+ # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
1613
+ snapshot_path = os.path.join(storage_folder, "snapshots")
1614
+ pointer_path = os.path.join(snapshot_path, revision, relative_filename)
1615
+ if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents:
1616
+ raise ValueError(
1617
+ "Invalid pointer path: cannot create pointer path in snapshot folder if"
1618
+ f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and"
1619
+ f" `relative_filename='{relative_filename}'`."
1620
+ )
1621
+ return pointer_path
.venv/lib/python3.11/site-packages/huggingface_hub/hf_api.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/huggingface_hub/hf_file_system.py ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import tempfile
4
+ from collections import deque
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ from itertools import chain
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union
10
+ from urllib.parse import quote, unquote
11
+
12
+ import fsspec
13
+ from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
14
+ from fsspec.utils import isfilelike
15
+ from requests import Response
16
+
17
+ from . import constants
18
+ from ._commit_api import CommitOperationCopy, CommitOperationDelete
19
+ from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
20
+ from .file_download import hf_hub_url, http_get
21
+ from .hf_api import HfApi, LastCommitInfo, RepoFile
22
+ from .utils import HFValidationError, hf_raise_for_status, http_backoff
23
+
24
+
25
# Regex used to match special revisions with "/" in them (see #1710).
# Compiled in verbose mode: whitespace and `#` comments inside the pattern are ignored.
# Both alternatives are anchored at the start of the string.
SPECIAL_REFS_REVISION_REGEX = re.compile(
    r"""
    (^refs\/convert\/\w+)     # `refs/convert/parquet` revisions
    |
    (^refs\/pr\/\d+)          # PR revisions
    """,
    re.VERBOSE,
)
34
+
35
+
36
@dataclass
class HfFileSystemResolvedPath:
    """Components of a Hugging Face file system path once it has been resolved."""

    repo_type: str
    repo_id: str
    revision: str
    path_in_repo: str
    # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision.
    # Used to reconstruct the unresolved path to return to the user.
    _raw_revision: Optional[str] = field(default=None, repr=False)

    def unresolve(self) -> str:
        """Rebuild the path string (repo type prefix, repo id, revision, path in repo)."""
        type_prefix = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "")
        repo_path = type_prefix + self.repo_id

        if self._raw_revision:
            # Reuse the revision exactly as the user wrote it.
            revision_suffix = f"@{self._raw_revision}"
        elif self.revision != constants.DEFAULT_REVISION:
            revision_suffix = f"@{safe_revision(self.revision)}"
        else:
            # The default revision is left implicit.
            revision_suffix = ""

        # Strip the trailing "/" left when `path_in_repo` is empty.
        return f"{repo_path}{revision_suffix}/{self.path_in_repo}".rstrip("/")
56
+
57
+
58
+ class HfFileSystem(fsspec.AbstractFileSystem):
59
+ """
60
+ Access a remote Hugging Face Hub repository as if it were a local file system.
61
+
62
+ <Tip warning={true}>
63
+
64
+ [`HfFileSystem`] provides fsspec compatibility, which is useful for libraries that require it (e.g., reading
65
+ Hugging Face datasets directly with `pandas`). However, it introduces additional overhead due to this compatibility
66
+ layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible.
67
+
68
+ </Tip>
69
+
70
+ Args:
71
+ token (`str` or `bool`, *optional*):
72
+ A valid user access token (string). Defaults to the locally saved
73
+ token, which is the recommended method for authentication (see
74
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
75
+ To disable authentication, pass `False`.
76
+ endpoint (`str`, *optional*):
77
+ Endpoint of the Hub. Defaults to <https://huggingface.co>.
78
+ Usage:
79
+
80
+ ```python
81
+ >>> from huggingface_hub import HfFileSystem
82
+
83
+ >>> fs = HfFileSystem()
84
+
85
+ >>> # List files
86
+ >>> fs.glob("my-username/my-model/*.bin")
87
+ ['my-username/my-model/pytorch_model.bin']
88
+ >>> fs.ls("datasets/my-username/my-dataset", detail=False)
89
+ ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']
90
+
91
+ >>> # Read/write files
92
+ >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
93
+ ... data = f.read()
94
+ >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
95
+ ... f.write(data)
96
+ ```
97
+ """
98
+
99
+ root_marker = ""
100
+ protocol = "hf"
101
+
102
def __init__(
    self,
    *args,
    endpoint: Optional[str] = None,
    token: Union[bool, str, None] = None,
    **storage_options,
):
    """Set up the file system and the `HfApi` client used to query the Hub.

    Args:
        endpoint (`str`, *optional*):
            Endpoint of the Hub. Falls back to `constants.ENDPOINT` when not provided.
        token (`str` or `bool`, *optional*):
            Token forwarded to the `HfApi` client.
        *args, **storage_options:
            Forwarded to the parent class constructor.
    """
    super().__init__(*args, **storage_options)
    self.endpoint = endpoint or constants.ENDPOINT
    self.token = token
    self._api = HfApi(endpoint=endpoint, token=token)
    # Cache keyed by (repo_type, repo_id, revision). Each value is a 2-tuple:
    #  * a boolean telling whether the repository and the revision exist
    #  * the exception raised if the repository or revision doesn't exist (`None` otherwise)
    self._repo_and_revision_exists_cache: Dict[
        Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
    ] = {}
119
+
120
+ def _repo_and_revision_exist(
121
+ self, repo_type: str, repo_id: str, revision: Optional[str]
122
+ ) -> Tuple[bool, Optional[Exception]]:
123
+ if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
124
+ try:
125
+ self._api.repo_info(
126
+ repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT
127
+ )
128
+ except (RepositoryNotFoundError, HFValidationError) as e:
129
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
130
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
131
+ except RevisionNotFoundError as e:
132
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
133
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
134
+ else:
135
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
136
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
137
+ return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]
138
+
139
def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
    """
    Resolve a Hugging Face file system path into its components.

    The path has the shape `[<repo_type_prefix>/]<repo_id>[@<revision>][/<path_in_repo>]`.
    The existence of the repository and revision is checked with
    `self._repo_and_revision_exist` as part of the resolution.

    Args:
        path (`str`):
            Path to resolve.
        revision (`str`, *optional*):
            The revision of the repo to resolve. Defaults to the revision specified in the path.

    Returns:
        [`HfFileSystemResolvedPath`]: Resolved path information containing `repo_type`, `repo_id`, `revision` and `path_in_repo`.

    Raises:
        `ValueError`:
            If path contains conflicting revision information.
        `NotImplementedError`:
            If trying to list repositories.
    """

    def _align_revision_in_path_with_revision(
        revision_in_path: Optional[str], revision: Optional[str]
    ) -> Optional[str]:
        # Return the effective revision: the explicit `revision` argument wins, but it must
        # not contradict a revision written in the path.
        if revision is not None:
            if revision_in_path is not None and revision_in_path != revision:
                raise ValueError(
                    f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
                    " are not the same."
                )
        else:
            revision = revision_in_path
        return revision

    path = self._strip_protocol(path)
    if not path:
        # can't list repositories at root
        raise NotImplementedError("Access to repositories lists is not implemented.")
    elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values():
        # The first segment is a repo type prefix: strip it and map it to a repo type.
        if "/" not in path:
            # can't list repositories at the repository type level
            raise NotImplementedError("Access to repositories lists is not implemented.")
        repo_type, path = path.split("/", 1)
        repo_type = constants.REPO_TYPES_MAPPING[repo_type]
    else:
        # No prefix => model repository
        repo_type = constants.REPO_TYPE_MODEL
    if path.count("/") > 0:
        if "@" in path:
            # Explicit revision in the path: everything before "@" is the repo id.
            repo_id, revision_in_path = path.split("@", 1)
            if "/" in revision_in_path:
                match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
                if match is not None and revision in (None, match.group()):
                    # Handle `refs/convert/parquet` and PR revisions separately
                    path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
                    revision_in_path = match.group()
                else:
                    # Regular case: the revision stops at the first "/"
                    revision_in_path, path_in_repo = revision_in_path.split("/", 1)
            else:
                path_in_repo = ""
            # The revision may be url-quoted in the path => unquote before comparing
            revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
            repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
            if not repo_and_revision_exist:
                _raise_file_not_found(path, err)
        else:
            # No revision in the path. The repo id may or may not contain a namespace:
            # try `<namespace>/<name>` first, then fall back to `<name>` alone.
            revision_in_path = None
            repo_id_with_namespace = "/".join(path.split("/")[:2])
            path_in_repo_with_namespace = "/".join(path.split("/")[2:])
            repo_id_without_namespace = path.split("/")[0]
            path_in_repo_without_namespace = "/".join(path.split("/")[1:])
            repo_id = repo_id_with_namespace
            path_in_repo = path_in_repo_with_namespace
            repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
            if not repo_and_revision_exist:
                if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
                    repo_id = repo_id_without_namespace
                    path_in_repo = path_in_repo_without_namespace
                    repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
                    if not repo_and_revision_exist:
                        # Report the error of the first attempt (repo id with namespace)
                        _raise_file_not_found(path, err)
                else:
                    _raise_file_not_found(path, err)
    else:
        # Single segment: the whole path is a repo id, optionally followed by "@<revision>"
        repo_id = path
        path_in_repo = ""
        if "@" in path:
            repo_id, revision_in_path = path.split("@", 1)
            revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
        else:
            revision_in_path = None
        repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
        if not repo_and_revision_exist:
            raise NotImplementedError("Access to repositories lists is not implemented.")

    revision = revision if revision is not None else constants.DEFAULT_REVISION
    return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path)
233
+
234
def invalidate_cache(self, path: Optional[str] = None) -> None:
    """
    Clear the cache for a given path.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.invalidate_cache).

    Args:
        path (`str`, *optional*):
            Path to clear from cache. If not provided, clear the entire cache.

    """
    if not path:
        # No path => drop everything.
        self.dircache.clear()
        self._repo_and_revision_exists_cache.clear()
        return

    resolved_path = self.resolve_path(path)

    # Drop the listing of the path itself and of each of its parents.
    current = resolved_path.unresolve()
    while current:
        self.dircache.pop(current, None)
        current = self._parent(current)

    # Only clear repo cache if path is to repo root
    if not resolved_path.path_in_repo:
        for cached_revision in (None, resolved_path.revision):
            self._repo_and_revision_exists_cache.pop(
                (resolved_path.repo_type, resolved_path.repo_id, cached_revision), None
            )
261
+
262
+ def _open(
263
+ self,
264
+ path: str,
265
+ mode: str = "rb",
266
+ revision: Optional[str] = None,
267
+ block_size: Optional[int] = None,
268
+ **kwargs,
269
+ ) -> "HfFileSystemFile":
270
+ if "a" in mode:
271
+ raise NotImplementedError("Appending to remote files is not yet supported.")
272
+ if block_size == 0:
273
+ return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
274
+ else:
275
+ return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs)
276
+
277
+ def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
278
+ resolved_path = self.resolve_path(path, revision=revision)
279
+ self._api.delete_file(
280
+ path_in_repo=resolved_path.path_in_repo,
281
+ repo_id=resolved_path.repo_id,
282
+ token=self.token,
283
+ repo_type=resolved_path.repo_type,
284
+ revision=resolved_path.revision,
285
+ commit_message=kwargs.get("commit_message"),
286
+ commit_description=kwargs.get("commit_description"),
287
+ )
288
+ self.invalidate_cache(path=resolved_path.unresolve())
289
+
290
def rm(
    self,
    path: str,
    recursive: bool = False,
    maxdepth: Optional[int] = None,
    revision: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Delete files from a repository.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.rm).

    <Tip warning={true}>

        Note: When possible, use `HfApi.delete_file()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to delete.
        recursive (`bool`, *optional*):
            If True, delete directory and all its contents. Defaults to False.
        maxdepth (`int`, *optional*):
            Maximum number of subdirectories to visit when deleting recursively.
        revision (`str`, *optional*):
            The git revision to delete from.

    """
    resolved_path = self.resolve_path(path, revision=revision)
    expanded_paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision)

    # Only files are deleted explicitly; directories are skipped.
    file_paths_in_repo = [
        self.resolve_path(candidate).path_in_repo for candidate in expanded_paths if not self.isdir(candidate)
    ]
    operations = [CommitOperationDelete(path_in_repo=file_path) for file_path in file_paths_in_repo]

    # Default commit message, used when the caller does not pass `commit_message`.
    message_parts = [f"Delete {path} "]
    if recursive:
        message_parts.append("recursively ")
    if maxdepth is not None:
        message_parts.append(f"up to depth {maxdepth} ")
    default_commit_message = "".join(message_parts)

    # TODO: use `commit_description` to list all the deleted paths?
    self._api.create_commit(
        repo_id=resolved_path.repo_id,
        repo_type=resolved_path.repo_type,
        token=self.token,
        operations=operations,
        revision=resolved_path.revision,
        commit_message=kwargs.get("commit_message", default_commit_message),
        commit_description=kwargs.get("commit_description"),
    )
    self.invalidate_cache(path=resolved_path.unresolve())
338
+
339
def ls(
    self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
) -> List[Union[str, Dict[str, Any]]]:
    """
    List the contents of a directory.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.ls).

    <Tip warning={true}>

        Note: When possible, use `HfApi.list_repo_tree()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to the directory.
        detail (`bool`, *optional*):
            If True, returns a list of dictionaries containing file information. If False,
            returns a list of file paths. Defaults to True.
        refresh (`bool`, *optional*):
            If True, bypass the cache and fetch the latest data. Defaults to False.
        revision (`str`, *optional*):
            The git revision to list from.

    Returns:
        `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
        dictionaries (if detail=True).
    """
    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    # `expand_info` follows `detail` unless the caller overrides it explicitly.
    kwargs = {"expand_info": detail, **kwargs}
    try:
        entries = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs)
    except EntryNotFoundError:
        # Path could be a file
        if not resolved_path.path_in_repo:
            _raise_file_not_found(path, None)
        # List the parent directory and keep only the entry matching the path.
        siblings = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs)
        entries = [entry for entry in siblings if entry["name"] == path]
        if not entries:
            _raise_file_not_found(path, None)

    if detail:
        return entries
    return [entry["name"] for entry in entries]
382
+
383
def _ls_tree(
    self,
    path: str,
    recursive: bool = False,
    refresh: bool = False,
    revision: Optional[str] = None,
    expand_info: bool = True,
):
    """
    List the tree under `path`, serving from `self.dircache` when possible.

    Args:
        path (`str`):
            Path of the directory to list.
        recursive (`bool`, *optional*):
            If True, include the content of sub-directories. Defaults to False.
        refresh (`bool`, *optional*):
            If True, ignore the cache and query the Hub. Defaults to False.
        revision (`str`, *optional*):
            The git revision to list from.
        expand_info (`bool`, *optional*):
            If True, entries must carry their `last_commit` information. Defaults to True.

    Returns:
        A list of entry dictionaries (`name`, `size`, `type`, ...). Every entry fetched from the Hub
        is also appended to `self.dircache` under its parent directory.
    """
    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    root_path = HfFileSystemResolvedPath(
        resolved_path.repo_type,
        resolved_path.repo_id,
        resolved_path.revision,
        path_in_repo="",
        _raw_revision=resolved_path._raw_revision,
    ).unresolve()

    out = []
    if path in self.dircache and not refresh:
        cached_path_infos = self.dircache[path]
        out.extend(cached_path_infos)
        dirs_not_in_dircache = []
        if recursive:
            # Use BFS to traverse the cache and build the "recursive" output
            # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same)
            dirs_to_visit = deque(
                [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
            )
            while dirs_to_visit:
                dir_info = dirs_to_visit.popleft()
                if dir_info["name"] not in self.dircache:
                    dirs_not_in_dircache.append(dir_info["name"])
                else:
                    cached_path_infos = self.dircache[dir_info["name"]]
                    out.extend(cached_path_infos)
                    dirs_to_visit.extend(
                        [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
                    )

        dirs_not_expanded = []
        if expand_info:
            # Check if there are directories with non-expanded entries
            dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None]

        if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded):
            # If the dircache is incomplete, find the common path of the missing and non-expanded entries
            # and extend the output with the result of `_ls_tree(common_path, recursive=True)`
            common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded)
            # Get the parent directory if the common prefix itself is not a directory
            common_path = (
                common_prefix.rstrip("/")
                if common_prefix.endswith("/")
                or common_prefix == root_path
                or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded)
                else self._parent(common_prefix)
            )
            out = [o for o in out if not o["name"].startswith(common_path + "/")]
            # Iterate over a snapshot of the keys: entries are removed inside the loop, and deleting
            # from a mapping while iterating it directly is only safe if its iterator tolerates it.
            for cached_path in list(self.dircache):
                if cached_path.startswith(common_path + "/"):
                    self.dircache.pop(cached_path, None)
            self.dircache.pop(common_path, None)
            out.extend(
                self._ls_tree(
                    common_path,
                    recursive=recursive,
                    refresh=True,
                    revision=revision,
                    expand_info=expand_info,
                )
            )
    else:
        tree = self._api.list_repo_tree(
            resolved_path.repo_id,
            resolved_path.path_in_repo,
            recursive=recursive,
            expand=expand_info,
            revision=resolved_path.revision,
            repo_type=resolved_path.repo_type,
        )
        for path_info in tree:
            if isinstance(path_info, RepoFile):
                cache_path_info = {
                    "name": root_path + "/" + path_info.path,
                    "size": path_info.size,
                    "type": "file",
                    "blob_id": path_info.blob_id,
                    "lfs": path_info.lfs,
                    "last_commit": path_info.last_commit,
                    "security": path_info.security,
                }
            else:
                cache_path_info = {
                    "name": root_path + "/" + path_info.path,
                    "size": 0,
                    "type": "directory",
                    "tree_id": path_info.tree_id,
                    "last_commit": path_info.last_commit,
                }
            # Register each entry under its parent directory so later listings can be served from the cache
            parent_path = self._parent(cache_path_info["name"])
            self.dircache.setdefault(parent_path, []).append(cache_path_info)
            out.append(cache_path_info)
    return out
486
+
487
def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]:
    """
    Walk the tree rooted at `path`.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.walk).

    Args:
        path (`str`):
            Root path to list files from.

    Returns:
        `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
    """
    # Commit metadata is only expanded when `detail` is requested: skipping it gives a x10 speed boost.
    # An explicit `expand_info` keyword still takes precedence.
    wants_detail = kwargs.get("detail", False)
    walk_kwargs = {"expand_info": wants_detail, **kwargs}
    normalized_path = self.resolve_path(path, revision=walk_kwargs.get("revision")).unresolve()
    yield from super().walk(normalized_path, *args, **walk_kwargs)
504
+
505
def glob(self, path: str, **kwargs) -> List[str]:
    """
    Return the paths matching a glob pattern.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob).

    Args:
        path (`str`):
            Path pattern to match.

    Returns:
        `List[str]`: List of paths matching the pattern.
    """
    # Commit metadata is only expanded when `detail` is requested: skipping it gives a x10 speed boost.
    # An explicit `expand_info` keyword still takes precedence.
    wants_detail = kwargs.get("detail", False)
    glob_kwargs = {"expand_info": wants_detail, **kwargs}
    pattern = self.resolve_path(path, revision=glob_kwargs.get("revision")).unresolve()
    return super().glob(pattern, **glob_kwargs)
522
+
523
def find(
    self,
    path: str,
    maxdepth: Optional[int] = None,
    withdirs: bool = False,
    detail: bool = False,
    refresh: bool = False,
    revision: Optional[str] = None,
    **kwargs,
) -> Union[List[str], Dict[str, Dict[str, Any]]]:
    """
    List everything below `path`, sorted by name.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.find).

    Args:
        path (`str`):
            Root path to list files from.
        maxdepth (`int`, *optional*):
            Maximum depth to descend into subdirectories.
        withdirs (`bool`, *optional*):
            Include directory paths in the output. Defaults to False.
        detail (`bool`, *optional*):
            If True, returns a dict mapping paths to file information. Defaults to False.
        refresh (`bool`, *optional*):
            If True, bypass the cache and fetch the latest data. Defaults to False.
        revision (`str`, *optional*):
            The git revision to list from.

    Returns:
        `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information.
    """
    if maxdepth:
        # A depth limit is delegated to the generic implementation of the parent class.
        return super().find(
            path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs
        )
    resolved = self.resolve_path(path, revision=revision)
    path = resolved.unresolve()
    kwargs = {"expand_info": detail, **kwargs}
    try:
        entries = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved.revision, **kwargs)
    except EntryNotFoundError:
        # Not a directory: a file is reported as a single entry with empty information.
        is_file = self.info(path, revision=revision, **kwargs)["type"] == "file"
        by_name = {path: {}} if is_file else {}
    else:
        if withdirs:
            # Include the directory itself to be consistent with the spec
            root_info = self.info(path, revision=resolved.revision, **kwargs)
            if root_info["type"] == "directory":
                entries = [root_info] + entries
        else:
            entries = [entry for entry in entries if entry["type"] != "directory"]
        by_name = {entry["name"]: entry for entry in entries}
    ordered_names = sorted(by_name)
    if detail:
        return {name: by_name[name] for name in ordered_names}
    return ordered_names
583
+
584
def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
    """
    Copy a file within or between repositories.

    <Tip warning={true}>

        Note: When possible, use `HfApi.upload_file()` for better performance.

    </Tip>

    Within a single repository the copy is a commit holding one copy operation. Between two
    repositories the source file is read entirely into memory and uploaded to the destination.

    Args:
        path1 (`str`):
            Source path to copy from.
        path2 (`str`):
            Destination path to copy to.
        revision (`str`, *optional*):
            The git revision to copy from.
        **kwargs:
            `commit_message` and `commit_description` are forwarded to the commit when provided.

    """
    resolved_path1 = self.resolve_path(path1, revision=revision)
    resolved_path2 = self.resolve_path(path2, revision=revision)

    same_repo = (
        resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
    )
    commit_message = kwargs.get("commit_message", f"Copy {path1} to {path2}")

    if same_repo:
        self._api.create_commit(
            repo_id=resolved_path1.repo_id,
            repo_type=resolved_path1.repo_type,
            # Pass the filesystem token explicitly, like the cross-repo branch below does
            token=self.token,
            revision=resolved_path2.revision,
            commit_message=commit_message,
            commit_description=kwargs.get("commit_description", ""),
            operations=[
                CommitOperationCopy(
                    src_path_in_repo=resolved_path1.path_in_repo,
                    path_in_repo=resolved_path2.path_in_repo,
                    src_revision=resolved_path1.revision,
                )
            ],
        )
    else:
        with self.open(path1, "rb", revision=resolved_path1.revision) as f:
            content = f.read()
        self._api.upload_file(
            path_or_fileobj=content,
            path_in_repo=resolved_path2.path_in_repo,
            repo_id=resolved_path2.repo_id,
            token=self.token,
            repo_type=resolved_path2.repo_type,
            revision=resolved_path2.revision,
            commit_message=commit_message,
            commit_description=kwargs.get("commit_description"),
        )
    # Cached listings of both locations may now be stale
    self.invalidate_cache(path=resolved_path1.unresolve())
    self.invalidate_cache(path=resolved_path2.unresolve())
642
+
643
def modified(self, path: str, **kwargs) -> datetime:
    """
    Return when `path` was last modified.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.modified).

    Args:
        path (`str`):
            Path to the file.

    Returns:
        `datetime`: Last commit date of the file.
    """
    # The modification time is the date of the last commit recorded in the path information
    last_commit = self.info(path, **kwargs)["last_commit"]
    return last_commit["date"]
658
+
659
def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]:
    """
    Get information about a file or directory.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.info).

    <Tip warning={true}>

        Note: When possible, use `HfApi.get_paths_info()` or `HfApi.repo_info()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to get info for.
        refresh (`bool`, *optional*):
            If True, bypass the cache and fetch the latest data. Defaults to False.
        revision (`str`, *optional*):
            The git revision to get info from.

    Returns:
        `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).

    Raises:
        `FileNotFoundError`: If the path does not exist.

    """
    resolved_path = self.resolve_path(path, revision=revision)
    path = resolved_path.unresolve()
    expand_info = kwargs.get(
        "expand_info", True
    )  # don't expose it as a parameter in the public API to follow the spec
    if not resolved_path.path_in_repo:
        # Path is the root directory
        out = {
            "name": path,
            "size": 0,
            "type": "directory",
        }
        if expand_info:
            # `list_repo_commits` returns the most recent commit first, so the last commit of the
            # revision is the first item (the last item is the initial commit of the repository).
            last_commit = self._api.list_repo_commits(
                resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision
            )[0]
            out = {
                **out,
                "tree_id": None,  # TODO: tree_id of the root directory?
                "last_commit": LastCommitInfo(
                    oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at
                ),
            }
    else:
        out = None
        parent_path = self._parent(path)
        if not expand_info and parent_path not in self.dircache:
            # Fill the cache with cheap call
            self.ls(parent_path, expand_info=False)
        if parent_path in self.dircache:
            # Check if the path is in the cache
            out1 = [o for o in self.dircache[parent_path] if o["name"] == path]
            if not out1:
                _raise_file_not_found(path, None)
            out = out1[0]
        # Query the Hub when asked to, when nothing was cached, or when the cached entry lacks the
        # commit information that the caller requires.
        if refresh or out is None or (expand_info and out and out["last_commit"] is None):
            paths_info = self._api.get_paths_info(
                resolved_path.repo_id,
                resolved_path.path_in_repo,
                expand=expand_info,
                revision=resolved_path.revision,
                repo_type=resolved_path.repo_type,
            )
            if not paths_info:
                _raise_file_not_found(path, None)
            path_info = paths_info[0]
            root_path = HfFileSystemResolvedPath(
                resolved_path.repo_type,
                resolved_path.repo_id,
                resolved_path.revision,
                path_in_repo="",
                _raw_revision=resolved_path._raw_revision,
            ).unresolve()
            if isinstance(path_info, RepoFile):
                out = {
                    "name": root_path + "/" + path_info.path,
                    "size": path_info.size,
                    "type": "file",
                    "blob_id": path_info.blob_id,
                    "lfs": path_info.lfs,
                    "last_commit": path_info.last_commit,
                    "security": path_info.security,
                }
            else:
                out = {
                    "name": root_path + "/" + path_info.path,
                    "size": 0,
                    "type": "directory",
                    "tree_id": path_info.tree_id,
                    "last_commit": path_info.last_commit,
                }
        if not expand_info:
            # Without expansion only the fields required by the spec are returned
            out = {k: out[k] for k in ["name", "size", "type"]}
    assert out is not None
    return out
758
+
759
def exists(self, path, **kwargs):
    """
    Check if a file exists.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists).

    <Tip warning={true}>

        Note: When possible, use `HfApi.file_exists()` for better performance.

    </Tip>

    Args:
        path (`str`):
            Path to check.
        **kwargs:
            Forwarded to `info`. Passing `refresh=True` also invalidates the cache for `path` first.

    Returns:
        `bool`: True if file exists, False otherwise.
    """
    try:
        if kwargs.get("refresh", False):
            self.invalidate_cache(path)

        self.info(path, **{**kwargs, "expand_info": False})
        return True
    except Exception:
        # Any lookup failure means "does not exist". `KeyboardInterrupt` and `SystemExit` are not
        # `Exception` subclasses, so they propagate instead of being reported as a missing file.
        return False
786
+
787
def isdir(self, path):
    """
    Tell whether `path` points to a directory.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isdir).

    Args:
        path (`str`):
            Path to check.

    Returns:
        `bool`: True if path is a directory, False otherwise.
    """
    try:
        entry = self.info(path, expand_info=False)
    except OSError:
        # A failing lookup (e.g. the path is not found) means "not a directory"
        return False
    return entry["type"] == "directory"
804
+
805
def isfile(self, path):
    """
    Check if a path is a file.

    For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isfile).

    Args:
        path (`str`):
            Path to check.

    Returns:
        `bool`: True if path is a file, False otherwise.
    """
    try:
        return self.info(path, expand_info=False)["type"] == "file"
    except Exception:
        # Any lookup failure means "not a file". `KeyboardInterrupt` and `SystemExit` are not
        # `Exception` subclasses, so they propagate instead of being reported as a missing file.
        return False
822
+
823
def url(self, path: str) -> str:
    """
    Build the HTTP URL of the given path.

    Args:
        path (`str`):
            Path to get URL for.

    Returns:
        `str`: HTTP URL to access the file or directory on the Hub.
    """
    target = self.resolve_path(path)
    http_url = hf_hub_url(
        target.repo_id,
        target.path_in_repo,
        repo_type=target.repo_type,
        revision=target.revision,
        endpoint=self.endpoint,
    )
    if not self.isdir(path):
        return http_url
    # For a directory, the first "/resolve/" segment of the URL is swapped for "/tree/"
    return http_url.replace("/resolve/", "/tree/", 1)
845
+
846
def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None:
    """
    Copy single remote file to local.

    <Tip warning={true}>

        Note: When possible, use `HfApi.hf_hub_download()` for better performance.

    </Tip>

    Args:
        rpath (`str`):
            Remote path to download from.
        lpath (`str`):
            Local path to download to.
        callback (`Callback`, *optional*):
            Optional callback to track download progress. Defaults to no callback.
        outfile (`IO`, *optional*):
            Optional file-like object to write to. If provided, `lpath` is ignored.
        **kwargs:
            Only `revision` is handled here. Any other keyword makes the call fall back to the
            implementation of the parent class.

    """
    revision = kwargs.get("revision")
    unhandled_kwargs = set(kwargs.keys()) - {"revision"}
    if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0:
        # for now, let's not handle custom callbacks
        # and let's not handle custom kwargs
        return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs)

    # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883
    if isfilelike(lpath):
        outfile = lpath
    elif self.isdir(rpath):
        os.makedirs(lpath, exist_ok=True)
        return None

    if isinstance(lpath, (str, Path)):  # otherwise, let's assume it's a file-like object
        parent_dir = os.path.dirname(lpath)
        # A bare file name has no directory part, and `os.makedirs("")` raises
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Open file if not already open
    close_file = False
    if outfile is None:
        outfile = open(lpath, "wb")
        close_file = True

    # Everything after the file is opened runs under `finally`, so a failure while resolving the
    # remote path or fetching its size does not leak the handle opened above.
    try:
        initial_pos = outfile.tell()

        # Custom implementation of `get_file` to use `http_get`.
        resolve_remote_path = self.resolve_path(rpath, revision=revision)
        expected_size = self.info(rpath, revision=revision)["size"]
        callback.set_size(expected_size)
        http_get(
            url=hf_hub_url(
                repo_id=resolve_remote_path.repo_id,
                revision=resolve_remote_path.revision,
                filename=resolve_remote_path.path_in_repo,
                repo_type=resolve_remote_path.repo_type,
                endpoint=self.endpoint,
            ),
            temp_file=outfile,
            displayed_filename=rpath,
            expected_size=expected_size,
            resume_size=0,
            headers=self._api._build_hf_headers(),
            _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None,
        )
        # Rewind to where the caller's stream was before the download
        outfile.seek(initial_pos)
    finally:
        # Close file only if we opened it ourselves
        if close_file:
            outfile.close()
916
+
917
@property
def transaction(self):
    """A context within which files are committed together upon exit

    Requires the file class to implement `.commit()` and `.discard()`
    for the normal and exception cases.

    Always raises `NotImplementedError`: this filesystem does not support transactions.
    """
    # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231
    # See https://github.com/huggingface/huggingface_hub/issues/1733
    unsupported = NotImplementedError("Transactional commits are not supported.")
    raise unsupported
927
+
928
def start_transaction(self):
    """Begin write transaction for deferring files, non-context version

    Always raises `NotImplementedError`: this filesystem does not support transactions.
    """
    # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241
    # See https://github.com/huggingface/huggingface_hub/issues/1733
    unsupported = NotImplementedError("Transactional commits are not supported.")
    raise unsupported
933
+
934
+
935
class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
    """
    Buffered file object for a file of a Hub repository.

    Reads are served with HTTP range requests. Writes are accumulated in a local temporary file
    that is uploaded as a single commit when the final chunk is flushed.
    """

    def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
        try:
            self.resolved_path = fs.resolve_path(path, revision=revision)
        except FileNotFoundError as e:
            if "w" in kwargs.get("mode", ""):
                raise FileNotFoundError(
                    f"{e}.\nMake sure the repository and revision exist before writing data."
                ) from e
            raise
        # avoid an unnecessary .info() call with expensive expand_info=True to instantiate .details
        if kwargs.get("mode", "rb") == "rb":
            self.details = fs.info(self.resolved_path.unresolve(), expand_info=False)
        super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
        self.fs: HfFileSystem

    def __del__(self):
        if not hasattr(self, "resolved_path"):
            # Means that the constructor failed. Nothing to do.
            return
        return super().__del__()

    def _fetch_range(self, start: int, end: int) -> bytes:
        """Download bytes `start` (inclusive) to `end` (exclusive) of the remote file."""
        headers = {
            # HTTP ranges are inclusive on both ends, hence `end - 1`
            "range": f"bytes={start}-{end - 1}",
            **self.fs._api._build_hf_headers(),
        }
        url = hf_hub_url(
            repo_id=self.resolved_path.repo_id,
            revision=self.resolved_path.revision,
            filename=self.resolved_path.path_in_repo,
            repo_type=self.resolved_path.repo_type,
            endpoint=self.fs.endpoint,
        )
        r = http_backoff(
            "GET",
            url,
            headers=headers,
            retry_on_status_codes=(500, 502, 503, 504),
            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
        )
        hf_raise_for_status(r)
        return r.content

    def _initiate_upload(self) -> None:
        """Create the local temporary file that collects the written chunks."""
        # `delete=False`: the file is closed before being uploaded by name, then removed explicitly
        self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False)

    def _upload_chunk(self, final: bool = False) -> None:
        """Append the buffer to the temporary file and, on the final chunk, upload it to the Hub."""
        self.buffer.seek(0)
        block = self.buffer.read()
        self.temp_file.write(block)
        if final:
            self.temp_file.close()
            try:
                self.fs._api.upload_file(
                    path_or_fileobj=self.temp_file.name,
                    path_in_repo=self.resolved_path.path_in_repo,
                    repo_id=self.resolved_path.repo_id,
                    token=self.fs.token,
                    repo_type=self.resolved_path.repo_type,
                    revision=self.resolved_path.revision,
                    commit_message=self.kwargs.get("commit_message"),
                    commit_description=self.kwargs.get("commit_description"),
                )
            finally:
                # Remove the temporary file even when the upload fails, so it is not left on disk
                os.remove(self.temp_file.name)
            self.fs.invalidate_cache(
                path=self.resolved_path.unresolve(),
            )

    def read(self, length=-1):
        """Read remote file.

        If `length` is None or -1 and nothing has been read yet in "rb" mode, the whole content is
        read through a second handle opened with `block_size=0`. Otherwise the buffered
        implementation of the parent class is used.
        """
        if self.mode == "rb" and (length is None or length == -1) and self.loc == 0:
            with self.fs.open(self.path, "rb", block_size=0) as f:  # block_size=0 enables fast streaming
                return f.read()
        return super().read(length)

    def url(self) -> str:
        """Return the HTTP URL of this file on the Hub."""
        return self.fs.url(self.path)
1017
+
1018
+
1019
class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
    """
    Read-only, forward-only file object that streams a file of a Hub repository.

    Only `block_size=0`, `cache_type="none"` and non-writing modes are accepted. The content is
    read sequentially from a single streaming HTTP response, so the file cannot be sought.
    """

    def __init__(
        self,
        fs: HfFileSystem,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        block_size: int = 0,
        cache_type: str = "none",
        **kwargs,
    ):
        if block_size != 0:
            raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}")
        if cache_type != "none":
            raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}")
        if "w" in mode:
            raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'")
        # Let `FileNotFoundError` propagate: swallowing it would leave `resolved_path` unset and fail
        # right below with an unrelated `AttributeError`. Write modes were rejected above, so there
        # is no write-specific error message to build here.
        self.resolved_path = fs.resolve_path(path, revision=revision)
        # avoid an unnecessary .info() call to instantiate .details
        self.details = {"name": self.resolved_path.unresolve(), "size": None}
        super().__init__(
            fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
        )
        self.response: Optional[Response] = None
        self.fs: HfFileSystem

    def seek(self, loc: int, whence: int = 0):
        """Accept only seeks that keep the current position; any other seek raises `ValueError`."""
        if loc == 0 and whence == 1:
            return
        if loc == self.loc and whence == 0:
            return
        raise ValueError("Cannot seek streaming HF file")

    def _connect(self, range_start: Optional[int] = None) -> None:
        """Open a streaming GET on the file and store it in `self.response`.

        When `range_start` is given, a `Range` header requests the content from that byte offset.
        """
        url = hf_hub_url(
            repo_id=self.resolved_path.repo_id,
            revision=self.resolved_path.revision,
            filename=self.resolved_path.path_in_repo,
            repo_type=self.resolved_path.repo_type,
            endpoint=self.fs.endpoint,
        )
        headers = self.fs._api._build_hf_headers()
        if range_start is not None:
            headers = {"Range": "bytes=%d-" % range_start, **headers}
        self.response = http_backoff(
            "GET",
            url,
            headers=headers,
            retry_on_status_codes=(500, 502, 503, 504),
            stream=True,
            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
        )
        hf_raise_for_status(self.response)

    def read(self, length: int = -1):
        """Read up to `length` bytes from the stream, or everything left if `length` is None or negative."""
        read_args = (length,) if length is not None and length >= 0 else ()
        if self.response is None or self.response.raw.isclosed():
            self._connect()
        try:
            out = self.response.raw.read(*read_args)
        except Exception:
            self.response.close()

            # Retry by recreating the connection, resuming from the current position
            self._connect(range_start=self.loc)
            try:
                out = self.response.raw.read(*read_args)
            except Exception:
                self.response.close()
                raise
        self.loc += len(out)
        return out

    def url(self) -> str:
        """Return the HTTP URL of this file on the Hub."""
        return self.fs.url(self.path)

    def __del__(self):
        if not hasattr(self, "resolved_path"):
            # Means that the constructor failed. Nothing to do.
            return
        return super().__del__()

    def __reduce__(self):
        # Pickling stores what is needed to open the file again through `reopen`
        return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
1118
+
1119
+
1120
def safe_revision(revision: str) -> str:
    """Return `revision` unchanged if it matches `SPECIAL_REFS_REVISION_REGEX`, else fully percent-encoded."""
    if SPECIAL_REFS_REVISION_REGEX.match(revision):
        return revision
    return safe_quote(revision)
1122
+
1123
+
1124
def safe_quote(s: str) -> str:
    """Percent-encode `s` with no character treated as safe, so that "/" is encoded as well."""
    encoded = quote(s, safe="")
    return encoded
1126
+
1127
+
1128
def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
    """Raise `FileNotFoundError` for `path`, chained to `err`, with a hint when the cause is recognized."""
    if isinstance(err, RepositoryNotFoundError):
        raise FileNotFoundError(f"{path} (repository not found)") from err
    if isinstance(err, RevisionNotFoundError):
        raise FileNotFoundError(f"{path} (revision not found)") from err
    if isinstance(err, HFValidationError):
        raise FileNotFoundError(f"{path} (invalid repository id)") from err
    # Unknown or missing cause: the message is the bare path
    raise FileNotFoundError(path) from err
1137
+
1138
+
1139
def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
    """Open `path` on `fs` again with the given mode, block size and cache type."""
    open_options = {"mode": mode, "block_size": block_size, "cache_type": cache_type}
    return fs.open(path, **open_options)
.venv/lib/python3.11/site-packages/huggingface_hub/hub_mixin.py ADDED
@@ -0,0 +1,836 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import json
3
+ import os
4
+ from dataclasses import Field, asdict, dataclass, is_dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
7
+
8
+ import packaging.version
9
+
10
+ from . import constants
11
+ from .errors import EntryNotFoundError, HfHubHTTPError
12
+ from .file_download import hf_hub_download
13
+ from .hf_api import HfApi
14
+ from .repocard import ModelCard, ModelCardData
15
+ from .utils import (
16
+ SoftTemporaryDirectory,
17
+ is_jsonable,
18
+ is_safetensors_available,
19
+ is_simple_optional_type,
20
+ is_torch_available,
21
+ logging,
22
+ unwrap_simple_optional_type,
23
+ validate_hf_hub_args,
24
+ )
25
+
26
+
27
+ if is_torch_available():
28
+ import torch # type: ignore
29
+
30
+ if is_safetensors_available():
31
+ import safetensors
32
+ from safetensors.torch import load_model as load_model_as_safetensor
33
+ from safetensors.torch import save_model as save_model_as_safetensor
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
# Structural type: it is satisfied by any object exposing a `__dataclass_fields__` class attribute.
class DataclassInstance(Protocol):
    __dataclass_fields__: ClassVar[Dict[str, Field]]
42
+
43
+
44
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Generic variable to represent an args type
ARGS_T = TypeVar("ARGS_T")
# Callable turning a value of the args type into another value
ENCODER_T = Callable[[ARGS_T], Any]
# Callable turning a value back into the args type
DECODER_T = Callable[[Any], ARGS_T]
# An (encoder, decoder) pair for a given args type
CODER_T = Tuple[ENCODER_T, DECODER_T]
51
+
52
+
53
# Default model card template. The `{{ ... }}` placeholders (`card_data`, `repo_url`, `docs_url`)
# are filled in when the card is rendered. The last two fall back to "[More Information Needed]".
DEFAULT_MODEL_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Library: {{ repo_url | default("[More Information Needed]", true) }}
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
64
+
65
+
66
@dataclass
class MixinInfo:
    """Per-class metadata about the library integrating `ModelHubMixin`.

    One instance is built in `ModelHubMixin.__init_subclass__` and read back in
    `ModelHubMixin.generate_model_card`.
    """

    # Template string rendered into the model card.
    model_card_template: str
    # Metadata (tags, license, library name, ...) injected into the template.
    model_card_data: ModelCardData
    # URL of the library repository, forwarded to the template as `repo_url`.
    repo_url: Optional[str] = None
    # URL of the library documentation, forwarded to the template as `docs_url`.
    docs_url: Optional[str] = None
72
+
73
+
74
class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
    has to be overridden in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
    of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.

    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
    [`ModelHubMixin`].

    For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations).

    Args:
        repo_url (`str`, *optional*):
            URL of the library repository. Used to generate model card.
        docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate model card. Defaults to a generic template.
        language (`str` or `List[str]`, *optional*):
            Language supported by the library. Used to generate model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate model card.
        license (`str`, *optional*):
            License of the library integrating ModelHubMixin. Used to generate model card.
            E.g: "apache-2.0"
        license_name (`str`, *optional*):
            Name of the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other`.
            E.g: "coqui-public-model-license".
        license_link (`str`, *optional*):
            URL to the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other` and `license_name` is set.
            E.g: "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
        tags (`List[str]`, *optional*):
            Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.

    Example:

    ```python
    >>> from huggingface_hub import ModelHubMixin

    # Inherit from ModelHubMixin
    >>> class MyCustomModel(
    ...         ModelHubMixin,
    ...         library_name="my-library",
    ...         tags=["x-custom-tag", "arxiv:2304.12244"],
    ...         repo_url="https://github.com/huggingface/my-cool-library",
    ...         docs_url="https://huggingface.co/docs/my-cool-library",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, size: int = 512, device: str = "cpu"):
    ...         # define how to initialize your model
    ...         super().__init__()
    ...         ...
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # define how to serialize your model
    ...         ...
    ...
    ...     @classmethod
    ...     def _from_pretrained(
    ...         cls: Type[T],
    ...         *,
    ...         model_id: str,
    ...         revision: Optional[str],
    ...         cache_dir: Optional[Union[str, Path]],
    ...         force_download: bool,
    ...         proxies: Optional[Dict],
    ...         resume_download: Optional[bool],
    ...         local_files_only: bool,
    ...         token: Optional[Union[str, bool]],
    ...         **model_kwargs,
    ...     ) -> T:
    ...         # define how to deserialize your model
    ...         ...

    >>> model = MyCustomModel(size=256, device="gpu")

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
    >>> reloaded_model.size
    256

    # Model card has been correctly populated
    >>> from huggingface_hub import ModelCard
    >>> card = ModelCard.load("username/my-awesome-model")
    >>> card.data.tags
    ["arxiv:2304.12244", "model_hub_mixin", "x-custom-tag"]
    >>> card.data.library_name
    "my-library"
    ```
    """

    _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
    # ^ optional config attribute automatically set in `from_pretrained`
    _hub_mixin_info: MixinInfo
    # ^ information about the library integrating ModelHubMixin (used to generate model card)
    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
    # ^ internal values to handle config

    def __init_subclass__(
        cls,
        *,
        # Generic info for model card
        repo_url: Optional[str] = None,
        docs_url: Optional[str] = None,
        # Model card template
        model_card_template: str = DEFAULT_MODEL_CARD,
        # Model card metadata
        language: Optional[Union[str, List[str]]] = None,
        library_name: Optional[str] = None,
        license: Optional[str] = None,
        license_name: Optional[str] = None,
        license_link: Optional[str] = None,
        pipeline_tag: Optional[str] = None,
        tags: Optional[List[str]] = None,
        # How to encode/decode arguments with custom type into a JSON config?
        coders: Optional[
            Dict[Type, CODER_T]
            # Key is a type.
            # Value is a tuple (encoder, decoder).
            # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
        ] = None,
    ) -> None:
        """Inspect __init__ signature only once when subclassing + handle modelcard.

        The keyword arguments are the class-definition arguments documented on the class. They are merged with the
        `MixinInfo` inherited from the parent class (if any) and stored on `cls._hub_mixin_info`.
        """
        super().__init_subclass__()

        # Will be reused when creating modelcard.
        # Build a NEW list: appending in place would mutate the list object owned by the caller.
        tags = [*(tags or []), "model_hub_mixin"]

        # Initialize MixinInfo if not existent
        info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())

        # If parent class has a MixinInfo, inherit from it as a copy
        if hasattr(cls, "_hub_mixin_info"):
            # Inherit model card template from parent class if not explicitly set
            if model_card_template == DEFAULT_MODEL_CARD:
                info.model_card_template = cls._hub_mixin_info.model_card_template

            # Inherit from parent model card data
            info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())

            # Inherit other info
            info.docs_url = cls._hub_mixin_info.docs_url
            info.repo_url = cls._hub_mixin_info.repo_url
        cls._hub_mixin_info = info

        # Update MixinInfo with metadata
        if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
            info.model_card_template = model_card_template
        if repo_url is not None:
            info.repo_url = repo_url
        if docs_url is not None:
            info.docs_url = docs_url
        if language is not None:
            info.model_card_data.language = language
        if library_name is not None:
            info.model_card_data.library_name = library_name
        if license is not None:
            info.model_card_data.license = license
        if license_name is not None:
            info.model_card_data.license_name = license_name
        if license_link is not None:
            info.model_card_data.license_link = license_link
        if pipeline_tag is not None:
            info.model_card_data.pipeline_tag = pipeline_tag

        # `tags` is always a list at this point: merge it with the inherited tags (if any).
        if info.model_card_data.tags is not None:
            info.model_card_data.tags.extend(tags)
        else:
            info.model_card_data.tags = tags

        # Deduplicate and keep a deterministic order
        info.model_card_data.tags = sorted(set(info.model_card_data.tags))

        # Handle encoders/decoders for args
        cls._hub_mixin_coders = coders or {}
        cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())

        # Inspect __init__ signature to handle config
        cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
        cls._hub_mixin_jsonable_default_values = {
            param.name: cls._encode_arg(param.default)
            for param in cls._hub_mixin_init_parameters.values()
            if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
        }
        cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters

    def __new__(cls: Type[T], *args, **kwargs) -> T:
        """Create a new instance of the class and handle config.

        3 cases:
        - If `self._hub_mixin_config` is already set, do nothing.
        - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
        - Otherwise, build `self._hub_mixin_config` from default values and passed values.
        """
        instance = super().__new__(cls)

        # If `config` is already set, return early
        if instance._hub_mixin_config is not None:
            return instance

        # Infer passed values
        passed_values = {
            **{
                key: value
                for key, value in zip(
                    # [1:] to skip `self` parameter
                    list(cls._hub_mixin_init_parameters)[1:],
                    args,
                )
            },
            **kwargs,
        }

        # If config passed as dataclass => set it and return early
        if is_dataclass(passed_values.get("config")):
            instance._hub_mixin_config = passed_values["config"]
            return instance

        # Otherwise, build config from default + passed values
        init_config = {
            # default values
            **cls._hub_mixin_jsonable_default_values,
            # passed values
            **{
                key: cls._encode_arg(value)  # Encode custom types as jsonable value
                for key, value in passed_values.items()
                if instance._is_jsonable(value)  # Only if jsonable or we have a custom encoder
            },
        }
        passed_config = init_config.pop("config", {})

        # Populate `init_config` with provided config
        if isinstance(passed_config, dict):
            init_config.update(passed_config)

        # Set `config` attribute and return
        if init_config != {}:
            instance._hub_mixin_config = init_config
        return instance

    @classmethod
    def _is_jsonable(cls, value: Any) -> bool:
        """Check if a value is JSON serializable.

        Instances of a type registered in `coders` are considered serializable since an encoder exists for them.
        """
        if isinstance(value, cls._hub_mixin_jsonable_custom_types):
            return True
        return is_jsonable(value)

    @classmethod
    def _encode_arg(cls, arg: Any) -> Any:
        """Encode an argument into a JSON serializable format.

        The encoder of the first registered type that `arg` is an instance of is used. If no registered type
        matches, `arg` is returned unchanged.
        """
        for type_, (encoder, _) in cls._hub_mixin_coders.items():
            if isinstance(arg, type_):
                if arg is None:
                    return None
                return encoder(arg)
        return arg

    @classmethod
    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
        """Decode a JSON serializable value into an argument.

        Args:
            expected_type: annotation of the `__init__` parameter the value is meant for.
            value: value read from the config.

        Returns:
            The decoded value, or `value` unchanged if `expected_type` is neither a dataclass nor a subclass of a
            type registered in `coders`.
        """
        if is_simple_optional_type(expected_type):
            if value is None:
                return None
            expected_type = unwrap_simple_optional_type(expected_type)
        # Dataclass => handle it
        if is_dataclass(expected_type):
            return _load_dataclass(expected_type, value)  # type: ignore[return-value]
        # Otherwise => check custom decoders
        for type_, (_, decoder) in cls._hub_mixin_coders.items():
            if inspect.isclass(expected_type) and issubclass(expected_type, type_):
                return decoder(value)
        # Otherwise => don't decode
        return value

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, DataclassInstance]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
            push_to_hub_kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        Returns:
            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
        # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
        # an existing config.json if it was not saved by `_save_pretrained`.
        config_path = save_directory / constants.CONFIG_NAME
        config_path.unlink(missing_ok=True)

        # save model weights/files (framework-specific)
        self._save_pretrained(save_directory)

        # save config (if provided and if not serialized yet in `_save_pretrained`)
        if config is None:
            config = self._hub_mixin_config
        if config is not None:
            if is_dataclass(config):
                config = asdict(config)  # type: ignore[arg-type]
            if not config_path.exists():
                config_str = json.dumps(config, sort_keys=True, indent=2)
                # Explicit encoding: `from_pretrained` reads this file back as UTF-8.
                config_path.write_text(config_str, encoding="utf-8")

        # save model card
        model_card_path = save_directory / "README.md"
        model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
        if not model_card_path.exists():  # do not overwrite if already exists
            self.generate_model_card(**model_card_kwargs).save(model_card_path)

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your model.
        Check out our [integration guide](../guides/integrations) for instructions.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Huggingface Hub and instantiate it.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                    [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
        """
        model_id = str(pretrained_model_name_or_path)
        config_file: Optional[str] = None
        if os.path.isdir(model_id):
            if constants.CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, constants.CONFIG_NAME)
            else:
                logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=constants.CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                # A missing config is not fatal: the model is then built from `model_kwargs` only.
                logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

        # Read config
        config = None
        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

            # Decode custom types in config
            for key, value in config.items():
                if key in cls._hub_mixin_init_parameters:
                    expected_type = cls._hub_mixin_init_parameters[key].annotation
                    if expected_type is not inspect.Parameter.empty:
                        config[key] = cls._decode_arg(expected_type, value)

            # Populate model_kwargs from config
            for param in cls._hub_mixin_init_parameters.values():
                if param.name not in model_kwargs and param.name in config:
                    model_kwargs[param.name] = config[param.name]

            # Check if `config` argument was passed at init
            if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
                # Decode `config` argument if it was passed
                config_annotation = cls._hub_mixin_init_parameters["config"].annotation
                config = cls._decode_arg(config_annotation, config)

                # Forward config to model initialization
                model_kwargs["config"] = config

            # Inject config if `**kwargs` are expected
            if is_dataclass(cls):
                for key in cls.__dataclass_fields__:
                    if key not in model_kwargs and key in config:
                        model_kwargs[key] = config[key]
            elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
                for key, value in config.items():
                    if key not in model_kwargs:
                        model_kwargs[key] = value

            # Finally, also inject if `_from_pretrained` expects it
            if cls._hub_mixin_inject_config and "config" not in model_kwargs:
                model_kwargs["config"] = config

        instance = cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            **model_kwargs,
        )

        # Implicitly set the config as instance attribute if not already set by the class
        # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
        if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
            instance._hub_mixin_config = config

        return instance

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Overwrite this method in subclass to define how to load your model from pretrained.

        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
        method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
        parameter to set on which device the model should be loaded.

        Check out our [integration guide](../guides/integrations) for more instructions.

        Args:
            model_id (`str`):
                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        config: Optional[Union[dict, DataclassInstance]] = None,
        commit_message: str = "Push model using huggingface_hub.",
        private: Optional[bool] = None,
        token: Optional[str] = None,
        branch: Optional[str] = None,
        create_pr: Optional[bool] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        delete_patterns: Optional[Union[List[str], str]] = None,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
        details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether the repository created should be private.
                If `None` (default), the repo will be public unless the organization's default is private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.

        Returns:
            The url of the commit of your model in the given repository.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp:
            saved_path = Path(tmp) / repo_id
            self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )

    def generate_model_card(self, *args, **kwargs) -> ModelCard:
        """Build the model card from the class-level `MixinInfo`.

        Args:
            *args:
                Accepted for signature compatibility but not used.
            **kwargs:
                Additional values forwarded to `ModelCard.from_template`, next to the card data, template,
                `repo_url` and `docs_url` stored in `_hub_mixin_info`.

        Returns:
            The generated [`ModelCard`].
        """
        card = ModelCard.from_template(
            card_data=self._hub_mixin_info.model_card_data,
            template_str=self._hub_mixin_info.model_card_template,
            repo_url=self._hub_mixin_info.repo_url,
            docs_url=self._hub_mixin_info.docs_url,
            **kwargs,
        )
        return card
699
+
700
+
701
class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    See [`ModelHubMixin`] for more details on how to use the mixin.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin

    >>> class MyModel(
    ...         nn.Module,
    ...         PyTorchModelHubMixin,
    ...         library_name="keras-nlp",
    ...         repo_url="https://github.com/keras-team/keras-nlp",
    ...         docs_url="https://keras.io/keras_nlp/",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
    ...         super().__init__()
    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
    ...         self.linear = nn.Linear(output_size, vocab_size)

    ...     def forward(self, x):
    ...         return self.linear(x + self.param)
    >>> model = MyModel(hidden_size=256)

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    >>> model.param.shape[0]
    256
    ```
    """

    def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
        """Add the `"pytorch_model_hub_mixin"` tag, then defer to `ModelHubMixin.__init_subclass__`."""
        # Build a NEW list instead of appending in place, so the caller's list object is never mutated.
        kwargs["tags"] = [*(tags or []), "pytorch_model_hub_mixin"]
        return super().__init_subclass__(*args, **kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # If the object exposes a `.module` attribute, save that inner module rather than the wrapper.
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        If `model_id` is a local directory, the safetensors file inside it is loaded. Otherwise the safetensors
        file is downloaded from the Hub, falling back to the pickled PyTorch weights file when the repo has no
        safetensors file (`EntryNotFoundError`).

        Args:
            map_location (`str`, defaults to `"cpu"`):
                Device on which the weights are loaded.
            strict (`bool`, defaults to `False`):
                Forwarded as `strict` to the underlying weight-loading call.
            model_kwargs:
                Keyword arguments used to instantiate the model (`cls(**model_kwargs)`).

        The remaining arguments are forwarded to [`hf_hub_download`].
        """
        model = cls(**model_kwargs)
        if os.path.isdir(model_id):
            # Library code: report through the module logger rather than writing to stdout.
            logger.info("Loading weights from local directory")
            model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        else:
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=constants.SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_safetensor(model, model_file, map_location, strict)
            except EntryNotFoundError:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=constants.PYTORCH_WEIGHTS_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_pickle(model, model_file, map_location, strict)

    @classmethod
    def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load a pickled state dict into `model` and switch the model to evaluation mode."""
        # `weights_only=True` is passed to `torch.load` for the pickled file.
        state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore
        return model

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load safetensors weights into `model` on device `map_location`.

        NOTE(review): unlike `_load_as_pickle`, this path does not call `model.eval()` itself; confirm whether
        the "evaluation mode by default" statement in the class docstring is meant to hold here as well.
        """
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
            # Older safetensors: load first, then move the model to the requested device.
            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
            if map_location != "cpu":
                logger.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)  # type: ignore [attr-defined]
        else:
            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
        return model
829
+
830
+
831
def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
    """Instantiate `datacls` from the entries of `data` that match one of its fields.

    Keys of `data` that are not declared as dataclass fields are silently dropped.
    """
    known_fields = datacls.__dataclass_fields__
    init_kwargs = {}
    for name, value in data.items():
        if name in known_fields:
            init_kwargs[name] = value
    return datacls(**init_kwargs)
.venv/lib/python3.11/site-packages/huggingface_hub/inference/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (198 Bytes). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_client.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_common.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains utilities used by both the sync and async inference clients."""
16
+
17
+ import base64
18
+ import io
19
+ import json
20
+ import logging
21
+ from abc import ABC, abstractmethod
22
+ from contextlib import contextmanager
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+ from typing import (
26
+ TYPE_CHECKING,
27
+ Any,
28
+ AsyncIterable,
29
+ BinaryIO,
30
+ ContextManager,
31
+ Dict,
32
+ Generator,
33
+ Iterable,
34
+ List,
35
+ Literal,
36
+ NoReturn,
37
+ Optional,
38
+ Union,
39
+ overload,
40
+ )
41
+
42
+ from requests import HTTPError
43
+
44
+ from huggingface_hub.errors import (
45
+ GenerationError,
46
+ IncompleteGenerationError,
47
+ OverloadedError,
48
+ TextGenerationError,
49
+ UnknownError,
50
+ ValidationError,
51
+ )
52
+
53
+ from ..utils import (
54
+ get_session,
55
+ is_aiohttp_available,
56
+ is_numpy_available,
57
+ is_pillow_available,
58
+ )
59
+ from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput
60
+
61
+
62
+ if TYPE_CHECKING:
63
+ from aiohttp import ClientResponse, ClientSession
64
+ from PIL.Image import Image
65
+
66
+ # TYPES
67
+ UrlT = str
68
+ PathT = Union[str, Path]
69
+ BinaryT = Union[bytes, BinaryIO]
70
+ ContentT = Union[BinaryT, PathT, UrlT]
71
+
72
+ # Used to set an "Accept: image/png" header
73
+ TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+
78
@dataclass
class RequestParameters:
    """Container for the pieces of one inference request.

    This is the declared return type of `TaskProviderHelper.prepare_request`.
    All six fields are required; the `Optional` ones must still be passed explicitly.
    """

    url: str
    task: str
    model: Optional[str]
    # `json` and `data` are both optional payload slots.
    # NOTE(review): presumably only one of them is set per request — confirm against callers.
    json: Optional[Union[str, Dict, List]]
    data: Optional[ContentT]
    headers: Dict[str, Any]
86
+
87
+
88
class TaskProviderHelper(ABC):
    """Abstract base class defining the interface for task-specific provider helpers.

    Both `prepare_request` and `get_response` are abstract, so a concrete subclass
    must implement them before it can be instantiated.
    """

    # Builds the `RequestParameters` for one call. Every argument is keyword-only.
    @abstractmethod
    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters: ...

    # Receives the response as raw bytes or as a dict.
    # NOTE(review): presumably returns the task-specific output — confirm in subclasses.
    @abstractmethod
    def get_response(self, response: Union[bytes, Dict]) -> Any: ...
104
+
105
+
106
+ # Dataclass describing a model's status. It is used by the `get_model_status` function.
107
@dataclass
class ModelStatus:
    """
    Status of a model in the Hugging Face Inference API.

    Args:
        loaded (`bool`):
            If the model is currently loaded into Hugging Face's InferenceAPI. Models
            are loaded on-demand, leading to the user's first request taking longer.
            If a model is loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`Dict`):
            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
            replicas.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
    """

    # All four fields are required (no defaults).
    loaded: bool
    state: str
    compute_type: Dict
    framework: str
135
+
136
+
137
+ ## IMPORT UTILS
138
+
139
+
140
def _import_aiohttp():
    """Return the `aiohttp` module, or raise `ImportError` if it is not installed."""
    if is_aiohttp_available():
        import aiohttp

        return aiohttp
    raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
147
+
148
+
149
def _import_numpy():
    """Return the `numpy` module, or raise `ImportError` if it is not installed."""
    if is_numpy_available():
        import numpy

        return numpy
    raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
156
+
157
+
158
def _import_pil_image():
    """Return the `PIL.Image` module, or raise `ImportError` if Pillow is not installed."""
    if is_pillow_available():
        from PIL import Image

        return Image
    raise ImportError(
        "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
        " post-processed, use `client.post(...)` and get the raw response from the server."
    )
168
+
169
+
170
+ ## ENCODING / DECODING UTILS
171
+
172
+
173
@overload
def _open_as_binary(
    content: ContentT,
) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"


@overload
def _open_as_binary(
    content: Literal[None],
) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"


@contextmanager  # type: ignore
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Context manager yielding `content` in binary form.

    - A string starting with "http://" or "https://" is downloaded and its body is yielded as bytes.
    - Any other string is treated as a local path; `FileNotFoundError` is raised if it does not exist.
    - A `Path` is opened in "rb" mode and the file object is yielded (closed on exit).
    - Anything else (bytes, file-like object, `None`) is yielded unchanged.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
    source = content

    # Strings are either URLs (downloaded right away) or local paths (converted to `Path`).
    if isinstance(source, str):
        if source.startswith(("https://", "http://")):
            logger.debug(f"Downloading content from {source}")
            yield get_session().get(source).content  # TODO: retrieve as stream and pipe to post request ?
            return
        source = Path(source)
        if not source.exists():
            raise FileNotFoundError(
                f"File not found at {source}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # Not a path: already bytes, a file-like object or None => pass it through untouched.
    if not isinstance(source, Path):
        yield source
        return

    logger.debug(f"Opening content from {source}")
    with source.open("rb") as opened_file:
        yield opened_file
215
+
216
+
217
def _b64_encode(content: ContentT) -> str:
    """Return the base64 text encoding of a raw file (image, audio).

    `content` may be bytes, an opened binary file, a local path or a URL.
    """
    with _open_as_binary(content) as opened:
        if isinstance(opened, bytes):
            raw = opened
        else:
            # File-like object: read everything it contains.
            raw = opened.read()
        return base64.b64encode(raw).decode()
222
+
223
+
224
def _b64_to_image(encoded_image: str) -> "Image":
    """Decode a base64 string and open the result as a PIL Image."""
    pil_image_module = _import_pil_image()
    raw_bytes = base64.b64decode(encoded_image)
    buffer = io.BytesIO(raw_bytes)
    return pil_image_module.open(buffer)
228
+
229
+
230
+ def _bytes_to_list(content: bytes) -> List:
231
+ """Parse bytes from a Response object into a Python list.
232
+
233
+ Expects the response body to be JSON-encoded data.
234
+
235
+ NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
236
+ dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
237
+ """
238
+ return json.loads(content.decode())
239
+
240
+
241
+ def _bytes_to_dict(content: bytes) -> Dict:
242
+ """Parse bytes from a Response object into a Python dictionary.
243
+
244
+ Expects the response body to be JSON-encoded data.
245
+
246
+ NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
247
+ list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
248
+ """
249
+ return json.loads(content.decode())
250
+
251
+
252
def _bytes_to_image(content: bytes) -> "Image":
    """Open the raw bytes of a response body as a PIL Image.

    The bytes must be the image itself. For base64-encoded images, use `_b64_to_image` instead.
    """
    pil_image_module = _import_pil_image()
    buffer = io.BytesIO(content)
    return pil_image_module.open(buffer)
259
+
260
+
261
+ def _as_dict(response: Union[bytes, Dict]) -> Dict:
262
+ return json.loads(response) if isinstance(response, bytes) else response
263
+
264
+
265
+ ## PAYLOAD UTILS
266
+
267
+
268
+ ## STREAMING UTILS
269
+
270
+
271
def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`.

    Parses each Server-Sent-Events line and yields the parsed outputs, skipping lines that produce
    nothing. Iteration stops as soon as the parser signals the end of the stream with `StopIteration`.
    """
    for raw_line in bytes_output_as_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream marker received.
            return
        if parsed is None:
            continue
        yield parsed
283
+
284
+
285
async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`.

    Parses each Server-Sent-Events line and yields the parsed outputs, skipping lines that produce
    nothing. Iteration stops as soon as the parser signals the end of the stream with `StopIteration`.
    """
    async for raw_line in bytes_output_as_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream marker received.
            return
        if parsed is None:
            continue
        yield parsed
297
+
298
+
299
def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Parse one Server-Sent-Events line of a text-generation stream.

    Args:
        byte_payload: One raw line of the stream.
        details: If `True`, return the full `TextGenerationStreamOutput`; otherwise only the token text.

    Returns:
        `None` if the line does not start with `data:`, else the token text or the parsed stream output.

    Raises:
        StopIteration: If the line is the `data: [DONE]` end-of-stream marker.
        TextGenerationError: If the JSON payload carries a non-null `"error"` entry.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload.
    # The literal `data:` prefix is sliced off. `str.lstrip("data:")` would treat its argument as a *set* of
    # characters (and could eat the start of the JSON value), and `rstrip("/n")` stripped '/' and 'n'
    # characters instead of the newline it was meant to remove.
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload[len("data:") :].strip())

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output
319
+
320
+
321
def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI.

    Parses each Server-Sent-Events line and yields the parsed chunks, skipping lines that produce
    nothing. Iteration stops as soon as the parser signals the end of the stream with `StopIteration`.
    """
    for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream marker received.
            return
        if chunk is None:
            continue
        yield chunk
332
+
333
+
334
async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`.

    Parses each Server-Sent-Events line and yields the parsed chunks, skipping lines that produce
    nothing. Iteration stops as soon as the parser signals the end of the stream with `StopIteration`.
    """
    async for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream marker received.
            return
        if chunk is None:
            continue
        yield chunk
345
+
346
+
347
def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    """Parse one Server-Sent-Events line of a chat-completion stream.

    Args:
        byte_payload: One raw line of the stream.

    Returns:
        `None` if the line does not start with `data:`, else the parsed `ChatCompletionStreamOutput`.

    Raises:
        StopIteration: If the line is the `data: [DONE]` end-of-stream marker.
        TextGenerationError: If the JSON payload carries a non-null `"error"` entry.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload.
    # The literal `data:` prefix is sliced off. `str.lstrip("data:")` would treat its argument as a *set* of
    # characters (and could eat the start of the JSON value), and `rstrip("/n")` stripped '/' and 'n'
    # characters instead of the newline it was meant to remove.
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload[len("data:") :].strip())

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
366
+
367
+
368
+ async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
369
+ async for byte_payload in response.content:
370
+ yield byte_payload.strip()
371
+ await client.close()
372
+
373
+
374
+ # "TGI servers" are servers running with the `text-generation-inference` backend.
375
+ # This backend is the go-to solution to run large language models at scale. However,
376
+ # for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
377
+ # solution is still in use.
378
+ #
379
+ # Both approaches have very similar APIs, but not exactly the same. What we do first in
380
+ # the `text_generation` method is to assume the model is served via TGI. If we realize
381
+ # it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
382
+ # default API with a warning message. When that's the case, we remember the unsupported
383
+ # attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
384
+ #
385
+ # In addition, TGI servers have a built-in API route for chat-completion, which is not
386
+ # available on the default API. We use this route to provide a more consistent behavior
387
+ # when available.
388
+ #
389
+ # For more details, see https://github.com/huggingface/text-generation-inference and
390
+ # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
391
+
392
+ _UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {}
393
+
394
+
395
def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None:
    """Append `unsupported_kwargs` to the list remembered for `model`, creating that list if needed."""
    remembered = _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, [])
    remembered.extend(unsupported_kwargs)
397
+
398
+
399
def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
    """Return the kwargs remembered as unsupported for `model`, or a new empty list if none are recorded."""
    try:
        return _UNSUPPORTED_TEXT_GENERATION_KWARGS[model]
    except KeyError:
        return []
401
+
402
+
403
+ # TEXT GENERATION ERRORS
404
+ # ----------------------
405
+ # Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
406
+ # inference project (https://github.com/huggingface/text-generation-inference).
407
+ # ----------------------
408
+
409
+
410
def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Try to parse a text-generation-inference error message; always raises.

    If the error payload has an `error_type`, the matching text-generation exception is raised, chained
    from `http_error`. In every other case (no payload, unreadable payload, no `error_type`),
    `http_error` itself is raised.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.
    """
    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None)
        if not payload:
            payload = http_error.response.json()
        error, error_type = payload.get("error"), payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # Without an error_type there is nothing more specific to report: fall back to the default error
    if error_type is None:
        raise http_error

    # If error_type => more information than `hf_raise_for_status`
    raise _parse_text_generation_error(error, error_type) from http_error
435
+
436
+
437
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Build (without raising) the exception that matches `error_type`.

    Unrecognized or missing error types map to `UnknownError`.
    """
    # Compared with `==` one by one (rather than a dict lookup) so that any `error_type` value is accepted.
    known_types = (
        ("generation", GenerationError),
        ("incomplete_generation", IncompleteGenerationError),
        ("overloaded", OverloadedError),
        ("validation", ValidationError),
    )
    for type_name, exception_cls in known_types:
        if error_type == type_name:
            return exception_cls(error)  # type: ignore
    return UnknownError(error)  # type: ignore
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
@dataclass
class AudioToAudioInput(BaseInferenceType):
    """Inputs for Audio to Audio inference"""

    # Required field, typed as `Any`.
    # NOTE(review): the accepted audio representation is not visible here — confirm against callers.
    inputs: Any
    """The input audio data"""
18
+
19
+
20
@dataclass
class AudioToAudioOutputElement(BaseInferenceType):
    """Outputs of inference for the Audio To Audio task.

    A generated audio file with its label.
    """

    # All three fields are required (no defaults).
    blob: Any
    """The generated audio file."""
    content_type: str
    """The content type of audio file."""
    label: str
    """The label of the audio file."""
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import List, Literal, Optional, Union
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
+ AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"]
13
+
14
+
15
@dataclass
class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process"""

    # Every field below is optional and defaults to `None`.
    do_sample: Optional[bool] = None
    """Whether to use sampling instead of greedy decoding when generating new tokens."""
    # Either a bool or the literal "never" (see `AutomaticSpeechRecognitionEarlyStoppingEnum`).
    early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None
    """Controls the stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """If set to float strictly between 0 and 1, only tokens with a conditional probability
    greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """The maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """The maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """The minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """The minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups to divide num_beams into in order to ensure diversity among different
    groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
    """
    num_beams: Optional[int] = None
    """Number of beams to use for beam search."""
    penalty_alpha: Optional[float] = None
    """The value balances the model confidence and the degeneration penalty in contrastive
    search decoding.
    """
    temperature: Optional[float] = None
    """The value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """If set to float < 1, only the smallest set of most probable tokens with probabilities
    that add up to top_p or higher are kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token next,
    given the partial text already generated. If set to float < 1, the smallest set of the
    most locally typical tokens with probabilities that add up to typical_p or higher are
    kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding"""
73
+
74
+
75
@dataclass
class AutomaticSpeechRecognitionParameters(BaseInferenceType):
    """Additional inference parameters for Automatic Speech Recognition"""

    # Both fields are optional and default to `None`.
    return_timestamps: Optional[bool] = None
    """Whether to output corresponding timestamps with the generated text"""
    # Will be deprecated in the future when the renaming to `generation_parameters` is implemented in transformers
    generate_kwargs: Optional[AutomaticSpeechRecognitionGenerationParameters] = None
    """Parametrization of the text generation process"""
84
+
85
+
86
@dataclass
class AutomaticSpeechRecognitionInput(BaseInferenceType):
    """Inputs for Automatic Speech Recognition inference"""

    # `inputs` is required; `parameters` is optional and defaults to `None`.
    inputs: str
    """The input audio data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the audio data as a raw bytes payload.
    """
    parameters: Optional[AutomaticSpeechRecognitionParameters] = None
    """Additional inference parameters for Automatic Speech Recognition"""
96
+
97
+
98
@dataclass
class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType):
    # One element of `AutomaticSpeechRecognitionOutput.chunks`: a piece of text with its timestamps.
    # Both fields are required (no defaults).
    text: str
    """A chunk of text identified by the model"""
    timestamps: List[float]
    """The start and end timestamps corresponding with the text"""
104
+
105
+
106
@dataclass
class AutomaticSpeechRecognitionOutput(BaseInferenceType):
    """Outputs of inference for the Automatic Speech Recognition task"""

    # `text` is required; `chunks` is optional and defaults to `None`.
    text: str
    """The recognized text."""
    chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None
    """When returnTimestamps is enabled, chunks contains a list of audio chunks identified by
    the model.
    """
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Optional
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
@dataclass
class DepthEstimationInput(BaseInferenceType):
    """Inputs for Depth Estimation inference"""

    # `inputs` is required; `parameters` is optional and defaults to `None`.
    inputs: Any
    """The input image data"""
    # Untyped mapping: no dedicated parameters dataclass is used for this task here.
    parameters: Optional[Dict[str, Any]] = None
    """Additional inference parameters for Depth Estimation"""
20
+
21
+
22
@dataclass
class DepthEstimationOutput(BaseInferenceType):
    """Outputs of inference for the Depth Estimation task"""

    # Both fields are required and typed as `Any`.
    # NOTE(review): the concrete image / tensor types are not visible here — confirm against callers.
    depth: Any
    """The predicted depth as an image"""
    predicted_depth: Any
    """The predicted depth as a tensor"""
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Any, List, Optional
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
@dataclass
class FillMaskParameters(BaseInferenceType):
    """Additional inference parameters for Fill Mask"""

    # Both fields are optional and default to `None`.
    targets: Optional[List[str]] = None
    """When passed, the model will limit the scores to the passed targets instead of looking up
    in the whole vocabulary. If the provided targets are not in the model vocab, they will be
    tokenized and the first resulting token will be used (with a warning, and that might be
    slower).
    """
    top_k: Optional[int] = None
    """When passed, overrides the number of predictions to return."""
24
+
25
+
26
@dataclass
class FillMaskInput(BaseInferenceType):
    """Inputs for Fill Mask inference"""

    # `inputs` is required; `parameters` is optional and defaults to `None`.
    inputs: str
    """The text with masked tokens"""
    parameters: Optional[FillMaskParameters] = None
    """Additional inference parameters for Fill Mask"""
34
+
35
+
36
@dataclass
class FillMaskOutputElement(BaseInferenceType):
    """Outputs of inference for the Fill Mask task"""

    score: float
    """The corresponding probability"""
    sequence: str
    """The corresponding input with the mask token prediction."""
    token: int
    """The predicted token id (to replace the masked one)."""
    # NOTE(review): `token_str` has no description of its own; the string below follows
    # `fill_mask_output_token_str`. Presumably both carry the predicted token text — confirm
    # against the JSON schema this file was generated from.
    token_str: Any
    fill_mask_output_token_str: Optional[str] = None
    """The predicted token (to replace the masked one)."""
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Literal, Optional
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
+ ImageSegmentationSubtask = Literal["instance", "panoptic", "semantic"]
13
+
14
+
15
@dataclass
class ImageSegmentationParameters(BaseInferenceType):
    """Additional inference parameters for Image Segmentation"""

    # Every field below is optional and defaults to `None`.
    mask_threshold: Optional[float] = None
    """Threshold to use when turning the predicted masks into binary values."""
    overlap_mask_area_threshold: Optional[float] = None
    """Mask overlap threshold to eliminate small, disconnected segments."""
    # One of "instance", "panoptic" or "semantic" (see `ImageSegmentationSubtask`).
    subtask: Optional["ImageSegmentationSubtask"] = None
    """Segmentation task to be performed, depending on model capabilities."""
    threshold: Optional[float] = None
    """Probability threshold to filter out predicted masks."""
27
+
28
+
29
@dataclass
class ImageSegmentationInput(BaseInferenceType):
    """Inputs for Image Segmentation inference"""

    # `inputs` is required; `parameters` is optional and defaults to `None`.
    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageSegmentationParameters] = None
    """Additional inference parameters for Image Segmentation"""
39
+
40
+
41
@dataclass
class ImageSegmentationOutputElement(BaseInferenceType):
    """Outputs of inference for the Image Segmentation task.

    A predicted mask / segment
    """

    # `label` and `mask` are required; `score` is optional and defaults to `None`.
    label: str
    """The label of the predicted segment."""
    mask: str
    """The corresponding mask as a black-and-white image (base64-encoded)."""
    score: Optional[float] = None
    """The score or confidence degree the model has."""
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
@dataclass
class ImageToImageTargetSize(BaseInferenceType):
    """The size in pixel of the output image."""

    # Both dimensions are required integers.
    height: int
    width: int
18
+
19
+
20
@dataclass
class ImageToImageParameters(BaseInferenceType):
    """Additional inference parameters for Image To Image"""

    # Every field below is optional and defaults to `None`.
    guidance_scale: Optional[float] = None
    """For diffusion models. A higher guidance scale value encourages the model to generate
    images closely linked to the text prompt at the expense of lower image quality.
    """
    negative_prompt: Optional[str] = None
    """One prompt to guide what NOT to include in image generation."""
    num_inference_steps: Optional[int] = None
    """For diffusion models. The number of denoising steps. More denoising steps usually lead to
    a higher quality image at the expense of slower inference.
    """
    target_size: Optional[ImageToImageTargetSize] = None
    """The size in pixel of the output image."""
36
+
37
+
38
@dataclass
class ImageToImageInput(BaseInferenceType):
    """Inputs for Image To Image inference"""

    # `inputs` is required; `parameters` is optional and defaults to `None`.
    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided, you can
    also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageToImageParameters] = None
    """Additional inference parameters for Image To Image"""
48
+
49
+
50
@dataclass
class ImageToImageOutput(BaseInferenceType):
    """Outputs of inference for the Image To Image task"""

    # Single required field, typed as `Any`.
    image: Any
    """The output image returned as raw bytes in the payload."""
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Any, Literal, Optional, Union
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
# String value accepted by `early_stopping` in addition to a plain boolean.
ImageToTextEarlyStoppingEnum = Literal["never"]


@dataclass
class ImageToTextGenerationParameters(BaseInferenceType):
    """Settings that parametrize the text generation process."""

    do_sample: Optional[bool] = None
    """Whether new tokens are generated by sampling rather than greedy decoding."""
    early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None
    """Stopping condition applied by beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """When set to a float strictly between 0 and 1, only tokens whose conditional probability
    exceeds epsilon_cutoff are sampled. The paper suggests values from 3e-4 to 9e-4,
    depending on the size of the model. See [Truncation Sampling as Language Model
    Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. When set to
    a float strictly between 0 and 1, a token is only considered if it is greater than
    either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))).
    Intuitively, the latter term is the expected next token probability, scaled by
    sqrt(eta_cutoff). The paper suggests values from 3e-4 to 2e-3, depending on the size of
    the model. See [Truncation Sampling as Language Model
    Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    max_length: Optional[int] = None
    """Maximum length, in tokens, of the generated text, input included."""
    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """Minimum length, in tokens, of the generated text, input included."""
    min_new_tokens: Optional[int] = None
    """Minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """How many groups num_beams is divided into, so that different groups of beams stay
    diverse. See [this paper](https://hf.co/papers/1610.02424) for more details.
    """
    num_beams: Optional[int] = None
    """How many beams beam search uses."""
    penalty_alpha: Optional[float] = None
    """Balances the model confidence against the degeneration penalty in contrastive search
    decoding.
    """
    temperature: Optional[float] = None
    """Value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """How many of the highest probability vocabulary tokens are kept for top-k-filtering."""
    top_p: Optional[float] = None
    """When set to a float < 1, generation keeps only the smallest set of most probable tokens
    whose probabilities add up to top_p or higher.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token
    next, given the partial text already generated. When set to a float < 1, generation
    keeps the smallest set of the most locally typical tokens whose probabilities add up to
    typical_p or higher. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding."""
73
+
74
+
75
@dataclass
class ImageToTextParameters(BaseInferenceType):
    """Optional extra parameters for an Image To Text inference request."""

    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    # Slated for deprecation once the rename to `generation_parameters` is implemented in transformers
    generate_kwargs: Optional[ImageToTextGenerationParameters] = None
    """Settings that parametrize the text generation process."""
84
+
85
+
86
@dataclass
class ImageToTextInput(BaseInferenceType):
    """Request payload for Image To Text inference."""

    inputs: Any
    """Input image data."""
    parameters: Optional[ImageToTextParameters] = None
    """Extra inference parameters for Image To Text."""
94
+
95
+
96
@dataclass
class ImageToTextOutput(BaseInferenceType):
    """Result of an inference call for the Image To Text task."""

    # NOTE(review): this required field has no description in the generated code, while the
    # optional field below is the one documented as "the generated text". Presumably a naming
    # artifact of the code generator; confirm against the upstream JSON schema.
    generated_text: Any
    image_to_text_output_generated_text: Optional[str] = None
    """The generated text."""
.venv/lib/python3.11/site-packages/huggingface_hub/inference/_generated/types/object_detection.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
@dataclass
class ObjectDetectionParameters(BaseInferenceType):
    """Optional extra parameters for an Object Detection inference request."""

    threshold: Optional[float] = None
    """Probability required for a prediction to be made."""
18
+
19
+
20
@dataclass
class ObjectDetectionInput(BaseInferenceType):
    """Request payload for Object Detection inference."""

    inputs: str
    """Input image data, encoded as a base64 string. When no `parameters` are supplied, the
    image data may instead be sent as a raw bytes payload.
    """
    parameters: Optional[ObjectDetectionParameters] = None
    """Extra inference parameters for Object Detection."""
30
+
31
+
32
@dataclass
class ObjectDetectionBoundingBox(BaseInferenceType):
    """Predicted bounding box, with coordinates expressed relative to the top left corner of
    the input image.
    """

    xmax: int
    """X-coordinate of the bounding box's bottom-right corner."""
    xmin: int
    """X-coordinate of the bounding box's top-left corner."""
    ymax: int
    """Y-coordinate of the bounding box's bottom-right corner."""
    ymin: int
    """Y-coordinate of the bounding box's top-left corner."""
46
+
47
+
48
@dataclass
class ObjectDetectionOutputElement(BaseInferenceType):
    """Result of an inference call for the Object Detection task."""

    box: ObjectDetectionBoundingBox
    """Predicted bounding box, with coordinates expressed relative to the top left corner of
    the input image.
    """
    label: str
    """Label predicted for the bounding box."""
    score: float
    """Score / probability associated with the prediction."""