diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5c3159c1367e70f760e7d4b3bd42c611b6a01da Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/_version.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/_version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef4b0508bb156c58a193de2a920650a28de40ebc Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/_version.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/archive.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/archive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd2a928aad2520c20aac150187814a723adc8630 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/archive.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/asyn.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/asyn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df7b61f65202dfdd54aae1f7030c4602b654db72 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/asyn.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/caching.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/caching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3891c0fa837eb5b616e9e0a3fd03726b93a24744 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/caching.cpython-310.pyc differ diff --git 
a/venv/lib/python3.10/site-packages/fsspec/__pycache__/callbacks.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/callbacks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfec5f94e50aa019751ea1561e522576d37b89ce Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/callbacks.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/compression.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/compression.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00011ce5bbf3e0e7066eb3639ae840a7ea380102 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/compression.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/config.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de633e5b4d7c2901ad09a85071c53785faef002b Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/config.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/conftest.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/conftest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85ea78498d9568cc002ff94074ab0cfd309fe581 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/conftest.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/core.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..433de20a7746831a355aa4774cac95ec137aa2a7 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/core.cpython-310.pyc differ diff --git 
a/venv/lib/python3.10/site-packages/fsspec/__pycache__/dircache.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/dircache.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ff8d7cb791ae5ded1215ccfbabde5e1ab7b66b Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/dircache.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/exceptions.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b52df4fb9eea51089fddeaed4f7081f0da83bd97 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/exceptions.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/fuse.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/fuse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..271f41b455ad7a494803ff32ac340d24d3904109 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/fuse.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/generic.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/generic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d9d0802296f0207d5682edd0d0e623dfe5f7de3 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/generic.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/gui.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/gui.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d23ab430adb50a52a800edf0ffe1f208efba5895 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/gui.cpython-310.pyc differ diff --git 
a/venv/lib/python3.10/site-packages/fsspec/__pycache__/json.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/json.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d26566e20544937827b74fe8387c70794a4a05 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/json.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/mapping.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/mapping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c2260ed9e2c70349dc000865f5e993928f472b Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/mapping.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/parquet.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/parquet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30d7a95a85a947b9c06e9ec90744af5665ad617e Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/parquet.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/registry.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e6f715609b3fc665fafb3e4e94cd3b8c817758 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/registry.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/spec.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/spec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba3fa82b1fc4eee3a9cb7f7e15bd427e0a00316 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/spec.cpython-310.pyc differ diff --git 
a/venv/lib/python3.10/site-packages/fsspec/__pycache__/transaction.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/transaction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..831238f7670dfe9d75469bbaa9313a70b6bc1f19 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/transaction.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/__pycache__/utils.cpython-310.pyc b/venv/lib/python3.10/site-packages/fsspec/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67d50db26391dca17a07e3906b2485bf4ba96133 Binary files /dev/null and b/venv/lib/python3.10/site-packages/fsspec/__pycache__/utils.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/__init__.py b/venv/lib/python3.10/site-packages/fsspec/implementations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/arrow.py b/venv/lib/python3.10/site-packages/fsspec/implementations/arrow.py new file mode 100644 index 0000000000000000000000000000000000000000..227d50930763f56bc8c01556249eb1306347350b --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/arrow.py @@ -0,0 +1,312 @@ +import errno +import io +import os +import secrets +import shutil +from contextlib import suppress +from functools import cached_property, wraps +from urllib.parse import parse_qs + +from fsspec.spec import AbstractFileSystem +from fsspec.utils import ( + get_package_version_without_import, + infer_storage_options, + mirror_from, + tokenize, +) + + +def wrap_exceptions(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except OSError as exception: + if not exception.args: + raise + + message, *args = exception.args + if isinstance(message, str) 
and "does not exist" in message: + raise FileNotFoundError(errno.ENOENT, message) from exception + else: + raise + + return wrapper + + +PYARROW_VERSION = None + + +class ArrowFSWrapper(AbstractFileSystem): + """FSSpec-compatible wrapper of pyarrow.fs.FileSystem. + + Parameters + ---------- + fs : pyarrow.fs.FileSystem + + """ + + root_marker = "/" + + def __init__(self, fs, **kwargs): + global PYARROW_VERSION + PYARROW_VERSION = get_package_version_without_import("pyarrow") + self.fs = fs + super().__init__(**kwargs) + + @property + def protocol(self): + return self.fs.type_name + + @cached_property + def fsid(self): + return "hdfs_" + tokenize(self.fs.host, self.fs.port) + + @classmethod + def _strip_protocol(cls, path): + ops = infer_storage_options(path) + path = ops["path"] + if path.startswith("//"): + # special case for "hdfs://path" (without the triple slash) + path = path[1:] + return path + + def ls(self, path, detail=False, **kwargs): + path = self._strip_protocol(path) + from pyarrow.fs import FileSelector + + try: + entries = [ + self._make_entry(entry) + for entry in self.fs.get_file_info(FileSelector(path)) + ] + except (FileNotFoundError, NotADirectoryError): + entries = [self.info(path, **kwargs)] + if detail: + return entries + else: + return [entry["name"] for entry in entries] + + def info(self, path, **kwargs): + path = self._strip_protocol(path) + [info] = self.fs.get_file_info([path]) + return self._make_entry(info) + + def exists(self, path): + path = self._strip_protocol(path) + try: + self.info(path) + except FileNotFoundError: + return False + else: + return True + + def _make_entry(self, info): + from pyarrow.fs import FileType + + if info.type is FileType.Directory: + kind = "directory" + elif info.type is FileType.File: + kind = "file" + elif info.type is FileType.NotFound: + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), info.path) + else: + kind = "other" + + return { + "name": info.path, + "size": info.size, + 
"type": kind, + "mtime": info.mtime, + } + + @wrap_exceptions + def cp_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1).rstrip("/") + path2 = self._strip_protocol(path2).rstrip("/") + + with self._open(path1, "rb") as lstream: + tmp_fname = f"{path2}.tmp.{secrets.token_hex(6)}" + try: + with self.open(tmp_fname, "wb") as rstream: + shutil.copyfileobj(lstream, rstream) + self.fs.move(tmp_fname, path2) + except BaseException: + with suppress(FileNotFoundError): + self.fs.delete_file(tmp_fname) + raise + + @wrap_exceptions + def mv(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1).rstrip("/") + path2 = self._strip_protocol(path2).rstrip("/") + self.fs.move(path1, path2) + + @wrap_exceptions + def rm_file(self, path): + path = self._strip_protocol(path) + self.fs.delete_file(path) + + @wrap_exceptions + def rm(self, path, recursive=False, maxdepth=None): + path = self._strip_protocol(path).rstrip("/") + if self.isdir(path): + if recursive: + self.fs.delete_dir(path) + else: + raise ValueError("Can't delete directories without recursive=False") + else: + self.fs.delete_file(path) + + @wrap_exceptions + def _open(self, path, mode="rb", block_size=None, seekable=True, **kwargs): + if mode == "rb": + if seekable: + method = self.fs.open_input_file + else: + method = self.fs.open_input_stream + elif mode == "wb": + method = self.fs.open_output_stream + elif mode == "ab": + method = self.fs.open_append_stream + else: + raise ValueError(f"unsupported mode for Arrow filesystem: {mode!r}") + + _kwargs = {} + if mode != "rb" or not seekable: + if int(PYARROW_VERSION.split(".")[0]) >= 4: + # disable compression auto-detection + _kwargs["compression"] = None + stream = method(path, **_kwargs) + + return ArrowFile(self, stream, path, mode, block_size, **kwargs) + + @wrap_exceptions + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if create_parents: + self.makedirs(path, exist_ok=True) + else: + 
self.fs.create_dir(path, recursive=False) + + @wrap_exceptions + def makedirs(self, path, exist_ok=False): + path = self._strip_protocol(path) + self.fs.create_dir(path, recursive=True) + + @wrap_exceptions + def rmdir(self, path): + path = self._strip_protocol(path) + self.fs.delete_dir(path) + + @wrap_exceptions + def modified(self, path): + path = self._strip_protocol(path) + return self.fs.get_file_info(path).mtime + + def cat_file(self, path, start=None, end=None, **kwargs): + kwargs.setdefault("seekable", start not in [None, 0]) + return super().cat_file(path, start=None, end=None, **kwargs) + + def get_file(self, rpath, lpath, **kwargs): + kwargs.setdefault("seekable", False) + super().get_file(rpath, lpath, **kwargs) + + +@mirror_from( + "stream", + [ + "read", + "seek", + "tell", + "write", + "readable", + "writable", + "close", + "seekable", + ], +) +class ArrowFile(io.IOBase): + def __init__(self, fs, stream, path, mode, block_size=None, **kwargs): + self.path = path + self.mode = mode + + self.fs = fs + self.stream = stream + + self.blocksize = self.block_size = block_size + self.kwargs = kwargs + + def __enter__(self): + return self + + @property + def size(self): + if self.stream.seekable(): + return self.stream.size() + return None + + def __exit__(self, *args): + return self.close() + + +class HadoopFileSystem(ArrowFSWrapper): + """A wrapper on top of the pyarrow.fs.HadoopFileSystem + to connect it's interface with fsspec""" + + protocol = "hdfs" + + def __init__( + self, + host="default", + port=0, + user=None, + kerb_ticket=None, + replication=3, + extra_conf=None, + **kwargs, + ): + """ + + Parameters + ---------- + host: str + Hostname, IP or "default" to try to read from Hadoop config + port: int + Port to connect on, or default from Hadoop config if 0 + user: str or None + If given, connect as this username + kerb_ticket: str or None + If given, use this ticket for authentication + replication: int + set replication factor of file for write 
operations. default value is 3. + extra_conf: None or dict + Passed on to HadoopFileSystem + """ + from pyarrow.fs import HadoopFileSystem + + fs = HadoopFileSystem( + host=host, + port=port, + user=user, + kerb_ticket=kerb_ticket, + replication=replication, + extra_conf=extra_conf, + ) + super().__init__(fs=fs, **kwargs) + + @staticmethod + def _get_kwargs_from_urls(path): + ops = infer_storage_options(path) + out = {} + if ops.get("host", None): + out["host"] = ops["host"] + if ops.get("username", None): + out["user"] = ops["username"] + if ops.get("port", None): + out["port"] = ops["port"] + if ops.get("url_query", None): + queries = parse_qs(ops["url_query"]) + if queries.get("replication", None): + out["replication"] = int(queries["replication"][0]) + return out diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/asyn_wrapper.py b/venv/lib/python3.10/site-packages/fsspec/implementations/asyn_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..91db5eb48d00e36b46d9deb49504a7d2ad76d690 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/asyn_wrapper.py @@ -0,0 +1,124 @@ +import asyncio +import functools +import inspect + +import fsspec +from fsspec.asyn import AsyncFileSystem, running_async + +from .chained import ChainedFileSystem + + +def async_wrapper(func, obj=None, semaphore=None): + """ + Wraps a synchronous function to make it awaitable. + + Parameters + ---------- + func : callable + The synchronous function to wrap. + obj : object, optional + The instance to bind the function to, if applicable. + semaphore : asyncio.Semaphore, optional + A semaphore to limit concurrent calls. + + Returns + ------- + coroutine + An awaitable version of the function. 
+ """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + if semaphore: + async with semaphore: + return await asyncio.to_thread(func, *args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) + + return wrapper + + +class AsyncFileSystemWrapper(AsyncFileSystem, ChainedFileSystem): + """ + A wrapper class to convert a synchronous filesystem into an asynchronous one. + + This class takes an existing synchronous filesystem implementation and wraps all + its methods to provide an asynchronous interface. + + Parameters + ---------- + sync_fs : AbstractFileSystem + The synchronous filesystem instance to wrap. + """ + + protocol = "asyncwrapper", "async_wrapper" + cachable = False + + def __init__( + self, + fs=None, + asynchronous=None, + target_protocol=None, + target_options=None, + semaphore=None, + max_concurrent_tasks=None, + **kwargs, + ): + if asynchronous is None: + asynchronous = running_async() + super().__init__(asynchronous=asynchronous, **kwargs) + if fs is not None: + self.sync_fs = fs + else: + self.sync_fs = fsspec.filesystem(target_protocol, **target_options) + self.protocol = self.sync_fs.protocol + self.semaphore = semaphore + self._wrap_all_sync_methods() + + @property + def fsid(self): + return f"async_{self.sync_fs.fsid}" + + def _wrap_all_sync_methods(self): + """ + Wrap all synchronous methods of the underlying filesystem with asynchronous versions. 
+ """ + excluded_methods = {"open"} + for method_name in dir(self.sync_fs): + if method_name.startswith("_") or method_name in excluded_methods: + continue + + attr = inspect.getattr_static(self.sync_fs, method_name) + if isinstance(attr, property): + continue + + method = getattr(self.sync_fs, method_name) + if callable(method) and not inspect.iscoroutinefunction(method): + async_method = async_wrapper(method, obj=self, semaphore=self.semaphore) + setattr(self, f"_{method_name}", async_method) + + @classmethod + def wrap_class(cls, sync_fs_class): + """ + Create a new class that can be used to instantiate an AsyncFileSystemWrapper + with lazy instantiation of the underlying synchronous filesystem. + + Parameters + ---------- + sync_fs_class : type + The class of the synchronous filesystem to wrap. + + Returns + ------- + type + A new class that wraps the provided synchronous filesystem class. + """ + + class GeneratedAsyncFileSystemWrapper(cls): + def __init__(self, *args, **kwargs): + sync_fs = sync_fs_class(*args, **kwargs) + super().__init__(sync_fs) + + GeneratedAsyncFileSystemWrapper.__name__ = ( + f"Async{sync_fs_class.__name__}Wrapper" + ) + return GeneratedAsyncFileSystemWrapper diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/cache_mapper.py b/venv/lib/python3.10/site-packages/fsspec/implementations/cache_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7c7d88afdddf12f77b26bb635bd8bf1e2bd7f1 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/cache_mapper.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import abc +import hashlib + +from fsspec.implementations.local import make_path_posix + + +class AbstractCacheMapper(abc.ABC): + """Abstract super-class for mappers from remote URLs to local cached + basenames. + """ + + @abc.abstractmethod + def __call__(self, path: str) -> str: ... + + def __eq__(self, other: object) -> bool: + # Identity only depends on class. 
When derived classes have attributes + # they will need to be included. + return isinstance(other, type(self)) + + def __hash__(self) -> int: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return hash(type(self)) + + +class BasenameCacheMapper(AbstractCacheMapper): + """Cache mapper that uses the basename of the remote URL and a fixed number + of directory levels above this. + + The default is zero directory levels, meaning different paths with the same + basename will have the same cached basename. + """ + + def __init__(self, directory_levels: int = 0): + if directory_levels < 0: + raise ValueError( + "BasenameCacheMapper requires zero or positive directory_levels" + ) + self.directory_levels = directory_levels + + # Separator for directories when encoded as strings. + self._separator = "_@_" + + def __call__(self, path: str) -> str: + path = make_path_posix(path) + prefix, *bits = path.rsplit("/", self.directory_levels + 1) + if bits: + return self._separator.join(bits) + else: + return prefix # No separator found, simple filename + + def __eq__(self, other: object) -> bool: + return super().__eq__(other) and self.directory_levels == other.directory_levels + + def __hash__(self) -> int: + return super().__hash__() ^ hash(self.directory_levels) + + +class HashCacheMapper(AbstractCacheMapper): + """Cache mapper that uses a hash of the remote URL.""" + + def __call__(self, path: str) -> str: + return hashlib.sha256(path.encode()).hexdigest() + + +def create_cache_mapper(same_names: bool) -> AbstractCacheMapper: + """Factory method to create cache mapper for backward compatibility with + ``CachingFileSystem`` constructor using ``same_names`` kwarg. 
+ """ + if same_names: + return BasenameCacheMapper() + else: + return HashCacheMapper() diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/cache_metadata.py b/venv/lib/python3.10/site-packages/fsspec/implementations/cache_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1f7eb7f846186606921ff6a1539442a0899506 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/cache_metadata.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import os +import pickle +import time +from typing import TYPE_CHECKING + +from fsspec.utils import atomic_write + +try: + import ujson as json +except ImportError: + if not TYPE_CHECKING: + import json + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any, Literal, TypeAlias + + from .cached import CachingFileSystem + + Detail: TypeAlias = dict[str, Any] + + +class CacheMetadata: + """Cache metadata. + + All reading and writing of cache metadata is performed by this class, + accessing the cached files and blocks is not. + + Metadata is stored in a single file per storage directory in JSON format. + For backward compatibility, also reads metadata stored in pickle format + which is converted to JSON when next saved. + """ + + def __init__(self, storage: list[str]): + """ + + Parameters + ---------- + storage: list[str] + Directories containing cached files, must be at least one. Metadata + is stored in the last of these directories by convention. + """ + if not storage: + raise ValueError("CacheMetadata expects at least one storage location") + + self._storage = storage + self.cached_files: list[Detail] = [{}] + + # Private attribute to force saving of metadata in pickle format rather than + # JSON for use in tests to confirm can read both pickle and JSON formats. 
+ self._force_save_pickle = False + + def _load(self, fn: str) -> Detail: + """Low-level function to load metadata from specific file""" + try: + with open(fn, "r") as f: + loaded = json.load(f) + except ValueError: + with open(fn, "rb") as f: + loaded = pickle.load(f) + for c in loaded.values(): + if isinstance(c.get("blocks"), list): + c["blocks"] = set(c["blocks"]) + return loaded + + def _save(self, metadata_to_save: Detail, fn: str) -> None: + """Low-level function to save metadata to specific file""" + if self._force_save_pickle: + with atomic_write(fn) as f: + pickle.dump(metadata_to_save, f) + else: + with atomic_write(fn, mode="w") as f: + json.dump(metadata_to_save, f) + + def _scan_locations( + self, writable_only: bool = False + ) -> Iterator[tuple[str, str, bool]]: + """Yield locations (filenames) where metadata is stored, and whether + writable or not. + + Parameters + ---------- + writable: bool + Set to True to only yield writable locations. + + Returns + ------- + Yields (str, str, bool) + """ + n = len(self._storage) + for i, storage in enumerate(self._storage): + writable = i == n - 1 + if writable_only and not writable: + continue + yield os.path.join(storage, "cache"), storage, writable + + def check_file( + self, path: str, cfs: CachingFileSystem | None + ) -> Literal[False] | tuple[Detail, str]: + """If path is in cache return its details, otherwise return ``False``. + + If the optional CachingFileSystem is specified then it is used to + perform extra checks to reject possible matches, such as if they are + too old. 
+ """ + for (fn, base, _), cache in zip(self._scan_locations(), self.cached_files): + if path not in cache: + continue + detail = cache[path].copy() + + if cfs is not None: + if cfs.check_files and detail["uid"] != cfs.fs.ukey(path): + # Wrong file as determined by hash of file properties + continue + if cfs.expiry and time.time() - detail["time"] > cfs.expiry: + # Cached file has expired + continue + + fn = os.path.join(base, detail["fn"]) + if os.path.exists(fn): + return detail, fn + return False + + def clear_expired(self, expiry_time: int) -> tuple[list[str], bool]: + """Remove expired metadata from the cache. + + Returns names of files corresponding to expired metadata and a boolean + flag indicating whether the writable cache is empty. Caller is + responsible for deleting the expired files. + """ + expired_files = [] + for path, detail in self.cached_files[-1].copy().items(): + if time.time() - detail["time"] > expiry_time: + fn = detail.get("fn", "") + if not fn: + raise RuntimeError( + f"Cache metadata does not contain 'fn' for {path}" + ) + fn = os.path.join(self._storage[-1], fn) + expired_files.append(fn) + self.cached_files[-1].pop(path) + + if self.cached_files[-1]: + cache_path = os.path.join(self._storage[-1], "cache") + self._save(self.cached_files[-1], cache_path) + + writable_cache_empty = not self.cached_files[-1] + return expired_files, writable_cache_empty + + def load(self) -> None: + """Load all metadata from disk and store in ``self.cached_files``""" + cached_files = [] + for fn, _, _ in self._scan_locations(): + if os.path.exists(fn): + # TODO: consolidate blocks here + cached_files.append(self._load(fn)) + else: + cached_files.append({}) + self.cached_files = cached_files or [{}] + + def on_close_cached_file(self, f: Any, path: str) -> None: + """Perform side-effect actions on closing a cached file. + + The actual closing of the file is the responsibility of the caller. 
+ """ + # File must be writeble, so in self.cached_files[-1] + c = self.cached_files[-1][path] + if c["blocks"] is not True and len(c["blocks"]) * f.blocksize >= f.size: + c["blocks"] = True + + def pop_file(self, path: str) -> str | None: + """Remove metadata of cached file. + + If path is in the cache, return the filename of the cached file, + otherwise return ``None``. Caller is responsible for deleting the + cached file. + """ + details = self.check_file(path, None) + if not details: + return None + _, fn = details + if fn.startswith(self._storage[-1]): + self.cached_files[-1].pop(path) + self.save() + else: + raise PermissionError( + "Can only delete cached file in last, writable cache location" + ) + return fn + + def save(self) -> None: + """Save metadata to disk""" + for (fn, _, writable), cache in zip(self._scan_locations(), self.cached_files): + if not writable: + continue + + if os.path.exists(fn): + cached_files = self._load(fn) + for k, c in cached_files.items(): + if k in cache: + if c["blocks"] is True or cache[k]["blocks"] is True: + c["blocks"] = True + else: + # self.cached_files[*][*]["blocks"] must continue to + # point to the same set object so that updates + # performed by MMapCache are propagated back to + # self.cached_files. 
+ blocks = cache[k]["blocks"] + blocks.update(c["blocks"]) + c["blocks"] = blocks + c["time"] = max(c["time"], cache[k]["time"]) + c["uid"] = cache[k]["uid"] + + # Files can be added to cache after it was written once + for k, c in cache.items(): + if k not in cached_files: + cached_files[k] = c + else: + cached_files = cache + cache = {k: v.copy() for k, v in cached_files.items()} + for c in cache.values(): + if isinstance(c["blocks"], set): + c["blocks"] = list(c["blocks"]) + self._save(cache, fn) + self.cached_files[-1] = cached_files + + def update_file(self, path: str, detail: Detail) -> None: + """Update metadata for specific file in memory, do not save""" + self.cached_files[-1][path] = detail diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/cached.py b/venv/lib/python3.10/site-packages/fsspec/implementations/cached.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4fc0f662d6989da346a5e672e06466156f5f21 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/cached.py @@ -0,0 +1,1021 @@ +from __future__ import annotations + +import inspect +import logging +import os +import tempfile +import time +import weakref +from collections.abc import Callable +from shutil import rmtree +from typing import TYPE_CHECKING, Any, ClassVar + +from fsspec import filesystem +from fsspec.callbacks import DEFAULT_CALLBACK +from fsspec.compression import compr +from fsspec.core import BaseCache, MMapCache +from fsspec.exceptions import BlocksizeMismatchError +from fsspec.implementations.cache_mapper import create_cache_mapper +from fsspec.implementations.cache_metadata import CacheMetadata +from fsspec.implementations.chained import ChainedFileSystem +from fsspec.implementations.local import LocalFileSystem +from fsspec.spec import AbstractBufferedFile +from fsspec.transaction import Transaction +from fsspec.utils import infer_compression + +if TYPE_CHECKING: + from fsspec.implementations.cache_mapper import 
AbstractCacheMapper + +logger = logging.getLogger("fsspec.cached") + + +class WriteCachedTransaction(Transaction): + def complete(self, commit=True): + rpaths = [f.path for f in self.files] + lpaths = [f.fn for f in self.files] + if commit: + self.fs.put(lpaths, rpaths) + self.files.clear() + self.fs._intrans = False + self.fs._transaction = None + self.fs = None # break cycle + + +class CachingFileSystem(ChainedFileSystem): + """Locally caching filesystem, layer over any other FS + + This class implements chunk-wise local storage of remote files, for quick + access after the initial download. The files are stored in a given + directory with hashes of URLs for the filenames. If no directory is given, + a temporary one is used, which should be cleaned up by the OS after the + process ends. The files themselves are sparse (as implemented in + :class:`~fsspec.caching.MMapCache`), so only the data which is accessed + takes up space. + + Restrictions: + + - the block-size must be the same for each access of a given file, unless + all blocks of the file have already been read + - caching can only be applied to file-systems which produce files + derived from fsspec.spec.AbstractBufferedFile ; LocalFileSystem is also + allowed, for testing + """ + + protocol: ClassVar[str | tuple[str, ...]] = ("blockcache", "cached") + _strip_tokenize_options = ("fo",) + + def __init__( + self, + target_protocol=None, + cache_storage="TMP", + cache_check=10, + check_files=False, + expiry_time=604800, + target_options=None, + fs=None, + same_names: bool | None = None, + compression=None, + cache_mapper: AbstractCacheMapper | None = None, + **kwargs, + ): + """ + + Parameters + ---------- + target_protocol: str (optional) + Target filesystem protocol. Provide either this or ``fs``. + cache_storage: str or list(str) + Location to store files. If "TMP", this is a temporary directory, + and will be cleaned up by the OS when this process ends (or later). 
+ If a list, each location will be tried in the order given, but + only the last will be considered writable. + cache_check: int + Number of seconds between reload of cache metadata + check_files: bool + Whether to explicitly see if the UID of the remote file matches + the stored one before using. Warning: some file systems such as + HTTP cannot reliably give a unique hash of the contents of some + path, so be sure to set this option to False. + expiry_time: int + The time in seconds after which a local copy is considered useless. + Set to falsy to prevent expiry. The default is equivalent to one + week. + target_options: dict or None + Passed to the instantiation of the FS, if fs is None. + fs: filesystem instance + The target filesystem to run against. Provide this or ``protocol``. + same_names: bool (optional) + By default, target URLs are hashed using a ``HashCacheMapper`` so + that files from different backends with the same basename do not + conflict. If this argument is ``true``, a ``BasenameCacheMapper`` + is used instead. Other cache mapper options are available by using + the ``cache_mapper`` keyword argument. Only one of this and + ``cache_mapper`` should be specified. + compression: str (optional) + To decompress on download. Can be 'infer' (guess from the URL name), + one of the entries in ``fsspec.compression.compr``, or None for no + decompression. + cache_mapper: AbstractCacheMapper (optional) + The object use to map from original filenames to cached filenames. + Only one of this and ``same_names`` should be specified. + """ + super().__init__(**kwargs) + if fs is None and target_protocol is None: + raise ValueError( + "Please provide filesystem instance(fs) or target_protocol" + ) + if not (fs is None) ^ (target_protocol is None): + raise ValueError( + "Both filesystems (fs) and target_protocol may not be both given." 
+ ) + if cache_storage == "TMP": + tempdir = tempfile.mkdtemp() + storage = [tempdir] + weakref.finalize(self, self._remove_tempdir, tempdir) + else: + if isinstance(cache_storage, str): + storage = [cache_storage] + else: + storage = cache_storage + os.makedirs(storage[-1], exist_ok=True) + self.storage = storage + self.kwargs = target_options or {} + self.cache_check = cache_check + self.check_files = check_files + self.expiry = expiry_time + self.compression = compression + + # Size of cache in bytes. If None then the size is unknown and will be + # recalculated the next time cache_size() is called. On writes to the + # cache this is reset to None. + self._cache_size = None + + if same_names is not None and cache_mapper is not None: + raise ValueError( + "Cannot specify both same_names and cache_mapper in " + "CachingFileSystem.__init__" + ) + if cache_mapper is not None: + self._mapper = cache_mapper + else: + self._mapper = create_cache_mapper( + same_names if same_names is not None else False + ) + + self.target_protocol = ( + target_protocol + if isinstance(target_protocol, str) + else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]) + ) + self._metadata = CacheMetadata(self.storage) + self.load_cache() + self.fs = fs if fs is not None else filesystem(target_protocol, **self.kwargs) + + def _strip_protocol(path): + # acts as a method, since each instance has a difference target + return self.fs._strip_protocol(type(self)._strip_protocol(path)) + + self._strip_protocol: Callable = _strip_protocol + + @staticmethod + def _remove_tempdir(tempdir): + try: + rmtree(tempdir) + except Exception: + pass + + def _mkcache(self): + os.makedirs(self.storage[-1], exist_ok=True) + + def cache_size(self): + """Return size of cache in bytes. + + If more than one cache directory is in use, only the size of the last + one (the writable cache directory) is returned. 
+ """ + if self._cache_size is None: + cache_dir = self.storage[-1] + self._cache_size = filesystem("file").du(cache_dir, withdirs=True) + return self._cache_size + + def load_cache(self): + """Read set of stored blocks from file""" + self._metadata.load() + self._mkcache() + self.last_cache = time.time() + + def save_cache(self): + """Save set of stored blocks from file""" + self._mkcache() + self._metadata.save() + self.last_cache = time.time() + self._cache_size = None + + def _check_cache(self): + """Reload caches if time elapsed or any disappeared""" + self._mkcache() + if not self.cache_check: + # explicitly told not to bother checking + return + timecond = time.time() - self.last_cache > self.cache_check + existcond = all(os.path.exists(storage) for storage in self.storage) + if timecond or not existcond: + self.load_cache() + + def _check_file(self, path): + """Is path in cache and still valid""" + path = self._strip_protocol(path) + self._check_cache() + return self._metadata.check_file(path, self) + + def clear_cache(self): + """Remove all files and metadata from the cache + + In the case of multiple cache locations, this clears only the last one, + which is assumed to be the read/write one. + """ + rmtree(self.storage[-1]) + self.load_cache() + self._cache_size = None + + def clear_expired_cache(self, expiry_time=None): + """Remove all expired files and metadata from the cache + + In the case of multiple cache locations, this clears only the last one, + which is assumed to be the read/write one. + + Parameters + ---------- + expiry_time: int + The time in seconds after which a local copy is considered useless. + If not defined the default is equivalent to the attribute from the + file caching instantiation. 
+ """ + + if not expiry_time: + expiry_time = self.expiry + + self._check_cache() + + expired_files, writable_cache_empty = self._metadata.clear_expired(expiry_time) + for fn in expired_files: + if os.path.exists(fn): + os.remove(fn) + + if writable_cache_empty: + rmtree(self.storage[-1]) + self.load_cache() + + self._cache_size = None + + def pop_from_cache(self, path): + """Remove cached version of given file + + Deletes local copy of the given (remote) path. If it is found in a cache + location which is not the last, it is assumed to be read-only, and + raises PermissionError + """ + path = self._strip_protocol(path) + fn = self._metadata.pop_file(path) + if fn is not None: + os.remove(fn) + self._cache_size = None + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + """Wrap the target _open + + If the whole file exists in the cache, just open it locally and + return that. + + Otherwise, open the file on the target FS, and make it have a mmap + cache pointing to the location which we determine, in our cache. + The ``blocks`` instance is shared, so as the mmap cache instance + updates, so does the entry in our ``cached_files`` attribute. + We monkey-patch this file, so that when it closes, we call + ``close_and_update`` to save the state of the blocks. 
+ """ + path = self._strip_protocol(path) + + path = self.fs._strip_protocol(path) + if "r" not in mode: + return self.fs._open( + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + **kwargs, + ) + detail = self._check_file(path) + if detail: + # file is in cache + detail, fn = detail + hash, blocks = detail["fn"], detail["blocks"] + if blocks is True: + # stored file is complete + logger.debug("Opening local copy of %s", path) + return open(fn, mode) + # TODO: action where partial file exists in read-only cache + logger.debug("Opening partially cached copy of %s", path) + else: + hash = self._mapper(path) + fn = os.path.join(self.storage[-1], hash) + blocks = set() + detail = { + "original": path, + "fn": hash, + "blocks": blocks, + "time": time.time(), + "uid": self.fs.ukey(path), + } + self._metadata.update_file(path, detail) + logger.debug("Creating local sparse file for %s", path) + + # explicitly submitting the size to the open call will avoid extra + # operations when opening. This is particularly relevant + # for any file that is read over a network, e.g. S3. 
+ size = detail.get("size") + + # call target filesystems open + self._mkcache() + f = self.fs._open( + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + cache_type="none", + size=size, + **kwargs, + ) + + # set size if not already set + if size is None: + detail["size"] = f.size + self._metadata.update_file(path, detail) + + if self.compression: + comp = ( + infer_compression(path) + if self.compression == "infer" + else self.compression + ) + f = compr[comp](f, mode="rb") + if "blocksize" in detail: + if detail["blocksize"] != f.blocksize: + raise BlocksizeMismatchError( + f"Cached file must be reopened with same block" + f" size as original (old: {detail['blocksize']}," + f" new {f.blocksize})" + ) + else: + detail["blocksize"] = f.blocksize + + def _fetch_ranges(ranges): + return self.fs.cat_ranges( + [path] * len(ranges), + [r[0] for r in ranges], + [r[1] for r in ranges], + **kwargs, + ) + + multi_fetcher = None if self.compression else _fetch_ranges + f.cache = MMapCache( + f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher + ) + close = f.close + f.close = lambda: self.close_and_update(f, close) + self.save_cache() + return f + + def _parent(self, path): + return self.fs._parent(path) + + def hash_name(self, path: str, *args: Any) -> str: + # Kept for backward compatibility with downstream libraries. + # Ignores extra arguments, previously same_name boolean. 
+ return self._mapper(path) + + def close_and_update(self, f, close): + """Called when a file is closing, so store the set of blocks""" + if f.closed: + return + path = self._strip_protocol(f.path) + self._metadata.on_close_cached_file(f, path) + try: + logger.debug("going to save") + self.save_cache() + logger.debug("saved") + except OSError: + logger.debug("Cache saving failed while closing file") + except NameError: + logger.debug("Cache save failed due to interpreter shutdown") + close() + f.closed = True + + def ls(self, path, detail=True): + return self.fs.ls(path, detail) + + def __getattribute__(self, item): + if item in { + "load_cache", + "_get_cached_file_before_open", + "_open", + "save_cache", + "close_and_update", + "__init__", + "__getattribute__", + "__reduce__", + "_make_local_details", + "open", + "cat", + "cat_file", + "_cat_file", + "cat_ranges", + "_cat_ranges", + "get", + "read_block", + "tail", + "head", + "info", + "ls", + "exists", + "isfile", + "isdir", + "_check_file", + "_check_cache", + "_mkcache", + "clear_cache", + "clear_expired_cache", + "pop_from_cache", + "local_file", + "_paths_from_path", + "get_mapper", + "open_many", + "commit_many", + "hash_name", + "__hash__", + "__eq__", + "to_json", + "to_dict", + "cache_size", + "pipe_file", + "pipe", + "start_transaction", + "end_transaction", + }: + # all the methods defined in this class. 
Note `open` here, since + # it calls `_open`, but is actually in superclass + return lambda *args, **kw: getattr(type(self), item).__get__(self)( + *args, **kw + ) + if item in ["__reduce_ex__"]: + raise AttributeError + if item in ["transaction"]: + # property + return type(self).transaction.__get__(self) + if item in {"_cache", "transaction_type", "protocol"}: + # class attributes + return getattr(type(self), item) + if item == "__class__": + return type(self) + d = object.__getattribute__(self, "__dict__") + fs = d.get("fs", None) # fs is not immediately defined + if item in d: + return d[item] + elif fs is not None: + if item in fs.__dict__: + # attribute of instance + return fs.__dict__[item] + # attributed belonging to the target filesystem + cls = type(fs) + m = getattr(cls, item) + if (inspect.isfunction(m) or inspect.isdatadescriptor(m)) and ( + not hasattr(m, "__self__") or m.__self__ is None + ): + # instance method + return m.__get__(fs, cls) + return m # class method or attribute + else: + # attributes of the superclass, while target is being set up + return super().__getattribute__(item) + + def __eq__(self, other): + """Test for equality.""" + if self is other: + return True + if not isinstance(other, type(self)): + return False + return ( + self.storage == other.storage + and self.kwargs == other.kwargs + and self.cache_check == other.cache_check + and self.check_files == other.check_files + and self.expiry == other.expiry + and self.compression == other.compression + and self._mapper == other._mapper + and self.target_protocol == other.target_protocol + ) + + def __hash__(self): + """Calculate hash.""" + return ( + hash(tuple(self.storage)) + ^ hash(str(self.kwargs)) + ^ hash(self.cache_check) + ^ hash(self.check_files) + ^ hash(self.expiry) + ^ hash(self.compression) + ^ hash(self._mapper) + ^ hash(self.target_protocol) + ) + + +class WholeFileCacheFileSystem(CachingFileSystem): + """Caches whole remote files on first access + + This class is 
intended as a layer over any other file system, and + will make a local copy of each file accessed, so that all subsequent + reads are local. This is similar to ``CachingFileSystem``, but without + the block-wise functionality and so can work even when sparse files + are not allowed. See its docstring for definition of the init + arguments. + + The class still needs access to the remote store for listing files, + and may refresh cached files. + """ + + protocol = "filecache" + local_file = True + + def open_many(self, open_files, **kwargs): + paths = [of.path for of in open_files] + if "r" in open_files.mode: + self._mkcache() + else: + return [ + LocalTempFile( + self.fs, + path, + mode=open_files.mode, + fn=os.path.join(self.storage[-1], self._mapper(path)), + **kwargs, + ) + for path in paths + ] + + if self.compression: + raise NotImplementedError + details = [self._check_file(sp) for sp in paths] + downpath = [p for p, d in zip(paths, details) if not d] + downfn0 = [ + os.path.join(self.storage[-1], self._mapper(p)) + for p, d in zip(paths, details) + ] # keep these path names for opening later + downfn = [fn for fn, d in zip(downfn0, details) if not d] + if downpath: + # skip if all files are already cached and up to date + self.fs.get(downpath, downfn) + + # update metadata - only happens when downloads are successful + newdetail = [ + { + "original": path, + "fn": self._mapper(path), + "blocks": True, + "time": time.time(), + "uid": self.fs.ukey(path), + } + for path in downpath + ] + for path, detail in zip(downpath, newdetail): + self._metadata.update_file(path, detail) + self.save_cache() + + def firstpart(fn): + # helper to adapt both whole-file and simple-cache + return fn[1] if isinstance(fn, tuple) else fn + + return [ + open(firstpart(fn0) if fn0 else fn1, mode=open_files.mode) + for fn0, fn1 in zip(details, downfn0) + ] + + def commit_many(self, open_files): + self.fs.put([f.fn for f in open_files], [f.path for f in open_files]) + [f.close() for f 
in open_files] + for f in open_files: + # in case autocommit is off, and so close did not already delete + try: + os.remove(f.name) + except FileNotFoundError: + pass + self._cache_size = None + + def _make_local_details(self, path): + hash = self._mapper(path) + fn = os.path.join(self.storage[-1], hash) + detail = { + "original": path, + "fn": hash, + "blocks": True, + "time": time.time(), + "uid": self.fs.ukey(path), + } + self._metadata.update_file(path, detail) + logger.debug("Copying %s to local cache", path) + return fn + + def cat( + self, + path, + recursive=False, + on_error="raise", + callback=DEFAULT_CALLBACK, + **kwargs, + ): + paths = self.expand_path( + path, recursive=recursive, maxdepth=kwargs.get("maxdepth") + ) + getpaths = [] + storepaths = [] + fns = [] + out = {} + for p in paths.copy(): + try: + detail = self._check_file(p) + if not detail: + fn = self._make_local_details(p) + getpaths.append(p) + storepaths.append(fn) + else: + detail, fn = detail if isinstance(detail, tuple) else (None, detail) + fns.append(fn) + except Exception as e: + if on_error == "raise": + raise + if on_error == "return": + out[p] = e + paths.remove(p) + + if getpaths: + self.fs.get(getpaths, storepaths) + self.save_cache() + + callback.set_size(len(paths)) + for p, fn in zip(paths, fns): + with open(fn, "rb") as f: + out[p] = f.read() + callback.relative_update(1) + if isinstance(path, str) and len(paths) == 1 and recursive is False: + out = out[paths[0]] + return out + + def _get_cached_file_before_open(self, path, **kwargs): + fn = self._make_local_details(path) + # call target filesystems open + self._mkcache() + if self.compression: + with self.fs._open(path, mode="rb", **kwargs) as f, open(fn, "wb") as f2: + if isinstance(f, AbstractBufferedFile): + # want no type of caching if just downloading whole thing + f.cache = BaseCache(0, f.cache.fetcher, f.size) + comp = ( + infer_compression(path) + if self.compression == "infer" + else self.compression + ) + f = 
compr[comp](f, mode="rb") + data = True + while data: + block = getattr(f, "blocksize", 5 * 2**20) + data = f.read(block) + f2.write(data) + else: + self.fs.get_file(path, fn) + self.save_cache() + + def _open(self, path, mode="rb", **kwargs): + path = self._strip_protocol(path) + # For read (or append), (try) download from remote + if "r" in mode or "a" in mode: + if not self._check_file(path): + if self.fs.exists(path): + self._get_cached_file_before_open(path, **kwargs) + elif "r" in mode: + raise FileNotFoundError(path) + + detail, fn = self._check_file(path) + _, blocks = detail["fn"], detail["blocks"] + if blocks is True: + logger.debug("Opening local copy of %s", path) + else: + raise ValueError( + f"Attempt to open partially cached file {path}" + f" as a wholly cached file" + ) + + # Just reading does not need special file handling + if "r" in mode and "+" not in mode: + # In order to support downstream filesystems to be able to + # infer the compression from the original filename, like + # the `TarFileSystem`, let's extend the `io.BufferedReader` + # fileobject protocol by adding a dedicated attribute + # `original`. + f = open(fn, mode) + f.original = detail.get("original") + return f + + hash = self._mapper(path) + fn = os.path.join(self.storage[-1], hash) + user_specified_kwargs = { + k: v + for k, v in kwargs.items() + # those kwargs were added by open(), we don't want them + if k not in ["autocommit", "block_size", "cache_options"] + } + return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs) + + +class SimpleCacheFileSystem(WholeFileCacheFileSystem): + """Caches whole remote files on first access + + This class is intended as a layer over any other file system, and + will make a local copy of each file accessed, so that all subsequent + reads are local. This implementation only copies whole files, and + does not keep any metadata about the download time or file details. 
+ It is therefore safer to use in multi-threaded/concurrent situations. + + This is the only of the caching filesystems that supports write: you will + be given a real local open file, and upon close and commit, it will be + uploaded to the target filesystem; the writability or the target URL is + not checked until that time. + + """ + + protocol = "simplecache" + local_file = True + transaction_type = WriteCachedTransaction + + def __init__(self, **kwargs): + kw = kwargs.copy() + for key in ["cache_check", "expiry_time", "check_files"]: + kw[key] = False + super().__init__(**kw) + for storage in self.storage: + if not os.path.exists(storage): + os.makedirs(storage, exist_ok=True) + + def _check_file(self, path): + self._check_cache() + sha = self._mapper(path) + for storage in self.storage: + fn = os.path.join(storage, sha) + if os.path.exists(fn): + return fn + + def save_cache(self): + pass + + def load_cache(self): + pass + + def pipe_file(self, path, value=None, **kwargs): + if self._intrans: + with self.open(path, "wb") as f: + f.write(value) + else: + super().pipe_file(path, value) + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + details = [] + try: + details = self.fs.ls( + path, detail=True, **kwargs + ).copy() # don't edit original! 
+ except FileNotFoundError as e: + ex = e + else: + ex = None + if self._intrans: + path1 = path.rstrip("/") + "/" + for f in self.transaction.files: + if f.path == path: + details.append( + {"name": path, "size": f.size or f.tell(), "type": "file"} + ) + elif f.path.startswith(path1): + if f.path.count("/") == path1.count("/"): + details.append( + {"name": f.path, "size": f.size or f.tell(), "type": "file"} + ) + else: + dname = "/".join(f.path.split("/")[: path1.count("/") + 1]) + details.append({"name": dname, "size": 0, "type": "directory"}) + if ex is not None and not details: + raise ex + if detail: + return details + return sorted(_["name"] for _ in details) + + def info(self, path, **kwargs): + path = self._strip_protocol(path) + if self._intrans: + f = [_ for _ in self.transaction.files if _.path == path] + if f: + size = os.path.getsize(f[0].fn) if f[0].closed else f[0].tell() + return {"name": path, "size": size, "type": "file"} + f = any(_.path.startswith(path + "/") for _ in self.transaction.files) + if f: + return {"name": path, "size": 0, "type": "directory"} + return self.fs.info(path, **kwargs) + + def pipe(self, path, value=None, **kwargs): + if isinstance(path, str): + self.pipe_file(self._strip_protocol(path), value, **kwargs) + elif isinstance(path, dict): + for k, v in path.items(): + self.pipe_file(self._strip_protocol(k), v, **kwargs) + else: + raise ValueError("path must be str or dict") + + async def _cat_file(self, path, start=None, end=None, **kwargs): + logger.debug("async cat_file %s", path) + path = self._strip_protocol(path) + sha = self._mapper(path) + fn = self._check_file(path) + + if not fn: + fn = os.path.join(self.storage[-1], sha) + await self.fs._get_file(path, fn, **kwargs) + + with open(fn, "rb") as f: # noqa ASYNC230 + if start: + f.seek(start) + size = -1 if end is None else end - f.tell() + return f.read(size) + + async def _cat_ranges( + self, paths, starts, ends, max_gap=None, on_error="return", **kwargs + ): + 
logger.debug("async cat ranges %s", paths) + lpaths = [] + rset = set() + download = [] + rpaths = [] + for p in paths: + fn = self._check_file(p) + if fn is None and p not in rset: + sha = self._mapper(p) + fn = os.path.join(self.storage[-1], sha) + download.append(fn) + rset.add(p) + rpaths.append(p) + lpaths.append(fn) + if download: + await self.fs._get(rpaths, download, on_error=on_error) + + return LocalFileSystem().cat_ranges( + lpaths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs + ) + + def cat_ranges( + self, paths, starts, ends, max_gap=None, on_error="return", **kwargs + ): + logger.debug("cat ranges %s", paths) + lpaths = [self._check_file(p) for p in paths] + rpaths = [p for l, p in zip(lpaths, paths) if l is False] + lpaths = [l for l, p in zip(lpaths, paths) if l is False] + self.fs.get(rpaths, lpaths) + paths = [self._check_file(p) for p in paths] + return LocalFileSystem().cat_ranges( + paths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs + ) + + def _get_cached_file_before_open(self, path, **kwargs): + sha = self._mapper(path) + fn = os.path.join(self.storage[-1], sha) + logger.debug("Copying %s to local cache", path) + + self._mkcache() + self._cache_size = None + + if self.compression: + with self.fs._open(path, mode="rb", **kwargs) as f, open(fn, "wb") as f2: + if isinstance(f, AbstractBufferedFile): + # want no type of caching if just downloading whole thing + f.cache = BaseCache(0, f.cache.fetcher, f.size) + comp = ( + infer_compression(path) + if self.compression == "infer" + else self.compression + ) + f = compr[comp](f, mode="rb") + data = True + while data: + block = getattr(f, "blocksize", 5 * 2**20) + data = f.read(block) + f2.write(data) + else: + self.fs.get_file(path, fn) + + def _open(self, path, mode="rb", **kwargs): + path = self._strip_protocol(path) + sha = self._mapper(path) + + # For read (or append), (try) download from remote + if "r" in mode or "a" in mode: + if not self._check_file(path): + # 
append does not require an existing file but read does + if self.fs.exists(path): + self._get_cached_file_before_open(path, **kwargs) + elif "r" in mode: + raise FileNotFoundError(path) + + fn = self._check_file(path) + # Just reading does not need special file handling + if "r" in mode and "+" not in mode: + return open(fn, mode) + + fn = os.path.join(self.storage[-1], sha) + user_specified_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ["autocommit", "block_size", "cache_options"] + } # those were added by open() + return LocalTempFile( + self, + path, + mode=mode, + autocommit=not self._intrans, + fn=fn, + **user_specified_kwargs, + ) + + +class LocalTempFile: + """A temporary local file, which will be uploaded on commit""" + + def __init__(self, fs, path, fn, mode="wb", autocommit=True, seek=0, **kwargs): + self.fn = fn + self.fh = open(fn, mode) + self.mode = mode + if seek: + self.fh.seek(seek) + self.path = path + self.size = None + self.fs = fs + self.closed = False + self.autocommit = autocommit + self.kwargs = kwargs + + def __reduce__(self): + # always open in r+b to allow continuing writing at a location + return ( + LocalTempFile, + (self.fs, self.path, self.fn, "r+b", self.autocommit, self.tell()), + ) + + def __enter__(self): + return self.fh + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + # self.size = self.fh.tell() + if self.closed: + return + self.fh.close() + self.closed = True + if self.autocommit: + self.commit() + + def discard(self): + self.fh.close() + os.remove(self.fn) + + def commit(self): + # calling put() with list arguments avoids path expansion and additional operations + # like isdir() + self.fs.put([self.fn], [self.path], **self.kwargs) + # we do not delete the local copy, it's still in the cache. 
+ + @property + def name(self): + return self.fn + + def __repr__(self) -> str: + return f"LocalTempFile: {self.path}" + + def __getattr__(self, item): + return getattr(self.fh, item) diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/chained.py b/venv/lib/python3.10/site-packages/fsspec/implementations/chained.py new file mode 100644 index 0000000000000000000000000000000000000000..bfce64334e8db0272eefa96b4428b23524b059f0 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/chained.py @@ -0,0 +1,23 @@ +from typing import ClassVar + +from fsspec import AbstractFileSystem + +__all__ = ("ChainedFileSystem",) + + +class ChainedFileSystem(AbstractFileSystem): + """Chained filesystem base class. + + A chained filesystem is designed to be layered over another FS. + This is useful to implement things like caching. + + This base class does very little on its own, but is used as a marker + that the class is designed for chaining. + + Right now this is only used in `url_to_fs` to provide the path argument + (`fo`) to the chained filesystem from the underlying filesystem. + + Additional functionality may be added in the future. 
+ """ + + protocol: ClassVar[str] = "chained" diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/dask.py b/venv/lib/python3.10/site-packages/fsspec/implementations/dask.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1276463db6866665e6a0fe114efc247971b57e --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/dask.py @@ -0,0 +1,152 @@ +import dask +from distributed.client import Client, _get_global_client +from distributed.worker import Worker + +from fsspec import filesystem +from fsspec.spec import AbstractBufferedFile, AbstractFileSystem +from fsspec.utils import infer_storage_options + + +def _get_client(client): + if client is None: + return _get_global_client() + elif isinstance(client, Client): + return client + else: + # e.g., connection string + return Client(client) + + +def _in_worker(): + return bool(Worker._instances) + + +class DaskWorkerFileSystem(AbstractFileSystem): + """View files accessible to a worker as any other remote file-system + + When instances are run on the worker, uses the real filesystem. When + run on the client, they call the worker to provide information or data. + + **Warning** this implementation is experimental, and read-only for now. 
+ """ + + def __init__( + self, target_protocol=None, target_options=None, fs=None, client=None, **kwargs + ): + super().__init__(**kwargs) + if not (fs is None) ^ (target_protocol is None): + raise ValueError( + "Please provide one of filesystem instance (fs) or" + " target_protocol, not both" + ) + self.target_protocol = target_protocol + self.target_options = target_options + self.worker = None + self.client = client + self.fs = fs + self._determine_worker() + + @staticmethod + def _get_kwargs_from_urls(path): + so = infer_storage_options(path) + if "host" in so and "port" in so: + return {"client": f"{so['host']}:{so['port']}"} + else: + return {} + + def _determine_worker(self): + if _in_worker(): + self.worker = True + if self.fs is None: + self.fs = filesystem( + self.target_protocol, **(self.target_options or {}) + ) + else: + self.worker = False + self.client = _get_client(self.client) + self.rfs = dask.delayed(self) + + def mkdir(self, *args, **kwargs): + if self.worker: + self.fs.mkdir(*args, **kwargs) + else: + self.rfs.mkdir(*args, **kwargs).compute() + + def rm(self, *args, **kwargs): + if self.worker: + self.fs.rm(*args, **kwargs) + else: + self.rfs.rm(*args, **kwargs).compute() + + def copy(self, *args, **kwargs): + if self.worker: + self.fs.copy(*args, **kwargs) + else: + self.rfs.copy(*args, **kwargs).compute() + + def mv(self, *args, **kwargs): + if self.worker: + self.fs.mv(*args, **kwargs) + else: + self.rfs.mv(*args, **kwargs).compute() + + def ls(self, *args, **kwargs): + if self.worker: + return self.fs.ls(*args, **kwargs) + else: + return self.rfs.ls(*args, **kwargs).compute() + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + if self.worker: + return self.fs._open( + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + **kwargs, + ) + else: + return DaskFile( + fs=self, + path=path, + mode=mode, + block_size=block_size, 
+ autocommit=autocommit, + cache_options=cache_options, + **kwargs, + ) + + def fetch_range(self, path, mode, start, end): + if self.worker: + with self._open(path, mode) as f: + f.seek(start) + return f.read(end - start) + else: + return self.rfs.fetch_range(path, mode, start, end).compute() + + +class DaskFile(AbstractBufferedFile): + def __init__(self, mode="rb", **kwargs): + if mode != "rb": + raise ValueError('Remote dask files can only be opened in "rb" mode') + super().__init__(**kwargs) + + def _upload_chunk(self, final=False): + pass + + def _initiate_upload(self): + """Create remote file/upload""" + pass + + def _fetch_range(self, start, end): + """Get the specified set of bytes from remote""" + return self.fs.fetch_range(self.path, self.mode, start, end) diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/data.py b/venv/lib/python3.10/site-packages/fsspec/implementations/data.py new file mode 100644 index 0000000000000000000000000000000000000000..f11542b48c98fd53fc367ade7425a00b38487619 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/data.py @@ -0,0 +1,57 @@ +import base64 +import io +from urllib.parse import unquote + +from fsspec import AbstractFileSystem + + +class DataFileSystem(AbstractFileSystem): + """A handy decoder for data-URLs + + Example + ------- + >>> with fsspec.open("data:,Hello%2C%20World%21") as f: + ... print(f.read()) + b"Hello, World!" 
+ + See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs + """ + + protocol = "data" + + def __init__(self, **kwargs): + """No parameters for this filesystem""" + super().__init__(**kwargs) + + def cat_file(self, path, start=None, end=None, **kwargs): + pref, data = path.split(",", 1) + if pref.endswith("base64"): + return base64.b64decode(data)[start:end] + return unquote(data).encode()[start:end] + + def info(self, path, **kwargs): + pref, name = path.split(",", 1) + data = self.cat_file(path) + mime = pref.split(":", 1)[1].split(";", 1)[0] + return {"name": name, "size": len(data), "type": "file", "mimetype": mime} + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + if "r" not in mode: + raise ValueError("Read only filesystem") + return io.BytesIO(self.cat_file(path)) + + @staticmethod + def encode(data: bytes, mime: str | None = None): + """Format the given data into data-URL syntax + + This version always base64 encodes, even when the data is ascii/url-safe. + """ + return f"data:{mime or ''};base64,{base64.b64encode(data).decode()}" diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/dbfs.py b/venv/lib/python3.10/site-packages/fsspec/implementations/dbfs.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7fc93d7389c894ecb5fc6267ce20abe4087068 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/dbfs.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +import base64 +import urllib + +import requests +from requests.adapters import HTTPAdapter, Retry +from typing_extensions import override + +from fsspec import AbstractFileSystem +from fsspec.spec import AbstractBufferedFile + + +class DatabricksException(Exception): + """ + Helper class for exceptions raised in this module. 
+ """ + + def __init__(self, error_code, message, details=None): + """Create a new DatabricksException""" + super().__init__(message) + + self.error_code = error_code + self.message = message + self.details = details + + +class DatabricksFileSystem(AbstractFileSystem): + """ + Get access to the Databricks filesystem implementation over HTTP. + Can be used inside and outside of a databricks cluster. + """ + + def __init__(self, instance, token, **kwargs): + """ + Create a new DatabricksFileSystem. + + Parameters + ---------- + instance: str + The instance URL of the databricks cluster. + For example for an Azure databricks cluster, this + has the form adb-..azuredatabricks.net. + token: str + Your personal token. Find out more + here: https://docs.databricks.com/dev-tools/api/latest/authentication.html + """ + self.instance = instance + self.token = token + self.session = requests.Session() + self.retries = Retry( + total=10, + backoff_factor=0.05, + status_forcelist=[408, 429, 500, 502, 503, 504], + ) + + self.session.mount("https://", HTTPAdapter(max_retries=self.retries)) + self.session.headers.update({"Authorization": f"Bearer {self.token}"}) + + super().__init__(**kwargs) + + @override + def _ls_from_cache(self, path) -> list[dict[str, str | int]] | None: + """Check cache for listing + + Returns listing, if found (may be empty list for a directory that + exists but contains nothing), None if not in cache. + """ + self.dircache.pop(path.rstrip("/"), None) + + parent = self._parent(path) + if parent in self.dircache: + for entry in self.dircache[parent]: + if entry["name"] == path.rstrip("/"): + if entry["type"] != "directory": + return [entry] + return [] + raise FileNotFoundError(path) + + def ls(self, path, detail=True, **kwargs): + """ + List the contents of the given path. + + Parameters + ---------- + path: str + Absolute path + detail: bool + Return not only the list of filenames, + but also additional information on file sizes + and types. 
+ """ + try: + out = self._ls_from_cache(path) + except FileNotFoundError: + # This happens if the `path`'s parent was cached, but `path` is not + # there. This suggests that `path` is new since the parent was + # cached. Attempt to invalidate parent's cache before continuing. + self.dircache.pop(self._parent(path), None) + out = None + + if not out: + try: + r = self._send_to_api( + method="get", endpoint="list", json={"path": path} + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + + raise + files = r.get("files", []) + out = [ + { + "name": o["path"], + "type": "directory" if o["is_dir"] else "file", + "size": o["file_size"], + } + for o in files + ] + self.dircache[path] = out + + if detail: + return out + return [o["name"] for o in out] + + def makedirs(self, path, exist_ok=True): + """ + Create a given absolute path and all of its parents. + + Parameters + ---------- + path: str + Absolute path to create + exist_ok: bool + If false, checks if the folder + exists before creating it (and raises an + Exception if this is the case) + """ + if not exist_ok: + try: + # If the following succeeds, the path is already present + self._send_to_api( + method="get", endpoint="get-status", json={"path": path} + ) + raise FileExistsError(f"Path {path} already exists") + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + pass + + try: + self._send_to_api(method="post", endpoint="mkdirs", json={"path": path}) + except DatabricksException as e: + if e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) from e + + raise + self.invalidate_cache(self._parent(path)) + + def mkdir(self, path, create_parents=True, **kwargs): + """ + Create a given absolute path and all of its parents. + + Parameters + ---------- + path: str + Absolute path to create + create_parents: bool + Whether to create all parents or not. + "False" is not implemented so far. 
+ """ + if not create_parents: + raise NotImplementedError + + self.mkdirs(path, **kwargs) + + def rm(self, path, recursive=False, **kwargs): + """ + Remove the file or folder at the given absolute path. + + Parameters + ---------- + path: str + Absolute path what to remove + recursive: bool + Recursively delete all files in a folder. + """ + try: + self._send_to_api( + method="post", + endpoint="delete", + json={"path": path, "recursive": recursive}, + ) + except DatabricksException as e: + # This is not really an exception, it just means + # not everything was deleted so far + if e.error_code == "PARTIAL_DELETE": + self.rm(path=path, recursive=recursive) + elif e.error_code == "IO_ERROR": + # Using the same exception as the os module would use here + raise OSError(e.message) from e + + raise + self.invalidate_cache(self._parent(path)) + + def mv( + self, source_path, destination_path, recursive=False, maxdepth=None, **kwargs + ): + """ + Move a source to a destination path. + + A note from the original [databricks API manual] + (https://docs.databricks.com/dev-tools/api/latest/dbfs.html#move). + + When moving a large number of files the API call will time out after + approximately 60s, potentially resulting in partially moved data. + Therefore, for operations that move more than 10k files, we strongly + discourage using the DBFS REST API. + + Parameters + ---------- + source_path: str + From where to move (absolute path) + destination_path: str + To where to move (absolute path) + recursive: bool + Not implemented to far. + maxdepth: + Not implemented to far. 
+ """ + if recursive: + raise NotImplementedError + if maxdepth: + raise NotImplementedError + + try: + self._send_to_api( + method="post", + endpoint="move", + json={"source_path": source_path, "destination_path": destination_path}, + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + elif e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) from e + + raise + self.invalidate_cache(self._parent(source_path)) + self.invalidate_cache(self._parent(destination_path)) + + def _open(self, path, mode="rb", block_size="default", **kwargs): + """ + Overwrite the base class method to make sure to create a DBFile. + All arguments are copied from the base method. + + Only the default blocksize is allowed. + """ + return DatabricksFile(self, path, mode=mode, block_size=block_size, **kwargs) + + def _send_to_api(self, method, endpoint, json): + """ + Send the given json to the DBFS API + using a get or post request (specified by the argument `method`). + + Parameters + ---------- + method: str + Which http method to use for communication; "get" or "post". + endpoint: str + Where to send the request to (last part of the API URL) + json: dict + Dictionary of information to send + """ + if method == "post": + session_call = self.session.post + elif method == "get": + session_call = self.session.get + else: + raise ValueError(f"Do not understand method {method}") + + url = urllib.parse.urljoin(f"https://{self.instance}/api/2.0/dbfs/", endpoint) + + r = session_call(url, json=json) + + # The DBFS API will return a json, also in case of an exception. + # We want to preserve this information as good as possible. 
+ try: + r.raise_for_status() + except requests.HTTPError as e: + # try to extract json error message + # if that fails, fall back to the original exception + try: + exception_json = e.response.json() + except Exception: + raise e from None + + raise DatabricksException(**exception_json) from e + + return r.json() + + def _create_handle(self, path, overwrite=True): + """ + Internal function to create a handle, which can be used to + write blocks of a file to DBFS. + A handle has a unique identifier which needs to be passed + whenever written during this transaction. + The handle is active for 10 minutes - after that a new + write transaction needs to be created. + Make sure to close the handle after you are finished. + + Parameters + ---------- + path: str + Absolute path for this file. + overwrite: bool + If a file already exist at this location, either overwrite + it or raise an exception. + """ + try: + r = self._send_to_api( + method="post", + endpoint="create", + json={"path": path, "overwrite": overwrite}, + ) + return r["handle"] + except DatabricksException as e: + if e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) from e + + raise + + def _close_handle(self, handle): + """ + Close a handle, which was opened by :func:`_create_handle`. + + Parameters + ---------- + handle: str + Which handle to close. + """ + try: + self._send_to_api(method="post", endpoint="close", json={"handle": handle}) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + + raise + + def _add_data(self, handle, data): + """ + Upload data to an already opened file handle + (opened by :func:`_create_handle`). + The maximal allowed data size is 1MB after + conversion to base64. + Remember to close the handle when you are finished. + + Parameters + ---------- + handle: str + Which handle to upload data to. + data: bytes + Block of data to add to the handle. 
+ """ + data = base64.b64encode(data).decode() + try: + self._send_to_api( + method="post", + endpoint="add-block", + json={"handle": handle, "data": data}, + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + elif e.error_code == "MAX_BLOCK_SIZE_EXCEEDED": + raise ValueError(e.message) from e + + raise + + def _get_data(self, path, start, end): + """ + Download data in bytes from a given absolute path in a block + from [start, start+length]. + The maximum number of allowed bytes to read is 1MB. + + Parameters + ---------- + path: str + Absolute path to download data from + start: int + Start position of the block + end: int + End position of the block + """ + try: + r = self._send_to_api( + method="get", + endpoint="read", + json={"path": path, "offset": start, "length": end - start}, + ) + return base64.b64decode(r["data"]) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + elif e.error_code in ["INVALID_PARAMETER_VALUE", "MAX_READ_SIZE_EXCEEDED"]: + raise ValueError(e.message) from e + + raise + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + self.dircache.pop(path, None) + super().invalidate_cache(path) + + +class DatabricksFile(AbstractBufferedFile): + """ + Helper class for files referenced in the DatabricksFileSystem. + """ + + DEFAULT_BLOCK_SIZE = 1 * 2**20 # only allowed block size + + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + **kwargs, + ): + """ + Create a new instance of the DatabricksFile. + + The blocksize needs to be the default one. 
+ """ + if block_size is None or block_size == "default": + block_size = self.DEFAULT_BLOCK_SIZE + + assert block_size == self.DEFAULT_BLOCK_SIZE, ( + f"Only the default block size is allowed, not {block_size}" + ) + + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + cache_options=cache_options or {}, + **kwargs, + ) + + def _initiate_upload(self): + """Internal function to start a file upload""" + self.handle = self.fs._create_handle(self.path) + + def _upload_chunk(self, final=False): + """Internal function to add a chunk of data to a started upload""" + self.buffer.seek(0) + data = self.buffer.getvalue() + + data_chunks = [ + data[start:end] for start, end in self._to_sized_blocks(len(data)) + ] + + for data_chunk in data_chunks: + self.fs._add_data(handle=self.handle, data=data_chunk) + + if final: + self.fs._close_handle(handle=self.handle) + return True + + def _fetch_range(self, start, end): + """Internal function to download a block of data""" + return_buffer = b"" + length = end - start + for chunk_start, chunk_end in self._to_sized_blocks(length, start): + return_buffer += self.fs._get_data( + path=self.path, start=chunk_start, end=chunk_end + ) + + return return_buffer + + def _to_sized_blocks(self, length, start=0): + """Helper function to split a range from 0 to total_length into blocksizes""" + end = start + length + for data_chunk in range(start, end, self.blocksize): + data_start = data_chunk + data_end = min(end, data_chunk + self.blocksize) + yield data_start, data_end diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/dirfs.py b/venv/lib/python3.10/site-packages/fsspec/implementations/dirfs.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3dd3cf4c2f421292ba5d9fab8b733a60550496 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/dirfs.py @@ -0,0 +1,389 @@ +from .. 
import filesystem +from ..asyn import AsyncFileSystem +from .chained import ChainedFileSystem + + +class DirFileSystem(AsyncFileSystem, ChainedFileSystem): + """Directory prefix filesystem + + The DirFileSystem is a filesystem-wrapper. It assumes every path it is dealing with + is relative to the `path`. After performing the necessary paths operation it + delegates everything to the wrapped filesystem. + """ + + protocol = "dir" + + def __init__( + self, + path=None, + fs=None, + fo=None, + target_protocol=None, + target_options=None, + **storage_options, + ): + """ + Parameters + ---------- + path: str + Path to the directory. + fs: AbstractFileSystem + An instantiated filesystem to wrap. + target_protocol, target_options: + if fs is none, construct it from these + fo: str + Alternate for path; do not provide both + """ + super().__init__(**storage_options) + if fs is None: + fs = filesystem(protocol=target_protocol, **(target_options or {})) + path = path or fo + + if self.asynchronous and not fs.async_impl: + raise ValueError("can't use asynchronous with non-async fs") + + if fs.async_impl and self.asynchronous != fs.asynchronous: + raise ValueError("both dirfs and fs should be in the same sync/async mode") + + self.path = fs._strip_protocol(path) + self.fs = fs + + def _join(self, path): + if isinstance(path, str): + if not self.path: + return path + if not path: + return self.path + return self.fs.sep.join((self.path, self._strip_protocol(path))) + if isinstance(path, dict): + return {self._join(_path): value for _path, value in path.items()} + return [self._join(_path) for _path in path] + + def _relpath(self, path): + if isinstance(path, str): + if not self.path: + return path + # We need to account for S3FileSystem returning paths that do not + # start with a '/' + if path == self.path or ( + self.path.startswith(self.fs.sep) and path == self.path[1:] + ): + return "" + prefix = self.path + self.fs.sep + if self.path.startswith(self.fs.sep) and not 
path.startswith(self.fs.sep): + prefix = prefix[1:] + assert path.startswith(prefix) + return path[len(prefix) :] + return [self._relpath(_path) for _path in path] + + # Wrappers below + + @property + def sep(self): + return self.fs.sep + + async def set_session(self, *args, **kwargs): + return await self.fs.set_session(*args, **kwargs) + + async def _rm_file(self, path, **kwargs): + return await self.fs._rm_file(self._join(path), **kwargs) + + def rm_file(self, path, **kwargs): + return self.fs.rm_file(self._join(path), **kwargs) + + async def _rm(self, path, *args, **kwargs): + return await self.fs._rm(self._join(path), *args, **kwargs) + + def rm(self, path, *args, **kwargs): + return self.fs.rm(self._join(path), *args, **kwargs) + + async def _cp_file(self, path1, path2, **kwargs): + return await self.fs._cp_file(self._join(path1), self._join(path2), **kwargs) + + def cp_file(self, path1, path2, **kwargs): + return self.fs.cp_file(self._join(path1), self._join(path2), **kwargs) + + async def _copy( + self, + path1, + path2, + *args, + **kwargs, + ): + return await self.fs._copy( + self._join(path1), + self._join(path2), + *args, + **kwargs, + ) + + def copy(self, path1, path2, *args, **kwargs): + return self.fs.copy( + self._join(path1), + self._join(path2), + *args, + **kwargs, + ) + + async def _pipe(self, path, *args, **kwargs): + return await self.fs._pipe(self._join(path), *args, **kwargs) + + def pipe(self, path, *args, **kwargs): + return self.fs.pipe(self._join(path), *args, **kwargs) + + async def _pipe_file(self, path, *args, **kwargs): + return await self.fs._pipe_file(self._join(path), *args, **kwargs) + + def pipe_file(self, path, *args, **kwargs): + return self.fs.pipe_file(self._join(path), *args, **kwargs) + + async def _cat_file(self, path, *args, **kwargs): + return await self.fs._cat_file(self._join(path), *args, **kwargs) + + def cat_file(self, path, *args, **kwargs): + return self.fs.cat_file(self._join(path), *args, **kwargs) + + async def 
_cat(self, path, *args, **kwargs): + ret = await self.fs._cat( + self._join(path), + *args, + **kwargs, + ) + + if isinstance(ret, dict): + return {self._relpath(key): value for key, value in ret.items()} + + return ret + + def cat(self, path, *args, **kwargs): + ret = self.fs.cat( + self._join(path), + *args, + **kwargs, + ) + + if isinstance(ret, dict): + return {self._relpath(key): value for key, value in ret.items()} + + return ret + + async def _put_file(self, lpath, rpath, **kwargs): + return await self.fs._put_file(lpath, self._join(rpath), **kwargs) + + def put_file(self, lpath, rpath, **kwargs): + return self.fs.put_file(lpath, self._join(rpath), **kwargs) + + async def _put( + self, + lpath, + rpath, + *args, + **kwargs, + ): + return await self.fs._put( + lpath, + self._join(rpath), + *args, + **kwargs, + ) + + def put(self, lpath, rpath, *args, **kwargs): + return self.fs.put( + lpath, + self._join(rpath), + *args, + **kwargs, + ) + + async def _get_file(self, rpath, lpath, **kwargs): + return await self.fs._get_file(self._join(rpath), lpath, **kwargs) + + def get_file(self, rpath, lpath, **kwargs): + return self.fs.get_file(self._join(rpath), lpath, **kwargs) + + async def _get(self, rpath, *args, **kwargs): + return await self.fs._get(self._join(rpath), *args, **kwargs) + + def get(self, rpath, *args, **kwargs): + return self.fs.get(self._join(rpath), *args, **kwargs) + + async def _isfile(self, path): + return await self.fs._isfile(self._join(path)) + + def isfile(self, path): + return self.fs.isfile(self._join(path)) + + async def _isdir(self, path): + return await self.fs._isdir(self._join(path)) + + def isdir(self, path): + return self.fs.isdir(self._join(path)) + + async def _size(self, path): + return await self.fs._size(self._join(path)) + + def size(self, path): + return self.fs.size(self._join(path)) + + async def _exists(self, path): + return await self.fs._exists(self._join(path)) + + def exists(self, path): + return 
self.fs.exists(self._join(path)) + + async def _info(self, path, **kwargs): + info = await self.fs._info(self._join(path), **kwargs) + info = info.copy() + info["name"] = self._relpath(info["name"]) + return info + + def info(self, path, **kwargs): + info = self.fs.info(self._join(path), **kwargs) + info = info.copy() + info["name"] = self._relpath(info["name"]) + return info + + async def _ls(self, path, detail=True, **kwargs): + ret = (await self.fs._ls(self._join(path), detail=detail, **kwargs)).copy() + if detail: + out = [] + for entry in ret: + entry = entry.copy() + entry["name"] = self._relpath(entry["name"]) + out.append(entry) + return out + + return self._relpath(ret) + + def ls(self, path, detail=True, **kwargs): + ret = self.fs.ls(self._join(path), detail=detail, **kwargs).copy() + if detail: + out = [] + for entry in ret: + entry = entry.copy() + entry["name"] = self._relpath(entry["name"]) + out.append(entry) + return out + + return self._relpath(ret) + + async def _walk(self, path, *args, **kwargs): + async for root, dirs, files in self.fs._walk(self._join(path), *args, **kwargs): + yield self._relpath(root), dirs, files + + def walk(self, path, *args, **kwargs): + for root, dirs, files in self.fs.walk(self._join(path), *args, **kwargs): + yield self._relpath(root), dirs, files + + async def _glob(self, path, **kwargs): + detail = kwargs.get("detail", False) + ret = await self.fs._glob(self._join(path), **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + def glob(self, path, **kwargs): + detail = kwargs.get("detail", False) + ret = self.fs.glob(self._join(path), **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + async def _du(self, path, *args, **kwargs): + total = kwargs.get("total", True) + ret = await self.fs._du(self._join(path), *args, **kwargs) + if total: + return ret + + return {self._relpath(path): 
size for path, size in ret.items()} + + def du(self, path, *args, **kwargs): + total = kwargs.get("total", True) + ret = self.fs.du(self._join(path), *args, **kwargs) + if total: + return ret + + return {self._relpath(path): size for path, size in ret.items()} + + async def _find(self, path, *args, **kwargs): + detail = kwargs.get("detail", False) + ret = await self.fs._find(self._join(path), *args, **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + def find(self, path, *args, **kwargs): + detail = kwargs.get("detail", False) + ret = self.fs.find(self._join(path), *args, **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + async def _expand_path(self, path, *args, **kwargs): + return self._relpath( + await self.fs._expand_path(self._join(path), *args, **kwargs) + ) + + def expand_path(self, path, *args, **kwargs): + return self._relpath(self.fs.expand_path(self._join(path), *args, **kwargs)) + + async def _mkdir(self, path, *args, **kwargs): + return await self.fs._mkdir(self._join(path), *args, **kwargs) + + def mkdir(self, path, *args, **kwargs): + return self.fs.mkdir(self._join(path), *args, **kwargs) + + async def _makedirs(self, path, *args, **kwargs): + return await self.fs._makedirs(self._join(path), *args, **kwargs) + + def makedirs(self, path, *args, **kwargs): + return self.fs.makedirs(self._join(path), *args, **kwargs) + + def rmdir(self, path): + return self.fs.rmdir(self._join(path)) + + def mv(self, path1, path2, **kwargs): + return self.fs.mv( + self._join(path1), + self._join(path2), + **kwargs, + ) + + def touch(self, path, **kwargs): + return self.fs.touch(self._join(path), **kwargs) + + def created(self, path): + return self.fs.created(self._join(path)) + + def modified(self, path): + return self.fs.modified(self._join(path)) + + def sign(self, path, *args, **kwargs): + return self.fs.sign(self._join(path), 
*args, **kwargs) + + def __repr__(self): + return f"{self.__class__.__qualname__}(path='{self.path}', fs={self.fs})" + + def open( + self, + path, + *args, + **kwargs, + ): + return self.fs.open( + self._join(path), + *args, + **kwargs, + ) + + async def open_async( + self, + path, + *args, + **kwargs, + ): + return await self.fs.open_async( + self._join(path), + *args, + **kwargs, + ) diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/ftp.py b/venv/lib/python3.10/site-packages/fsspec/implementations/ftp.py new file mode 100644 index 0000000000000000000000000000000000000000..ca151c6ea3cc2bb514701cd95fd5258bed3c9899 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/ftp.py @@ -0,0 +1,437 @@ +import os +import ssl +import uuid +from ftplib import FTP, FTP_TLS, Error, error_perm +from typing import Any + +from ..spec import AbstractBufferedFile, AbstractFileSystem +from ..utils import infer_storage_options, isfilelike + +SECURITY_PROTOCOL_MAP = { + "tls": ssl.PROTOCOL_TLS, + "tlsv1": ssl.PROTOCOL_TLSv1, + "tlsv1_1": ssl.PROTOCOL_TLSv1_1, + "tlsv1_2": ssl.PROTOCOL_TLSv1_2, + "sslv23": ssl.PROTOCOL_SSLv23, +} + + +class ImplicitFTPTLS(FTP_TLS): + """ + FTP_TLS subclass that automatically wraps sockets in SSL + to support implicit FTPS. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._sock = None + + @property + def sock(self): + """Return the socket.""" + return self._sock + + @sock.setter + def sock(self, value): + """When modifying the socket, ensure that it is ssl wrapped.""" + if value is not None and not isinstance(value, ssl.SSLSocket): + value = self.context.wrap_socket(value) + self._sock = value + + +class FTPFileSystem(AbstractFileSystem): + """A filesystem over classic FTP""" + + root_marker = "/" + cachable = False + protocol = "ftp" + + def __init__( + self, + host, + port=21, + username=None, + password=None, + acct=None, + block_size=None, + tempdir=None, + timeout=30, + encoding="utf-8", + tls=False, + **kwargs, + ): + """ + You can use _get_kwargs_from_urls to get some kwargs from + a reasonable FTP url. + + Authentication will be anonymous if username/password are not + given. + + Parameters + ---------- + host: str + The remote server name/ip to connect to + port: int + Port to connect with + username: str or None + If authenticating, the user's identifier + password: str of None + User's password on the server, if using + acct: str or None + Some servers also need an "account" string for auth + block_size: int or None + If given, the read-ahead or write buffer size. 
+ tempdir: str + Directory on remote to put temporary files when in a transaction + timeout: int + Timeout of the ftp connection in seconds + encoding: str + Encoding to use for directories and filenames in FTP connection + tls: bool or str + Enable FTP-TLS for secure connections: + - False: Plain FTP (default) + - True: Explicit TLS (FTPS with AUTH TLS command) + - "tls": Auto-negotiate highest protocol + - "tlsv1": TLS v1.0 + - "tlsv1_1": TLS v1.1 + - "tlsv1_2": TLS v1.2 + """ + super().__init__(**kwargs) + self.host = host + self.port = port + self.tempdir = tempdir or "/tmp" + self.cred = username or "", password or "", acct or "" + self.timeout = timeout + self.encoding = encoding + if block_size is not None: + self.blocksize = block_size + else: + self.blocksize = 2**16 + self.tls = tls + self._connect() + if isinstance(self.tls, bool) and self.tls: + self.ftp.prot_p() + + def _connect(self): + security = None + if self.tls: + if isinstance(self.tls, str): + ftp_cls = ImplicitFTPTLS + security = SECURITY_PROTOCOL_MAP.get( + self.tls, + f"Not supported {self.tls} protocol", + ) + if isinstance(security, str): + raise ValueError(security) + else: + ftp_cls = FTP_TLS + else: + ftp_cls = FTP + self.ftp = ftp_cls(timeout=self.timeout, encoding=self.encoding) + if security: + self.ftp.ssl_version = security + self.ftp.connect(self.host, self.port) + self.ftp.login(*self.cred) + + @classmethod + def _strip_protocol(cls, path): + return "/" + infer_storage_options(path)["path"].lstrip("/").rstrip("/") + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + return out + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + out = [] + if path not in self.dircache: + try: + try: + out = [ + (fn, details) + for (fn, details) in self.ftp.mlsd(path) + if fn not in [".", ".."] + and details["type"] not in ["pdir", "cdir"] + ] + except error_perm: + out = 
_mlsd2(self.ftp, path) # Not platform independent + for fn, details in out: + details["name"] = "/".join( + ["" if path == "/" else path, fn.lstrip("/")] + ) + if details["type"] == "file": + details["size"] = int(details["size"]) + else: + details["size"] = 0 + if details["type"] == "dir": + details["type"] = "directory" + self.dircache[path] = out + except Error: + try: + info = self.info(path) + if info["type"] == "file": + out = [(path, info)] + except (Error, IndexError) as exc: + raise FileNotFoundError(path) from exc + files = self.dircache.get(path, out) + if not detail: + return sorted([fn for fn, details in files]) + return [details for fn, details in files] + + def info(self, path, **kwargs): + # implement with direct method + path = self._strip_protocol(path) + if path == "/": + # special case, since this dir has no real entry + return {"name": "/", "size": 0, "type": "directory"} + files = self.ls(self._parent(path).lstrip("/"), True) + try: + out = next(f for f in files if f["name"] == path) + except StopIteration as exc: + raise FileNotFoundError(path) from exc + return out + + def get_file(self, rpath, lpath, **kwargs): + if self.isdir(rpath): + if not os.path.exists(lpath): + os.mkdir(lpath) + return + if isfilelike(lpath): + outfile = lpath + else: + outfile = open(lpath, "wb") + + def cb(x): + outfile.write(x) + + self.ftp.retrbinary( + f"RETR {rpath}", + blocksize=self.blocksize, + callback=cb, + ) + if not isfilelike(lpath): + outfile.close() + + def cat_file(self, path, start=None, end=None, **kwargs): + if end is not None: + return super().cat_file(path, start, end, **kwargs) + out = [] + + def cb(x): + out.append(x) + + try: + self.ftp.retrbinary( + f"RETR {path}", + blocksize=self.blocksize, + rest=start, + callback=cb, + ) + except (Error, error_perm) as orig_exc: + raise FileNotFoundError(path) from orig_exc + return b"".join(out) + + def _open( + self, + path, + mode="rb", + block_size=None, + cache_options=None, + autocommit=True, + 
**kwargs, + ): + path = self._strip_protocol(path) + block_size = block_size or self.blocksize + return FTPFile( + self, + path, + mode=mode, + block_size=block_size, + tempdir=self.tempdir, + autocommit=autocommit, + cache_options=cache_options, + ) + + def _rm(self, path): + path = self._strip_protocol(path) + self.ftp.delete(path) + self.invalidate_cache(self._parent(path)) + + def rm(self, path, recursive=False, maxdepth=None): + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(paths): + if self.isfile(p): + self.rm_file(p) + else: + self.rmdir(p) + + def mkdir(self, path: str, create_parents: bool = True, **kwargs: Any) -> None: + path = self._strip_protocol(path) + parent = self._parent(path) + if parent != self.root_marker and not self.exists(parent) and create_parents: + self.mkdir(parent, create_parents=create_parents) + + self.ftp.mkd(path) + self.invalidate_cache(self._parent(path)) + + def makedirs(self, path: str, exist_ok: bool = False) -> None: + path = self._strip_protocol(path) + if self.exists(path): + # NB: "/" does not "exist" as it has no directory entry + if not exist_ok: + raise FileExistsError(f"{path} exists without `exist_ok`") + # exists_ok=True -> no-op + else: + self.mkdir(path, create_parents=True) + + def rmdir(self, path): + path = self._strip_protocol(path) + self.ftp.rmd(path) + self.invalidate_cache(self._parent(path)) + + def mv(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + self.ftp.rename(path1, path2) + self.invalidate_cache(self._parent(path1)) + self.invalidate_cache(self._parent(path2)) + + def __del__(self): + self.ftp.close() + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + self.dircache.pop(path, None) + super().invalidate_cache(path) + + +class TransferDone(Exception): + """Internal exception to break out of transfer""" + + pass + + +class FTPFile(AbstractBufferedFile): + 
"""Interact with a remote FTP file with read/write buffering""" + + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + **kwargs, + ): + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + if not autocommit: + self.target = self.path + self.path = "/".join([kwargs["tempdir"], str(uuid.uuid4())]) + + def commit(self): + self.fs.mv(self.path, self.target) + + def discard(self): + self.fs.rm(self.path) + + def _fetch_range(self, start, end): + """Get bytes between given byte limits + + Implemented by raising an exception in the fetch callback when the + number of bytes received reaches the requested amount. + + Will fail if the server does not respect the REST command on + retrieve requests. + """ + out = [] + total = [0] + + def callback(x): + total[0] += len(x) + if total[0] > end - start: + out.append(x[: (end - start) - total[0]]) + if end < self.size: + raise TransferDone + else: + out.append(x) + + if total[0] == end - start and end < self.size: + raise TransferDone + + try: + self.fs.ftp.retrbinary( + f"RETR {self.path}", + blocksize=self.blocksize, + rest=start, + callback=callback, + ) + except TransferDone: + try: + # stop transfer, we got enough bytes for this block + self.fs.ftp.abort() + self.fs.ftp.getmultiline() + except Error: + self.fs._connect() + + return b"".join(out) + + def _upload_chunk(self, final=False): + self.buffer.seek(0) + self.fs.ftp.storbinary( + f"STOR {self.path}", self.buffer, blocksize=self.blocksize, rest=self.offset + ) + return True + + +def _mlsd2(ftp, path="."): + """ + Fall back to using `dir` instead of `mlsd` if not supported. + + This parses a Linux style `ls -l` response to `dir`, but the response may + be platform dependent. 
+ + Parameters + ---------- + ftp: ftplib.FTP + path: str + Expects to be given path, but defaults to ".". + """ + lines = [] + minfo = [] + ftp.dir(path, lines.append) + for line in lines: + split_line = line.split() + if len(split_line) < 9: + continue + this = ( + split_line[-1], + { + "modify": " ".join(split_line[5:8]), + "unix.owner": split_line[2], + "unix.group": split_line[3], + "unix.mode": split_line[0], + "size": split_line[4], + }, + ) + if this[1]["unix.mode"][0] == "d": + this[1]["type"] = "dir" + else: + this[1]["type"] = "file" + minfo.append(this) + return minfo diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/gist.py b/venv/lib/python3.10/site-packages/fsspec/implementations/gist.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9ac0b6a1cdbcfba6188e2cdeab2350bb9aad0a --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/gist.py @@ -0,0 +1,241 @@ +import requests + +from ..spec import AbstractFileSystem +from ..utils import infer_storage_options +from .memory import MemoryFile + + +class GistFileSystem(AbstractFileSystem): + """ + Interface to files in a single GitHub Gist. + + Provides read-only access to a gist's files. Gists do not contain + subdirectories, so file listing is straightforward. + + Parameters + ---------- + gist_id: str + The ID of the gist you want to access (the long hex value from the URL). + filenames: list[str] (optional) + If provided, only make a file system representing these files, and do not fetch + the list of all files for this gist. + sha: str (optional) + If provided, fetch a particular revision of the gist. If omitted, + the latest revision is used. + username: str (optional) + GitHub username for authentication. + token: str (optional) + GitHub personal access token (required if username is given), or. + timeout: (float, float) or float, optional + Connect and read timeouts for requests (default 60s each). 
+ kwargs: dict + Stored on `self.request_kw` and passed to `requests.get` when fetching Gist + metadata or reading ("opening") a file. + """ + + protocol = "gist" + gist_url = "https://api.github.com/gists/{gist_id}" + gist_rev_url = "https://api.github.com/gists/{gist_id}/{sha}" + + def __init__( + self, + gist_id, + filenames=None, + sha=None, + username=None, + token=None, + timeout=None, + **kwargs, + ): + super().__init__() + self.gist_id = gist_id + self.filenames = filenames + self.sha = sha # revision of the gist (optional) + if username is not None and token is None: + raise ValueError("User auth requires a token") + self.username = username + self.token = token + self.request_kw = kwargs + # Default timeouts to 60s connect/read if none provided + self.timeout = timeout if timeout is not None else (60, 60) + + # We use a single-level "directory" cache, because a gist is essentially flat + self.dircache[""] = self._fetch_file_list() + + @property + def kw(self): + """Auth parameters passed to 'requests' if we have username/token.""" + kw = { + "headers": { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + } + kw.update(self.request_kw) + if self.username and self.token: + kw["auth"] = (self.username, self.token) + elif self.token: + kw["headers"]["Authorization"] = f"Bearer {self.token}" + return kw + + def _fetch_gist_metadata(self): + """ + Fetch the JSON metadata for this gist (possibly for a specific revision). + """ + if self.sha: + url = self.gist_rev_url.format(gist_id=self.gist_id, sha=self.sha) + else: + url = self.gist_url.format(gist_id=self.gist_id) + + r = requests.get(url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError( + f"Gist not found: {self.gist_id}@{self.sha or 'latest'}" + ) + r.raise_for_status() + return r.json() + + def _fetch_file_list(self): + """ + Returns a list of dicts describing each file in the gist. These get stored + in self.dircache[""]. 
+ """ + meta = self._fetch_gist_metadata() + if self.filenames: + available_files = meta.get("files", {}) + files = {} + for fn in self.filenames: + if fn not in available_files: + raise FileNotFoundError(fn) + files[fn] = available_files[fn] + else: + files = meta.get("files", {}) + + out = [] + for fname, finfo in files.items(): + if finfo is None: + # Occasionally GitHub returns a file entry with null if it was deleted + continue + # Build a directory entry + out.append( + { + "name": fname, # file's name + "type": "file", # gists have no subdirectories + "size": finfo.get("size", 0), # file size in bytes + "raw_url": finfo.get("raw_url"), + } + ) + return out + + @classmethod + def _strip_protocol(cls, path): + """ + Remove 'gist://' from the path, if present. + """ + # The default infer_storage_options can handle gist://username:token@id/file + # or gist://id/file, but let's ensure we handle a normal usage too. + # We'll just strip the protocol prefix if it exists. + path = infer_storage_options(path).get("path", path) + return path.lstrip("/") + + @staticmethod + def _get_kwargs_from_urls(path): + """ + Parse 'gist://' style URLs into GistFileSystem constructor kwargs. + For example: + gist://:TOKEN@/file.txt + gist://username:TOKEN@/file.txt + """ + so = infer_storage_options(path) + out = {} + if "username" in so and so["username"]: + out["username"] = so["username"] + if "password" in so and so["password"]: + out["token"] = so["password"] + if "host" in so and so["host"]: + # We interpret 'host' as the gist ID + out["gist_id"] = so["host"] + + # Extract SHA and filename from path + if "path" in so and so["path"]: + path_parts = so["path"].rsplit("/", 2)[-2:] + if len(path_parts) == 2: + if path_parts[0]: # SHA present + out["sha"] = path_parts[0] + if path_parts[1]: # filename also present + out["filenames"] = [path_parts[1]] + + return out + + def ls(self, path="", detail=False, **kwargs): + """ + List files in the gist. 
Gists are single-level, so any 'path' is basically + the filename, or empty for all files. + + Parameters + ---------- + path : str, optional + The filename to list. If empty, returns all files in the gist. + detail : bool, default False + If True, return a list of dicts; if False, return a list of filenames. + """ + path = self._strip_protocol(path or "") + # If path is empty, return all + if path == "": + results = self.dircache[""] + else: + # We want just the single file with this name + all_files = self.dircache[""] + results = [f for f in all_files if f["name"] == path] + if not results: + raise FileNotFoundError(path) + if detail: + return results + else: + return sorted(f["name"] for f in results) + + def _open(self, path, mode="rb", block_size=None, **kwargs): + """ + Read a single file from the gist. + """ + if mode != "rb": + raise NotImplementedError("GitHub Gist FS is read-only (no write).") + + path = self._strip_protocol(path) + # Find the file entry in our dircache + matches = [f for f in self.dircache[""] if f["name"] == path] + if not matches: + raise FileNotFoundError(path) + finfo = matches[0] + + raw_url = finfo.get("raw_url") + if not raw_url: + raise FileNotFoundError(f"No raw_url for file: {path}") + + r = requests.get(raw_url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + return MemoryFile(path, None, r.content) + + def cat(self, path, recursive=False, on_error="raise", **kwargs): + """ + Return {path: contents} for the given file or files. If 'recursive' is True, + and path is empty, returns all files in the gist. 
+ """ + paths = self.expand_path(path, recursive=recursive) + out = {} + for p in paths: + try: + with self.open(p, "rb") as f: + out[p] = f.read() + except FileNotFoundError as e: + if on_error == "raise": + raise e + elif on_error == "omit": + pass # skip + else: + out[p] = e + if len(paths) == 1 and paths[0] == path: + return out[path] + return out diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/git.py b/venv/lib/python3.10/site-packages/fsspec/implementations/git.py new file mode 100644 index 0000000000000000000000000000000000000000..808d293a1c991ea87d19a2129f3e56d9b813daaa --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/git.py @@ -0,0 +1,114 @@ +import os + +import pygit2 + +from fsspec.spec import AbstractFileSystem + +from .memory import MemoryFile + + +class GitFileSystem(AbstractFileSystem): + """Browse the files of a local git repo at any hash/tag/branch + + (experimental backend) + """ + + root_marker = "" + cachable = True + + def __init__(self, path=None, fo=None, ref=None, **kwargs): + """ + + Parameters + ---------- + path: str (optional) + Local location of the repo (uses current directory if not given). + May be deprecated in favour of ``fo``. When used with a higher + level function such as fsspec.open(), may be of the form + "git://[path-to-repo[:]][ref@]path/to/file" (but the actual + file path should not contain "@" or ":"). + fo: str (optional) + Same as ``path``, but passed as part of a chained URL. This one + takes precedence if both are given. + ref: str (optional) + Reference to work with, could be a hash, tag or branch name. Defaults + to current working tree. 
Note that ``ls`` and ``open`` also take hash, + so this becomes the default for those operations + kwargs + """ + super().__init__(**kwargs) + self.repo = pygit2.Repository(fo or path or os.getcwd()) + self.ref = ref or "master" + + @classmethod + def _strip_protocol(cls, path): + path = super()._strip_protocol(path).lstrip("/") + if ":" in path: + path = path.split(":", 1)[1] + if "@" in path: + path = path.split("@", 1)[1] + return path.lstrip("/") + + def _path_to_object(self, path, ref): + comm, ref = self.repo.resolve_refish(ref or self.ref) + parts = path.split("/") + tree = comm.tree + for part in parts: + if part and isinstance(tree, pygit2.Tree): + if part not in tree: + raise FileNotFoundError(path) + tree = tree[part] + return tree + + @staticmethod + def _get_kwargs_from_urls(path): + path = path.removeprefix("git://") + out = {} + if ":" in path: + out["path"], path = path.split(":", 1) + if "@" in path: + out["ref"], path = path.split("@", 1) + return out + + @staticmethod + def _object_to_info(obj, path=None): + # obj.name and obj.filemode are None for the root tree! 
+ is_dir = isinstance(obj, pygit2.Tree) + return { + "type": "directory" if is_dir else "file", + "name": ( + "/".join([path, obj.name or ""]).lstrip("/") if path else obj.name + ), + "hex": str(obj.id), + "mode": "100644" if obj.filemode is None else f"{obj.filemode:o}", + "size": 0 if is_dir else obj.size, + } + + def ls(self, path, detail=True, ref=None, **kwargs): + tree = self._path_to_object(self._strip_protocol(path), ref) + return [ + GitFileSystem._object_to_info(obj, path) + if detail + else GitFileSystem._object_to_info(obj, path)["name"] + for obj in (tree if isinstance(tree, pygit2.Tree) else [tree]) + ] + + def info(self, path, ref=None, **kwargs): + tree = self._path_to_object(self._strip_protocol(path), ref) + return GitFileSystem._object_to_info(tree, path) + + def ukey(self, path, ref=None): + return self.info(path, ref=ref)["hex"] + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + ref=None, + **kwargs, + ): + obj = self._path_to_object(path, ref or self.ref) + return MemoryFile(data=obj.data) diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/github.py b/venv/lib/python3.10/site-packages/fsspec/implementations/github.py new file mode 100644 index 0000000000000000000000000000000000000000..3630f6db54413e2c396f6cc1b6b10cd379200043 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/github.py @@ -0,0 +1,333 @@ +import base64 +import re + +import requests + +from ..spec import AbstractFileSystem +from ..utils import infer_storage_options +from .memory import MemoryFile + + +class GithubFileSystem(AbstractFileSystem): + """Interface to files in github + + An instance of this class provides the files residing within a remote github + repository. You may specify a point in the repos history, by SHA, branch + or tag (default is current master). + + For files less than 1 MB in size, file content is returned directly in a + MemoryFile. 
For larger files, or for files tracked by git-lfs, file content + is returned as an HTTPFile wrapping the ``download_url`` provided by the + GitHub API. + + When using fsspec.open, allows URIs of the form: + + - "github://path/file", in which case you must specify org, repo and + may specify sha in the extra args + - 'github://org:repo@/precip/catalog.yml', where the org and repo are + part of the URI + - 'github://org:repo@sha/precip/catalog.yml', where the sha is also included + + ``sha`` can be the full or abbreviated hex of the commit you want to fetch + from, or a branch or tag name (so long as it doesn't contain special characters + like "/", "?", which would have to be HTTP-encoded). + + For authorised access, you must provide username and token, which can be made + at https://github.com/settings/tokens + """ + + url = "https://api.github.com/repos/{org}/{repo}/git/trees/{sha}" + content_url = "https://api.github.com/repos/{org}/{repo}/contents/{path}?ref={sha}" + protocol = "github" + timeout = (60, 60) # connect, read timeouts + + def __init__( + self, org, repo, sha=None, username=None, token=None, timeout=None, **kwargs + ): + super().__init__(**kwargs) + self.org = org + self.repo = repo + if (username is None) ^ (token is None): + raise ValueError("Auth required both username and token") + self.username = username + self.token = token + if timeout is not None: + self.timeout = timeout + if sha is None: + # look up default branch (not necessarily "master") + u = "https://api.github.com/repos/{org}/{repo}" + r = requests.get( + u.format(org=org, repo=repo), timeout=self.timeout, **self.kw + ) + r.raise_for_status() + sha = r.json()["default_branch"] + + self.root = sha + self.ls("") + try: + from .http import HTTPFileSystem + + self.http_fs = HTTPFileSystem(**kwargs) + except ImportError: + self.http_fs = None + + @property + def kw(self): + if self.username: + return {"auth": (self.username, self.token)} + return {} + + @classmethod + def repos(cls, 
org_or_user, is_org=True): + """List repo names for given org or user + + This may become the top level of the FS + + Parameters + ---------- + org_or_user: str + Name of the github org or user to query + is_org: bool (default True) + Whether the name is an organisation (True) or user (False) + + Returns + ------- + List of string + """ + r = requests.get( + f"https://api.github.com/{['users', 'orgs'][is_org]}/{org_or_user}/repos", + timeout=cls.timeout, + ) + r.raise_for_status() + return [repo["name"] for repo in r.json()] + + @property + def tags(self): + """Names of tags in the repo""" + r = requests.get( + f"https://api.github.com/repos/{self.org}/{self.repo}/tags", + timeout=self.timeout, + **self.kw, + ) + r.raise_for_status() + return [t["name"] for t in r.json()] + + @property + def branches(self): + """Names of branches in the repo""" + r = requests.get( + f"https://api.github.com/repos/{self.org}/{self.repo}/branches", + timeout=self.timeout, + **self.kw, + ) + r.raise_for_status() + return [t["name"] for t in r.json()] + + @property + def refs(self): + """Named references, tags and branches""" + return {"tags": self.tags, "branches": self.branches} + + def ls(self, path, detail=False, sha=None, _sha=None, **kwargs): + """List files at given path + + Parameters + ---------- + path: str + Location to list, relative to repo root + detail: bool + If True, returns list of dicts, one per file; if False, returns + list of full filenames only + sha: str (optional) + List at the given point in the repo history, branch or tag name or commit + SHA + _sha: str (optional) + List this specific tree object (used internally to descend into trees) + """ + path = self._strip_protocol(path) + if path == "": + _sha = sha or self.root + if _sha is None: + parts = path.rstrip("/").split("/") + so_far = "" + _sha = sha or self.root + for part in parts: + out = self.ls(so_far, True, sha=sha, _sha=_sha) + so_far += "/" + part if so_far else part + out = [o for o in out if 
o["name"] == so_far] + if not out: + raise FileNotFoundError(path) + out = out[0] + if out["type"] == "file": + if detail: + return [out] + else: + return path + _sha = out["sha"] + if path not in self.dircache or sha not in [self.root, None]: + r = requests.get( + self.url.format(org=self.org, repo=self.repo, sha=_sha), + timeout=self.timeout, + **self.kw, + ) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + types = {"blob": "file", "tree": "directory"} + out = [ + { + "name": path + "/" + f["path"] if path else f["path"], + "mode": f["mode"], + "type": types[f["type"]], + "size": f.get("size", 0), + "sha": f["sha"], + } + for f in r.json()["tree"] + if f["type"] in types + ] + if sha in [self.root, None]: + self.dircache[path] = out + else: + out = self.dircache[path] + if detail: + return out + else: + return sorted([f["name"] for f in out]) + + def invalidate_cache(self, path=None): + self.dircache.clear() + + @classmethod + def _strip_protocol(cls, path): + opts = infer_storage_options(path) + if "username" not in opts: + return super()._strip_protocol(path) + return opts["path"].lstrip("/") + + @staticmethod + def _get_kwargs_from_urls(path): + opts = infer_storage_options(path) + if "username" not in opts: + return {} + out = {"org": opts["username"], "repo": opts["password"]} + if opts["host"]: + out["sha"] = opts["host"] + return out + + def _open( + self, + path, + mode="rb", + block_size=None, + cache_options=None, + sha=None, + **kwargs, + ): + if mode != "rb": + raise NotImplementedError + + # construct a url to hit the GitHub API's repo contents API + url = self.content_url.format( + org=self.org, repo=self.repo, path=path, sha=sha or self.root + ) + + # make a request to this API, and parse the response as JSON + r = requests.get(url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + content_json = r.json() + + # if the response's content key is not 
empty, try to parse it as base64 + if content_json["content"]: + content = base64.b64decode(content_json["content"]) + + # as long as the content does not start with the string + # "version https://git-lfs.github.com/" + # then it is probably not a git-lfs pointer and we can just return + # the content directly + if not content.startswith(b"version https://git-lfs.github.com/"): + return MemoryFile(None, None, content) + + # we land here if the content was not present in the first response + # (regular file over 1MB or git-lfs tracked file) + # in this case, we get let the HTTPFileSystem handle the download + if self.http_fs is None: + raise ImportError( + "Please install fsspec[http] to access github files >1 MB " + "or git-lfs tracked files." + ) + return self.http_fs.open( + content_json["download_url"], + mode=mode, + block_size=block_size, + cache_options=cache_options, + **kwargs, + ) + + def rm(self, path, recursive=False, maxdepth=None, message=None): + path = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(path): + self.rm_file(p, message=message) + + def rm_file(self, path, message=None, **kwargs): + """ + Remove a file from a specified branch using a given commit message. + + Since Github DELETE operation requires a branch name, and we can't reliably + determine whether the provided SHA refers to a branch, tag, or commit, we + assume it's a branch. If it's not, the user will encounter an error when + attempting to retrieve the file SHA or delete the file. + + Parameters + ---------- + path: str + The file's location relative to the repository root. + message: str, optional + The commit message for the deletion. 
+ """ + + if not self.username: + raise ValueError("Authentication required") + + path = self._strip_protocol(path) + + # Attempt to get SHA from cache or Github API + sha = self._get_sha_from_cache(path) + if not sha: + url = self.content_url.format( + org=self.org, repo=self.repo, path=path.lstrip("/"), sha=self.root + ) + r = requests.get(url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + sha = r.json()["sha"] + + # Delete the file + delete_url = self.content_url.format( + org=self.org, repo=self.repo, path=path, sha=self.root + ) + branch = self.root + data = { + "message": message or f"Delete {path}", + "sha": sha, + **({"branch": branch} if branch else {}), + } + + r = requests.delete(delete_url, json=data, timeout=self.timeout, **self.kw) + error_message = r.json().get("message", "") + if re.search(r"Branch .+ not found", error_message): + error = "Remove only works when the filesystem is initialised from a branch or default (None)" + raise ValueError(error) + r.raise_for_status() + + self.invalidate_cache(path) + + def _get_sha_from_cache(self, path): + for entries in self.dircache.values(): + for entry in entries: + entry_path = entry.get("name") + if entry_path and entry_path == path and "sha" in entry: + return entry["sha"] + return None diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/http.py b/venv/lib/python3.10/site-packages/fsspec/implementations/http.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb1bc36074ff4c85463133387601ae16ae1280e --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/http.py @@ -0,0 +1,897 @@ +import asyncio +import io +import logging +import re +import weakref +from copy import copy +from urllib.parse import urlparse + +import aiohttp +import yarl + +from fsspec.asyn import AbstractAsyncStreamedFile, AsyncFileSystem, sync, sync_wrapper +from fsspec.callbacks import DEFAULT_CALLBACK +from 
fsspec.exceptions import FSTimeoutError +from fsspec.spec import AbstractBufferedFile +from fsspec.utils import ( + DEFAULT_BLOCK_SIZE, + glob_translate, + isfilelike, + nullcontext, + tokenize, +) + +from ..caching import AllBytes + +# https://stackoverflow.com/a/15926317/3821154 +ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P[^"']+)""") +ex2 = re.compile(r"""(?Phttp[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""") +logger = logging.getLogger("fsspec.http") + + +async def get_client(**kwargs): + return aiohttp.ClientSession(**kwargs) + + +class HTTPFileSystem(AsyncFileSystem): + """ + Simple File-System for fetching data via HTTP(S) + + ``ls()`` is implemented by loading the parent page and doing a regex + match on the result. If simple_link=True, anything of the form + "http(s)://server.com/stuff?thing=other"; otherwise only links within + HTML href tags will be used. + """ + + protocol = ("http", "https") + sep = "/" + + def __init__( + self, + simple_links=True, + block_size=None, + same_scheme=True, + size_policy=None, + cache_type="bytes", + cache_options=None, + asynchronous=False, + loop=None, + client_kwargs=None, + get_client=get_client, + encoded=False, + **storage_options, + ): + """ + NB: if this is called async, you must await set_client + + Parameters + ---------- + block_size: int + Blocks to read bytes; if 0, will default to raw requests file-like + objects instead of HTTPFile instances + simple_links: bool + If True, will consider both HTML tags and anything that looks + like a URL; if False, will consider only the former. + same_scheme: True + When doing ls/glob, if this is True, only consider paths that have + http/https matching the input URLs. 
+ size_policy: this argument is deprecated + client_kwargs: dict + Passed to aiohttp.ClientSession, see + https://docs.aiohttp.org/en/stable/client_reference.html + For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}`` + get_client: Callable[..., aiohttp.ClientSession] + A callable, which takes keyword arguments and constructs + an aiohttp.ClientSession. Its state will be managed by + the HTTPFileSystem class. + storage_options: key-value + Any other parameters passed on to requests + cache_type, cache_options: defaults used in open() + """ + super().__init__(self, asynchronous=asynchronous, loop=loop, **storage_options) + self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE + self.simple_links = simple_links + self.same_schema = same_scheme + self.cache_type = cache_type + self.cache_options = cache_options + self.client_kwargs = client_kwargs or {} + self.get_client = get_client + self.encoded = encoded + self.kwargs = storage_options + self._session = None + + # Clean caching-related parameters from `storage_options` + # before propagating them as `request_options` through `self.kwargs`. + # TODO: Maybe rename `self.kwargs` to `self.request_options` to make + # it clearer. 
+ request_options = copy(storage_options) + self.use_listings_cache = request_options.pop("use_listings_cache", False) + request_options.pop("listings_expiry_time", None) + request_options.pop("max_paths", None) + request_options.pop("skip_instance_cache", None) + self.kwargs = request_options + + @property + def fsid(self): + return "http" + + def encode_url(self, url): + return yarl.URL(url, encoded=self.encoded) + + @staticmethod + def close_session(loop, session): + if loop is not None and loop.is_running(): + try: + sync(loop, session.close, timeout=0.1) + return + except (TimeoutError, FSTimeoutError, NotImplementedError): + pass + connector = getattr(session, "_connector", None) + if connector is not None: + # close after loop is dead + connector._close() + + async def set_session(self): + if self._session is None: + self._session = await self.get_client(loop=self.loop, **self.client_kwargs) + if not self.asynchronous: + weakref.finalize(self, self.close_session, self.loop, self._session) + return self._session + + @classmethod + def _strip_protocol(cls, path): + """For HTTP, we always want to keep the full URL""" + return path + + @classmethod + def _parent(cls, path): + # override, since _strip_protocol is different for URLs + par = super()._parent(path) + if len(par) > 7: # "http://..." 
+ return par + return "" + + async def _ls_real(self, url, detail=True, **kwargs): + # ignoring URL-encoded arguments + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + session = await self.set_session() + async with session.get(self.encode_url(url), **self.kwargs) as r: + self._raise_not_found_for_status(r, url) + + if "Content-Type" in r.headers: + mimetype = r.headers["Content-Type"].partition(";")[0] + else: + mimetype = None + + if mimetype in ("text/html", None): + try: + text = await r.text(errors="ignore") + if self.simple_links: + links = ex2.findall(text) + [u[2] for u in ex.findall(text)] + else: + links = [u[2] for u in ex.findall(text)] + except UnicodeDecodeError: + links = [] # binary, not HTML + else: + links = [] + + out = set() + parts = urlparse(url) + for l in links: + if isinstance(l, tuple): + l = l[1] + if l.startswith("/") and len(l) > 1: + # absolute URL on this server + l = f"{parts.scheme}://{parts.netloc}{l}" + if l.startswith("http"): + if self.same_schema and l.startswith(url.rstrip("/") + "/"): + out.add(l) + elif l.replace("https", "http").startswith( + url.replace("https", "http").rstrip("/") + "/" + ): + # allowed to cross http <-> https + out.add(l) + else: + if l not in ["..", "../"]: + # Ignore FTP-like "parent" + out.add("/".join([url.rstrip("/"), l.lstrip("/")])) + if not out and url.endswith("/"): + out = await self._ls_real(url.rstrip("/"), detail=False) + if detail: + return [ + { + "name": u, + "size": None, + "type": "directory" if u.endswith("/") else "file", + } + for u in out + ] + else: + return sorted(out) + + async def _ls(self, url, detail=True, **kwargs): + if self.use_listings_cache and url in self.dircache: + out = self.dircache[url] + else: + out = await self._ls_real(url, detail=detail, **kwargs) + self.dircache[url] = out + return out + + ls = sync_wrapper(_ls) + + def _raise_not_found_for_status(self, response, url): + """ + Raises FileNotFoundError for 404s, otherwise uses raise_for_status. 
+ """ + if response.status == 404: + raise FileNotFoundError(url) + response.raise_for_status() + + async def _cat_file(self, url, start=None, end=None, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + + if start is not None or end is not None: + if start == end: + return b"" + headers = kw.pop("headers", {}).copy() + + headers["Range"] = await self._process_limits(url, start, end) + kw["headers"] = headers + session = await self.set_session() + async with session.get(self.encode_url(url), **kw) as r: + out = await r.read() + self._raise_not_found_for_status(r, url) + return out + + async def _get_file( + self, rpath, lpath, chunk_size=5 * 2**20, callback=DEFAULT_CALLBACK, **kwargs + ): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(rpath) + session = await self.set_session() + async with session.get(self.encode_url(rpath), **kw) as r: + try: + size = int(r.headers["content-length"]) + except (ValueError, KeyError): + size = None + + callback.set_size(size) + self._raise_not_found_for_status(r, rpath) + if isfilelike(lpath): + outfile = lpath + else: + outfile = open(lpath, "wb") # noqa: ASYNC230 + + try: + chunk = True + while chunk: + chunk = await r.content.read(chunk_size) + outfile.write(chunk) + callback.relative_update(len(chunk)) + finally: + if not isfilelike(lpath): + outfile.close() + + async def _put_file( + self, + lpath, + rpath, + chunk_size=5 * 2**20, + callback=DEFAULT_CALLBACK, + method="post", + mode="overwrite", + **kwargs, + ): + if mode != "overwrite": + raise NotImplementedError("Exclusive write") + + async def gen_chunks(): + # Support passing arbitrary file-like objects + # and use them instead of streams. 
+ if isinstance(lpath, io.IOBase): + context = nullcontext(lpath) + use_seek = False # might not support seeking + else: + context = open(lpath, "rb") # noqa: ASYNC230 + use_seek = True + + with context as f: + if use_seek: + callback.set_size(f.seek(0, 2)) + f.seek(0) + else: + callback.set_size(getattr(f, "size", None)) + + chunk = f.read(chunk_size) + while chunk: + yield chunk + callback.relative_update(len(chunk)) + chunk = f.read(chunk_size) + + kw = self.kwargs.copy() + kw.update(kwargs) + session = await self.set_session() + + method = method.lower() + if method not in ("post", "put"): + raise ValueError( + f"method has to be either 'post' or 'put', not: {method!r}" + ) + + meth = getattr(session, method) + async with meth(self.encode_url(rpath), data=gen_chunks(), **kw) as resp: + self._raise_not_found_for_status(resp, rpath) + + async def _exists(self, path, strict=False, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + try: + logger.debug(path) + session = await self.set_session() + r = await session.get(self.encode_url(path), **kw) + async with r: + if strict: + self._raise_not_found_for_status(r, path) + return r.status < 400 + except FileNotFoundError: + return False + except aiohttp.ClientError: + if strict: + raise + return False + + async def _isfile(self, path, **kwargs): + return await self._exists(path, **kwargs) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=None, # XXX: This differs from the base class. + cache_type=None, + cache_options=None, + size=None, + **kwargs, + ): + """Make a file-like object + + Parameters + ---------- + path: str + Full URL with protocol + mode: string + must be "rb" + block_size: int or None + Bytes to download in one request; use instance value if None. If + zero, will return a streaming Requests file-like instance. 
+ kwargs: key-value + Any other parameters, passed to requests calls + """ + if mode != "rb": + raise NotImplementedError + block_size = block_size if block_size is not None else self.block_size + kw = self.kwargs.copy() + kw["asynchronous"] = self.asynchronous + kw.update(kwargs) + info = {} + size = size or info.update(self.info(path, **kwargs)) or info["size"] + session = sync(self.loop, self.set_session) + if block_size and size and info.get("partial", True): + return HTTPFile( + self, + path, + session=session, + block_size=block_size, + mode=mode, + size=size, + cache_type=cache_type or self.cache_type, + cache_options=cache_options or self.cache_options, + loop=self.loop, + **kw, + ) + else: + return HTTPStreamFile( + self, + path, + mode=mode, + loop=self.loop, + session=session, + **kw, + ) + + async def open_async(self, path, mode="rb", size=None, **kwargs): + session = await self.set_session() + if size is None: + try: + size = (await self._info(path, **kwargs))["size"] + except FileNotFoundError: + pass + return AsyncStreamFile( + self, + path, + loop=self.loop, + session=session, + size=size, + **kwargs, + ) + + def ukey(self, url): + """Unique identifier; assume HTTP files are static, unchanging""" + return tokenize(url, self.kwargs, self.protocol) + + async def _info(self, url, **kwargs): + """Get info of URL + + Tries to access location via HEAD, and then GET methods, but does + not fetch the data. + + It is possible that the server does not supply any size information, in + which case size will be given as None (and certain operations on the + corresponding file will not work). 
+ """ + info = {} + session = await self.set_session() + + for policy in ["head", "get"]: + try: + info.update( + await _file_info( + self.encode_url(url), + size_policy=policy, + session=session, + **self.kwargs, + **kwargs, + ) + ) + if info.get("size") is not None: + break + except Exception as exc: + if policy == "get": + # If get failed, then raise a FileNotFoundError + raise FileNotFoundError(url) from exc + logger.debug("", exc_info=exc) + + return {"name": url, "size": None, **info, "type": "file"} + + async def _glob(self, path, maxdepth=None, **kwargs): + """ + Find files by glob-matching. + + This implementation is idntical to the one in AbstractFileSystem, + but "?" is not considered as a character for globbing, because it is + so common in URLs, often identifying the "query" part. + """ + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + import re + + ends_with_slash = path.endswith("/") # _strip_protocol strips trailing slash + path = self._strip_protocol(path) + append_slash_to_dirname = ends_with_slash or path.endswith(("/**", "/*")) + idx_star = path.find("*") if path.find("*") >= 0 else len(path) + idx_brace = path.find("[") if path.find("[") >= 0 else len(path) + + min_idx = min(idx_star, idx_brace) + + detail = kwargs.pop("detail", False) + + if not has_magic(path): + if await self._exists(path, **kwargs): + if not detail: + return [path] + else: + return {path: await self._info(path, **kwargs)} + else: + if not detail: + return [] # glob of non-existent returns empty + else: + return {} + elif "/" in path[:min_idx]: + min_idx = path[:min_idx].rindex("/") + root = path[: min_idx + 1] + depth = path[min_idx + 1 :].count("/") + 1 + else: + root = "" + depth = path[min_idx + 1 :].count("/") + 1 + + if "**" in path: + if maxdepth is not None: + idx_double_stars = path.find("**") + depth_double_stars = path[idx_double_stars:].count("/") + 1 + depth = depth - depth_double_stars + maxdepth + else: + depth = 
None + + allpaths = await self._find( + root, maxdepth=depth, withdirs=True, detail=True, **kwargs + ) + + pattern = glob_translate(path + ("/" if ends_with_slash else "")) + pattern = re.compile(pattern) + + out = { + ( + p.rstrip("/") + if not append_slash_to_dirname + and info["type"] == "directory" + and p.endswith("/") + else p + ): info + for p, info in sorted(allpaths.items()) + if pattern.match(p.rstrip("/")) + } + + if detail: + return out + else: + return list(out) + + async def _isdir(self, path): + # override, since all URLs are (also) files + try: + return bool(await self._ls(path)) + except (FileNotFoundError, ValueError): + return False + + async def _pipe_file(self, path, value, mode="overwrite", **kwargs): + """ + Write bytes to a remote file over HTTP. + + Parameters + ---------- + path : str + Target URL where the data should be written + value : bytes + Data to be written + mode : str + How to write to the file - 'overwrite' or 'append' + **kwargs : dict + Additional parameters to pass to the HTTP request + """ + url = self._strip_protocol(path) + headers = kwargs.pop("headers", {}) + headers["Content-Length"] = str(len(value)) + + session = await self.set_session() + + async with session.put(url, data=value, headers=headers, **kwargs) as r: + r.raise_for_status() + + +class HTTPFile(AbstractBufferedFile): + """ + A file-like object pointing to a remote HTTP(S) resource + + Supports only reading, with read-ahead of a predetermined block-size. + + In the case that the server does not supply the filesize, only reading of + the complete file in one go is supported. + + Parameters + ---------- + url: str + Full URL of the remote resource, including the protocol + session: aiohttp.ClientSession or None + All calls will be made within this session, to avoid restarting + connections where the server allows this + block_size: int or None + The amount of read-ahead to do, in bytes. 
Default is 5MB, or the value + configured for the FileSystem creating this file + size: None or int + If given, this is the size of the file in bytes, and we don't attempt + to call the server to find the value. + kwargs: all other key-values are passed to requests calls. + """ + + def __init__( + self, + fs, + url, + session=None, + block_size=None, + mode="rb", + cache_type="bytes", + cache_options=None, + size=None, + loop=None, + asynchronous=False, + **kwargs, + ): + if mode != "rb": + raise NotImplementedError("File mode not supported") + self.asynchronous = asynchronous + self.loop = loop + self.url = url + self.session = session + self.details = {"name": url, "size": size, "type": "file"} + super().__init__( + fs=fs, + path=url, + mode=mode, + block_size=block_size, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + + def read(self, length=-1): + """Read bytes from file + + Parameters + ---------- + length: int + Read up to this many bytes. If negative, read all content to end of + file. If the server has not supplied the filesize, attempting to + read only part of the data will raise a ValueError. + """ + if ( + (length < 0 and self.loc == 0) # explicit read all + # but not when the size is known and fits into a block anyways + and not (self.size is not None and self.size <= self.blocksize) + ): + self._fetch_all() + if self.size is None: + if length < 0: + self._fetch_all() + else: + length = min(self.size - self.loc, length) + return super().read(length) + + async def async_fetch_all(self): + """Read whole file in one shot, without caching + + This is only called when position is still at zero, + and read() is called without a byte-count. 
+ """ + logger.debug(f"Fetch all for {self}") + if not isinstance(self.cache, AllBytes): + r = await self.session.get(self.fs.encode_url(self.url), **self.kwargs) + async with r: + r.raise_for_status() + out = await r.read() + self.cache = AllBytes( + size=len(out), fetcher=None, blocksize=None, data=out + ) + self.size = len(out) + + _fetch_all = sync_wrapper(async_fetch_all) + + def _parse_content_range(self, headers): + """Parse the Content-Range header""" + s = headers.get("Content-Range", "") + m = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", s) + if not m: + return None, None, None + + if m[1] == "*": + start = end = None + else: + start, end = [int(x) for x in m[1].split("-")] + total = None if m[2] == "*" else int(m[2]) + return start, end, total + + async def async_fetch_range(self, start, end): + """Download a block of data + + The expectation is that the server returns only the requested bytes, + with HTTP code 206. If this is not the case, we first check the headers, + and then stream the output - if the data size is bigger than we + requested, an exception is raised. + """ + logger.debug(f"Fetch range for {self}: {start}-{end}") + kwargs = self.kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + logger.debug(f"{self.url} : {headers['Range']}") + r = await self.session.get( + self.fs.encode_url(self.url), headers=headers, **kwargs + ) + async with r: + if r.status == 416: + # range request outside file + return b"" + r.raise_for_status() + + # If the server has handled the range request, it should reply + # with status 206 (partial content). But we'll guess that a suitable + # Content-Range header or a Content-Length no more than the + # requested range also mean we have got the desired range. 
+ response_is_range = ( + r.status == 206 + or self._parse_content_range(r.headers)[0] == start + or int(r.headers.get("Content-Length", end + 1)) <= end - start + ) + + if response_is_range: + # partial content, as expected + out = await r.read() + elif start > 0: + raise ValueError( + "The HTTP server doesn't appear to support range requests. " + "Only reading this file from the beginning is supported. " + "Open with block_size=0 for a streaming file interface." + ) + else: + # Response is not a range, but we want the start of the file, + # so we can read the required amount anyway. + cl = 0 + out = [] + while True: + chunk = await r.content.read(2**20) + # data size unknown, let's read until we have enough + if chunk: + out.append(chunk) + cl += len(chunk) + if cl > end - start: + break + else: + break + out = b"".join(out)[: end - start] + return out + + _fetch_range = sync_wrapper(async_fetch_range) + + +magic_check = re.compile("([*[])") + + +def has_magic(s): + match = magic_check.search(s) + return match is not None + + +class HTTPStreamFile(AbstractBufferedFile): + def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs): + self.asynchronous = kwargs.pop("asynchronous", False) + self.url = url + self.loop = loop + self.session = session + if mode != "rb": + raise ValueError + self.details = {"name": url, "size": None} + super().__init__(fs=fs, path=url, mode=mode, cache_type="none", **kwargs) + + async def cor(): + r = await self.session.get(self.fs.encode_url(url), **kwargs).__aenter__() + self.fs._raise_not_found_for_status(r, url) + return r + + self.r = sync(self.loop, cor) + self.loop = fs.loop + + def seek(self, loc, whence=0): + if loc == 0 and whence == 1: + return + if loc == self.loc and whence == 0: + return + raise ValueError("Cannot seek streaming HTTP file") + + async def _read(self, num=-1): + out = await self.r.content.read(num) + self.loc += len(out) + return out + + read = sync_wrapper(_read) + + async def _close(self): + 
self.r.close() + + def close(self): + asyncio.run_coroutine_threadsafe(self._close(), self.loop) + super().close() + + +class AsyncStreamFile(AbstractAsyncStreamedFile): + def __init__( + self, fs, url, mode="rb", loop=None, session=None, size=None, **kwargs + ): + self.url = url + self.session = session + self.r = None + if mode != "rb": + raise ValueError + self.details = {"name": url, "size": None} + self.kwargs = kwargs + super().__init__(fs=fs, path=url, mode=mode, cache_type="none") + self.size = size + + async def read(self, num=-1): + if self.r is None: + r = await self.session.get( + self.fs.encode_url(self.url), **self.kwargs + ).__aenter__() + self.fs._raise_not_found_for_status(r, self.url) + self.r = r + out = await self.r.content.read(num) + self.loc += len(out) + return out + + async def close(self): + if self.r is not None: + self.r.close() + self.r = None + await super().close() + + +async def get_range(session, url, start, end, file=None, **kwargs): + # explicit get a range when we know it must be safe + kwargs = kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + r = await session.get(url, headers=headers, **kwargs) + r.raise_for_status() + async with r: + out = await r.read() + if file: + with open(file, "r+b") as f: # noqa: ASYNC230 + f.seek(start) + f.write(out) + else: + return out + + +async def _file_info(url, session, size_policy="head", **kwargs): + """Call HEAD on the server to get details about the file (size/checksum etc.) + + Default operation is to explicitly allow redirects and use encoding + 'identity' (no compression) to get the true size of the target. 
+ """ + logger.debug("Retrieve file size for %s", url) + kwargs = kwargs.copy() + ar = kwargs.pop("allow_redirects", True) + head = kwargs.get("headers", {}).copy() + head["Accept-Encoding"] = "identity" + kwargs["headers"] = head + + info = {} + if size_policy == "head": + r = await session.head(url, allow_redirects=ar, **kwargs) + elif size_policy == "get": + r = await session.get(url, allow_redirects=ar, **kwargs) + else: + raise TypeError(f'size_policy must be "head" or "get", got {size_policy}') + async with r: + r.raise_for_status() + + if "Content-Length" in r.headers: + # Some servers may choose to ignore Accept-Encoding and return + # compressed content, in which case the returned size is unreliable. + if "Content-Encoding" not in r.headers or r.headers["Content-Encoding"] in [ + "identity", + "", + ]: + info["size"] = int(r.headers["Content-Length"]) + elif "Content-Range" in r.headers: + info["size"] = int(r.headers["Content-Range"].split("/")[1]) + + if "Content-Type" in r.headers: + info["mimetype"] = r.headers["Content-Type"].partition(";")[0] + + if r.headers.get("Accept-Ranges") == "none": + # Some servers may explicitly discourage partial content requests, but + # the lack of "Accept-Ranges" does not always indicate they would fail + info["partial"] = False + + info["url"] = str(r.url) + + for checksum_field in ["ETag", "Content-MD5", "Digest", "Last-Modified"]: + if r.headers.get(checksum_field): + info[checksum_field] = r.headers[checksum_field] + + return info + + +async def _file_size(url, session=None, *args, **kwargs): + if session is None: + session = await get_client() + info = await _file_info(url, session=session, *args, **kwargs) + return info.get("size") + + +file_size = sync_wrapper(_file_size) diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/http_sync.py b/venv/lib/python3.10/site-packages/fsspec/implementations/http_sync.py new file mode 100644 index 
0000000000000000000000000000000000000000..a67ea3ea5fee9e6b51f7f3f66773e8cf65735e52 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/http_sync.py @@ -0,0 +1,937 @@ +"""This file is largely copied from http.py""" + +import io +import logging +import re +import urllib.error +import urllib.parse +from copy import copy +from json import dumps, loads +from urllib.parse import urlparse + +try: + import yarl +except (ImportError, ModuleNotFoundError, OSError): + yarl = False + +from fsspec.callbacks import _DEFAULT_CALLBACK +from fsspec.registry import register_implementation +from fsspec.spec import AbstractBufferedFile, AbstractFileSystem +from fsspec.utils import DEFAULT_BLOCK_SIZE, isfilelike, nullcontext, tokenize + +from ..caching import AllBytes + +# https://stackoverflow.com/a/15926317/3821154 +ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P[^"']+)""") +ex2 = re.compile(r"""(?Phttp[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""") +logger = logging.getLogger("fsspec.http") + + +class JsHttpException(urllib.error.HTTPError): ... + + +class StreamIO(io.BytesIO): + # fake class, so you can set attributes on it + # will eventually actually stream + ... 
+ + +class ResponseProxy: + """Looks like a requests response""" + + def __init__(self, req, stream=False): + self.request = req + self.stream = stream + self._data = None + self._headers = None + + @property + def raw(self): + if self._data is None: + b = self.request.response.to_bytes() + if self.stream: + self._data = StreamIO(b) + else: + self._data = b + return self._data + + def close(self): + if hasattr(self, "_data"): + del self._data + + @property + def headers(self): + if self._headers is None: + self._headers = dict( + [ + _.split(": ") + for _ in self.request.getAllResponseHeaders().strip().split("\r\n") + ] + ) + return self._headers + + @property + def status_code(self): + return int(self.request.status) + + def raise_for_status(self): + if not self.ok: + raise JsHttpException( + self.url, self.status_code, self.reason, self.headers, None + ) + + def iter_content(self, chunksize, *_, **__): + while True: + out = self.raw.read(chunksize) + if out: + yield out + else: + break + + @property + def reason(self): + return self.request.statusText + + @property + def ok(self): + return self.status_code < 400 + + @property + def url(self): + return self.request.response.responseURL + + @property + def text(self): + # TODO: encoding from headers + return self.content.decode() + + @property + def content(self): + self.stream = False + return self.raw + + def json(self): + return loads(self.text) + + +class RequestsSessionShim: + def __init__(self): + self.headers = {} + + def request( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=None, + allow_redirects=None, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ): + from js import Blob, XMLHttpRequest + + logger.debug("JS request: %s %s", method, url) + + if cert or verify or proxies or files or cookies or hooks: + raise NotImplementedError + if data and json: + raise ValueError("Use json= or data=, not 
both") + req = XMLHttpRequest.new() + extra = auth if auth else () + if params: + url = f"{url}?{urllib.parse.urlencode(params)}" + req.open(method, url, False, *extra) + if timeout: + req.timeout = timeout + if headers: + for k, v in headers.items(): + req.setRequestHeader(k, v) + + req.setRequestHeader("Accept", "application/octet-stream") + req.responseType = "arraybuffer" + if json: + blob = Blob.new([dumps(data)], {type: "application/json"}) + req.send(blob) + elif data: + if isinstance(data, io.IOBase): + data = data.read() + blob = Blob.new([data], {type: "application/octet-stream"}) + req.send(blob) + else: + req.send(None) + return ResponseProxy(req, stream=stream) + + def get(self, url, **kwargs): + return self.request("GET", url, **kwargs) + + def head(self, url, **kwargs): + return self.request("HEAD", url, **kwargs) + + def post(self, url, **kwargs): + return self.request("POST}", url, **kwargs) + + def put(self, url, **kwargs): + return self.request("PUT", url, **kwargs) + + def patch(self, url, **kwargs): + return self.request("PATCH", url, **kwargs) + + def delete(self, url, **kwargs): + return self.request("DELETE", url, **kwargs) + + +class HTTPFileSystem(AbstractFileSystem): + """ + Simple File-System for fetching data via HTTP(S) + + This is the BLOCKING version of the normal HTTPFileSystem. It uses + requests in normal python and the JS runtime in pyodide. 
+ + ***This implementation is extremely experimental, do not use unless + you are testing pyodide/pyscript integration*** + """ + + protocol = ("http", "https", "sync-http", "sync-https") + sep = "/" + + def __init__( + self, + simple_links=True, + block_size=None, + same_scheme=True, + cache_type="readahead", + cache_options=None, + client_kwargs=None, + encoded=False, + **storage_options, + ): + """ + + Parameters + ---------- + block_size: int + Blocks to read bytes; if 0, will default to raw requests file-like + objects instead of HTTPFile instances + simple_links: bool + If True, will consider both HTML tags and anything that looks + like a URL; if False, will consider only the former. + same_scheme: True + When doing ls/glob, if this is True, only consider paths that have + http/https matching the input URLs. + size_policy: this argument is deprecated + client_kwargs: dict + Passed to aiohttp.ClientSession, see + https://docs.aiohttp.org/en/stable/client_reference.html + For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}`` + storage_options: key-value + Any other parameters passed on to requests + cache_type, cache_options: defaults used in open + """ + super().__init__(self, **storage_options) + self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE + self.simple_links = simple_links + self.same_schema = same_scheme + self.cache_type = cache_type + self.cache_options = cache_options + self.client_kwargs = client_kwargs or {} + self.encoded = encoded + self.kwargs = storage_options + + try: + import js # noqa: F401 + + logger.debug("Starting JS session") + self.session = RequestsSessionShim() + self.js = True + except Exception as e: + import requests + + logger.debug("Starting cpython session because of: %s", e) + self.session = requests.Session(**(client_kwargs or {})) + self.js = False + + request_options = copy(storage_options) + self.use_listings_cache = request_options.pop("use_listings_cache", False) + 
request_options.pop("listings_expiry_time", None) + request_options.pop("max_paths", None) + request_options.pop("skip_instance_cache", None) + self.kwargs = request_options + + @property + def fsid(self): + return "sync-http" + + def encode_url(self, url): + if yarl: + return yarl.URL(url, encoded=self.encoded) + return url + + @classmethod + def _strip_protocol(cls, path: str) -> str: + """For HTTP, we always want to keep the full URL""" + path = path.replace("sync-http://", "http://").replace( + "sync-https://", "https://" + ) + return path + + @classmethod + def _parent(cls, path): + # override, since _strip_protocol is different for URLs + par = super()._parent(path) + if len(par) > 7: # "http://..." + return par + return "" + + def _ls_real(self, url, detail=True, **kwargs): + # ignoring URL-encoded arguments + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + r = self.session.get(self.encode_url(url), **self.kwargs) + self._raise_not_found_for_status(r, url) + text = r.text + if self.simple_links: + links = ex2.findall(text) + [u[2] for u in ex.findall(text)] + else: + links = [u[2] for u in ex.findall(text)] + out = set() + parts = urlparse(url) + for l in links: + if isinstance(l, tuple): + l = l[1] + if l.startswith("/") and len(l) > 1: + # absolute URL on this server + l = parts.scheme + "://" + parts.netloc + l + if l.startswith("http"): + if self.same_schema and l.startswith(url.rstrip("/") + "/"): + out.add(l) + elif l.replace("https", "http").startswith( + url.replace("https", "http").rstrip("/") + "/" + ): + # allowed to cross http <-> https + out.add(l) + else: + if l not in ["..", "../"]: + # Ignore FTP-like "parent" + out.add("/".join([url.rstrip("/"), l.lstrip("/")])) + if not out and url.endswith("/"): + out = self._ls_real(url.rstrip("/"), detail=False) + if detail: + return [ + { + "name": u, + "size": None, + "type": "directory" if u.endswith("/") else "file", + } + for u in out + ] + else: + return sorted(out) + + def 
ls(self, url, detail=True, **kwargs): + if self.use_listings_cache and url in self.dircache: + out = self.dircache[url] + else: + out = self._ls_real(url, detail=detail, **kwargs) + self.dircache[url] = out + return out + + def _raise_not_found_for_status(self, response, url): + """ + Raises FileNotFoundError for 404s, otherwise uses raise_for_status. + """ + if response.status_code == 404: + raise FileNotFoundError(url) + response.raise_for_status() + + def cat_file(self, url, start=None, end=None, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + + if start is not None or end is not None: + if start == end: + return b"" + headers = kw.pop("headers", {}).copy() + + headers["Range"] = self._process_limits(url, start, end) + kw["headers"] = headers + r = self.session.get(self.encode_url(url), **kw) + self._raise_not_found_for_status(r, url) + return r.content + + def get_file( + self, rpath, lpath, chunk_size=5 * 2**20, callback=_DEFAULT_CALLBACK, **kwargs + ): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(rpath) + r = self.session.get(self.encode_url(rpath), **kw) + try: + size = int( + r.headers.get("content-length", None) + or r.headers.get("Content-Length", None) + ) + except (ValueError, KeyError, TypeError): + size = None + + callback.set_size(size) + self._raise_not_found_for_status(r, rpath) + if not isfilelike(lpath): + lpath = open(lpath, "wb") + for chunk in r.iter_content(chunk_size, decode_unicode=False): + lpath.write(chunk) + callback.relative_update(len(chunk)) + + def put_file( + self, + lpath, + rpath, + chunk_size=5 * 2**20, + callback=_DEFAULT_CALLBACK, + method="post", + **kwargs, + ): + def gen_chunks(): + # Support passing arbitrary file-like objects + # and use them instead of streams. 
+ if isinstance(lpath, io.IOBase): + context = nullcontext(lpath) + use_seek = False # might not support seeking + else: + context = open(lpath, "rb") + use_seek = True + + with context as f: + if use_seek: + callback.set_size(f.seek(0, 2)) + f.seek(0) + else: + callback.set_size(getattr(f, "size", None)) + + chunk = f.read(chunk_size) + while chunk: + yield chunk + callback.relative_update(len(chunk)) + chunk = f.read(chunk_size) + + kw = self.kwargs.copy() + kw.update(kwargs) + + method = method.lower() + if method not in ("post", "put"): + raise ValueError( + f"method has to be either 'post' or 'put', not: {method!r}" + ) + + meth = getattr(self.session, method) + resp = meth(rpath, data=gen_chunks(), **kw) + self._raise_not_found_for_status(resp, rpath) + + def _process_limits(self, url, start, end): + """Helper for "Range"-based _cat_file""" + size = None + suff = False + if start is not None and start < 0: + # if start is negative and end None, end is the "suffix length" + if end is None: + end = -start + start = "" + suff = True + else: + size = size or self.info(url)["size"] + start = size + start + elif start is None: + start = 0 + if not suff: + if end is not None and end < 0: + if start is not None: + size = size or self.info(url)["size"] + end = size + end + elif end is None: + end = "" + if isinstance(end, int): + end -= 1 # bytes range is inclusive + return f"bytes={start}-{end}" + + def exists(self, path, strict=False, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + try: + logger.debug(path) + r = self.session.get(self.encode_url(path), **kw) + if strict: + self._raise_not_found_for_status(r, path) + return r.status_code < 400 + except FileNotFoundError: + return False + except Exception: + if strict: + raise + return False + + def isfile(self, path, **kwargs): + return self.exists(path, **kwargs) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=None, # XXX: This differs from the base class. 
+ cache_type=None, + cache_options=None, + size=None, + **kwargs, + ): + """Make a file-like object + + Parameters + ---------- + path: str + Full URL with protocol + mode: string + must be "rb" + block_size: int or None + Bytes to download in one request; use instance value if None. If + zero, will return a streaming Requests file-like instance. + kwargs: key-value + Any other parameters, passed to requests calls + """ + if mode != "rb": + raise NotImplementedError + block_size = block_size if block_size is not None else self.block_size + kw = self.kwargs.copy() + kw.update(kwargs) + size = size or self.info(path, **kwargs)["size"] + if block_size and size: + return HTTPFile( + self, + path, + session=self.session, + block_size=block_size, + mode=mode, + size=size, + cache_type=cache_type or self.cache_type, + cache_options=cache_options or self.cache_options, + **kw, + ) + else: + return HTTPStreamFile( + self, + path, + mode=mode, + session=self.session, + **kw, + ) + + def ukey(self, url): + """Unique identifier; assume HTTP files are static, unchanging""" + return tokenize(url, self.kwargs, self.protocol) + + def info(self, url, **kwargs): + """Get info of URL + + Tries to access location via HEAD, and then GET methods, but does + not fetch the data. + + It is possible that the server does not supply any size information, in + which case size will be given as None (and certain operations on the + corresponding file will not work). 
+ """ + info = {} + for policy in ["head", "get"]: + try: + info.update( + _file_info( + self.encode_url(url), + size_policy=policy, + session=self.session, + **self.kwargs, + **kwargs, + ) + ) + if info.get("size") is not None: + break + except Exception as exc: + if policy == "get": + # If get failed, then raise a FileNotFoundError + raise FileNotFoundError(url) from exc + logger.debug(str(exc)) + + return {"name": url, "size": None, **info, "type": "file"} + + def glob(self, path, maxdepth=None, **kwargs): + """ + Find files by glob-matching. + + This implementation is idntical to the one in AbstractFileSystem, + but "?" is not considered as a character for globbing, because it is + so common in URLs, often identifying the "query" part. + """ + import re + + ends = path.endswith("/") + path = self._strip_protocol(path) + indstar = path.find("*") if path.find("*") >= 0 else len(path) + indbrace = path.find("[") if path.find("[") >= 0 else len(path) + + ind = min(indstar, indbrace) + + detail = kwargs.pop("detail", False) + + if not has_magic(path): + root = path + depth = 1 + if ends: + path += "/*" + elif self.exists(path): + if not detail: + return [path] + else: + return {path: self.info(path)} + else: + if not detail: + return [] # glob of non-existent returns empty + else: + return {} + elif "/" in path[:ind]: + ind2 = path[:ind].rindex("/") + root = path[: ind2 + 1] + depth = None if "**" in path else path[ind2 + 1 :].count("/") + 1 + else: + root = "" + depth = None if "**" in path else path[ind + 1 :].count("/") + 1 + + allpaths = self.find( + root, maxdepth=maxdepth or depth, withdirs=True, detail=True, **kwargs + ) + # Escape characters special to python regex, leaving our supported + # special characters in place. + # See https://www.gnu.org/software/bash/manual/html_node/Pattern-Matching.html + # for shell globbing details. 
+ pattern = ( + "^" + + ( + path.replace("\\", r"\\") + .replace(".", r"\.") + .replace("+", r"\+") + .replace("//", "/") + .replace("(", r"\(") + .replace(")", r"\)") + .replace("|", r"\|") + .replace("^", r"\^") + .replace("$", r"\$") + .replace("{", r"\{") + .replace("}", r"\}") + .rstrip("/") + ) + + "$" + ) + pattern = re.sub("[*]{2}", "=PLACEHOLDER=", pattern) + pattern = re.sub("[*]", "[^/]*", pattern) + pattern = re.compile(pattern.replace("=PLACEHOLDER=", ".*")) + out = { + p: allpaths[p] + for p in sorted(allpaths) + if pattern.match(p.replace("//", "/").rstrip("/")) + } + if detail: + return out + else: + return list(out) + + def isdir(self, path): + # override, since all URLs are (also) files + try: + return bool(self.ls(path)) + except (FileNotFoundError, ValueError): + return False + + +class HTTPFile(AbstractBufferedFile): + """ + A file-like object pointing to a remove HTTP(S) resource + + Supports only reading, with read-ahead of a predermined block-size. + + In the case that the server does not supply the filesize, only reading of + the complete file in one go is supported. + + Parameters + ---------- + url: str + Full URL of the remote resource, including the protocol + session: requests.Session or None + All calls will be made within this session, to avoid restarting + connections where the server allows this + block_size: int or None + The amount of read-ahead to do, in bytes. Default is 5MB, or the value + configured for the FileSystem creating this file + size: None or int + If given, this is the size of the file in bytes, and we don't attempt + to call the server to find the value. + kwargs: all other key-values are passed to requests calls. 
+ """ + + def __init__( + self, + fs, + url, + session=None, + block_size=None, + mode="rb", + cache_type="bytes", + cache_options=None, + size=None, + **kwargs, + ): + if mode != "rb": + raise NotImplementedError("File mode not supported") + self.url = url + self.session = session + self.details = {"name": url, "size": size, "type": "file"} + super().__init__( + fs=fs, + path=url, + mode=mode, + block_size=block_size, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + + def read(self, length=-1): + """Read bytes from file + + Parameters + ---------- + length: int + Read up to this many bytes. If negative, read all content to end of + file. If the server has not supplied the filesize, attempting to + read only part of the data will raise a ValueError. + """ + if ( + (length < 0 and self.loc == 0) # explicit read all + # but not when the size is known and fits into a block anyways + and not (self.size is not None and self.size <= self.blocksize) + ): + self._fetch_all() + if self.size is None: + if length < 0: + self._fetch_all() + else: + length = min(self.size - self.loc, length) + return super().read(length) + + def _fetch_all(self): + """Read whole file in one shot, without caching + + This is only called when position is still at zero, + and read() is called without a byte-count. 
+ """ + logger.debug(f"Fetch all for {self}") + if not isinstance(self.cache, AllBytes): + r = self.session.get(self.fs.encode_url(self.url), **self.kwargs) + r.raise_for_status() + out = r.content + self.cache = AllBytes(size=len(out), fetcher=None, blocksize=None, data=out) + self.size = len(out) + + def _parse_content_range(self, headers): + """Parse the Content-Range header""" + s = headers.get("Content-Range", "") + m = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", s) + if not m: + return None, None, None + + if m[1] == "*": + start = end = None + else: + start, end = [int(x) for x in m[1].split("-")] + total = None if m[2] == "*" else int(m[2]) + return start, end, total + + def _fetch_range(self, start, end): + """Download a block of data + + The expectation is that the server returns only the requested bytes, + with HTTP code 206. If this is not the case, we first check the headers, + and then stream the output - if the data size is bigger than we + requested, an exception is raised. + """ + logger.debug(f"Fetch range for {self}: {start}-{end}") + kwargs = self.kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + logger.debug("%s : %s", self.url, headers["Range"]) + r = self.session.get(self.fs.encode_url(self.url), headers=headers, **kwargs) + if r.status_code == 416: + # range request outside file + return b"" + r.raise_for_status() + + # If the server has handled the range request, it should reply + # with status 206 (partial content). But we'll guess that a suitable + # Content-Range header or a Content-Length no more than the + # requested range also mean we have got the desired range. 
+ cl = r.headers.get("Content-Length", r.headers.get("content-length", end + 1)) + response_is_range = ( + r.status_code == 206 + or self._parse_content_range(r.headers)[0] == start + or int(cl) <= end - start + ) + + if response_is_range: + # partial content, as expected + out = r.content + elif start > 0: + raise ValueError( + "The HTTP server doesn't appear to support range requests. " + "Only reading this file from the beginning is supported. " + "Open with block_size=0 for a streaming file interface." + ) + else: + # Response is not a range, but we want the start of the file, + # so we can read the required amount anyway. + cl = 0 + out = [] + for chunk in r.iter_content(2**20, False): + out.append(chunk) + cl += len(chunk) + out = b"".join(out)[: end - start] + return out + + +magic_check = re.compile("([*[])") + + +def has_magic(s): + match = magic_check.search(s) + return match is not None + + +class HTTPStreamFile(AbstractBufferedFile): + def __init__(self, fs, url, mode="rb", session=None, **kwargs): + self.url = url + self.session = session + if mode != "rb": + raise ValueError + self.details = {"name": url, "size": None} + super().__init__(fs=fs, path=url, mode=mode, cache_type="readahead", **kwargs) + + r = self.session.get(self.fs.encode_url(url), stream=True, **kwargs) + self.fs._raise_not_found_for_status(r, url) + self.it = r.iter_content(1024, False) + self.leftover = b"" + + self.r = r + + def seek(self, *args, **kwargs): + raise ValueError("Cannot seek streaming HTTP file") + + def read(self, num=-1): + bufs = [self.leftover] + leng = len(self.leftover) + while leng < num or num < 0: + try: + out = self.it.__next__() + except StopIteration: + break + if out: + bufs.append(out) + else: + break + leng += len(out) + out = b"".join(bufs) + if num >= 0: + self.leftover = out[num:] + out = out[:num] + else: + self.leftover = b"" + self.loc += len(out) + return out + + def close(self): + self.r.close() + self.closed = True + + +def get_range(session, 
url, start, end, **kwargs): + # explicit get a range when we know it must be safe + kwargs = kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + r = session.get(url, headers=headers, **kwargs) + r.raise_for_status() + return r.content + + +def _file_info(url, session, size_policy="head", **kwargs): + """Call HEAD on the server to get details about the file (size/checksum etc.) + + Default operation is to explicitly allow redirects and use encoding + 'identity' (no compression) to get the true size of the target. + """ + logger.debug("Retrieve file size for %s", url) + kwargs = kwargs.copy() + ar = kwargs.pop("allow_redirects", True) + head = kwargs.get("headers", {}).copy() + # TODO: not allowed in JS + # head["Accept-Encoding"] = "identity" + kwargs["headers"] = head + + info = {} + if size_policy == "head": + r = session.head(url, allow_redirects=ar, **kwargs) + elif size_policy == "get": + r = session.get(url, allow_redirects=ar, **kwargs) + else: + raise TypeError(f'size_policy must be "head" or "get", got {size_policy}') + r.raise_for_status() + + # TODO: + # recognise lack of 'Accept-Ranges', + # or 'Accept-Ranges': 'none' (not 'bytes') + # to mean streaming only, no random access => return None + if "Content-Length" in r.headers: + info["size"] = int(r.headers["Content-Length"]) + elif "Content-Range" in r.headers: + info["size"] = int(r.headers["Content-Range"].split("/")[1]) + elif "content-length" in r.headers: + info["size"] = int(r.headers["content-length"]) + elif "content-range" in r.headers: + info["size"] = int(r.headers["content-range"].split("/")[1]) + + for checksum_field in ["ETag", "Content-MD5", "Digest"]: + if r.headers.get(checksum_field): + info[checksum_field] = r.headers[checksum_field] + + return info + + +# importing this is enough to register it +def register(): + register_implementation("http", HTTPFileSystem, clobber=True) + register_implementation("https", HTTPFileSystem, 
clobber=True) + register_implementation("sync-http", HTTPFileSystem, clobber=True) + register_implementation("sync-https", HTTPFileSystem, clobber=True) + + +register() + + +def unregister(): + from fsspec.implementations.http import HTTPFileSystem + + register_implementation("http", HTTPFileSystem, clobber=True) + register_implementation("https", HTTPFileSystem, clobber=True) diff --git a/venv/lib/python3.10/site-packages/fsspec/implementations/libarchive.py b/venv/lib/python3.10/site-packages/fsspec/implementations/libarchive.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8e750002df72865d611b48022e6634f9572614 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/implementations/libarchive.py @@ -0,0 +1,213 @@ +from contextlib import contextmanager +from ctypes import ( + CFUNCTYPE, + POINTER, + c_int, + c_longlong, + c_void_p, + cast, + create_string_buffer, +) + +import libarchive +import libarchive.ffi as ffi + +from fsspec import open_files +from fsspec.archive import AbstractArchiveFileSystem +from fsspec.implementations.memory import MemoryFile +from fsspec.utils import DEFAULT_BLOCK_SIZE + +# Libarchive requires seekable files or memory only for certain archive +# types. However, since we read the directory first to cache the contents +# and also allow random access to any file, the file-like object needs +# to be seekable no matter what. + +# Seek call-backs (not provided in the libarchive python wrapper) +SEEK_CALLBACK = CFUNCTYPE(c_longlong, c_int, c_void_p, c_longlong, c_int) +read_set_seek_callback = ffi.ffi( + "read_set_seek_callback", [ffi.c_archive_p, SEEK_CALLBACK], c_int, ffi.check_int +) +new_api = hasattr(ffi, "NO_OPEN_CB") + + +@contextmanager +def custom_reader(file, format_name="all", filter_name="all", block_size=ffi.page_size): + """Read an archive from a seekable file-like object. + + The `file` object must support the standard `readinto` and 'seek' methods. 
+ """ + buf = create_string_buffer(block_size) + buf_p = cast(buf, c_void_p) + + def read_func(archive_p, context, ptrptr): + # readinto the buffer, returns number of bytes read + length = file.readinto(buf) + # write the address of the buffer into the pointer + ptrptr = cast(ptrptr, POINTER(c_void_p)) + ptrptr[0] = buf_p + # tell libarchive how much data was written into the buffer + return length + + def seek_func(archive_p, context, offset, whence): + file.seek(offset, whence) + # tell libarchvie the current position + return file.tell() + + read_cb = ffi.READ_CALLBACK(read_func) + seek_cb = SEEK_CALLBACK(seek_func) + + if new_api: + open_cb = ffi.NO_OPEN_CB + close_cb = ffi.NO_CLOSE_CB + else: + open_cb = libarchive.read.OPEN_CALLBACK(ffi.VOID_CB) + close_cb = libarchive.read.CLOSE_CALLBACK(ffi.VOID_CB) + + with libarchive.read.new_archive_read(format_name, filter_name) as archive_p: + read_set_seek_callback(archive_p, seek_cb) + ffi.read_open(archive_p, None, open_cb, read_cb, close_cb) + yield libarchive.read.ArchiveRead(archive_p) + + +class LibArchiveFileSystem(AbstractArchiveFileSystem): + """Compressed archives as a file-system (read-only) + + Supports the following formats: + tar, pax , cpio, ISO9660, zip, mtree, shar, ar, raw, xar, lha/lzh, rar + Microsoft CAB, 7-Zip, WARC + + See the libarchive documentation for further restrictions. + https://www.libarchive.org/ + + Keeps file object open while instance lives. It only works in seekable + file-like objects. In case the filesystem does not support this kind of + file object, it is recommended to cache locally. + + This class is pickleable, but not necessarily thread-safe (depends on the + platform). See libarchive documentation for details. 
+ """ + + root_marker = "" + protocol = "libarchive" + cachable = False + + def __init__( + self, + fo="", + mode="r", + target_protocol=None, + target_options=None, + block_size=DEFAULT_BLOCK_SIZE, + **kwargs, + ): + """ + Parameters + ---------- + fo: str or file-like + Contains ZIP, and must exist. If a str, will fetch file using + :meth:`~fsspec.open_files`, which must return one file exactly. + mode: str + Currently, only 'r' accepted + target_protocol: str (optional) + If ``fo`` is a string, this value can be used to override the + FS protocol inferred from a URL + target_options: dict (optional) + Kwargs passed when instantiating the target FS, if ``fo`` is + a string. + """ + super().__init__(self, **kwargs) + if mode != "r": + raise ValueError("Only read from archive files accepted") + if isinstance(fo, str): + files = open_files(fo, protocol=target_protocol, **(target_options or {})) + if len(files) != 1: + raise ValueError( + f'Path "{fo}" did not resolve to exactly one file: "{files}"' + ) + fo = files[0] + self.of = fo + self.fo = fo.__enter__() # the whole instance is a context + self.block_size = block_size + self.dir_cache = None + + @contextmanager + def _open_archive(self): + self.fo.seek(0) + with custom_reader(self.fo, block_size=self.block_size) as arc: + yield arc + + @classmethod + def _strip_protocol(cls, path): + # file paths are always relative to the archive root + return super()._strip_protocol(path).lstrip("/") + + def _get_dirs(self): + fields = { + "name": "pathname", + "size": "size", + "created": "ctime", + "mode": "mode", + "uid": "uid", + "gid": "gid", + "mtime": "mtime", + } + + if self.dir_cache is not None: + return + + self.dir_cache = {} + list_names = [] + with self._open_archive() as arc: + for entry in arc: + if not entry.isdir and not entry.isfile: + # Skip symbolic links, fifo entries, etc. 
+ continue + self.dir_cache.update( + { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(set(entry.name)) + } + ) + f = {key: getattr(entry, fields[key]) for key in fields} + f["type"] = "directory" if entry.isdir else "file" + list_names.append(entry.name) + + self.dir_cache[f["name"]] = f + # libarchive does not seem to return an entry for the directories (at least + # not in all formats), so get the directories names from the files names + self.dir_cache.update( + { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(list_names) + } + ) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if mode != "rb": + raise NotImplementedError + + data = b"" + with self._open_archive() as arc: + for entry in arc: + if entry.pathname != path: + continue + + if entry.size == 0: + # empty file, so there are no blocks + break + + for block in entry.get_blocks(entry.size): + data = block + break + else: + raise ValueError + return MemoryFile(fs=self, path=path, data=data) diff --git a/venv/lib/python3.10/site-packages/fsspec/parquet.py b/venv/lib/python3.10/site-packages/fsspec/parquet.py new file mode 100644 index 0000000000000000000000000000000000000000..25f1b702d36a047f15a2095c151ecf86e4fe99ab --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/parquet.py @@ -0,0 +1,572 @@ +import io +import json +import warnings + +import fsspec + +from .core import url_to_fs +from .spec import AbstractBufferedFile +from .utils import merge_offset_ranges + +# Parquet-Specific Utilities for fsspec +# +# Most of the functions defined in this module are NOT +# intended for public consumption. The only exception +# to this is `open_parquet_file`, which should be used +# place of `fs.open()` to open parquet-formatted files +# on remote file systems. 
+ + +class AlreadyBufferedFile(AbstractBufferedFile): + def _fetch_range(self, start, end): + raise NotImplementedError + + +def open_parquet_files( + path: list[str], + fs: None | fsspec.AbstractFileSystem = None, + metadata=None, + columns: None | list[str] = None, + row_groups: None | list[int] = None, + storage_options: None | dict = None, + engine: str = "auto", + max_gap: int = 64_000, + max_block: int = 256_000_000, + footer_sample_size: int = 1_000_000, + filters: None | list[list[list[str]]] = None, + **kwargs, +): + """ + Return a file-like object for a single Parquet file. + + The specified parquet `engine` will be used to parse the + footer metadata, and determine the required byte ranges + from the file. The target path will then be opened with + the "parts" (`KnownPartsOfAFile`) caching strategy. + + Note that this method is intended for usage with remote + file systems, and is unlikely to improve parquet-read + performance on local file systems. + + Parameters + ---------- + path: str + Target file path. + metadata: Any, optional + Parquet metadata object. Object type must be supported + by the backend parquet engine. For now, only the "fastparquet" + engine supports an explicit `ParquetFile` metadata object. + If a metadata object is supplied, the remote footer metadata + will not need to be transferred into local memory. + fs: AbstractFileSystem, optional + Filesystem object to use for opening the file. If nothing is + specified, an `AbstractFileSystem` object will be inferred. + engine : str, default "auto" + Parquet engine to use for metadata parsing. Allowed options + include "fastparquet", "pyarrow", and "auto". The specified + engine must be installed in the current environment. If + "auto" is specified, and both engines are installed, + "fastparquet" will take precedence over "pyarrow". + columns: list, optional + List of all column names that may be read from the file. 
+ row_groups : list, optional + List of all row-groups that may be read from the file. This + may be a list of row-group indices (integers), or it may be + a list of `RowGroup` metadata objects (if the "fastparquet" + engine is used). + storage_options : dict, optional + Used to generate an `AbstractFileSystem` object if `fs` was + not specified. + max_gap : int, optional + Neighboring byte ranges will only be merged when their + inter-range gap is <= `max_gap`. Default is 64KB. + max_block : int, optional + Neighboring byte ranges will only be merged when the size of + the aggregated range is <= `max_block`. Default is 256MB. + footer_sample_size : int, optional + Number of bytes to read from the end of the path to look + for the footer metadata. If the sampled bytes do not contain + the footer, a second read request will be required, and + performance will suffer. Default is 1MB. + filters : list[list], optional + List of filters to apply to prevent reading row groups, of the + same format as accepted by the loading engines. Ignored if + ``row_groups`` is specified. 
+ **kwargs : + Optional key-word arguments to pass to `fs.open` + """ + + # Make sure we have an `AbstractFileSystem` object + # to work with + if fs is None: + path0 = path + if isinstance(path, (list, tuple)): + path = path[0] + fs, path = url_to_fs(path, **(storage_options or {})) + else: + path0 = path + + # For now, `columns == []` not supported, is the same + # as all columns + if columns is not None and len(columns) == 0: + columns = None + + # Set the engine + engine = _set_engine(engine) + + if isinstance(path0, (list, tuple)): + paths = path0 + elif "*" in path: + paths = fs.glob(path) + elif path0.endswith("/"): # or fs.isdir(path): + paths = [ + _ + for _ in fs.find(path, withdirs=False, detail=False) + if _.endswith((".parquet", ".parq")) + ] + else: + paths = [path] + + data = _get_parquet_byte_ranges( + paths, + fs, + metadata=metadata, + columns=columns, + row_groups=row_groups, + engine=engine, + max_gap=max_gap, + max_block=max_block, + footer_sample_size=footer_sample_size, + filters=filters, + ) + + # Call self.open with "parts" caching + options = kwargs.pop("cache_options", {}).copy() + return [ + AlreadyBufferedFile( + fs=None, + path=fn, + mode="rb", + cache_type="parts", + cache_options={ + **options, + "data": ranges, + }, + size=max(_[1] for _ in ranges), + **kwargs, + ) + for fn, ranges in data.items() + ] + + +def open_parquet_file(*args, **kwargs): + """Create files tailed to reading specific parts of parquet files + + Please see ``open_parquet_files`` for details of the arguments. The + difference is, this function always returns a single ``AleadyBufferedFile``, + whereas `open_parquet_files`` always returns a list of files, even if + there are one or zero matching parquet files. 
+ """ + return open_parquet_files(*args, **kwargs)[0] + + +def _get_parquet_byte_ranges( + paths, + fs, + metadata=None, + columns=None, + row_groups=None, + max_gap=64_000, + max_block=256_000_000, + footer_sample_size=1_000_000, + engine="auto", + filters=None, +): + """Get a dictionary of the known byte ranges needed + to read a specific column/row-group selection from a + Parquet dataset. Each value in the output dictionary + is intended for use as the `data` argument for the + `KnownPartsOfAFile` caching strategy of a single path. + """ + + # Set engine if necessary + if isinstance(engine, str): + engine = _set_engine(engine) + + # Pass to a specialized function if metadata is defined + if metadata is not None: + # Use the provided parquet metadata object + # to avoid transferring/parsing footer metadata + return _get_parquet_byte_ranges_from_metadata( + metadata, + fs, + engine, + columns=columns, + row_groups=row_groups, + max_gap=max_gap, + max_block=max_block, + filters=filters, + ) + + # Populate global paths, starts, & ends + if columns is None and row_groups is None and filters is None: + # We are NOT selecting specific columns or row-groups. + # + # We can avoid sampling the footers, and just transfer + # all file data with cat_ranges + result = {path: {(0, len(data)): data} for path, data in fs.cat(paths).items()} + else: + # We ARE selecting specific columns or row-groups. + # + # Get file sizes asynchronously + file_sizes = fs.sizes(paths) + data_paths = [] + data_starts = [] + data_ends = [] + # Gather file footers. + # We just take the last `footer_sample_size` bytes of each + # file (or the entire file if it is smaller than that) + footer_starts = [ + max(0, file_size - footer_sample_size) for file_size in file_sizes + ] + footer_samples = fs.cat_ranges(paths, footer_starts, file_sizes) + + # Check our footer samples and re-sample if necessary. 
+ large_footer = [] + for i, path in enumerate(paths): + footer_size = int.from_bytes(footer_samples[i][-8:-4], "little") + real_footer_start = file_sizes[i] - (footer_size + 8) + if real_footer_start < footer_starts[i]: + large_footer.append((i, real_footer_start)) + if large_footer: + warnings.warn( + f"Not enough data was used to sample the parquet footer. " + f"Try setting footer_sample_size >= {large_footer}." + ) + path0 = [paths[i] for i, _ in large_footer] + starts = [_[1] for _ in large_footer] + ends = [file_sizes[i] - footer_sample_size for i, _ in large_footer] + data = fs.cat_ranges(path0, starts, ends) + for i, (path, start, block) in enumerate(zip(path0, starts, data)): + footer_samples[i] = block + footer_samples[i] + footer_starts[i] = start + result = { + path: {(start, size): data} + for path, start, size, data in zip( + paths, footer_starts, file_sizes, footer_samples + ) + } + + # Calculate required byte ranges for each path + for i, path in enumerate(paths): + # Use "engine" to collect data byte ranges + path_data_starts, path_data_ends = engine._parquet_byte_ranges( + columns, + row_groups=row_groups, + footer=footer_samples[i], + footer_start=footer_starts[i], + filters=filters, + ) + + data_paths += [path] * len(path_data_starts) + data_starts += path_data_starts + data_ends += path_data_ends + + # Merge adjacent offset ranges + data_paths, data_starts, data_ends = merge_offset_ranges( + data_paths, + data_starts, + data_ends, + max_gap=max_gap, + max_block=max_block, + sort=True, + ) + + # Transfer the data byte-ranges into local memory + _transfer_ranges(fs, result, data_paths, data_starts, data_ends) + + # Add b"PAR1" to headers + _add_header_magic(result) + + return result + + +def _get_parquet_byte_ranges_from_metadata( + metadata, + fs, + engine, + columns=None, + row_groups=None, + max_gap=64_000, + max_block=256_000_000, + filters=None, +): + """Simplified version of `_get_parquet_byte_ranges` for + the case that an engine-specific 
`metadata` object is + provided, and the remote footer metadata does not need to + be transferred before calculating the required byte ranges. + """ + + # Use "engine" to collect data byte ranges + data_paths, data_starts, data_ends = engine._parquet_byte_ranges( + columns, row_groups=row_groups, metadata=metadata, filters=filters + ) + + # Merge adjacent offset ranges + data_paths, data_starts, data_ends = merge_offset_ranges( + data_paths, + data_starts, + data_ends, + max_gap=max_gap, + max_block=max_block, + sort=False, # Should be sorted + ) + + # Transfer the data byte-ranges into local memory + result = {fn: {} for fn in list(set(data_paths))} + _transfer_ranges(fs, result, data_paths, data_starts, data_ends) + + # Add b"PAR1" to header + _add_header_magic(result) + + return result + + +def _transfer_ranges(fs, blocks, paths, starts, ends): + # Use cat_ranges to gather the data byte_ranges + ranges = (paths, starts, ends) + for path, start, stop, data in zip(*ranges, fs.cat_ranges(*ranges)): + blocks[path][(start, stop)] = data + + +def _add_header_magic(data): + # Add b"PAR1" to file headers + for path in list(data): + add_magic = True + for k in data[path]: + if k[0] == 0 and k[1] >= 4: + add_magic = False + break + if add_magic: + data[path][(0, 4)] = b"PAR1" + + +def _set_engine(engine_str): + # Define a list of parquet engines to try + if engine_str == "auto": + try_engines = ("fastparquet", "pyarrow") + elif not isinstance(engine_str, str): + raise ValueError( + "Failed to set parquet engine! 
" + "Please pass 'fastparquet', 'pyarrow', or 'auto'" + ) + elif engine_str not in ("fastparquet", "pyarrow"): + raise ValueError(f"{engine_str} engine not supported by `fsspec.parquet`") + else: + try_engines = [engine_str] + + # Try importing the engines in `try_engines`, + # and choose the first one that succeeds + for engine in try_engines: + try: + if engine == "fastparquet": + return FastparquetEngine() + elif engine == "pyarrow": + return PyarrowEngine() + except ImportError: + pass + + # Raise an error if a supported parquet engine + # was not found + raise ImportError( + f"The following parquet engines are not installed " + f"in your python environment: {try_engines}." + f"Please install 'fastparquert' or 'pyarrow' to " + f"utilize the `fsspec.parquet` module." + ) + + +class FastparquetEngine: + # The purpose of the FastparquetEngine class is + # to check if fastparquet can be imported (on initialization) + # and to define a `_parquet_byte_ranges` method. In the + # future, this class may also be used to define other + # methods/logic that are specific to fastparquet. 
+ + def __init__(self): + import fastparquet as fp + + self.fp = fp + + def _parquet_byte_ranges( + self, + columns, + row_groups=None, + metadata=None, + footer=None, + footer_start=None, + filters=None, + ): + # Initialize offset ranges and define ParqetFile metadata + pf = metadata + data_paths, data_starts, data_ends = [], [], [] + if filters and row_groups: + raise ValueError("filters and row_groups cannot be used together") + if pf is None: + pf = self.fp.ParquetFile(io.BytesIO(footer)) + + # Convert columns to a set and add any index columns + # specified in the pandas metadata (just in case) + column_set = None if columns is None else {c.split(".", 1)[0] for c in columns} + if column_set is not None and hasattr(pf, "pandas_metadata"): + md_index = [ + ind + for ind in pf.pandas_metadata.get("index_columns", []) + # Ignore RangeIndex information + if not isinstance(ind, dict) + ] + column_set |= set(md_index) + + # Check if row_groups is a list of integers + # or a list of row-group metadata + if filters: + from fastparquet.api import filter_row_groups + + row_group_indices = None + row_groups = filter_row_groups(pf, filters) + elif row_groups and not isinstance(row_groups[0], int): + # Input row_groups contains row-group metadata + row_group_indices = None + else: + # Input row_groups contains row-group indices + row_group_indices = row_groups + row_groups = pf.row_groups + if column_set is not None: + column_set = [ + _ if isinstance(_, list) else _.split(".") for _ in column_set + ] + + # Loop through column chunks to add required byte ranges + for r, row_group in enumerate(row_groups): + # Skip this row-group if we are targeting + # specific row-groups + if row_group_indices is None or r in row_group_indices: + # Find the target parquet-file path for `row_group` + fn = pf.row_group_filename(row_group) + + for column in row_group.columns: + name = column.meta_data.path_in_schema + # Skip this column if we are targeting specific columns + if column_set is 
None or _cmp(name, column_set): + file_offset0 = column.meta_data.dictionary_page_offset + if file_offset0 is None: + file_offset0 = column.meta_data.data_page_offset + num_bytes = column.meta_data.total_compressed_size + if footer_start is None or file_offset0 < footer_start: + data_paths.append(fn) + data_starts.append(file_offset0) + data_ends.append( + min( + file_offset0 + num_bytes, + footer_start or (file_offset0 + num_bytes), + ) + ) + + if metadata: + # The metadata in this call may map to multiple + # file paths. Need to include `data_paths` + return data_paths, data_starts, data_ends + return data_starts, data_ends + + +class PyarrowEngine: + # The purpose of the PyarrowEngine class is + # to check if pyarrow can be imported (on initialization) + # and to define a `_parquet_byte_ranges` method. In the + # future, this class may also be used to define other + # methods/logic that are specific to pyarrow. + + def __init__(self): + import pyarrow.parquet as pq + + self.pq = pq + + def _parquet_byte_ranges( + self, + columns, + row_groups=None, + metadata=None, + footer=None, + footer_start=None, + filters=None, + ): + if metadata is not None: + raise ValueError("metadata input not supported for PyarrowEngine") + if filters: + # there must be a way! 
+ raise NotImplementedError + + data_starts, data_ends = [], [] + md = self.pq.ParquetFile(io.BytesIO(footer)).metadata + + # Convert columns to a set and add any index columns + # specified in the pandas metadata (just in case) + column_set = None if columns is None else set(columns) + if column_set is not None: + schema = md.schema.to_arrow_schema() + has_pandas_metadata = ( + schema.metadata is not None and b"pandas" in schema.metadata + ) + if has_pandas_metadata: + md_index = [ + ind + for ind in json.loads( + schema.metadata[b"pandas"].decode("utf8") + ).get("index_columns", []) + # Ignore RangeIndex information + if not isinstance(ind, dict) + ] + column_set |= set(md_index) + if column_set is not None: + column_set = [ + _[:1] if isinstance(_, list) else _.split(".")[:1] for _ in column_set + ] + + # Loop through column chunks to add required byte ranges + for r in range(md.num_row_groups): + # Skip this row-group if we are targeting + # specific row-groups + if row_groups is None or r in row_groups: + row_group = md.row_group(r) + for c in range(row_group.num_columns): + column = row_group.column(c) + name = column.path_in_schema.split(".") + # Skip this column if we are targeting specific columns + if column_set is None or _cmp(name, column_set): + meta = column.to_dict() + # Any offset could be the first one + file_offset0 = min( + _ + for _ in [ + meta.get("dictionary_page_offset"), + meta.get("data_page_offset"), + meta.get("index_page_offset"), + ] + if _ is not None + ) + if file_offset0 < footer_start: + data_starts.append(file_offset0) + data_ends.append( + min( + meta["total_compressed_size"] + file_offset0, + footer_start, + ) + ) + + data_starts.append(footer_start) + data_ends.append(footer_start + len(footer)) + return data_starts, data_ends + + +def _cmp(name, column_set): + return any(all(a == b for a, b in zip(name, _)) for _ in column_set) diff --git a/venv/lib/python3.10/site-packages/fsspec/registry.py 
b/venv/lib/python3.10/site-packages/fsspec/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..597c6ba57a7cb70c65fd6daeb97ce7ab703a5121 --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/registry.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +import importlib +import types +import warnings + +__all__ = ["registry", "get_filesystem_class", "default"] + +# internal, mutable +_registry: dict[str, type] = {} + +# external, immutable +registry = types.MappingProxyType(_registry) +default = "file" + + +def register_implementation(name, cls, clobber=False, errtxt=None): + """Add implementation class to the registry + + Parameters + ---------- + name: str + Protocol name to associate with the class + cls: class or str + if a class: fsspec-compliant implementation class (normally inherits from + ``fsspec.AbstractFileSystem``, gets added straight to the registry. If a + str, the full path to an implementation class like package.module.class, + which gets added to known_implementations, + so the import is deferred until the filesystem is actually used. + clobber: bool (optional) + Whether to overwrite a protocol with the same name; if False, will raise + instead. + errtxt: str (optional) + If given, then a failure to import the given class will result in this + text being given. + """ + if isinstance(cls, str): + if name in known_implementations and clobber is False: + if cls != known_implementations[name]["class"]: + raise ValueError( + f"Name ({name}) already in the known_implementations and clobber " + f"is False" + ) + else: + known_implementations[name] = { + "class": cls, + "err": errtxt or f"{cls} import failed for protocol {name}", + } + + else: + if name in registry and clobber is False: + if _registry[name] is not cls: + raise ValueError( + f"Name ({name}) already in the registry and clobber is False" + ) + else: + _registry[name] = cls + + +# protocols mapped to the class which implements them. 
This dict can be +# updated with register_implementation +known_implementations = { + "abfs": { + "class": "adlfs.AzureBlobFileSystem", + "err": "Install adlfs to access Azure Datalake Gen2 and Azure Blob Storage", + }, + "adl": { + "class": "adlfs.AzureDatalakeFileSystem", + "err": "Install adlfs to access Azure Datalake Gen1", + }, + "arrow_hdfs": { + "class": "fsspec.implementations.arrow.HadoopFileSystem", + "err": "pyarrow and local java libraries required for HDFS", + }, + "async_wrapper": { + "class": "fsspec.implementations.asyn_wrapper.AsyncFileSystemWrapper", + }, + "asynclocal": { + "class": "morefs.asyn_local.AsyncLocalFileSystem", + "err": "Install 'morefs[asynclocalfs]' to use AsyncLocalFileSystem", + }, + "asyncwrapper": { + "class": "fsspec.implementations.asyn_wrapper.AsyncFileSystemWrapper", + }, + "az": { + "class": "adlfs.AzureBlobFileSystem", + "err": "Install adlfs to access Azure Datalake Gen2 and Azure Blob Storage", + }, + "blockcache": {"class": "fsspec.implementations.cached.CachingFileSystem"}, + "box": { + "class": "boxfs.BoxFileSystem", + "err": "Please install boxfs to access BoxFileSystem", + }, + "cached": {"class": "fsspec.implementations.cached.CachingFileSystem"}, + "dask": { + "class": "fsspec.implementations.dask.DaskWorkerFileSystem", + "err": "Install dask distributed to access worker file system", + }, + "data": {"class": "fsspec.implementations.data.DataFileSystem"}, + "dbfs": { + "class": "fsspec.implementations.dbfs.DatabricksFileSystem", + "err": "Install the requests package to use the DatabricksFileSystem", + }, + "dir": {"class": "fsspec.implementations.dirfs.DirFileSystem"}, + "dropbox": { + "class": "dropboxdrivefs.DropboxDriveFileSystem", + "err": ( + 'DropboxFileSystem requires "dropboxdrivefs","requests" and "' + '"dropbox" to be installed' + ), + }, + "dvc": { + "class": "dvc.api.DVCFileSystem", + "err": "Install dvc to access DVCFileSystem", + }, + "file": {"class": 
"fsspec.implementations.local.LocalFileSystem"}, + "filecache": {"class": "fsspec.implementations.cached.WholeFileCacheFileSystem"}, + "ftp": {"class": "fsspec.implementations.ftp.FTPFileSystem"}, + "gcs": { + "class": "gcsfs.GCSFileSystem", + "err": "Please install gcsfs to access Google Storage", + }, + "gdrive": { + "class": "gdrive_fsspec.GoogleDriveFileSystem", + "err": "Please install gdrive_fs for access to Google Drive", + }, + "generic": {"class": "fsspec.generic.GenericFileSystem"}, + "gist": { + "class": "fsspec.implementations.gist.GistFileSystem", + "err": "Install the requests package to use the gist FS", + }, + "git": { + "class": "fsspec.implementations.git.GitFileSystem", + "err": "Install pygit2 to browse local git repos", + }, + "github": { + "class": "fsspec.implementations.github.GithubFileSystem", + "err": "Install the requests package to use the github FS", + }, + "gs": { + "class": "gcsfs.GCSFileSystem", + "err": "Please install gcsfs to access Google Storage", + }, + "hdfs": { + "class": "fsspec.implementations.arrow.HadoopFileSystem", + "err": "pyarrow and local java libraries required for HDFS", + }, + "hf": { + "class": "huggingface_hub.HfFileSystem", + "err": "Install huggingface_hub to access HfFileSystem", + }, + "http": { + "class": "fsspec.implementations.http.HTTPFileSystem", + "err": 'HTTPFileSystem requires "requests" and "aiohttp" to be installed', + }, + "https": { + "class": "fsspec.implementations.http.HTTPFileSystem", + "err": 'HTTPFileSystem requires "requests" and "aiohttp" to be installed', + }, + "jlab": { + "class": "fsspec.implementations.jupyter.JupyterFileSystem", + "err": "Jupyter FS requires requests to be installed", + }, + "jupyter": { + "class": "fsspec.implementations.jupyter.JupyterFileSystem", + "err": "Jupyter FS requires requests to be installed", + }, + "lakefs": { + "class": "lakefs_spec.LakeFSFileSystem", + "err": "Please install lakefs-spec to access LakeFSFileSystem", + }, + "libarchive": { + "class": 
"fsspec.implementations.libarchive.LibArchiveFileSystem", + "err": "LibArchive requires to be installed", + }, + "local": {"class": "fsspec.implementations.local.LocalFileSystem"}, + "memory": {"class": "fsspec.implementations.memory.MemoryFileSystem"}, + "oci": { + "class": "ocifs.OCIFileSystem", + "err": "Install ocifs to access OCI Object Storage", + }, + "ocilake": { + "class": "ocifs.OCIFileSystem", + "err": "Install ocifs to access OCI Data Lake", + }, + "oss": { + "class": "ossfs.OSSFileSystem", + "err": "Install ossfs to access Alibaba Object Storage System", + }, + "pyscript": { + "class": "pyscript_fsspec_client.client.PyscriptFileSystem", + "err": "This only runs in a pyscript context", + }, + "reference": {"class": "fsspec.implementations.reference.ReferenceFileSystem"}, + "root": { + "class": "fsspec_xrootd.XRootDFileSystem", + "err": ( + "Install fsspec-xrootd to access xrootd storage system. " + "Note: 'root' is the protocol name for xrootd storage systems, " + "not referring to root directories" + ), + }, + "s3": {"class": "s3fs.S3FileSystem", "err": "Install s3fs to access S3"}, + "s3a": {"class": "s3fs.S3FileSystem", "err": "Install s3fs to access S3"}, + "sftp": { + "class": "fsspec.implementations.sftp.SFTPFileSystem", + "err": 'SFTPFileSystem requires "paramiko" to be installed', + }, + "simplecache": {"class": "fsspec.implementations.cached.SimpleCacheFileSystem"}, + "smb": { + "class": "fsspec.implementations.smb.SMBFileSystem", + "err": 'SMB requires "smbprotocol" or "smbprotocol[kerberos]" installed', + }, + "ssh": { + "class": "fsspec.implementations.sftp.SFTPFileSystem", + "err": 'SFTPFileSystem requires "paramiko" to be installed', + }, + "tar": {"class": "fsspec.implementations.tar.TarFileSystem"}, + "tos": { + "class": "tosfs.TosFileSystem", + "err": "Install tosfs to access ByteDance volcano engine Tinder Object Storage", + }, + "tosfs": { + "class": "tosfs.TosFileSystem", + "err": "Install tosfs to access ByteDance volcano engine 
Tinder Object Storage", + }, + "wandb": {"class": "wandbfs.WandbFS", "err": "Install wandbfs to access wandb"}, + "webdav": { + "class": "webdav4.fsspec.WebdavFileSystem", + "err": "Install webdav4 to access WebDAV", + }, + "webhdfs": { + "class": "fsspec.implementations.webhdfs.WebHDFS", + "err": 'webHDFS access requires "requests" to be installed', + }, + "zip": {"class": "fsspec.implementations.zip.ZipFileSystem"}, +} + +assert list(known_implementations) == sorted(known_implementations), ( + "Not in alphabetical order" +) + + +def get_filesystem_class(protocol): + """Fetch named protocol implementation from the registry + + The dict ``known_implementations`` maps protocol names to the locations + of classes implementing the corresponding file-system. When used for the + first time, appropriate imports will happen and the class will be placed in + the registry. All subsequent calls will fetch directly from the registry. + + Some protocol implementations require additional dependencies, and so the + import may fail. In this case, the string in the "err" field of the + ``known_implementations`` will be given as the error message. + """ + if not protocol: + protocol = default + + if protocol not in registry: + if protocol not in known_implementations: + raise ValueError(f"Protocol not known: {protocol}") + bit = known_implementations[protocol] + try: + register_implementation(protocol, _import_class(bit["class"])) + except ImportError as e: + raise ImportError(bit.get("err")) from e + cls = registry[protocol] + if getattr(cls, "protocol", None) in ("abstract", None): + cls.protocol = protocol + + return cls + + +s3_msg = """Your installed version of s3fs is very old and known to cause +severe performance issues, see also https://github.com/dask/dask/issues/10276 + +To fix, you should specify a lower version bound on s3fs, or +update the current installation. 
+""" + + +def _import_class(fqp: str): + """Take a fully-qualified path and return the imported class or identifier. + + ``fqp`` is of the form "package.module.klass" or + "package.module:subobject.klass". + + Warnings + -------- + This can import arbitrary modules. Make sure you haven't installed any modules + that may execute malicious code at import time. + """ + if ":" in fqp: + mod, name = fqp.rsplit(":", 1) + else: + mod, name = fqp.rsplit(".", 1) + + is_s3 = mod == "s3fs" + mod = importlib.import_module(mod) + if is_s3 and mod.__version__.split(".") < ["0", "5"]: + warnings.warn(s3_msg) + for part in name.split("."): + mod = getattr(mod, part) + + if not isinstance(mod, type): + raise TypeError(f"{fqp} is not a class") + + return mod + + +def filesystem(protocol, **storage_options): + """Instantiate filesystems for given protocol and arguments + + ``storage_options`` are specific to the protocol being chosen, and are + passed directly to the class. + """ + if protocol == "arrow_hdfs": + warnings.warn( + "The 'arrow_hdfs' protocol has been deprecated and will be " + "removed in the future. Specify it as 'hdfs'.", + DeprecationWarning, + ) + + cls = get_filesystem_class(protocol) + return cls(**storage_options) + + +def available_protocols(): + """Return a list of the implemented protocols. + + Note that any given protocol may require extra packages to be importable. 
+ """ + return list(known_implementations) diff --git a/venv/lib/python3.10/site-packages/fsspec/spec.py b/venv/lib/python3.10/site-packages/fsspec/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..b67d5c16fcdf09ce6f9e1354727196042cde3c4c --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/spec.py @@ -0,0 +1,2281 @@ +from __future__ import annotations + +import io +import json +import logging +import os +import threading +import warnings +import weakref +from errno import ESPIPE +from glob import has_magic +from hashlib import sha256 +from typing import Any, ClassVar + +from .callbacks import DEFAULT_CALLBACK +from .config import apply_config, conf +from .dircache import DirCache +from .transaction import Transaction +from .utils import ( + _unstrip_protocol, + glob_translate, + isfilelike, + other_paths, + read_block, + stringify_path, + tokenize, +) + +logger = logging.getLogger("fsspec") + + +def make_instance(cls, args, kwargs): + return cls(*args, **kwargs) + + +class _Cached(type): + """ + Metaclass for caching file system instances. + + Notes + ----- + Instances are cached according to + + * The values of the class attributes listed in `_extra_tokenize_attributes` + * The arguments passed to ``__init__``. + + This creates an additional reference to the filesystem, which prevents the + filesystem from being garbage collected when all *user* references go away. + A call to the :meth:`AbstractFileSystem.clear_instance_cache` must *also* + be made for a filesystem instance to be garbage collected. + """ + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + # Note: we intentionally create a reference here, to avoid garbage + # collecting instances when all other references are gone. To really + # delete a FileSystem, the cache must be cleared. 
+ if conf.get("weakref_instance_cache"): # pragma: no cover + # debug option for analysing fork/spawn conditions + cls._cache = weakref.WeakValueDictionary() + else: + cls._cache = {} + cls._pid = os.getpid() + + def __call__(cls, *args, **kwargs): + kwargs = apply_config(cls, kwargs) + extra_tokens = tuple( + getattr(cls, attr, None) for attr in cls._extra_tokenize_attributes + ) + strip_tokenize_options = { + k: kwargs.pop(k) for k in cls._strip_tokenize_options if k in kwargs + } + token = tokenize( + cls, cls._pid, threading.get_ident(), *args, *extra_tokens, **kwargs + ) + skip = kwargs.pop("skip_instance_cache", False) + if os.getpid() != cls._pid: + cls._cache.clear() + cls._pid = os.getpid() + if not skip and cls.cachable and token in cls._cache: + cls._latest = token + return cls._cache[token] + else: + obj = super().__call__(*args, **kwargs, **strip_tokenize_options) + # Setting _fs_token here causes some static linters to complain. + obj._fs_token_ = token + obj.storage_args = args + obj.storage_options = kwargs + if obj.async_impl and obj.mirror_sync_methods: + from .asyn import mirror_sync_methods + + mirror_sync_methods(obj) + + if cls.cachable and not skip: + cls._latest = token + cls._cache[token] = obj + return obj + + +class AbstractFileSystem(metaclass=_Cached): + """ + An abstract super-class for pythonic file-systems + + Implementations are expected to be compatible with or, better, subclass + from here. + """ + + cachable = True # this class can be cached, instances reused + _cached = False + blocksize = 2**22 + sep = "/" + protocol: ClassVar[str | tuple[str, ...]] = "abstract" + _latest = None + async_impl = False + mirror_sync_methods = False + root_marker = "" # For some FSs, may require leading '/' or other character + transaction_type = Transaction + + #: Extra *class attributes* that should be considered when hashing. + _extra_tokenize_attributes = () + #: *storage options* that should not be considered when hashing. 
+ _strip_tokenize_options = () + + # Set by _Cached metaclass + storage_args: tuple[Any, ...] + storage_options: dict[str, Any] + + def __init__(self, *args, **storage_options): + """Create and configure file-system instance + + Instances may be cachable, so if similar enough arguments are seen + a new instance is not required. The token attribute exists to allow + implementations to cache instances if they wish. + + A reasonable default should be provided if there are no arguments. + + Subclasses should call this method. + + Parameters + ---------- + use_listings_cache, listings_expiry_time, max_paths: + passed to ``DirCache``, if the implementation supports + directory listing caching. Pass use_listings_cache=False + to disable such caching. + skip_instance_cache: bool + If this is a cachable implementation, pass True here to force + creating a new instance even if a matching instance exists, and prevent + storing this instance. + asynchronous: bool + loop: asyncio-compatible IOLoop or None + """ + if self._cached: + # reusing instance, don't change + return + self._cached = True + self._intrans = False + self._transaction = None + self._invalidated_caches_in_transaction = [] + self.dircache = DirCache(**storage_options) + + if storage_options.pop("add_docs", None): + warnings.warn("add_docs is no longer supported.", FutureWarning) + + if storage_options.pop("add_aliases", None): + warnings.warn("add_aliases has been removed.", FutureWarning) + # This is set in _Cached + self._fs_token_ = None + + @property + def fsid(self): + """Persistent filesystem id that can be used to compare filesystems + across sessions. 
+ """ + raise NotImplementedError + + @property + def _fs_token(self): + return self._fs_token_ + + def __dask_tokenize__(self): + return self._fs_token + + def __hash__(self): + return int(self._fs_token, 16) + + def __eq__(self, other): + return isinstance(other, type(self)) and self._fs_token == other._fs_token + + def __reduce__(self): + return make_instance, (type(self), self.storage_args, self.storage_options) + + @classmethod + def _strip_protocol(cls, path): + """Turn path from fully-qualified to file-system-specific + + May require FS-specific handling, e.g., for relative paths or links. + """ + if isinstance(path, list): + return [cls._strip_protocol(p) for p in path] + path = stringify_path(path) + protos = (cls.protocol,) if isinstance(cls.protocol, str) else cls.protocol + for protocol in protos: + if path.startswith(protocol + "://"): + path = path[len(protocol) + 3 :] + elif path.startswith(protocol + "::"): + path = path[len(protocol) + 2 :] + path = path.rstrip("/") + # use of root_marker to make minimum required path, e.g., "/" + return path or cls.root_marker + + def unstrip_protocol(self, name: str) -> str: + """Format FS-specific path to generic, including protocol""" + protos = (self.protocol,) if isinstance(self.protocol, str) else self.protocol + for protocol in protos: + if name.startswith(f"{protocol}://"): + return name + return f"{protos[0]}://{name}" + + @staticmethod + def _get_kwargs_from_urls(path): + """If kwargs can be encoded in the paths, extract them here + + This should happen before instantiation of the class; incoming paths + then should be amended to strip the options in methods. + + Examples may look like an sftp path "sftp://user@host:/my/path", where + the user and host should become kwargs and later get stripped. 
+ """ + # by default, nothing happens + return {} + + @classmethod + def current(cls): + """Return the most recently instantiated FileSystem + + If no instance has been created, then create one with defaults + """ + if cls._latest in cls._cache: + return cls._cache[cls._latest] + return cls() + + @property + def transaction(self): + """A context within which files are committed together upon exit + + Requires the file class to implement `.commit()` and `.discard()` + for the normal and exception cases. + """ + if self._transaction is None: + self._transaction = self.transaction_type(self) + return self._transaction + + def start_transaction(self): + """Begin write transaction for deferring files, non-context version""" + self._intrans = True + self._transaction = self.transaction_type(self) + return self.transaction + + def end_transaction(self): + """Finish write transaction, non-context version""" + self.transaction.complete() + self._transaction = None + # The invalid cache must be cleared after the transaction is completed. + for path in self._invalidated_caches_in_transaction: + self.invalidate_cache(path) + self._invalidated_caches_in_transaction.clear() + + def invalidate_cache(self, path=None): + """ + Discard any cached directory information + + Parameters + ---------- + path: string or None + If None, clear all listings cached else listings at or under given + path. + """ + # Not necessary to implement invalidation mechanism, may have no cache. + # But if have, you should call this method of parent class from your + # subclass to ensure expiring caches after transacations correctly. 
+ # See the implementation of FTPFileSystem in ftp.py + if self._intrans: + self._invalidated_caches_in_transaction.append(path) + + def mkdir(self, path, create_parents=True, **kwargs): + """ + Create directory entry at path + + For systems that don't have true directories, may create an for + this instance only and not touch the real filesystem + + Parameters + ---------- + path: str + location + create_parents: bool + if True, this is equivalent to ``makedirs`` + kwargs: + may be permissions, etc. + """ + pass # not necessary to implement, may not have directories + + def makedirs(self, path, exist_ok=False): + """Recursively make directories + + Creates directory at path and any intervening required directories. + Raises exception if, for instance, the path already exists but is a + file. + + Parameters + ---------- + path: str + leaf directory name + exist_ok: bool (False) + If False, will error if the target already exists + """ + pass # not necessary to implement, may not have directories + + def rmdir(self, path): + """Remove a directory, if empty""" + pass # not necessary to implement, may not have directories + + def ls(self, path, detail=True, **kwargs): + """List objects at path. + + This should include subdirectories and files at that location. The + difference between a file and a directory must be clear when details + are requested. + + The specific keys, or perhaps a FileInfo class, or similar, is TBD, + but must be consistent across implementations. + Must include: + + - full path to the entry (without protocol) + - size of the entry, in bytes. If the value cannot be determined, will + be ``None``. + - type of entry, "file", "directory" or other + + Additional information + may be present, appropriate to the file-system, e.g., generation, + checksum, etc. + + May use refresh=True|False to allow use of self._ls_from_cache to + check for a saved listing and avoid calling the backend. This would be + common where listing may be expensive. 
+ + Parameters + ---------- + path: str + detail: bool + if True, gives a list of dictionaries, where each is the same as + the result of ``info(path)``. If False, gives a list of paths + (str). + kwargs: may have additional backend-specific options, such as version + information + + Returns + ------- + List of strings if detail is False, or list of directory information + dicts if detail is True. + """ + raise NotImplementedError + + def _ls_from_cache(self, path): + """Check cache for listing + + Returns listing, if found (may be empty list for a directly that exists + but contains nothing), None if not in cache. + """ + parent = self._parent(path) + try: + return self.dircache[path.rstrip("/")] + except KeyError: + pass + try: + files = [ + f + for f in self.dircache[parent] + if f["name"] == path + or (f["name"] == path.rstrip("/") and f["type"] == "directory") + ] + if len(files) == 0: + # parent dir was listed but did not contain this file + raise FileNotFoundError(path) + return files + except KeyError: + pass + + def walk(self, path, maxdepth=None, topdown=True, on_error="omit", **kwargs): + """Return all files under the given path. + + List all files, recursing into subdirectories; output is iterator-style, + like ``os.walk()``. For a simple list of files, ``find()`` is available. + + When topdown is True, the caller can modify the dirnames list in-place (perhaps + using del or slice assignment), and walk() will + only recurse into the subdirectories whose names remain in dirnames; + this can be used to prune the search, impose a specific order of visiting, + or even to inform walk() about directories the caller creates or renames before + it resumes walk() again. + Modifying dirnames when topdown is False has no effect. (see os.walk) + + Note that the "files" outputted will include anything that is not + a directory, such as links. + + Parameters + ---------- + path: str + Root to recurse into + maxdepth: int + Maximum recursion depth. 
None means limitless, but not recommended + on link-based file-systems. + topdown: bool (True) + Whether to walk the directory tree from the top downwards or from + the bottom upwards. + on_error: "omit", "raise", a callable + if omit (default), path with exception will simply be empty; + If raise, an underlying exception will be raised; + if callable, it will be called with a single OSError instance as argument + kwargs: passed to ``ls`` + """ + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + + path = self._strip_protocol(path) + full_dirs = {} + dirs = {} + files = {} + + detail = kwargs.pop("detail", False) + try: + listing = self.ls(path, detail=True, **kwargs) + except (FileNotFoundError, OSError) as e: + if on_error == "raise": + raise + if callable(on_error): + on_error(e) + return + + for info in listing: + # each info name must be at least [path]/part , but here + # we check also for names like [path]/part/ + pathname = info["name"].rstrip("/") + name = pathname.rsplit("/", 1)[-1] + if info["type"] == "directory" and pathname != path: + # do not include "self" path + full_dirs[name] = pathname + dirs[name] = info + elif pathname == path: + # file-like with same name as give path + files[""] = info + else: + files[name] = info + + if not detail: + dirs = list(dirs) + files = list(files) + + if topdown: + # Yield before recursion if walking top down + yield path, dirs, files + + if maxdepth is not None: + maxdepth -= 1 + if maxdepth < 1: + if not topdown: + yield path, dirs, files + return + + for d in dirs: + yield from self.walk( + full_dirs[d], + maxdepth=maxdepth, + detail=detail, + topdown=topdown, + **kwargs, + ) + + if not topdown: + # Yield after recursion if walking bottom up + yield path, dirs, files + + def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs): + """List all files below path. 
+ + Like posix ``find`` command without conditions + + Parameters + ---------- + path : str + maxdepth: int or None + If not None, the maximum number of levels to descend + withdirs: bool + Whether to include directory paths in the output. This is True + when used by glob, but users usually only want files. + kwargs are passed to ``ls``. + """ + # TODO: allow equivalent of -name parameter + path = self._strip_protocol(path) + out = {} + + # Add the root directory if withdirs is requested + # This is needed for posix glob compliance + if withdirs and path != "" and self.isdir(path): + out[path] = self.info(path) + + for _, dirs, files in self.walk(path, maxdepth, detail=True, **kwargs): + if withdirs: + files.update(dirs) + out.update({info["name"]: info for name, info in files.items()}) + if not out and self.isfile(path): + # walk works on directories, but find should also return [path] + # when path happens to be a file + out[path] = {} + names = sorted(out) + if not detail: + return names + else: + return {name: out[name] for name in names} + + def du(self, path, total=True, maxdepth=None, withdirs=False, **kwargs): + """Space used by files and optionally directories within a path + + Directory size does not include the size of its contents. + + Parameters + ---------- + path: str + total: bool + Whether to sum all the file sizes + maxdepth: int or None + Maximum number of directory levels to descend, None for unlimited. + withdirs: bool + Whether to include directory paths in the output. + kwargs: passed to ``find`` + + Returns + ------- + Dict of {path: size} if total=False, or int otherwise, where numbers + refer to bytes used. 
+ """ + sizes = {} + if withdirs and self.isdir(path): + # Include top-level directory in output + info = self.info(path) + sizes[info["name"]] = info["size"] + for f in self.find(path, maxdepth=maxdepth, withdirs=withdirs, **kwargs): + info = self.info(f) + sizes[info["name"]] = info["size"] + if total: + return sum(sizes.values()) + else: + return sizes + + def glob(self, path, maxdepth=None, **kwargs): + """Find files by glob-matching. + + Pattern matching capabilities for finding files that match the given pattern. + + Parameters + ---------- + path: str + The glob pattern to match against + maxdepth: int or None + Maximum depth for ``'**'`` patterns. Applied on the first ``'**'`` found. + Must be at least 1 if provided. + kwargs: + Additional arguments passed to ``find`` (e.g., detail=True) + + Returns + ------- + List of matched paths, or dict of paths and their info if detail=True + + Notes + ----- + Supported patterns: + - '*': Matches any sequence of characters within a single directory level + - ``'**'``: Matches any number of directory levels (must be an entire path component) + - '?': Matches exactly one character + - '[abc]': Matches any character in the set + - '[a-z]': Matches any character in the range + - '[!abc]': Matches any character NOT in the set + + Special behaviors: + - If the path ends with '/', only folders are returned + - Consecutive '*' characters are compressed into a single '*' + - Empty brackets '[]' never match anything + - Negated empty brackets '[!]' match any single character + - Special characters in character classes are escaped properly + + Limitations: + - ``'**'`` must be a complete path component (e.g., ``'a/**/b'``, not ``'a**b'``) + - No brace expansion ('{a,b}.txt') + - No extended glob patterns ('+(pattern)', '!(pattern)') + """ + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + + import re + + seps = (os.path.sep, os.path.altsep) if os.path.altsep else (os.path.sep,) + 
ends_with_sep = path.endswith(seps) # _strip_protocol strips trailing slash + path = self._strip_protocol(path) + append_slash_to_dirname = ends_with_sep or path.endswith( + tuple(sep + "**" for sep in seps) + ) + idx_star = path.find("*") if path.find("*") >= 0 else len(path) + idx_qmark = path.find("?") if path.find("?") >= 0 else len(path) + idx_brace = path.find("[") if path.find("[") >= 0 else len(path) + + min_idx = min(idx_star, idx_qmark, idx_brace) + + detail = kwargs.pop("detail", False) + + if not has_magic(path): + if self.exists(path, **kwargs): + if not detail: + return [path] + else: + return {path: self.info(path, **kwargs)} + else: + if not detail: + return [] # glob of non-existent returns empty + else: + return {} + elif "/" in path[:min_idx]: + min_idx = path[:min_idx].rindex("/") + root = path[: min_idx + 1] + depth = path[min_idx + 1 :].count("/") + 1 + else: + root = "" + depth = path[min_idx + 1 :].count("/") + 1 + + if "**" in path: + if maxdepth is not None: + idx_double_stars = path.find("**") + depth_double_stars = path[idx_double_stars:].count("/") + 1 + depth = depth - depth_double_stars + maxdepth + else: + depth = None + + allpaths = self.find(root, maxdepth=depth, withdirs=True, detail=True, **kwargs) + + pattern = glob_translate(path + ("/" if ends_with_sep else "")) + pattern = re.compile(pattern) + + out = { + p: info + for p, info in sorted(allpaths.items()) + if pattern.match( + p + "/" + if append_slash_to_dirname and info["type"] == "directory" + else p + ) + } + + if detail: + return out + else: + return list(out) + + def exists(self, path, **kwargs): + """Is there a file at the given path""" + try: + self.info(path, **kwargs) + return True + except: # noqa: E722 + # any exception allowed bar FileNotFoundError? 
+ return False + + def lexists(self, path, **kwargs): + """If there is a file at the given path (including + broken links)""" + return self.exists(path) + + def info(self, path, **kwargs): + """Give details of entry at path + + Returns a single dictionary, with exactly the same information as ``ls`` + would with ``detail=True``. + + The default implementation calls ls and could be overridden by a + shortcut. kwargs are passed on to ```ls()``. + + Some file systems might not be able to measure the file's size, in + which case, the returned dict will include ``'size': None``. + + Returns + ------- + dict with keys: name (full path in the FS), size (in bytes), type (file, + directory, or something else) and other FS-specific keys. + """ + path = self._strip_protocol(path) + out = self.ls(self._parent(path), detail=True, **kwargs) + out = [o for o in out if o["name"].rstrip("/") == path] + if out: + return out[0] + out = self.ls(path, detail=True, **kwargs) + path = path.rstrip("/") + out1 = [o for o in out if o["name"].rstrip("/") == path] + if len(out1) == 1: + if "size" not in out1[0]: + out1[0]["size"] = None + return out1[0] + elif len(out1) > 1 or out: + return {"name": path, "size": 0, "type": "directory"} + else: + raise FileNotFoundError(path) + + def checksum(self, path): + """Unique value for current version of file + + If the checksum is the same from one moment to another, the contents + are guaranteed to be the same. If the checksum changes, the contents + *might* have changed. 
+ + This should normally be overridden; default will probably capture + creation/modification timestamp (which would be good) or maybe + access timestamp (which would be bad) + """ + return int(tokenize(self.info(path)), 16) + + def size(self, path): + """Size in bytes of file""" + return self.info(path).get("size", None) + + def sizes(self, paths): + """Size in bytes of each file in a list of paths""" + return [self.size(p) for p in paths] + + def isdir(self, path): + """Is this entry directory-like?""" + try: + return self.info(path)["type"] == "directory" + except OSError: + return False + + def isfile(self, path): + """Is this entry file-like?""" + try: + return self.info(path)["type"] == "file" + except: # noqa: E722 + return False + + def read_text(self, path, encoding=None, errors=None, newline=None, **kwargs): + """Get the contents of the file as a string. + + Parameters + ---------- + path: str + URL of file on this filesystems + encoding, errors, newline: same as `open`. + """ + with self.open( + path, + mode="r", + encoding=encoding, + errors=errors, + newline=newline, + **kwargs, + ) as f: + return f.read() + + def write_text( + self, path, value, encoding=None, errors=None, newline=None, **kwargs + ): + """Write the text to the given file. + + An existing file will be overwritten. + + Parameters + ---------- + path: str + URL of file on this filesystems + value: str + Text to write. + encoding, errors, newline: same as `open`. + """ + with self.open( + path, + mode="w", + encoding=encoding, + errors=errors, + newline=newline, + **kwargs, + ) as f: + return f.write(value) + + def cat_file(self, path, start=None, end=None, **kwargs): + """Get the content of a file + + Parameters + ---------- + path: URL of file on this filesystems + start, end: int + Bytes limits of the read. If negative, backwards from end, + like usual python slices. Either can be None for start or + end of file, respectively + kwargs: passed to ``open()``. 
+ """ + # explicitly set buffering off? + with self.open(path, "rb", **kwargs) as f: + if start is not None: + if start >= 0: + f.seek(start) + else: + f.seek(max(0, f.size + start)) + if end is not None: + if end < 0: + end = f.size + end + return f.read(end - f.tell()) + return f.read() + + def pipe_file(self, path, value, mode="overwrite", **kwargs): + """Set the bytes of given file""" + if mode == "create" and self.exists(path): + # non-atomic but simple way; or could use "xb" in open(), which is likely + # not as well supported + raise FileExistsError + with self.open(path, "wb", **kwargs) as f: + f.write(value) + + def pipe(self, path, value=None, **kwargs): + """Put value into path + + (counterpart to ``cat``) + + Parameters + ---------- + path: string or dict(str, bytes) + If a string, a single remote location to put ``value`` bytes; if a dict, + a mapping of {path: bytesvalue}. + value: bytes, optional + If using a single path, these are the bytes to put there. Ignored if + ``path`` is a dict + """ + if isinstance(path, str): + self.pipe_file(self._strip_protocol(path), value, **kwargs) + elif isinstance(path, dict): + for k, v in path.items(): + self.pipe_file(self._strip_protocol(k), v, **kwargs) + else: + raise ValueError("path must be str or dict") + + def cat_ranges( + self, paths, starts, ends, max_gap=None, on_error="return", **kwargs + ): + """Get the contents of byte ranges from one or more files + + Parameters + ---------- + paths: list + A list of of filepaths on this filesystems + starts, ends: int or list + Bytes limits of the read. If using a single int, the same value will be + used to read all the specified files. 
+ """ + if max_gap is not None: + raise NotImplementedError + if not isinstance(paths, list): + raise TypeError + if not isinstance(starts, list): + starts = [starts] * len(paths) + if not isinstance(ends, list): + ends = [ends] * len(paths) + if len(starts) != len(paths) or len(ends) != len(paths): + raise ValueError + out = [] + for p, s, e in zip(paths, starts, ends): + try: + out.append(self.cat_file(p, s, e)) + except Exception as e: + if on_error == "return": + out.append(e) + else: + raise + return out + + def cat(self, path, recursive=False, on_error="raise", **kwargs): + """Fetch (potentially multiple) paths' contents + + Parameters + ---------- + recursive: bool + If True, assume the path(s) are directories, and get all the + contained files + on_error : "raise", "omit", "return" + If raise, an underlying exception will be raised (converted to KeyError + if the type is in self.missing_exceptions); if omit, keys with exception + will simply not be included in the output; if "return", all keys are + included in the output, but the value will be bytes or an exception + instance. 
+ kwargs: passed to cat_file + + Returns + ------- + dict of {path: contents} if there are multiple paths + or the path has been otherwise expanded + """ + paths = self.expand_path(path, recursive=recursive, **kwargs) + if ( + len(paths) > 1 + or isinstance(path, list) + or paths[0] != self._strip_protocol(path) + ): + out = {} + for path in paths: + try: + out[path] = self.cat_file(path, **kwargs) + except Exception as e: + if on_error == "raise": + raise + if on_error == "return": + out[path] = e + return out + else: + return self.cat_file(paths[0], **kwargs) + + def get_file(self, rpath, lpath, callback=DEFAULT_CALLBACK, outfile=None, **kwargs): + """Copy single remote file to local""" + from .implementations.local import LocalFileSystem + + if isfilelike(lpath): + outfile = lpath + elif self.isdir(rpath): + os.makedirs(lpath, exist_ok=True) + return None + + fs = LocalFileSystem(auto_mkdir=True) + fs.makedirs(fs._parent(lpath), exist_ok=True) + + with self.open(rpath, "rb", **kwargs) as f1: + if outfile is None: + outfile = open(lpath, "wb") + + try: + callback.set_size(getattr(f1, "size", None)) + data = True + while data: + data = f1.read(self.blocksize) + segment_len = outfile.write(data) + if segment_len is None: + segment_len = len(data) + callback.relative_update(segment_len) + finally: + if not isfilelike(lpath): + outfile.close() + + def get( + self, + rpath, + lpath, + recursive=False, + callback=DEFAULT_CALLBACK, + maxdepth=None, + **kwargs, + ): + """Copy file(s) to local. + + Copies a specific file or tree of files (if recursive=True). If lpath + ends with a "/", it will be assumed to be a directory, and target files + will go within. Can submit a list of paths, which may be glob-patterns + and will be expanded. + + Calls get_file for each source. 
def put_file(
    self, lpath, rpath, callback=DEFAULT_CALLBACK, mode="overwrite", **kwargs
):
    """Copy a single local file to the remote location ``rpath``

    If ``mode`` is "create" and ``rpath`` already exists, FileExistsError is
    raised. If ``lpath`` is a local directory, only the remote directory is
    created. Otherwise the file is streamed in ``self.blocksize`` pieces and
    ``callback`` is updated after each piece. ``kwargs`` go to ``self.open``.
    """
    if mode == "create" and self.exists(rpath):
        raise FileExistsError
    if os.path.isdir(lpath):
        self.makedirs(rpath, exist_ok=True)
        return None

    with open(lpath, "rb") as source:
        # seeking to the end gives the total size to report to the callback
        total = source.seek(0, 2)
        callback.set_size(total)
        source.seek(0)

        parent = self._parent(os.fspath(rpath))
        self.mkdirs(parent, exist_ok=True)
        with self.open(rpath, "wb", **kwargs) as sink:
            while source.tell() < total:
                chunk = source.read(self.blocksize)
                written = sink.write(chunk)
                if written is None:
                    # writer did not report a count; fall back to chunk length
                    written = len(chunk)
                callback.relative_update(written)
def head(self, path, size=1024):
    """Return up to the first ``size`` bytes of the file at ``path``"""
    with self.open(path, "rb") as stream:
        leading = stream.read(size)
    return leading
def copy(
    self, path1, path2, recursive=False, maxdepth=None, on_error=None, **kwargs
):
    """Copy within two locations in the filesystem

    Parameters
    ----------
    path1: str or list of str
        Source(s); expanded with ``expand_path`` unless both ``path1`` and
        ``path2`` are lists, in which case they are paired as given.
    path2: str or list of str
        Destination(s).
    recursive: bool
        Passed to ``expand_path``; also selects the default ``on_error``.
    maxdepth: int or None
        Passed to ``expand_path``.
    on_error : "raise", "ignore"
        If raise, any not-found exceptions will be raised; if ignore any
        not-found exceptions will cause the path to be skipped; defaults to
        raise unless recursive is true, where the default is ignore
    kwargs: passed to both ``expand_path`` and ``cp_file``

    Only ``FileNotFoundError`` from ``cp_file`` is subject to ``on_error``;
    any other exception propagates. Returns None.
    """
    # Resolve the default error policy: lenient for recursive copies only.
    if on_error is None and recursive:
        on_error = "ignore"
    elif on_error is None:
        on_error = "raise"

    if isinstance(path1, list) and isinstance(path2, list):
        # No need to expand paths when both source and destination
        # are provided as lists
        paths1 = path1
        paths2 = path2
    else:
        from .implementations.local import trailing_sep

        source_is_str = isinstance(path1, str)
        paths1 = self.expand_path(
            path1, recursive=recursive, maxdepth=maxdepth, **kwargs
        )
        if source_is_str and (not recursive or maxdepth is not None):
            # Non-recursive glob does not copy directories
            paths1 = [p for p in paths1 if not (trailing_sep(p) or self.isdir(p))]
            if not paths1:
                # nothing left to copy after dropping directories
                return

        source_is_file = len(paths1) == 1
        dest_is_dir = isinstance(path2, str) and (
            trailing_sep(path2) or self.isdir(path2)
        )

        # NOTE(review): ``exists`` and ``flatten`` are forwarded to the
        # ``other_paths`` helper defined elsewhere; presumably they control how
        # destination names are derived from the sources - confirm there.
        exists = source_is_str and (
            (has_magic(path1) and source_is_file)
            or (not has_magic(path1) and dest_is_dir and not trailing_sep(path1))
        )
        paths2 = other_paths(
            paths1,
            path2,
            exists=exists,
            flatten=not source_is_str,
        )

    for p1, p2 in zip(paths1, paths2):
        try:
            self.cp_file(p1, p2, **kwargs)
        except FileNotFoundError:
            # any other on_error value silently skips the missing source
            if on_error == "raise":
                raise
def mv(self, path1, path2, recursive=False, maxdepth=None, **kwargs):
    """Move file(s) from one location to another

    Implemented as ``copy`` followed by ``rm`` of the source. Nothing is done
    (apart from a debug log message) when both paths compare equal.

    Parameters
    ----------
    path1: str or list of str
        Source location(s)
    path2: str or list of str
        Destination location(s)
    recursive: bool
        Passed to both ``copy`` and ``rm``
    maxdepth: int or None
        Passed to ``copy``
    kwargs:
        Accepted for interface compatibility; not forwarded.
    """
    if path1 == path2:
        logger.debug("%s mv: The paths are the same, so no files were moved.", self)
        return

    # ``copy`` defaults to on_error="ignore" when recursive=True. The source is
    # deleted right after, so a skipped file would be lost: force the copy to
    # raise instead. The keyword must be spelled ``on_error``; the previous
    # ``onerror`` spelling fell into ``copy``'s **kwargs and had no effect on
    # the error policy.
    self.copy(path1, path2, recursive=recursive, maxdepth=maxdepth, on_error="raise")
    self.rm(path1, recursive=recursive)
+ recursive: bool + If file(s) are directories, recursively delete contents and then + also remove the directory + maxdepth: int or None + Depth to pass to walk for finding files to delete, if recursive. + If None, there will be no limit and infinite recursion may be + possible. + """ + path = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(path): + self.rm_file(p) + + @classmethod + def _parent(cls, path): + path = cls._strip_protocol(path) + if "/" in path: + parent = path.rsplit("/", 1)[0].lstrip(cls.root_marker) + return cls.root_marker + parent + else: + return cls.root_marker + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + """Return raw bytes-mode file-like from the file-system""" + return AbstractBufferedFile( + self, + path, + mode, + block_size, + autocommit, + cache_options=cache_options, + **kwargs, + ) + + def open( + self, + path, + mode="rb", + block_size=None, + cache_options=None, + compression=None, + **kwargs, + ): + """ + Return a file-like object from the filesystem + + The resultant instance must function correctly in a context ``with`` + block. + + Parameters + ---------- + path: str + Target file + mode: str like 'rb', 'w' + See builtin ``open()`` + Mode "x" (exclusive write) may be implemented by the backend. Even if + it is, whether it is checked up front or on commit, and whether it is + atomic is implementation-dependent. + block_size: int + Some indication of buffering - this is a value in bytes + cache_options : dict, optional + Extra arguments to pass through to the cache. + compression: string or None + If given, open file using compression codec. Can either be a compression + name (a key in ``fsspec.compression.compr``) or "infer" to guess the + compression from the filename suffix. 
def read_block(self, fn, offset, length, delimiter=None):
    """Read a block of bytes from a file

    Opens ``fn`` in binary read mode and delegates to
    :func:`fsspec.utils.read_block`, after clamping the requested span to the
    file size when the size is known.

    Parameters
    ----------
    fn: string
        Path to filename
    offset: int
        Byte offset to start read
    length: int
        Number of bytes to read. If None, the file size is used, i.e. read
        to the end.
    delimiter: bytes (optional)
        Passed on to the helper, which aligns the start and stop of the read
        to delimiter boundaries.

    Examples
    --------
    >>> fs.read_block('data/file.csv', 0, 13)  # doctest: +SKIP
    b'Alice, 100\\nBo'
    >>> fs.read_block('data/file.csv', 0, 13, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\n'

    Use ``length=None`` to read to the end of the file.
    >>> fs.read_block('data/file.csv', 0, None, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\nCharlie, 300'

    See Also
    --------
    :func:`fsspec.utils.read_block`
    """
    with self.open(fn, "rb") as handle:
        total = handle.size
        span = total if length is None else length
        if total is not None and offset + span > total:
            # never ask for bytes beyond the end of the file
            span = total - offset
        return read_block(handle, offset, span, delimiter)
def to_dict(self, *, include_password: bool = True) -> dict[str, Any]:
    """
    JSON-serializable dictionary representation of this filesystem instance.

    Parameters
    ----------
    include_password: bool, default True
        If False, the ``password`` entry (if any) is removed from the storage
        options before they are serialized.

    Returns
    -------
    Dictionary with keys ``cls`` (``module:ClassName`` of this instance),
    ``protocol`` (the first entry when the protocol is a tuple or list),
    ``args`` (the serialized ``storage_args``), and every storage option as
    its own key.

    Warnings
    --------
    Serialized filesystems may contain sensitive information which have been
    passed to the constructor, such as passwords and tokens. Make sure you
    store and send them in a secure environment!
    """
    from .json import FilesystemJSONEncoder

    encoder = FilesystemJSONEncoder()
    fs_type = type(self)

    protocol = self.protocol
    if isinstance(protocol, (tuple, list)):
        protocol = protocol[0]

    # work on a copy so the instance's own options are left untouched
    options = dict(self.storage_options)
    if not include_password:
        options.pop("password", None)

    return dict(
        cls=f"{fs_type.__module__}:{fs_type.__name__}",
        protocol=protocol,
        args=encoder.make_serializable(self.storage_args),
        **encoder.make_serializable(options),
    )
def tree(
    self,
    path: str = "/",
    recursion_limit: int = 2,
    max_display: int = 25,
    display_size: bool = False,
    prefix: str = "",
    is_last: bool = True,
    first: bool = True,
    indent_size: int = 4,
) -> str:
    """
    Return a tree-like structure of the filesystem starting from the given path as a string.

    Parameters
    ----------
    path: Root path to start traversal from
    recursion_limit: Maximum depth of directory traversal
    max_display: Maximum number of items to display per directory
    display_size: Whether to display file sizes (and, for directories, the
        number of files/subfolders they directly contain)
    prefix: Current line prefix for visual tree structure
    is_last: Whether current item is last in its level (accepted for
        interface compatibility; not used when rendering)
    first: Whether this is the first call (displays root path)
    indent_size: Number of spaces by indent

    Returns
    -------
    str: A string representing the tree structure.

    Example
    -------
    >>> from fsspec import filesystem

    >>> fs = filesystem('ftp', host='test.rebex.net', user='demo', password='password')
    >>> tree = fs.tree(display_size=True, recursion_limit=3, indent_size=8, max_display=10)
    >>> print(tree)
    """

    def format_bytes(n: int) -> str:
        """Format bytes as text."""
        for unit, k in (
            ("P", 2**50),
            ("T", 2**40),
            ("G", 2**30),
            ("M", 2**20),
            ("k", 2**10),
        ):
            if n >= 0.9 * k:
                return f"{n / k:.2f} {unit}b"
        return f"{n}B"

    def plural(count: int, noun: str) -> str:
        """Render ``count noun`` with a trailing "s" when count > 1."""
        return f"{count} {noun}{'s' if count > 1 else ''}"

    result = []

    if first:
        result.append(path)

    if recursion_limit:
        indent = " " * indent_size
        rule = "─" * (indent_size - 2)
        contents = self.ls(path, detail=True)
        # directories first, then alphabetical by name
        contents.sort(
            key=lambda x: (x.get("type") != "directory", x.get("name", ""))
        )

        if max_display is not None and len(contents) > max_display:
            displayed_contents = contents[:max_display]
            remaining_count = len(contents) - max_display
        else:
            displayed_contents = contents
            remaining_count = 0

        for i, item in enumerate(displayed_contents):
            # the "N more item(s)" line, when present, is the real last entry
            is_last_item = (i == len(displayed_contents) - 1) and (
                remaining_count == 0
            )

            branch = ("└" if is_last_item else "├") + rule + " "
            new_prefix = prefix + (
                indent if is_last_item else "│" + " " * (indent_size - 1)
            )

            name = os.path.basename(item.get("name", ""))
            kind = item.get("type")

            if display_size and kind == "directory":
                sub_contents = self.ls(item.get("name", ""), detail=True)
                num_files = sum(
                    1 for sub_item in sub_contents if sub_item.get("type") == "file"
                )
                num_folders = sum(
                    1
                    for sub_item in sub_contents
                    if sub_item.get("type") == "directory"
                )
                parts = []
                if num_files:
                    parts.append(plural(num_files, "file"))
                if num_folders:
                    parts.append(plural(num_folders, "subfolder"))
                size = f" ({', '.join(parts)})" if parts else " (empty folder)"
            elif display_size and kind == "file":
                # a listing may carry ``size: None`` for an unknown size;
                # treat it like a missing key (0) instead of failing in the
                # numeric comparison inside format_bytes
                size = f" ({format_bytes(item.get('size') or 0)})"
            else:
                size = ""

            result.append(f"{prefix}{branch}{name}{size}")

            if kind == "directory" and recursion_limit > 0:
                result.append(
                    self.tree(
                        path=item.get("name", ""),
                        recursion_limit=recursion_limit - 1,
                        max_display=max_display,
                        display_size=display_size,
                        prefix=new_prefix,
                        is_last=is_last_item,
                        first=False,
                        indent_size=indent_size,
                    )
                )

        if remaining_count > 0:
            more_message = f"{remaining_count} more item(s) not displayed."
            result.append(f"{prefix}└{rule} {more_message}")

    # sub-trees that rendered nothing come back as "" and are dropped here
    return "\n".join(_ for _ in result if _)
class AbstractBufferedFile(io.IOBase):
    """Convenient class to derive from to provide buffering

    In the case that the backend does not provide a pythonic file-like object
    already, this class contains much of the logic to build one. The only
    methods that need to be overridden are ``_upload_chunk``,
    ``_initiate_upload`` and ``_fetch_range``.
    """

    DEFAULT_BLOCK_SIZE = 5 * 2**20
    _details = None

    def __init__(
        self,
        fs,
        path,
        mode="rb",
        block_size="default",
        autocommit=True,
        cache_type="readahead",
        cache_options=None,
        size=None,
        **kwargs,
    ):
        """
        Template for files with buffered reading and writing

        Parameters
        ----------
        fs: instance of FileSystem
        path: str
            location in file-system
        mode: str
            Normal file modes. One of 'rb', 'wb', 'ab' or 'xb'; anything else
            raises NotImplementedError. Some file systems may be read-only,
            and some may not support append.
        block_size: int
            Buffer size for reading or writing, 'default' (or None) for class
            default
        autocommit: bool
            Whether to write to final destination; may only impact what
            happens when file is being closed.
        cache_type: {"readahead", "none", "mmap", "bytes"}, default "readahead"
            Caching policy in read mode. See the definitions in ``core``.
        cache_options : dict
            Additional options passed to the constructor for the cache specified
            by `cache_type`.
        size: int
            If given and in read mode, suppressed having to look up the file size
        kwargs:
            Gets stored as self.kwargs
        """
        from .core import caches

        self.path = path
        self.fs = fs
        self.mode = mode
        self.blocksize = (
            self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size
        )
        self.loc = 0
        self.autocommit = autocommit
        self.end = None
        self.start = None
        self.closed = False

        if cache_options is None:
            cache_options = {}

        if "trim" in kwargs:
            warnings.warn(
                "Passing 'trim' to control the cache behavior has been deprecated. "
                "Specify it within the 'cache_options' argument instead.",
                FutureWarning,
            )
            cache_options["trim"] = kwargs.pop("trim")

        self.kwargs = kwargs

        if mode not in {"ab", "rb", "wb", "xb"}:
            raise NotImplementedError("File mode not supported")
        if mode == "rb":
            if size is not None:
                self.size = size
            else:
                # triggers an ``fs.info`` lookup through the ``details`` property
                self.size = self.details["size"]
            self.cache = caches[cache_type](
                self.blocksize, self._fetch_range, self.size, **cache_options
            )
        else:
            self.buffer = io.BytesIO()
            self.offset = None
            self.forced = False
            self.location = None

    @property
    def details(self):
        """Info dict for this path, fetched from ``fs.info`` on first access"""
        if self._details is None:
            self._details = self.fs.info(self.path)
        return self._details

    @details.setter
    def details(self, value):
        self._details = value
        self.size = value["size"]

    @property
    def full_name(self):
        return _unstrip_protocol(self.path, self.fs)

    @property
    def closed(self):
        # get around this attr being read-only in IOBase
        # use getattr here, since this can be called during del
        return getattr(self, "_closed", True)

    @closed.setter
    def closed(self, c):
        self._closed = c

    def __hash__(self):
        if "w" in self.mode:
            return id(self)
        else:
            return int(tokenize(self.details), 16)

    def __eq__(self, other):
        """Files are equal if they have the same checksum, only in read mode"""
        if self is other:
            return True
        return (
            isinstance(other, type(self))
            and self.mode == "rb"
            and other.mode == "rb"
            and hash(self) == hash(other)
        )

    def commit(self):
        """Move from temp to final destination"""

    def discard(self):
        """Throw away temporary file"""

    def info(self):
        """File information about this path"""
        if self.readable():
            return self.details
        else:
            raise ValueError("Info not available while writing")

    def tell(self):
        """Current file location"""
        return self.loc

    def seek(self, loc, whence=0):
        """Set current file location

        Parameters
        ----------
        loc: int
            byte location
        whence: {0, 1, 2}
            from start of file, current location or end of file, resp.

        Returns
        -------
        int: the new absolute position

        Raises
        ------
        OSError: if the file is not in "rb" mode
        ValueError: for an invalid ``whence`` or a resulting negative position
        """
        loc = int(loc)
        if not self.mode == "rb":
            raise OSError(ESPIPE, "Seek only available in read mode")
        if whence == 0:
            nloc = loc
        elif whence == 1:
            nloc = self.loc + loc
        elif whence == 2:
            nloc = self.size + loc
        else:
            raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
        if nloc < 0:
            raise ValueError("Seek before start of file")
        self.loc = nloc
        return self.loc

    def write(self, data):
        """
        Write data to buffer.

        Buffer only sent on flush() or if buffer is greater than
        or equal to blocksize.

        Parameters
        ----------
        data: bytes
            Set of bytes to be written.

        Returns
        -------
        int: number of bytes accepted into the buffer
        """
        if not self.writable():
            raise ValueError("File not in write mode")
        if self.closed:
            raise ValueError("I/O operation on closed file.")
        if self.forced:
            raise ValueError("This file has been force-flushed, can only close")
        out = self.buffer.write(data)
        self.loc += out
        if self.buffer.tell() >= self.blocksize:
            self.flush()
        return out

    def flush(self, force=False):
        """
        Write buffered data to backend store.

        Writes the current buffer, if it is larger than the block-size, or if
        the file is being closed.

        Parameters
        ----------
        force: bool
            When closing, write the last block even if it is smaller than
            blocks are allowed to be. Disallows further writing to this file.
        """

        if self.closed:
            raise ValueError("Flush on closed file")
        if force and self.forced:
            raise ValueError("Force flush cannot be called more than once")
        if force:
            self.forced = True

        if self.readable():
            # no-op to flush on read-mode
            return

        if not force and self.buffer.tell() < self.blocksize:
            # Defer write on small block
            return

        if self.offset is None:
            # Initialize a multipart upload
            self.offset = 0
            try:
                self._initiate_upload()
            except BaseException:
                # deliberately broad (includes KeyboardInterrupt): whatever
                # interrupted the initiation, mark the file closed so no
                # later flush/close is attempted, then re-raise unchanged
                self.closed = True
                raise

        # a backend returns False to signal "keep buffering, nothing was sent"
        if self._upload_chunk(final=force) is not False:
            self.offset += self.buffer.seek(0, 2)
            self.buffer = io.BytesIO()

    def _upload_chunk(self, final=False):
        """Write one part of a multi-block file upload

        Parameters
        ==========
        final: bool
            This is the last block, so should complete file, if
            self.autocommit is True.
        """
        # may not yet have been initialized, may need to call _initialize_upload

    def _initiate_upload(self):
        """Create remote file/upload"""
        pass

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote"""
        return self.fs.cat_file(self.path, start=start, end=end)

    def read(self, length=-1):
        """
        Return data from cache, or fetch pieces as necessary

        Parameters
        ----------
        length: int (-1)
            Number of bytes to read; if <0 (or None), all remaining bytes.
        """
        length = -1 if length is None else int(length)
        if self.mode != "rb":
            raise ValueError("File not in read mode")
        if length < 0:
            length = self.size - self.loc
        if self.closed:
            raise ValueError("I/O operation on closed file.")
        if length == 0:
            # don't even bother calling fetch
            return b""
        out = self.cache._fetch(self.loc, self.loc + length)

        logger.debug(
            "%s read: %i - %i %s",
            self,
            self.loc,
            self.loc + length,
            self.cache._log_stats(),
        )
        self.loc += len(out)
        return out

    def readinto(self, b):
        """mirrors builtin file's readinto method

        https://docs.python.org/3/library/io.html#io.RawIOBase.readinto
        """
        out = memoryview(b).cast("B")
        data = self.read(out.nbytes)
        out[: len(data)] = data
        return len(data)

    def readuntil(self, char=b"\n", blocks=None):
        """Return data between current position and first occurrence of char

        char is included in the output, except if the end of the file is
        encountered first.

        Parameters
        ----------
        char: bytes
            Thing to find
        blocks: None or int
            How much to read in each go. Defaults to file blocksize - which may
            mean a new read on every call.
        """
        out = []
        while True:
            start = self.tell()
            part = self.read(blocks or self.blocksize)
            if len(part) == 0:
                break
            found = part.find(char)
            if found > -1:
                out.append(part[: found + len(char)])
                # rewind to just after the delimiter; the read overshot it
                self.seek(start + found + len(char))
                break
            out.append(part)
        return b"".join(out)

    def readline(self):
        """Read until and including the first occurrence of newline character

        Note that, because of character encoding, this is not necessarily a
        true line ending.
        """
        return self.readuntil(b"\n")

    def __next__(self):
        out = self.readline()
        if out:
            return out
        raise StopIteration

    def __iter__(self):
        return self

    def readlines(self):
        """Return all remaining data as a list of lines

        Lines are split on ``b"\\n"`` and keep their newline. A final piece
        without a trailing newline is included as the last element; when no
        data remains the list is empty.
        """
        data = self.read()
        lines = data.split(b"\n")
        out = [l + b"\n" for l in lines[:-1]]
        # The last piece is empty both when the data ends with a newline and
        # when there is no data at all; only a non-empty remainder is a line.
        if lines[-1]:
            out.append(lines[-1])
        return out

    def readinto1(self, b):
        return self.readinto(b)

    def close(self):
        """Close file

        Finalizes writes, discards cache
        """
        if getattr(self, "_unclosable", False):
            return
        if self.closed:
            return
        try:
            if self.mode == "rb":
                self.cache = None
            else:
                if not self.forced:
                    self.flush(force=True)

                if self.fs is not None:
                    self.fs.invalidate_cache(self.path)
                    self.fs.invalidate_cache(self.fs._parent(self.path))
        finally:
            # mark closed even if the final flush failed
            self.closed = True

    def readable(self):
        """Whether opened for reading"""
        return "r" in self.mode and not self.closed

    def seekable(self):
        """Whether is seekable (only in read mode)"""
        return self.readable()

    def writable(self):
        """Whether opened for writing"""
        return self.mode in {"wb", "ab", "xb"} and not self.closed

    def __reduce__(self):
        if self.mode != "rb":
            raise RuntimeError("Pickling a writeable file is not supported")

        return reopen, (
            self.fs,
            self.path,
            self.mode,
            self.blocksize,
            self.loc,
            self.size,
            self.autocommit,
            self.cache.name if self.cache else "none",
            self.kwargs,
        )

    def __del__(self):
        if not self.closed:
            self.close()

    def __str__(self):
        # Previously an empty f-string, which made str()/repr() and the "%s"
        # in the ``read`` debug log render as nothing.
        return f"<File-like object {type(self.fs).__name__}, {self.path}>"

    __repr__ = __str__

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
+ if loc > 0: + file.seek(loc) + return file diff --git a/venv/lib/python3.10/site-packages/fsspec/transaction.py b/venv/lib/python3.10/site-packages/fsspec/transaction.py new file mode 100644 index 0000000000000000000000000000000000000000..77293f63ecc5f611e19d849ef236d53e9c258efc --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/transaction.py @@ -0,0 +1,90 @@ +from collections import deque + + +class Transaction: + """Filesystem transaction write context + + Gathers files for deferred commit or discard, so that several write + operations can be finalized semi-atomically. This works by having this + instance as the ``.transaction`` attribute of the given filesystem + """ + + def __init__(self, fs, **kwargs): + """ + Parameters + ---------- + fs: FileSystem instance + """ + self.fs = fs + self.files = deque() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """End transaction and commit, if exit is not due to exception""" + # only commit if there was no exception + self.complete(commit=exc_type is None) + if self.fs: + self.fs._intrans = False + self.fs._transaction = None + self.fs = None + + def start(self): + """Start a transaction on this FileSystem""" + self.files = deque() # clean up after previous failed completions + self.fs._intrans = True + + def complete(self, commit=True): + """Finish transaction: commit or discard all deferred files""" + while self.files: + f = self.files.popleft() + if commit: + f.commit() + else: + f.discard() + self.fs._intrans = False + self.fs._transaction = None + self.fs = None + + +class FileActor: + def __init__(self): + self.files = [] + + def commit(self): + for f in self.files: + f.commit() + self.files.clear() + + def discard(self): + for f in self.files: + f.discard() + self.files.clear() + + def append(self, f): + self.files.append(f) + + +class DaskTransaction(Transaction): + def __init__(self, fs): + """ + Parameters + ---------- + fs: FileSystem 
instance + """ + import distributed + + super().__init__(fs) + client = distributed.default_client() + self.files = client.submit(FileActor, actor=True).result() + + def complete(self, commit=True): + """Finish transaction: commit or discard all deferred files""" + if commit: + self.files.commit().result() + else: + self.files.discard().result() + self.fs._intrans = False + self.fs = None diff --git a/venv/lib/python3.10/site-packages/fsspec/utils.py b/venv/lib/python3.10/site-packages/fsspec/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b06dd581417eca1e90a19ef25fb14dcd3cf9c9c --- /dev/null +++ b/venv/lib/python3.10/site-packages/fsspec/utils.py @@ -0,0 +1,748 @@ +from __future__ import annotations + +import contextlib +import logging +import math +import os +import re +import sys +import tempfile +from collections.abc import Callable, Iterable, Iterator, Sequence +from functools import partial +from hashlib import md5 +from importlib.metadata import version +from typing import IO, TYPE_CHECKING, Any, TypeVar +from urllib.parse import urlsplit + +if TYPE_CHECKING: + import pathlib + from typing import TypeGuard + + from fsspec.spec import AbstractFileSystem + + +DEFAULT_BLOCK_SIZE = 5 * 2**20 + +T = TypeVar("T") + + +def infer_storage_options( + urlpath: str, inherit_storage_options: dict[str, Any] | None = None +) -> dict[str, Any]: + """Infer storage options from URL path and merge it with existing storage + options. + + Parameters + ---------- + urlpath: str or unicode + Either local absolute file path or URL (hdfs://namenode:8020/file.csv) + inherit_storage_options: dict (optional) + Its contents will get merged with the inferred information from the + given path + + Returns + ------- + Storage options dict. + + Examples + -------- + >>> infer_storage_options('/mnt/datasets/test.csv') # doctest: +SKIP + {"protocol": "file", "path", "/mnt/datasets/test.csv"} + >>> infer_storage_options( + ... 
'hdfs://username:pwd@node:123/mnt/datasets/test.csv?q=1', + ... inherit_storage_options={'extra': 'value'}, + ... ) # doctest: +SKIP + {"protocol": "hdfs", "username": "username", "password": "pwd", + "host": "node", "port": 123, "path": "/mnt/datasets/test.csv", + "url_query": "q=1", "extra": "value"} + """ + # Handle Windows paths including disk name in this special case + if ( + re.match(r"^[a-zA-Z]:[\\/]", urlpath) + or re.match(r"^[a-zA-Z0-9]+://", urlpath) is None + ): + return {"protocol": "file", "path": urlpath} + + parsed_path = urlsplit(urlpath) + protocol = parsed_path.scheme or "file" + if parsed_path.fragment: + path = "#".join([parsed_path.path, parsed_path.fragment]) + else: + path = parsed_path.path + if protocol == "file": + # Special case parsing file protocol URL on Windows according to: + # https://msdn.microsoft.com/en-us/library/jj710207.aspx + windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path) + if windows_path: + drive, path = windows_path.groups() + path = f"{drive}:{path}" + + if protocol in ["http", "https"]: + # for HTTP, we don't want to parse, as requests will anyway + return {"protocol": protocol, "path": urlpath} + + options: dict[str, Any] = {"protocol": protocol, "path": path} + + if parsed_path.netloc: + # Parse `hostname` from netloc manually because `parsed_path.hostname` + # lowercases the hostname which is not always desirable (e.g. 
in S3): + # https://github.com/dask/dask/issues/1417 + options["host"] = parsed_path.netloc.rsplit("@", 1)[-1].rsplit(":", 1)[0] + + if protocol in ("s3", "s3a", "gcs", "gs"): + options["path"] = options["host"] + options["path"] + else: + options["host"] = options["host"] + if parsed_path.port: + options["port"] = parsed_path.port + if parsed_path.username: + options["username"] = parsed_path.username + if parsed_path.password: + options["password"] = parsed_path.password + + if parsed_path.query: + options["url_query"] = parsed_path.query + if parsed_path.fragment: + options["url_fragment"] = parsed_path.fragment + + if inherit_storage_options: + update_storage_options(options, inherit_storage_options) + + return options + + +def update_storage_options( + options: dict[str, Any], inherited: dict[str, Any] | None = None +) -> None: + if not inherited: + inherited = {} + collisions = set(options) & set(inherited) + if collisions: + for collision in collisions: + if options.get(collision) != inherited.get(collision): + raise KeyError( + f"Collision between inferred and specified storage " + f"option:\n{collision}" + ) + options.update(inherited) + + +# Compression extensions registered via fsspec.compression.register_compression +compressions: dict[str, str] = {} + + +def infer_compression(filename: str) -> str | None: + """Infer compression, if available, from filename. + + Infer a named compression type, if registered and available, from filename + extension. This includes builtin (gz, bz2, zip) compressions, as well as + optional compressions. See fsspec.compression.register_compression. 
+ """ + extension = os.path.splitext(filename)[-1].strip(".").lower() + if extension in compressions: + return compressions[extension] + return None + + +def build_name_function(max_int: float) -> Callable[[int], str]: + """Returns a function that receives a single integer + and returns it as a string padded by enough zero characters + to align with maximum possible integer + + >>> name_f = build_name_function(57) + + >>> name_f(7) + '07' + >>> name_f(31) + '31' + >>> build_name_function(1000)(42) + '0042' + >>> build_name_function(999)(42) + '042' + >>> build_name_function(0)(0) + '0' + """ + # handle corner cases max_int is 0 or exact power of 10 + max_int += 1e-8 + + pad_length = int(math.ceil(math.log10(max_int))) + + def name_function(i: int) -> str: + return str(i).zfill(pad_length) + + return name_function + + +def seek_delimiter(file: IO[bytes], delimiter: bytes, blocksize: int) -> bool: + r"""Seek current file to file start, file end, or byte after delimiter seq. + + Seeks file to next chunk delimiter, where chunks are defined on file start, + a delimiting sequence, and file end. Use file.tell() to see location afterwards. + Note that file start is a valid split, so must be at offset > 0 to seek for + delimiter. + + Parameters + ---------- + file: a file + delimiter: bytes + a delimiter like ``b'\n'`` or message sentinel, matching file .read() type + blocksize: int + Number of bytes to read from the file at once. + + + Returns + ------- + Returns True if a delimiter was found, False if at file start or end. + + """ + + if file.tell() == 0: + # beginning-of-file, return without seek + return False + + # Interface is for binary IO, with delimiter as bytes, but initialize last + # with result of file.read to preserve compatibility with text IO. 
+ last: bytes | None = None + while True: + current = file.read(blocksize) + if not current: + # end-of-file without delimiter + return False + full = last + current if last else current + try: + if delimiter in full: + i = full.index(delimiter) + file.seek(file.tell() - (len(full) - i) + len(delimiter)) + return True + elif len(current) < blocksize: + # end-of-file without delimiter + return False + except (OSError, ValueError): + pass + last = full[-len(delimiter) :] + + +def read_block( + f: IO[bytes], + offset: int, + length: int | None, + delimiter: bytes | None = None, + split_before: bool = False, +) -> bytes: + """Read a block of bytes from a file + + Parameters + ---------- + f: File + Open file + offset: int + Byte offset to start read + length: int + Number of bytes to read, read through end of file if None + delimiter: bytes (optional) + Ensure reading starts and stops at delimiter bytestring + split_before: bool (optional) + Start/stop read *before* delimiter bytestring. + + + If using the ``delimiter=`` keyword argument we ensure that the read + starts and stops at delimiter boundaries that follow the locations + ``offset`` and ``offset + length``. If ``offset`` is zero then we + start at zero, regardless of delimiter. The bytestring returned WILL + include the terminating delimiter string. 
+ + Examples + -------- + + >>> from io import BytesIO # doctest: +SKIP + >>> f = BytesIO(b'Alice, 100\\nBob, 200\\nCharlie, 300') # doctest: +SKIP + >>> read_block(f, 0, 13) # doctest: +SKIP + b'Alice, 100\\nBo' + + >>> read_block(f, 0, 13, delimiter=b'\\n') # doctest: +SKIP + b'Alice, 100\\nBob, 200\\n' + + >>> read_block(f, 10, 10, delimiter=b'\\n') # doctest: +SKIP + b'Bob, 200\\nCharlie, 300' + """ + if delimiter: + f.seek(offset) + found_start_delim = seek_delimiter(f, delimiter, 2**16) + if length is None: + return f.read() + start = f.tell() + length -= start - offset + + f.seek(start + length) + found_end_delim = seek_delimiter(f, delimiter, 2**16) + end = f.tell() + + # Adjust split location to before delimiter if seek found the + # delimiter sequence, not start or end of file. + if found_start_delim and split_before: + start -= len(delimiter) + + if found_end_delim and split_before: + end -= len(delimiter) + + offset = start + length = end - start + + f.seek(offset) + + # TODO: allow length to be None and read to the end of the file? + assert length is not None + b = f.read(length) + return b + + +def tokenize(*args: Any, **kwargs: Any) -> str: + """Deterministic token + + (modified from dask.base) + + >>> tokenize([1, 2, '3']) + '9d71491b50023b06fc76928e6eddb952' + + >>> tokenize('Hello') == tokenize('Hello') + True + """ + if kwargs: + args += (kwargs,) + try: + h = md5(str(args).encode()) + except ValueError: + # FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380 + h = md5(str(args).encode(), usedforsecurity=False) + return h.hexdigest() + + +def stringify_path(filepath: str | os.PathLike[str] | pathlib.Path) -> str: + """Attempt to convert a path-like object to a string. + + Parameters + ---------- + filepath: object to be converted + + Returns + ------- + filepath_str: maybe a string version of the object + + Notes + ----- + Objects supporting the fspath protocol are coerced according to its + __fspath__ method. 
+ + For backwards compatibility with older Python version, pathlib.Path + objects are specially coerced. + + Any other object is passed through unchanged, which includes bytes, + strings, buffers, or anything else that's not even path-like. + """ + if isinstance(filepath, str): + return filepath + elif hasattr(filepath, "__fspath__"): + return filepath.__fspath__() + elif hasattr(filepath, "path"): + return filepath.path + else: + return filepath # type: ignore[return-value] + + +def make_instance( + cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any] +) -> T: + inst = cls(*args, **kwargs) + inst._determine_worker() # type: ignore[attr-defined] + return inst + + +def common_prefix(paths: Iterable[str]) -> str: + """For a list of paths, find the shortest prefix common to all""" + parts = [p.split("/") for p in paths] + lmax = min(len(p) for p in parts) + end = 0 + for i in range(lmax): + end = all(p[i] == parts[0][i] for p in parts) + if not end: + break + i += end + return "/".join(parts[0][:i]) + + +def other_paths( + paths: list[str], + path2: str | list[str], + exists: bool = False, + flatten: bool = False, +) -> list[str]: + """In bulk file operations, construct a new file tree from a list of files + + Parameters + ---------- + paths: list of str + The input file tree + path2: str or list of str + Root to construct the new list in. If this is already a list of str, we just + assert it has the right number of elements. + exists: bool (optional) + For a str destination, it is already exists (and is a dir), files should + end up inside. + flatten: bool (optional) + Whether to flatten the input directory tree structure so that the output files + are in the same directory. 
+ + Returns + ------- + list of str + """ + + if isinstance(path2, str): + path2 = path2.rstrip("/") + + if flatten: + path2 = ["/".join((path2, p.split("/")[-1])) for p in paths] + else: + cp = common_prefix(paths) + if exists: + cp = cp.rsplit("/", 1)[0] + if not cp and all(not s.startswith("/") for s in paths): + path2 = ["/".join([path2, p]) for p in paths] + else: + path2 = [p.replace(cp, path2, 1) for p in paths] + else: + assert len(paths) == len(path2) + return path2 + + +def is_exception(obj: Any) -> bool: + return isinstance(obj, BaseException) + + +def isfilelike(f: Any) -> TypeGuard[IO[bytes]]: + return all(hasattr(f, attr) for attr in ["read", "close", "tell"]) + + +def get_protocol(url: str) -> str: + url = stringify_path(url) + parts = re.split(r"(\:\:|\://)", url, maxsplit=1) + if len(parts) > 1: + return parts[0] + return "file" + + +def get_file_extension(url: str) -> str: + url = stringify_path(url) + ext_parts = url.rsplit(".", 1) + if len(ext_parts) > 1: + return ext_parts[-1] + return "" + + +def can_be_local(path: str) -> bool: + """Can the given URL be used with open_local?""" + from fsspec import get_filesystem_class + + try: + return getattr(get_filesystem_class(get_protocol(path)), "local_file", False) + except (ValueError, ImportError): + # not in registry or import failed + return False + + +def get_package_version_without_import(name: str) -> str | None: + """For given package name, try to find the version without importing it + + Import and package.__version__ is still the backup here, so an import + *might* happen. + + Returns either the version string, or None if the package + or the version was not readily found. 
+ """ + if name in sys.modules: + mod = sys.modules[name] + if hasattr(mod, "__version__"): + return mod.__version__ + try: + return version(name) + except: # noqa: E722 + pass + try: + import importlib + + mod = importlib.import_module(name) + return mod.__version__ + except (ImportError, AttributeError): + return None + + +def setup_logging( + logger: logging.Logger | None = None, + logger_name: str | None = None, + level: str = "DEBUG", + clear: bool = True, +) -> logging.Logger: + if logger is None and logger_name is None: + raise ValueError("Provide either logger object or logger name") + logger = logger or logging.getLogger(logger_name) + handle = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s -- %(message)s" + ) + handle.setFormatter(formatter) + if clear: + logger.handlers.clear() + logger.addHandler(handle) + logger.setLevel(level) + return logger + + +def _unstrip_protocol(name: str, fs: AbstractFileSystem) -> str: + return fs.unstrip_protocol(name) + + +def mirror_from( + origin_name: str, methods: Iterable[str] +) -> Callable[[type[T]], type[T]]: + """Mirror attributes and methods from the given + origin_name attribute of the instance to the + decorated class""" + + def origin_getter(method: str, self: Any) -> Any: + origin = getattr(self, origin_name) + return getattr(origin, method) + + def wrapper(cls: type[T]) -> type[T]: + for method in methods: + wrapped_method = partial(origin_getter, method) + setattr(cls, method, property(wrapped_method)) + return cls + + return wrapper + + +@contextlib.contextmanager +def nullcontext(obj: T) -> Iterator[T]: + yield obj + + +def merge_offset_ranges( + paths: list[str], + starts: list[int] | int, + ends: list[int] | int, + max_gap: int = 0, + max_block: int | None = None, + sort: bool = True, +) -> tuple[list[str], list[int], list[int]]: + """Merge adjacent byte-offset ranges when the inter-range + gap is <= `max_gap`, and when the merged byte 
range does not + exceed `max_block` (if specified). By default, this function + will re-order the input paths and byte ranges to ensure sorted + order. If the user can guarantee that the inputs are already + sorted, passing `sort=False` will skip the re-ordering. + """ + # Check input + if not isinstance(paths, list): + raise TypeError + if not isinstance(starts, list): + starts = [starts] * len(paths) + if not isinstance(ends, list): + ends = [ends] * len(paths) + if len(starts) != len(paths) or len(ends) != len(paths): + raise ValueError + + # Early Return + if len(starts) <= 1: + return paths, starts, ends + + starts = [s or 0 for s in starts] + # Sort by paths and then ranges if `sort=True` + if sort: + paths, starts, ends = ( + list(v) + for v in zip( + *sorted( + zip(paths, starts, ends), + ) + ) + ) + remove = [] + for i, (path, start, end) in enumerate(zip(paths, starts, ends)): + if any( + e is not None and p == path and start >= s and end <= e and i != i2 + for i2, (p, s, e) in enumerate(zip(paths, starts, ends)) + ): + remove.append(i) + paths = [p for i, p in enumerate(paths) if i not in remove] + starts = [s for i, s in enumerate(starts) if i not in remove] + ends = [e for i, e in enumerate(ends) if i not in remove] + + if paths: + # Loop through the coupled `paths`, `starts`, and + # `ends`, and merge adjacent blocks when appropriate + new_paths = paths[:1] + new_starts = starts[:1] + new_ends = ends[:1] + for i in range(1, len(paths)): + if paths[i] == paths[i - 1] and new_ends[-1] is None: + continue + elif ( + paths[i] != paths[i - 1] + or ((starts[i] - new_ends[-1]) > max_gap) + or (max_block is not None and (ends[i] - new_starts[-1]) > max_block) + ): + # Cannot merge with previous block. 
+ # Add new `paths`, `starts`, and `ends` elements + new_paths.append(paths[i]) + new_starts.append(starts[i]) + new_ends.append(ends[i]) + else: + # Merge with the previous block by updating the + # last element of `ends` + new_ends[-1] = ends[i] + return new_paths, new_starts, new_ends + + # `paths` is empty. Just return input lists + return paths, starts, ends + + +def file_size(filelike: IO[bytes]) -> int: + """Find length of any open read-mode file-like""" + pos = filelike.tell() + try: + return filelike.seek(0, 2) + finally: + filelike.seek(pos) + + +@contextlib.contextmanager +def atomic_write(path: str, mode: str = "wb"): + """ + A context manager that opens a temporary file next to `path` and, on exit, + replaces `path` with the temporary file, thereby updating `path` + atomically. + """ + fd, fn = tempfile.mkstemp( + dir=os.path.dirname(path), prefix=os.path.basename(path) + "-" + ) + try: + with open(fd, mode) as fp: + yield fp + except BaseException: + with contextlib.suppress(FileNotFoundError): + os.unlink(fn) + raise + else: + os.replace(fn, path) + + +def _translate(pat, STAR, QUESTION_MARK): + # Copied from: https://github.com/python/cpython/pull/106703. + res: list[str] = [] + add = res.append + i, n = 0, len(pat) + while i < n: + c = pat[i] + i = i + 1 + if c == "*": + # compress consecutive `*` into one + if (not res) or res[-1] is not STAR: + add(STAR) + elif c == "?": + add(QUESTION_MARK) + elif c == "[": + j = i + if j < n and pat[j] == "!": + j = j + 1 + if j < n and pat[j] == "]": + j = j + 1 + while j < n and pat[j] != "]": + j = j + 1 + if j >= n: + add("\\[") + else: + stuff = pat[i:j] + if "-" not in stuff: + stuff = stuff.replace("\\", r"\\") + else: + chunks = [] + k = i + 2 if pat[i] == "!" else i + 1 + while True: + k = pat.find("-", k, j) + if k < 0: + break + chunks.append(pat[i:k]) + i = k + 1 + k = k + 3 + chunk = pat[i:j] + if chunk: + chunks.append(chunk) + else: + chunks[-1] += "-" + # Remove empty ranges -- invalid in RE. 
+ for k in range(len(chunks) - 1, 0, -1): + if chunks[k - 1][-1] > chunks[k][0]: + chunks[k - 1] = chunks[k - 1][:-1] + chunks[k][1:] + del chunks[k] + # Escape backslashes and hyphens for set difference (--). + # Hyphens that create ranges shouldn't be escaped. + stuff = "-".join( + s.replace("\\", r"\\").replace("-", r"\-") for s in chunks + ) + # Escape set operations (&&, ~~ and ||). + stuff = re.sub(r"([&~|])", r"\\\1", stuff) + i = j + 1 + if not stuff: + # Empty range: never match. + add("(?!)") + elif stuff == "!": + # Negated empty range: match any character. + add(".") + else: + if stuff[0] == "!": + stuff = "^" + stuff[1:] + elif stuff[0] in ("^", "["): + stuff = "\\" + stuff + add(f"[{stuff}]") + else: + add(re.escape(c)) + assert i == n + return res + + +def glob_translate(pat): + # Copied from: https://github.com/python/cpython/pull/106703. + # The keyword parameters' values are fixed to: + # recursive=True, include_hidden=True, seps=None + """Translate a pathname with shell wildcards to a regular expression.""" + if os.path.altsep: + seps = os.path.sep + os.path.altsep + else: + seps = os.path.sep + escaped_seps = "".join(map(re.escape, seps)) + any_sep = f"[{escaped_seps}]" if len(seps) > 1 else escaped_seps + not_sep = f"[^{escaped_seps}]" + one_last_segment = f"{not_sep}+" + one_segment = f"{one_last_segment}{any_sep}" + any_segments = f"(?:.+{any_sep})?" 
+ any_last_segments = ".*" + results = [] + parts = re.split(any_sep, pat) + last_part_idx = len(parts) - 1 + for idx, part in enumerate(parts): + if part == "*": + results.append(one_segment if idx < last_part_idx else one_last_segment) + continue + if part == "**": + results.append(any_segments if idx < last_part_idx else any_last_segments) + continue + elif "**" in part: + raise ValueError( + "Invalid pattern: '**' can only be an entire path component" + ) + if part: + results.extend(_translate(part, f"{not_sep}*", not_sep)) + if idx < last_part_idx: + results.append(any_sep) + res = "".join(results) + return rf"(?s:{res})\Z" diff --git a/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/INSTALLER b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/METADATA b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..8056834e6714b089ef49847820064a1ae4b041fd --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/METADATA @@ -0,0 +1,625 @@ +Metadata-Version: 2.4 +Name: httpcore +Version: 1.0.9 +Summary: A minimal low-level HTTP client. 
+Project-URL: Documentation, https://www.encode.io/httpcore +Project-URL: Homepage, https://www.encode.io/httpcore/ +Project-URL: Source, https://github.com/encode/httpcore +Author-email: Tom Christie +License-Expression: BSD-3-Clause +License-File: LICENSE.md +Classifier: Development Status :: 3 - Alpha +Classifier: Environment :: Web Environment +Classifier: Framework :: AsyncIO +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Requires-Dist: certifi +Requires-Dist: h11>=0.16 +Provides-Extra: asyncio +Requires-Dist: anyio<5.0,>=4.0; extra == 'asyncio' +Provides-Extra: http2 +Requires-Dist: h2<5,>=3; extra == 'http2' +Provides-Extra: socks +Requires-Dist: socksio==1.*; extra == 'socks' +Provides-Extra: trio +Requires-Dist: trio<1.0,>=0.22.0; extra == 'trio' +Description-Content-Type: text/markdown + +# HTTP Core + +[![Test Suite](https://github.com/encode/httpcore/workflows/Test%20Suite/badge.svg)](https://github.com/encode/httpcore/actions) +[![Package version](https://badge.fury.io/py/httpcore.svg)](https://pypi.org/project/httpcore/) + +> *Do one thing, and do it well.* + +The HTTP Core package provides a minimal low-level HTTP client, which does +one thing only. Sending HTTP requests. 
+ +It does not provide any high level model abstractions over the API, +does not handle redirects, multipart uploads, building authentication headers, +transparent HTTP caching, URL parsing, session cookie handling, +content or charset decoding, handling JSON, environment based configuration +defaults, or any of that Jazz. + +Some things HTTP Core does do: + +* Sending HTTP requests. +* Thread-safe / task-safe connection pooling. +* HTTP(S) proxy & SOCKS proxy support. +* Supports HTTP/1.1 and HTTP/2. +* Provides both sync and async interfaces. +* Async backend support for `asyncio` and `trio`. + +## Requirements + +Python 3.8+ + +## Installation + +For HTTP/1.1 only support, install with: + +```shell +$ pip install httpcore +``` + +There are also a number of optional extras available... + +```shell +$ pip install httpcore['asyncio,trio,http2,socks'] +``` + +## Sending requests + +Send an HTTP request: + +```python +import httpcore + +response = httpcore.request("GET", "https://www.example.com/") + +print(response) +# +print(response.status) +# 200 +print(response.headers) +# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...] +print(response.content) +# b'\n\n\nExample Domain\n\n\n ...' +``` + +The top-level `httpcore.request()` function is provided for convenience. In practice whenever you're working with `httpcore` you'll want to use the connection pooling functionality that it provides. + +```python +import httpcore + +http = httpcore.ConnectionPool() +response = http.request("GET", "https://www.example.com/") +``` + +Once you're ready to get going, [head over to the documentation](https://www.encode.io/httpcore/). + +## Motivation + +You *probably* don't want to be using HTTP Core directly. It might make sense if +you're writing something like a proxy service in Python, and you just want +something at the lowest possible level, but more typically you'll want to use +a higher level client library, such as `httpx`. 
+ +The motivation for `httpcore` is: + +* To provide a reusable low-level client library, that other packages can then build on top of. +* To provide a *really clear interface split* between the networking code and client logic, + so that each is easier to understand and reason about in isolation. + +## Dependencies + +The `httpcore` package has the following dependencies... + +* `h11` +* `certifi` + +And the following optional extras... + +* `anyio` - Required by `pip install httpcore['asyncio']`. +* `trio` - Required by `pip install httpcore['trio']`. +* `h2` - Required by `pip install httpcore['http2']`. +* `socksio` - Required by `pip install httpcore['socks']`. + +## Versioning + +We use [SEMVER for our versioning policy](https://semver.org/). + +For changes between package versions please see our [project changelog](CHANGELOG.md). + +We recommend pinning your requirements either the most current major version, or a more specific version range: + +```python +pip install 'httpcore==1.*' +``` +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). + +## Version 1.0.9 (April 24th, 2025) + +- Resolve https://github.com/advisories/GHSA-vqfr-h8mv-ghfj with h11 dependency update. (#1008) + +## Version 1.0.8 (April 11th, 2025) + +- Fix `AttributeError` when importing on Python 3.14. (#1005) + +## Version 1.0.7 (November 15th, 2024) + +- Support `proxy=…` configuration on `ConnectionPool()`. (#974) + +## Version 1.0.6 (October 1st, 2024) + +- Relax `trio` dependency pinning. (#956) +- Handle `trio` raising `NotImplementedError` on unsupported platforms. (#955) +- Handle mapping `ssl.SSLError` to `httpcore.ConnectError`. (#918) + +## 1.0.5 (March 27th, 2024) + +- Handle `EndOfStream` exception for anyio backend. (#899) +- Allow trio `0.25.*` series in package dependancies. (#903) + +## 1.0.4 (February 21st, 2024) + +- Add `target` request extension. 
(#888) +- Fix support for connection `Upgrade` and `CONNECT` when some data in the stream has been read. (#882) + +## 1.0.3 (February 13th, 2024) + +- Fix support for async cancellations. (#880) +- Fix trace extension when used with socks proxy. (#849) +- Fix SSL context for connections using the "wss" scheme (#869) + +## 1.0.2 (November 10th, 2023) + +- Fix `float("inf")` timeouts in `Event.wait` function. (#846) + +## 1.0.1 (November 3rd, 2023) + +- Fix pool timeout to account for the total time spent retrying. (#823) +- Raise a neater RuntimeError when the correct async deps are not installed. (#826) +- Add support for synchronous TLS-in-TLS streams. (#840) + +## 1.0.0 (October 6th, 2023) + +From version 1.0 our async support is now optional, as the package has minimal dependencies by default. + +For async support use either `pip install 'httpcore[asyncio]'` or `pip install 'httpcore[trio]'`. + +The project versioning policy is now explicitly governed by SEMVER. See https://semver.org/. + +- Async support becomes fully optional. (#809) +- Add support for Python 3.12. (#807) + +## 0.18.0 (September 8th, 2023) + +- Add support for HTTPS proxies. (#745, #786) +- Drop Python 3.7 support. (#727) +- Handle `sni_hostname` extension with SOCKS proxy. (#774) +- Handle HTTP/1.1 half-closed connections gracefully. (#641) +- Change the type of `Extensions` from `Mapping[Str, Any]` to `MutableMapping[Str, Any]`. (#762) + +## 0.17.3 (July 5th, 2023) + +- Support async cancellations, ensuring that the connection pool is left in a clean state when cancellations occur. (#726) +- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699) +- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. 
(#730) +- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717) +- Drop Python 3.7 support. (#727) + +## 0.17.2 (May 23th, 2023) + +- Add `socket_options` argument to `ConnectionPool` and `HTTProxy` classes. (#668) +- Improve logging with per-module logger names. (#690) +- Add `sni_hostname` request extension. (#696) +- Resolve race condition during import of `anyio` package. (#692) +- Enable TCP_NODELAY for all synchronous sockets. (#651) + +## 0.17.1 (May 17th, 2023) + +- If 'retries' is set, then allow retries if an SSL handshake error occurs. (#669) +- Improve correctness of tracebacks on network exceptions, by raising properly chained exceptions. (#678) +- Prevent connection-hanging behaviour when HTTP/2 connections are closed by a server-sent 'GoAway' frame. (#679) +- Fix edge-case exception when removing requests from the connection pool. (#680) +- Fix pool timeout edge-case. (#688) + +## 0.17.0 (March 16th, 2023) + +- Add DEBUG level logging. (#648) +- Respect HTTP/2 max concurrent streams when settings updates are sent by server. (#652) +- Increase the allowable HTTP header size to 100kB. (#647) +- Add `retries` option to SOCKS proxy classes. (#643) + +## 0.16.3 (December 20th, 2022) + +- Allow `ws` and `wss` schemes. Allows us to properly support websocket upgrade connections. (#625) +- Forwarding HTTP proxies use a connection-per-remote-host. Required by some proxy implementations. (#637) +- Don't raise `RuntimeError` when closing a connection pool with active connections. Removes some error cases when cancellations are used. (#631) +- Lazy import `anyio`, so that it's no longer a hard dependancy, and isn't imported if unused. (#639) + +## 0.16.2 (November 25th, 2022) + +- Revert 'Fix async cancellation behaviour', which introduced race conditions. (#627) +- Raise `RuntimeError` if attempting to us UNIX domain sockets on Windows. 
(#619) + +## 0.16.1 (November 17th, 2022) + +- Fix HTTP/1.1 interim informational responses, such as "100 Continue". (#605) + +## 0.16.0 (October 11th, 2022) + +- Support HTTP/1.1 informational responses. (#581) +- Fix async cancellation behaviour. (#580) +- Support `h11` 0.14. (#579) + +## 0.15.0 (May 17th, 2022) + +- Drop Python 3.6 support (#535) +- Ensure HTTP proxy CONNECT requests include `timeout` configuration. (#506) +- Switch to explicit `typing.Optional` for type hints. (#513) +- For `trio` map OSError exceptions to `ConnectError`. (#543) + +## 0.14.7 (February 4th, 2022) + +- Requests which raise a PoolTimeout need to be removed from the pool queue. (#502) +- Fix AttributeError that happened when Socks5Connection were terminated. (#501) + +## 0.14.6 (February 1st, 2022) + +- Fix SOCKS support for `http://` URLs. (#492) +- Resolve race condition around exceptions during streaming a response. (#491) + +## 0.14.5 (January 18th, 2022) + +- SOCKS proxy support. (#478) +- Add proxy_auth argument to HTTPProxy. (#481) +- Improve error message on 'RemoteProtocolError' exception when server disconnects without sending a response. (#479) + +## 0.14.4 (January 5th, 2022) + +- Support HTTP/2 on HTTPS tunnelling proxies. (#468) +- Fix proxy headers missing on HTTP forwarding. (#456) +- Only instantiate SSL context if required. (#457) +- More robust HTTP/2 handling. (#253, #439, #440, #441) + +## 0.14.3 (November 17th, 2021) + +- Fix race condition when removing closed connections from the pool. (#437) + +## 0.14.2 (November 16th, 2021) + +- Failed connections no longer remain in the pool. (Pull #433) + +## 0.14.1 (November 12th, 2021) + +- `max_connections` becomes optional. (Pull #429) +- `certifi` is now included in the install dependancies. (Pull #428) +- `h2` is now strictly optional. 
(Pull #428) + +## 0.14.0 (November 11th, 2021) + +The 0.14 release is a complete reworking of `httpcore`, comprehensively addressing some underlying issues in the connection pooling, as well as substantially redesigning the API to be more user friendly. + +Some of the lower-level API design also makes the components more easily testable in isolation, and the package now has 100% test coverage. + +See [discussion #419](https://github.com/encode/httpcore/discussions/419) for a little more background. + +There's some other neat bits in there too, such as the "trace" extension, which gives a hook into inspecting the internal events that occur during the request/response cycle. This extension is needed for the HTTPX cli, in order to... + +* Log the point at which the connection is established, and the IP/port on which it is made. +* Determine if the outgoing request should log as HTTP/1.1 or HTTP/2, rather than having to assume it's HTTP/2 if the --http2 flag was passed. (Which may not actually be true.) +* Log SSL version info / certificate info. + +Note that `curio` support is not currently available in 0.14.0. If you're using `httpcore` with `curio` please get in touch, so we can assess if we ought to prioritize it as a feature or not. + +## 0.13.7 (September 13th, 2021) + +- Fix broken error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #403) + +## 0.13.6 (June 15th, 2021) + +### Fixed + +- Close sockets when read or write timeouts occur. (Pull #365) + +## 0.13.5 (June 14th, 2021) + +### Fixed + +- Resolved niggles with AnyIO EOF behaviours. (Pull #358, #362) + +## 0.13.4 (June 9th, 2021) + +### Added + +- Improved error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #354) + +### Fixed + +- Switched to `anyio` as the default backend implementation when running with `asyncio`. Resolves some awkward [TLS timeout issues](https://github.com/encode/httpx/discussions/1511). 
+ +## 0.13.3 (May 6th, 2021) + +### Added + +- Support HTTP/2 prior knowledge, using `httpcore.SyncConnectionPool(http1=False)`. (Pull #333) + +### Fixed + +- Handle cases where environment does not provide `select.poll` support. (Pull #331) + +## 0.13.2 (April 29th, 2021) + +### Added + +- Improve error message for specific case of `RemoteProtocolError` where server disconnects without sending a response. (Pull #313) + +## 0.13.1 (April 28th, 2021) + +### Fixed + +- More resiliant testing for closed connections. (Pull #311) +- Don't raise exceptions on ungraceful connection closes. (Pull #310) + +## 0.13.0 (April 21st, 2021) + +The 0.13 release updates the core API in order to match the HTTPX Transport API, +introduced in HTTPX 0.18 onwards. + +An example of making requests with the new interface is: + +```python +with httpcore.SyncConnectionPool() as http: + status_code, headers, stream, extensions = http.handle_request( + method=b'GET', + url=(b'https', b'example.org', 443, b'/'), + headers=[(b'host', b'example.org'), (b'user-agent', b'httpcore')] + stream=httpcore.ByteStream(b''), + extensions={} + ) + body = stream.read() + print(status_code, body) +``` + +### Changed + +- The `.request()` method is now `handle_request()`. (Pull #296) +- The `.arequest()` method is now `.handle_async_request()`. (Pull #296) +- The `headers` argument is no longer optional. (Pull #296) +- The `stream` argument is no longer optional. (Pull #296) +- The `ext` argument is now named `extensions`, and is no longer optional. (Pull #296) +- The `"reason"` extension keyword is now named `"reason_phrase"`. (Pull #296) +- The `"reason_phrase"` and `"http_version"` extensions now use byte strings for their values. (Pull #296) +- The `httpcore.PlainByteStream()` class becomes `httpcore.ByteStream()`. (Pull #296) + +### Added + +- Streams now support a `.read()` interface. (Pull #296) + +### Fixed + +- Task cancellation no longer leaks connections from the connection pool. 
(Pull #305) + +## 0.12.3 (December 7th, 2020) + +### Fixed + +- Abort SSL connections on close rather than waiting for remote EOF when using `asyncio`. (Pull #167) +- Fix exception raised in case of connect timeouts when using the `anyio` backend. (Pull #236) +- Fix `Host` header precedence for `:authority` in HTTP/2. (Pull #241, #243) +- Handle extra edge case when detecting for socket readability when using `asyncio`. (Pull #242, #244) +- Fix `asyncio` SSL warning when using proxy tunneling. (Pull #249) + +## 0.12.2 (November 20th, 2020) + +### Fixed + +- Properly wrap connect errors on the asyncio backend. (Pull #235) +- Fix `ImportError` occurring on Python 3.9 when using the HTTP/1.1 sync client in a multithreaded context. (Pull #237) + +## 0.12.1 (November 7th, 2020) + +### Added + +- Add connect retries. (Pull #221) + +### Fixed + +- Tweak detection of dropped connections, resolving an issue with open files limits on Linux. (Pull #185) +- Avoid leaking connections when establishing an HTTP tunnel to a proxy has failed. (Pull #223) +- Properly wrap OS errors when using `trio`. (Pull #225) + +## 0.12.0 (October 6th, 2020) + +### Changed + +- HTTP header casing is now preserved, rather than always sent in lowercase. (#216 and python-hyper/h11#104) + +### Added + +- Add Python 3.9 to officially supported versions. + +### Fixed + +- Gracefully handle a stdlib asyncio bug when a connection is closed while it is in a paused-for-reading state. (#201) + +## 0.11.1 (September 28nd, 2020) + +### Fixed + +- Add await to async semaphore release() coroutine (#197) +- Drop incorrect curio classifier (#192) + +## 0.11.0 (September 22nd, 2020) + +The Transport API with 0.11.0 has a couple of significant changes. + +Firstly we've moved changed the request interface in order to allow extensions, which will later enable us to support features +such as trailing headers, HTTP/2 server push, and CONNECT/Upgrade connections. 
+ +The interface changes from: + +```python +def request(method, url, headers, stream, timeout): + return (http_version, status_code, reason, headers, stream) +``` + +To instead including an optional dictionary of extensions on the request and response: + +```python +def request(method, url, headers, stream, ext): + return (status_code, headers, stream, ext) +``` + +Having an open-ended extensions point will allow us to add later support for various optional features, that wouldn't otherwise be supported without these API changes. + +In particular: + +* Trailing headers support. +* HTTP/2 Server Push +* sendfile. +* Exposing raw connection on CONNECT, Upgrade, HTTP/2 bi-di streaming. +* Exposing debug information out of the API, including template name, template context. + +Currently extensions are limited to: + +* request: `timeout` - Optional. Timeout dictionary. +* response: `http_version` - Optional. Include the HTTP version used on the response. +* response: `reason` - Optional. Include the reason phrase used on the response. Only valid with HTTP/1.*. + +See https://github.com/encode/httpx/issues/1274#issuecomment-694884553 for the history behind this. + +Secondly, the async version of `request` is now namespaced as `arequest`. + +This allows concrete transports to support both sync and async implementations on the same class. + +### Added + +- Add curio support. (Pull #168) +- Add anyio support, with `backend="anyio"`. (Pull #169) + +### Changed + +- Update the Transport API to use 'ext' for optional extensions. (Pull #190) +- Update the Transport API to use `.request` and `.arequest` so implementations can support both sync and async. (Pull #189) + +## 0.10.2 (August 20th, 2020) + +### Added + +- Added Unix Domain Socket support. (Pull #139) + +### Fixed + +- Always include the port on proxy CONNECT requests. (Pull #154) +- Fix `max_keepalive_connections` configuration. 
(Pull #153) +- Fixes behaviour in HTTP/1.1 where server disconnects can be used to signal the end of the response body. (Pull #164) + +## 0.10.1 (August 7th, 2020) + +- Include `max_keepalive_connections` on `AsyncHTTPProxy`/`SyncHTTPProxy` classes. + +## 0.10.0 (August 7th, 2020) + +The most notable change in the 0.10.0 release is that HTTP/2 support is now fully optional. + +Use either `pip install httpcore` for HTTP/1.1 support only, or `pip install httpcore[http2]` for HTTP/1.1 and HTTP/2 support. + +### Added + +- HTTP/2 support becomes optional. (Pull #121, #130) +- Add `local_address=...` support. (Pull #100, #134) +- Add `PlainByteStream`, `IteratorByteStream`, `AsyncIteratorByteStream`. The `AsyncByteSteam` and `SyncByteStream` classes are now pure interface classes. (#133) +- Add `LocalProtocolError`, `RemoteProtocolError` exceptions. (Pull #129) +- Add `UnsupportedProtocol` exception. (Pull #128) +- Add `.get_connection_info()` method. (Pull #102, #137) +- Add better TRACE logs. (Pull #101) + +### Changed + +- `max_keepalive` is deprecated in favour of `max_keepalive_connections`. (Pull #140) + +### Fixed + +- Improve handling of server disconnects. (Pull #112) + +## 0.9.1 (May 27th, 2020) + +### Fixed + +- Proper host resolution for sync case, including IPv6 support. (Pull #97) +- Close outstanding connections when connection pool is closed. (Pull #98) + +## 0.9.0 (May 21th, 2020) + +### Changed + +- URL port becomes an `Optional[int]` instead of `int`. (Pull #92) + +### Fixed + +- Honor HTTP/2 max concurrent streams settings. (Pull #89, #90) +- Remove incorrect debug log. (Pull #83) + +## 0.8.4 (May 11th, 2020) + +### Added + +- Logging via HTTPCORE_LOG_LEVEL and HTTPX_LOG_LEVEL environment variables +and TRACE level logging. (Pull #79) + +### Fixed + +- Reuse of connections on HTTP/2 in close concurrency situations. (Pull #81) + +## 0.8.3 (May 6rd, 2020) + +### Fixed + +- Include `Host` and `Accept` headers on proxy "CONNECT" requests. 
+- De-duplicate any headers also contained in proxy_headers. +- HTTP/2 flag not being passed down to proxy connections. + +## 0.8.2 (May 3rd, 2020) + +### Fixed + +- Fix connections using proxy forwarding requests not being added to the +connection pool properly. (Pull #70) + +## 0.8.1 (April 30th, 2020) + +### Changed + +- Allow inherintance of both `httpcore.AsyncByteStream`, `httpcore.SyncByteStream` without type conflicts. + +## 0.8.0 (April 30th, 2020) + +### Fixed + +- Fixed tunnel proxy support. + +### Added + +- New `TimeoutException` base class. + +## 0.7.0 (March 5th, 2020) + +- First integration with HTTPX. diff --git a/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/RECORD b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..aac7c19612e1316861a77cb89aab09c6b1a0f1b2 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/RECORD @@ -0,0 +1,68 @@ +httpcore-1.0.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +httpcore-1.0.9.dist-info/METADATA,sha256=_i1P2mGZEol4d54M8n88BFxTGGP83Zh-rMdPOhjUHCE,21529 +httpcore-1.0.9.dist-info/RECORD,, +httpcore-1.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87 +httpcore-1.0.9.dist-info/licenses/LICENSE.md,sha256=_ctZFUx0y6uhahEkL3dAvqnyPW_rVUeRfYxflKgDkqU,1518 +httpcore/__init__.py,sha256=9kT_kqChCCJUTHww24ZmR_ezcdbpRYWksD-gYNzkZP8,3445 +httpcore/__pycache__/__init__.cpython-310.pyc,, +httpcore/__pycache__/_api.cpython-310.pyc,, +httpcore/__pycache__/_exceptions.cpython-310.pyc,, +httpcore/__pycache__/_models.cpython-310.pyc,, +httpcore/__pycache__/_ssl.cpython-310.pyc,, +httpcore/__pycache__/_synchronization.cpython-310.pyc,, +httpcore/__pycache__/_trace.cpython-310.pyc,, +httpcore/__pycache__/_utils.cpython-310.pyc,, +httpcore/_api.py,sha256=unZmeDschBWCGCPCwkS3Wot9euK6bg_kKxLtGTxw214,3146 
+httpcore/_async/__init__.py,sha256=EWdl2v4thnAHzJpqjU4h2a8DUiGAvNiWrkii9pfhTf0,1221 +httpcore/_async/__pycache__/__init__.cpython-310.pyc,, +httpcore/_async/__pycache__/connection.cpython-310.pyc,, +httpcore/_async/__pycache__/connection_pool.cpython-310.pyc,, +httpcore/_async/__pycache__/http11.cpython-310.pyc,, +httpcore/_async/__pycache__/http2.cpython-310.pyc,, +httpcore/_async/__pycache__/http_proxy.cpython-310.pyc,, +httpcore/_async/__pycache__/interfaces.cpython-310.pyc,, +httpcore/_async/__pycache__/socks_proxy.cpython-310.pyc,, +httpcore/_async/connection.py,sha256=6OcPXqMEfc0BU38_-iHUNDd1vKSTc2UVT09XqNb_BOk,8449 +httpcore/_async/connection_pool.py,sha256=DOIQ2s2ZCf9qfwxhzMprTPLqCL8OxGXiKF6qRHxvVyY,17307 +httpcore/_async/http11.py,sha256=-qM9bV7PjSQF5vxs37-eUXOIFwbIjPcZbNliuX9TtBw,13880 +httpcore/_async/http2.py,sha256=azX1fcmtXaIwjputFlZ4vd92J8xwjGOa9ax9QIv4394,23936 +httpcore/_async/http_proxy.py,sha256=2zVkrlv-Ds-rWGaqaXlrhEJiAQFPo23BT3Gq_sWoBXU,14701 +httpcore/_async/interfaces.py,sha256=jTiaWL83pgpGC9ziv90ZfwaKNMmHwmOalzaKiuTxATo,4455 +httpcore/_async/socks_proxy.py,sha256=lLKgLlggPfhFlqi0ODeBkOWvt9CghBBUyqsnsU1tx6Q,13841 +httpcore/_backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +httpcore/_backends/__pycache__/__init__.cpython-310.pyc,, +httpcore/_backends/__pycache__/anyio.cpython-310.pyc,, +httpcore/_backends/__pycache__/auto.cpython-310.pyc,, +httpcore/_backends/__pycache__/base.cpython-310.pyc,, +httpcore/_backends/__pycache__/mock.cpython-310.pyc,, +httpcore/_backends/__pycache__/sync.cpython-310.pyc,, +httpcore/_backends/__pycache__/trio.cpython-310.pyc,, +httpcore/_backends/anyio.py,sha256=x8PgEhXRC8bVqsdzk_YJx8Y6d9Tub06CuUSwnbmtqoY,5252 +httpcore/_backends/auto.py,sha256=zO136PKZmsaTDK-HRk84eA-MUg8_2wJf4NvmK432Aio,1662 +httpcore/_backends/base.py,sha256=aShgRdZnMmRhFWHetjumlM73f8Kz1YOAyCUP_4kHslA,3042 +httpcore/_backends/mock.py,sha256=er9T436uSe7NLrfiLa4x6Nuqg5ivQ693CxWYCWsgbH4,4077 
+httpcore/_backends/sync.py,sha256=bhE4d9iK9Umxdsdsgm2EfKnXaBms2WggGYU-7jmUujU,7977 +httpcore/_backends/trio.py,sha256=LHu4_Mr5MswQmmT3yE4oLgf9b_JJfeVS4BjDxeJc7Ro,5996 +httpcore/_exceptions.py,sha256=looCKga3_YVYu3s-d3L9RMPRJyhsY7fiuuGxvkOD0c0,1184 +httpcore/_models.py,sha256=IO2CcXcdpovRcLTdGFGB6RyBZdEm2h_TOmoCc4rEKho,17623 +httpcore/_ssl.py,sha256=srqmSNU4iOUvWF-SrJvb8G_YEbHFELOXQOwdDIBTS9c,187 +httpcore/_sync/__init__.py,sha256=JBDIgXt5la1LCJ1sLQeKhjKFpLnpNr8Svs6z2ni3fgg,1141 +httpcore/_sync/__pycache__/__init__.cpython-310.pyc,, +httpcore/_sync/__pycache__/connection.cpython-310.pyc,, +httpcore/_sync/__pycache__/connection_pool.cpython-310.pyc,, +httpcore/_sync/__pycache__/http11.cpython-310.pyc,, +httpcore/_sync/__pycache__/http2.cpython-310.pyc,, +httpcore/_sync/__pycache__/http_proxy.cpython-310.pyc,, +httpcore/_sync/__pycache__/interfaces.cpython-310.pyc,, +httpcore/_sync/__pycache__/socks_proxy.cpython-310.pyc,, +httpcore/_sync/connection.py,sha256=9exGOb3PB-Mp2T1-sckSeL2t-tJ_9-NXomV8ihmWCgU,8238 +httpcore/_sync/connection_pool.py,sha256=a-T8LTsUxc7r0Ww1atfHSDoWPjQ0fA8Ul7S3-F0Mj70,16955 +httpcore/_sync/http11.py,sha256=IFobD1Md5JFlJGKWnh1_Q3epikUryI8qo09v8MiJIEA,13476 +httpcore/_sync/http2.py,sha256=AxU4yhcq68Bn5vqdJYtiXKYUj7nvhYbxz3v4rT4xnvA,23400 +httpcore/_sync/http_proxy.py,sha256=_al_6crKuEZu2wyvu493RZImJdBJnj5oGKNjLOJL2Zo,14463 +httpcore/_sync/interfaces.py,sha256=snXON42vUDHO5JBJvo8D4VWk2Wat44z2OXXHDrjbl94,4344 +httpcore/_sync/socks_proxy.py,sha256=zegZW9Snqj2_992DFJa8_CppOVBkVL4AgwduRkStakQ,13614 +httpcore/_synchronization.py,sha256=zSi13mAColBnknjZBknUC6hKNDQT4C6ijnezZ-r0T2s,9434 +httpcore/_trace.py,sha256=ck6ZoIzYTkdNAIfq5MGeKqBXDtqjOX-qfYwmZFbrGco,3952 +httpcore/_utils.py,sha256=_RLgXYOAYC350ikALV59GZ68IJrdocRZxPs9PjmzdFY,1537 +httpcore/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/WHEEL b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/WHEEL new 
file mode 100644 index 0000000000000000000000000000000000000000..12228d414b6cfed7c39d3781c85c63256a1d7fb5 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.27.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/licenses/LICENSE.md b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/licenses/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..311b2b56c53f678ab95fc0def708c675d521a807 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore-1.0.9.dist-info/licenses/LICENSE.md @@ -0,0 +1,27 @@ +Copyright © 2020, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8f7f9d5185b5fc047b3419df015a299f0b632d Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_api.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08644c01ab7f609acf7e3ce7b3f543cbe6e475b1 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_api.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_exceptions.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c1283371ed56c4f4ac9508e976fe89bbe446ae1 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_exceptions.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_models.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_models.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d76c5440be65487b3d87960b2195776f978473c5 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_models.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_ssl.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_ssl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fb8e080754dfb1014c33531243b2319f92cc953 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_ssl.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_synchronization.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_synchronization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bf604efd292523fe150f1e13e90855d8839f6b5 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_synchronization.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_trace.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_trace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9583ca3612821e2dfb1c61d48a22b9600d118ab Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_trace.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/__pycache__/_utils.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9cf95cd87a93c53b311a9362adce1dbc99931b Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/__pycache__/_utils.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__init__.py b/venv/lib/python3.10/site-packages/httpcore/_async/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..88dc7f01e132933728cbcf45c88ce82e85ddf65f --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/__init__.py @@ -0,0 +1,39 @@ +from .connection import AsyncHTTPConnection +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .http_proxy import AsyncHTTPProxy +from .interfaces import AsyncConnectionInterface + +try: + from .http2 import AsyncHTTP2Connection +except ImportError: # pragma: nocover + + class AsyncHTTP2Connection: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use http2 support, but the `h2` package is not " + "installed. Use 'pip install httpcore[http2]'." + ) + + +try: + from .socks_proxy import AsyncSOCKSProxy +except ImportError: # pragma: nocover + + class AsyncSOCKSProxy: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use SOCKS support, but the `socksio` package is not " + "installed. Use 'pip install httpcore[socks]'." 
+ ) + + +__all__ = [ + "AsyncHTTPConnection", + "AsyncConnectionPool", + "AsyncHTTPProxy", + "AsyncHTTP11Connection", + "AsyncHTTP2Connection", + "AsyncConnectionInterface", + "AsyncSOCKSProxy", +] diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb8f3857ee8b3c930a10663073ea7d6536131e5 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/connection.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/connection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a16afa783e6bb6436e35c30d83ad78418215f5b Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/connection.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/connection_pool.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/connection_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb49662e0fa07377f5a320a8e02a54ef3219b16f Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/connection_pool.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http11.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http11.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a70bac2c4e8bceea591ccd327c81726f4b969862 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http11.cpython-310.pyc differ diff --git 
a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http2.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a54516d8efd602432c2f9ce1d56d472c1ab0a4 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http2.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http_proxy.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http_proxy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c78e4f956dae8cf1702f64d7a60e988cbd3593e4 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/http_proxy.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/interfaces.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194f2f34c5484f55348e15d721152281dd9179e0 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/interfaces.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/socks_proxy.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/socks_proxy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22b7a74671946f397fe6a7eb3aa01cfd8166f8dc Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_async/__pycache__/socks_proxy.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/connection.py b/venv/lib/python3.10/site-packages/httpcore/_async/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..b42581dff8aabf4c2ef80ffda26296e1b368d693 --- /dev/null +++ 
b/venv/lib/python3.10/site-packages/httpcore/_async/connection.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import itertools +import logging +import ssl +import types +import typing + +from .._backends.auto import AutoBackend +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream +from .._exceptions import ConnectError, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + + +logger = logging.getLogger("httpcore.connection") + + +def exponential_backoff(factor: float) -> typing.Iterator[float]: + """ + Generate a geometric sequence that has a ratio of 2 and starts with 0. + + For example: + - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...` + - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...` + """ + yield 0 + for n in itertools.count(): + yield factor * 2**n + + +class AsyncHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + origin: Origin, + ssl_context: ssl.SSLContext | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: str | None = None, + uds: str | None = None, + network_backend: AsyncNetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + self._origin = origin + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connection: AsyncConnectionInterface | None = None + self._connect_failed: bool = False + self._request_lock = 
AsyncLock() + self._socket_options = socket_options + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) + + try: + async with self._request_lock: + if self._connection is None: + stream = await self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except BaseException as exc: + self._connect_failed = True + raise exc + + return await self._connection.handle_async_request(request) + + async def _connect(self, request: Request) -> AsyncNetworkStream: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + "socket_options": self._socket_options, + } + async with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + else: + kwargs = { + "path": self._uds, + "timeout": timeout, + "socket_options": self._socket_options, + } + async with Trace( + "connect_unix_socket", logger, request, kwargs + ) as trace: + stream = 
await self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream + + if self._origin.scheme in (b"https", b"wss"): + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + return stream + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + async with Trace("retry", logger, request, kwargs) as trace: + await self._network_backend.sleep(delay) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + async def aclose(self) -> None: + if self._connection is not None: + async with Trace("close", logger, None, {}): + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. 
+ return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + async def __aenter__(self) -> AsyncHTTPConnection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + await self.aclose() diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/connection_pool.py b/venv/lib/python3.10/site-packages/httpcore/_async/connection_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..96e973d0ce223f6bed9be9e6a6a2f3c01622c611 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/connection_pool.py @@ -0,0 +1,420 @@ +from __future__ import annotations + +import ssl +import sys +import types +import typing + +from .._backends.auto import AutoBackend +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Proxy, Request, Response +from .._synchronization import AsyncEvent, AsyncShieldCancellation, 
AsyncThreadLock +from .connection import AsyncHTTPConnection +from .interfaces import AsyncConnectionInterface, AsyncRequestInterface + + +class AsyncPoolRequest: + def __init__(self, request: Request) -> None: + self.request = request + self.connection: AsyncConnectionInterface | None = None + self._connection_acquired = AsyncEvent() + + def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None: + self.connection = connection + self._connection_acquired.set() + + def clear_connection(self) -> None: + self.connection = None + self._connection_acquired = AsyncEvent() + + async def wait_for_connection( + self, timeout: float | None = None + ) -> AsyncConnectionInterface: + if self.connection is None: + await self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + def is_queued(self) -> bool: + return self.connection is None + + +class AsyncConnectionPool(AsyncRequestInterface): + """ + A connection pool for making HTTP requests. + """ + + def __init__( + self, + ssl_context: ssl.SSLContext | None = None, + proxy: Proxy | None = None, + max_connections: int | None = 10, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: str | None = None, + uds: str | None = None, + network_backend: AsyncNetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. 
+ max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + socket_options: Socket options that have to be included + in the TCP socket when the connection was established. + """ + self._ssl_context = ssl_context + self._proxy = proxy + self._max_connections = ( + sys.maxsize if max_connections is None else max_connections + ) + self._max_keepalive_connections = ( + sys.maxsize + if max_keepalive_connections is None + else max_keepalive_connections + ) + self._max_keepalive_connections = min( + self._max_connections, self._max_keepalive_connections + ) + + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._socket_options = socket_options + + # The mutable state on a connection pool is the queue of incoming requests, + # and the set of connections that are servicing those requests. 
+ self._connections: list[AsyncConnectionInterface] = [] + self._requests: list[AsyncPoolRequest] = [] + + # We only mutate the state of the connection pool within an 'optional_thread_lock' + # context. This holds a threading lock unless we're running in async mode, + # in which case it is a no-op. + self._optional_thread_lock = AsyncThreadLock() + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + if self._proxy is not None: + if self._proxy.url.scheme in (b"socks5", b"socks5h"): + from .socks_proxy import AsyncSocks5Connection + + return AsyncSocks5Connection( + proxy_origin=self._proxy.url.origin, + proxy_auth=self._proxy.auth, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + elif origin.scheme == b"http": + from .http_proxy import AsyncForwardHTTPConnection + + return AsyncForwardHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + ) + from .http_proxy import AsyncTunnelHTTPConnection + + return AsyncTunnelHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + return AsyncHTTPConnection( + origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + retries=self._retries, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, + socket_options=self._socket_options, + ) + + @property + def connections(self) -> 
list[AsyncConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + , + , + , + ] + ``` + """ + return list(self._connections) + + async def handle_async_request(self, request: Request) -> Response: + """ + Send an HTTP request, and return an HTTP response. + + This is the core implementation that is called into by `.request()` or `.stream()`. + """ + scheme = request.url.scheme.decode() + if scheme == "": + raise UnsupportedProtocol( + "Request URL is missing an 'http://' or 'https://' protocol." + ) + if scheme not in ("http", "https", "ws", "wss"): + raise UnsupportedProtocol( + f"Request URL has an unsupported protocol '{scheme}://'." + ) + + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + + with self._optional_thread_lock: + # Add the incoming request to our request queue. + pool_request = AsyncPoolRequest(request) + self._requests.append(pool_request) + + try: + while True: + with self._optional_thread_lock: + # Assign incoming requests to available connections, + # closing or creating new connections as required. + closing = self._assign_requests_to_connections() + await self._close_connections(closing) + + # Wait until this request has an assigned connection. + connection = await pool_request.wait_for_connection(timeout=timeout) + + try: + # Send the request on the assigned connection. + response = await connection.handle_async_request( + pool_request.request + ) + except ConnectionNotAvailable: + # In some cases a connection may initially be available to + # handle a request, but then become unavailable. + # + # In this case we clear the connection and try again. + pool_request.clear_connection() + else: + break # pragma: nocover + + except BaseException as exc: + with self._optional_thread_lock: + # For any exception or cancellation we remove the request from + # the queue, and then re-assign requests to connections. 
+ self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() + + await self._close_connections(closing) + raise exc from None + + # Return the response. Note that in this case we still have to manage + # the point at which the response is closed. + assert isinstance(response.stream, typing.AsyncIterable) + return Response( + status=response.status, + headers=response.headers, + content=PoolByteStream( + stream=response.stream, pool_request=pool_request, pool=self + ), + extensions=response.extensions, + ) + + def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]: + """ + Manage the state of the connection pool, assigning incoming + requests to connections as available. + + Called whenever a new request is added or removed from the pool. + + Any closing connections are returned, allowing the I/O for closing + those connections to be handled seperately. + """ + closing_connections = [] + + # First we handle cleaning up any connections that are closed, + # have expired their keep-alive, or surplus idle connections. + for connection in list(self._connections): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): + # log: "closing expired connection" + self._connections.remove(connection) + closing_connections.append(connection) + elif ( + connection.is_idle() + and len([connection.is_idle() for connection in self._connections]) + > self._max_keepalive_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) + + # Assign queued requests to connections. 
+ queued_requests = [request for request in self._requests if request.is_queued()] + for pool_request in queued_requests: + origin = pool_request.request.url.origin + available_connections = [ + connection + for connection in self._connections + if connection.can_handle_request(origin) and connection.is_available() + ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] + + # There are three cases for how we may be able to handle the request: + # + # 1. There is an existing connection that can handle the request. + # 2. We can create a new connection to handle the request. + # 3. We can close an idle connection and then create a new connection + # to handle the request. + if available_connections: + # log: "reusing existing connection" + connection = available_connections[0] + pool_request.assign_to_connection(connection) + elif len(self._connections) < self._max_connections: + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + elif idle_connections: + # log: "closing idle connection" + connection = idle_connections[0] + self._connections.remove(connection) + closing_connections.append(connection) + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + + return closing_connections + + async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None: + # Close connections which have been removed from the pool. + with AsyncShieldCancellation(): + for connection in closing: + await connection.aclose() + + async def aclose(self) -> None: + # Explicitly close the connection pool. + # Clears all existing requests and connections. 
+ with self._optional_thread_lock: + closing_connections = list(self._connections) + self._connections = [] + await self._close_connections(closing_connections) + + async def __aenter__(self) -> AsyncConnectionPool: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + await self.aclose() + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + with self._optional_thread_lock: + request_is_queued = [request.is_queued() for request in self._requests] + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + + num_active_requests = request_is_queued.count(False) + num_queued_requests = request_is_queued.count(True) + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) + + requests_info = ( + f"Requests: {num_active_requests} active, {num_queued_requests} queued" + ) + connection_info = ( + f"Connections: {num_active_connections} active, {num_idle_connections} idle" + ) + + return f"<{class_name} [{requests_info} | {connection_info}]>" + + +class PoolByteStream: + def __init__( + self, + stream: typing.AsyncIterable[bytes], + pool_request: AsyncPoolRequest, + pool: AsyncConnectionPool, + ) -> None: + self._stream = stream + self._pool_request = pool_request + self._pool = pool + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + try: + async for part in self._stream: + yield part + except BaseException as exc: + await self.aclose() + raise exc from None + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + with AsyncShieldCancellation(): + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + + with self._pool._optional_thread_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + + 
await self._pool._close_connections(closing) diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/http11.py b/venv/lib/python3.10/site-packages/httpcore/_async/http11.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d6d709852b137a862cfe2b3af42dc790fa705d --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/http11.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import enum +import logging +import ssl +import time +import types +import typing + +import h11 + +from .._backends.base import AsyncNetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + WriteError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._trace import Trace +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.http11") + + +# A subset of `h11.Event` types supported by `_send_event` +H11SendEvent = typing.Union[ + h11.Request, + h11.Data, + h11.EndOfMessage, +] + + +class HTTPConnectionState(enum.IntEnum): + NEW = 0 + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class AsyncHTTP11Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: float | None = None, + ) -> None: + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: float | None = keepalive_expiry + self._expire_at: float | None = None + self._state = HTTPConnectionState.NEW + self._state_lock = AsyncLock() + self._request_count = 0 + self._h11_state = h11.Connection( + our_role=h11.CLIENT, + max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, + ) + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request 
to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + try: + async with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + await self._send_request_headers(**kwargs) + async with Trace("send_request_body", logger, request, kwargs) as trace: + await self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, + ) + except BaseException as exc: + with AsyncShieldCancellation(): + async with Trace("response_closed", logger, request) as trace: + await self._response_closed() + raise exc + + # Sending the request... 
+ + async def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + await self._send_event(event, timeout=timeout) + + async def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, typing.AsyncIterable) + async for chunk in request.stream: + event = h11.Data(data=chunk) + await self._send_event(event, timeout=timeout) + + await self._send_event(h11.EndOfMessage(), timeout=timeout) + + async def _send_event(self, event: h11.Event, timeout: float | None = None) -> None: + bytes_to_send = self._h11_state.send(event) + if bytes_to_send is not None: + await self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response... + + async def _receive_response_headers( + self, request: Request + ) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = await self._receive_event(timeout=timeout) + if isinstance(event, h11.Response): + break + if ( + isinstance(event, h11.InformationalResponse) + and event.status_code == 101 + ): + break + + http_version = b"HTTP/" + event.http_version + + # h11 version 0.11+ supports a `raw_items` interface to get the + # raw header casing, rather than the enforced lowercase headers. 
+ headers = event.headers.raw_items() + + trailing_data, _ = self._h11_state.trailing_data + + return http_version, event.status_code, event.reason, headers, trailing_data + + async def _receive_response_body( + self, request: Request + ) -> typing.AsyncIterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = await self._receive_event(timeout=timeout) + if isinstance(event, h11.Data): + yield bytes(event.data) + elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + break + + async def _receive_event( + self, timeout: float | None = None + ) -> h11.Event | type[h11.PAUSED]: + while True: + with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + event = self._h11_state.next_event() + + if event is h11.NEED_DATA: + data = await self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) + + # If we feed this case through h11 we'll raise an exception like: + # + # httpcore.RemoteProtocolError: can't handle event type + # ConnectionClosed when role=SERVER and state=SEND_RESPONSE + # + # Which is accurate, but not very informative from an end-user + # perspective. Instead we handle this case distinctly and treat + # it as a ConnectError. + if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: + msg = "Server disconnected without sending a response." 
+ raise RemoteProtocolError(msg) + + self._h11_state.receive_data(data) + else: + # mypy fails to narrow the type in the above if statement above + return event # type: ignore[return-value] + + async def _response_closed(self) -> None: + async with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + await self.aclose() + + # Once the connection is no longer required... + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # The AsyncConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request. + return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. 
+ server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + async def __aenter__(self) -> AsyncHTTP11Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + await self.aclose() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + kwargs = {"request": self._request} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. 
+ with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + async with Trace("response_closed", logger, self._request): + await self._connection._response_closed() + + +class AsyncHTTP11UpgradeStream(AsyncNetworkStream): + def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None: + self._stream = stream + self._leading_data = leading_data + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + if self._leading_data: + buffer = self._leading_data[:max_bytes] + self._leading_data = self._leading_data[max_bytes:] + return buffer + else: + return await self._stream.read(max_bytes, timeout) + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + await self._stream.write(buffer, timeout) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> AsyncNetworkStream: + return await self._stream.start_tls(ssl_context, server_hostname, timeout) + + def get_extra_info(self, info: str) -> typing.Any: + return self._stream.get_extra_info(info) diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/http2.py b/venv/lib/python3.10/site-packages/httpcore/_async/http2.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd0beeb4da32d8c0175d412fa442eae8f837723 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/http2.py @@ -0,0 +1,592 @@ +from __future__ import annotations + +import enum +import logging +import time +import types +import typing + +import h2.config +import h2.connection +import h2.events +import h2.exceptions +import h2.settings + +from .._backends.base import AsyncNetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, +) +from .._models import Origin, 
Request, Response +from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation +from .._trace import Trace +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.http2") + + +def has_body_headers(request: Request) -> bool: + return any( + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ) + + +class HTTPConnectionState(enum.IntEnum): + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class AsyncHTTP2Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: float | None = None, + ): + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: float | None = keepalive_expiry + self._h2_state = h2.connection.H2Connection(config=self.CONFIG) + self._state = HTTPConnectionState.IDLE + self._expire_at: float | None = None + self._request_count = 0 + self._init_lock = AsyncLock() + self._state_lock = AsyncLock() + self._read_lock = AsyncLock() + self._write_lock = AsyncLock() + self._sent_connection_init = False + self._used_all_stream_ids = False + self._connection_error = False + + # Mapping from stream ID to response stream events. + self._events: dict[ + int, + list[ + h2.events.ResponseReceived + | h2.events.DataReceived + | h2.events.StreamEnded + | h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. 
+ self._connection_terminated: h2.events.ConnectionTerminated | None = None + + self._read_exception: Exception | None = None + self._write_exception: Exception | None = None + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + # This cannot occur in normal operation, since the connection pool + # will only send requests on connections that handle them. + # It's in place simply for resilience as a guard against incorrect + # usage, for anyone working directly with httpcore connections. + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() + + async with self._init_lock: + if not self._sent_connection_init: + try: + sci_kwargs = {"request": request} + async with Trace( + "send_connection_init", logger, request, sci_kwargs + ): + await self._send_connection_init(**sci_kwargs) + except BaseException as exc: + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + self._sent_connection_init = True + + # Initially start with just 1 until the remote server provides + # its max_concurrent_streams value + self._max_streams = 1 + + local_settings_max_streams = ( + self._h2_state.local_settings.max_concurrent_streams + ) + self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams) + + for _ in range(local_settings_max_streams - self._max_streams): + await self._max_streams_semaphore.acquire() + + await self._max_streams_semaphore.acquire() + + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + 
raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + async with Trace("send_request_headers", logger, request, kwargs): + await self._send_request_headers(request=request, stream_id=stream_id) + async with Trace("send_request_body", logger, request, kwargs): + await self._send_request_body(request=request, stream_id=stream_id) + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = await self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, + ) + except BaseException as exc: # noqa: PIE786 + with AsyncShieldCancellation(): + kwargs = {"stream_id": stream_id} + async with Trace("response_closed", logger, request, kwargs): + await self._response_closed(stream_id=stream_id) + + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. 
+ raise LocalProtocolError(exc) # pragma: nocover + + raise exc + + async def _send_connection_init(self, request: Request) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. + """ + # Need to set these manually here instead of manipulating via + # __setitem__() otherwise the H2Connection will emit SettingsUpdate + # frames in addition to sending the undesired defaults. + self._h2_state.local_settings = h2.settings.Settings( + client=True, + initial_values={ + # Disable PUSH_PROMISE frames from the server since we don't do anything + # with them for now. Maybe when we support caching? + h2.settings.SettingCodes.ENABLE_PUSH: 0, + # These two are taken from h2 for safe defaults + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + }, + ) + + # Some websites (*cough* Yahoo *cough*) balk at this setting being + # present in the initial handshake since it's not defined in the original + # RFC despite the RFC mandating ignoring settings you don't know about. + del self._h2_state.local_settings[ + h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL + ] + + self._h2_state.initiate_connection() + self._h2_state.increment_flow_control_window(2**24) + await self._write_outgoing_data(request) + + # Sending the request... + + async def _send_request_headers(self, request: Request, stream_id: int) -> None: + """ + Send the request headers to a given stream ID. + """ + end_stream = not has_body_headers(request) + + # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. + # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require + # HTTP/1.1 style headers, and map them appropriately if we end up on + # an HTTP/2 connection. 
+ authority = [v for k, v in request.headers if k.lower() == b"host"][0] + + headers = [ + (b":method", request.method), + (b":authority", authority), + (b":scheme", request.url.scheme), + (b":path", request.url.target), + ] + [ + (k.lower(), v) + for k, v in request.headers + if k.lower() + not in ( + b"host", + b"transfer-encoding", + ) + ] + + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id) + await self._write_outgoing_data(request) + + async def _send_request_body(self, request: Request, stream_id: int) -> None: + """ + Iterate over the request body sending it to a given stream ID. + """ + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.AsyncIterable) + async for data in request.stream: + await self._send_stream_data(request, stream_id, data) + await self._send_end_stream(request, stream_id) + + async def _send_stream_data( + self, request: Request, stream_id: int, data: bytes + ) -> None: + """ + Send a single chunk of data in one or more data frames. + """ + while data: + max_flow = await self._wait_for_outgoing_flow(request, stream_id) + chunk_size = min(len(data), max_flow) + chunk, data = data[:chunk_size], data[chunk_size:] + self._h2_state.send_data(stream_id, chunk) + await self._write_outgoing_data(request) + + async def _send_end_stream(self, request: Request, stream_id: int) -> None: + """ + Send an empty data frame on on a given stream ID with the END_STREAM flag set. + """ + self._h2_state.end_stream(stream_id) + await self._write_outgoing_data(request) + + # Receiving the response... + + async def _receive_response( + self, request: Request, stream_id: int + ) -> tuple[int, list[tuple[bytes, bytes]]]: + """ + Return the response status code and headers for a given stream ID. 
+ """ + while True: + event = await self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + assert event.headers is not None + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode("ascii", errors="ignore")) + elif not k.startswith(b":"): + headers.append((k, v)) + + return (status_code, headers) + + async def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.AsyncIterator[bytes]: + """ + Iterator that returns the bytes of the response body for a given stream ID. + """ + while True: + event = await self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.DataReceived): + assert event.flow_controlled_length is not None + assert event.data is not None + amount = event.flow_controlled_length + self._h2_state.acknowledge_received_data(amount, stream_id) + await self._write_outgoing_data(request) + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + async def _receive_stream_event( + self, request: Request, stream_id: int + ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded: + """ + Return the next available event for a given stream ID. + + Will read more data from the network if required. + """ + while not self._events.get(stream_id): + await self._receive_events(request, stream_id) + event = self._events[stream_id].pop(0) + if isinstance(event, h2.events.StreamReset): + raise RemoteProtocolError(event) + return event + + async def _receive_events( + self, request: Request, stream_id: int | None = None + ) -> None: + """ + Read some data from the network until we see one or more events + for a given stream ID. 
+ """ + async with self._read_lock: + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) + + # This conditional is a bit icky. We don't want to block reading if we've + # actually got an event to return for a given stream. We need to do that + # check *within* the atomic read lock. Though it also need to be optional, + # because when we call it from `_wait_for_outgoing_flow` we *do* want to + # block until we've available flow control, event when we have events + # pending for the stream ID we're attempting to send on. + if stream_id is None or not self._events.get(stream_id): + events = await self._read_incoming_data(request) + for event in events: + if isinstance(event, h2.events.RemoteSettingsChanged): + async with Trace( + "receive_remote_settings", logger, request + ) as trace: + await self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event + + await self._write_outgoing_data(request) + + async def _receive_remote_settings_change( + self, event: h2.events.RemoteSettingsChanged + ) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + await 
self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + await self._max_streams_semaphore.acquire() + self._max_streams -= 1 + + async def _response_closed(self, stream_id: int) -> None: + await self._max_streams_semaphore.release() + del self._events[stream_id] + async with self._state_lock: + if self._connection_terminated and not self._events: + await self.aclose() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + await self.aclose() + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._h2_state.close_connection() + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # Wrappers around network read/write operations... + + async def _read_incoming_data(self, request: Request) -> list[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. 
+ self._read_exception = exc + self._connection_error = True + raise exc + + events: list[h2.events.Event] = self._h2_state.receive_data(data) + + return events + + async def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + async with self._write_lock: + data_to_send = self._h2_state.data_to_send() + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + await self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc + + # Flow control... + + async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow: int = self._h2_state.local_flow_control_window(stream_id) + max_frame_size: int = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + await self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... 
+ + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
+ + async def __aenter__(self) -> AsyncHTTP2Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + await self.aclose() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: AsyncHTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. 
+ with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + kwargs = {"stream_id": self._stream_id} + async with Trace("response_closed", logger, self._request, kwargs): + await self._connection._response_closed(stream_id=self._stream_id) diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/http_proxy.py b/venv/lib/python3.10/site-packages/httpcore/_async/http_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9d92066e1680576846e46ccdf645a2b1dd5718 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/http_proxy.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import base64 +import logging +import ssl +import typing + +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend +from .._exceptions import ProxyError +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, +) +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .connection import AsyncHTTPConnection +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +ByteOrStr = typing.Union[bytes, str] +HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]] +HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr] + + +logger = logging.getLogger("httpcore.proxy") + + +def merge_headers( + default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, + override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, +) -> list[tuple[bytes, bytes]]: + """ + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. 
+ """ + default_headers = [] if default_headers is None else list(default_headers) + override_headers = [] if override_headers is None else list(override_headers) + has_override = set(key.lower() for key, value in override_headers) + default_headers = [ + (key, value) + for key, value in default_headers + if key.lower() not in has_override + ] + return default_headers + override_headers + + +class AsyncHTTPProxy(AsyncConnectionPool): # pragma: nocover + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: URL | bytes | str, + proxy_auth: tuple[bytes | str, bytes | str] | None = None, + proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None, + ssl_context: ssl.SSLContext | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + max_connections: int | None = 10, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: str | None = None, + uds: str | None = None, + network_backend: AsyncNetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + proxy_auth: Any proxy authentication as a two-tuple of + (username, password). May be either bytes or ascii-only str. + proxy_headers: Any HTTP headers to use for the proxy requests. + For example `{"Proxy-Authorization": "Basic :"}`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. 
Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, + socket_options=socket_options, + ) + + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if ( + self._proxy_url.scheme == b"http" and proxy_ssl_context is not None + ): # pragma: no cover + raise RuntimeError( + "The `proxy_ssl_context` argument is not allowed for the http scheme" + ) + + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + if proxy_auth is not None: + username = enforce_bytes(proxy_auth[0], name="proxy_auth") + password = enforce_bytes(proxy_auth[1], name="proxy_auth") + userpass = username + b":" + password + authorization = b"Basic " + base64.b64encode(userpass) + self._proxy_headers = [ + (b"Proxy-Authorization", authorization) + ] + self._proxy_headers + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + if origin.scheme == b"http": + return AsyncForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + proxy_ssl_context=self._proxy_ssl_context, + ) + return AsyncTunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + ssl_context=self._ssl_context, + proxy_ssl_context=self._proxy_ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class AsyncForwardHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_headers: HeadersAsMapping 
| HeadersAsSequence | None = None, + keepalive_expiry: float | None = None, + network_backend: AsyncNetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + ) -> None: + self._connection = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._remote_origin = remote_origin + + async def handle_async_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), + ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return await self._connection.handle_async_request(proxy_request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class AsyncTunnelHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: ssl.SSLContext | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + proxy_headers: typing.Sequence[tuple[bytes, 
bytes]] | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + network_backend: AsyncNetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + self._connection: AsyncConnectionInterface = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._connect_lock = AsyncLock() + self._connected = False + + async def handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = await self._connection.handle_async_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + await self._connection.aclose() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + 
+ # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + return await self._connection.handle_async_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/interfaces.py 
b/venv/lib/python3.10/site-packages/httpcore/_async/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..361583bede6b2b84088b38054d5d8116ef9f1597 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/interfaces.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import contextlib +import typing + +from .._models import ( + URL, + Extensions, + HeaderTypes, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class AsyncRequestInterface: + async def request( + self, + method: bytes | str, + url: URL | bytes | str, + *, + headers: HeaderTypes = None, + content: bytes | typing.AsyncIterator[bytes] | None = None, + extensions: Extensions | None = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = await self.handle_async_request(request) + try: + await response.aread() + finally: + await response.aclose() + return response + + @contextlib.asynccontextmanager + async def stream( + self, + method: bytes | str, + url: URL | bytes | str, + *, + headers: HeaderTypes = None, + content: bytes | typing.AsyncIterator[bytes] | None = None, + extensions: Extensions | None = None, + ) -> typing.AsyncIterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. 
+ headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = await self.handle_async_request(request) + try: + yield response + finally: + await response.aclose() + + async def handle_async_request(self, request: Request) -> Response: + raise NotImplementedError() # pragma: nocover + + +class AsyncConnectionInterface(AsyncRequestInterface): + async def aclose(self) -> None: + raise NotImplementedError() # pragma: nocover + + def info(self) -> str: + raise NotImplementedError() # pragma: nocover + + def can_handle_request(self, origin: Origin) -> bool: + raise NotImplementedError() # pragma: nocover + + def is_available(self) -> bool: + """ + Return `True` if the connection is currently able to accept an + outgoing request. + + An HTTP/1.1 connection will only be available if it is currently idle. + + An HTTP/2 connection will be available so long as the stream ID space is + not yet exhausted, and the connection is not in an error state. + + While the connection is being established we may not yet know if it is going + to result in an HTTP/1.1 or HTTP/2 connection. The connection should be + treated as being available, but might ultimately raise `NewConnectionRequired` + required exceptions if multiple requests are attempted over a connection + that ends up being established as HTTP/1.1. + """ + raise NotImplementedError() # pragma: nocover + + def has_expired(self) -> bool: + """ + Return `True` if the connection is in a state where it should be closed. + + This either means that the connection is idle and it has passed the + expiry time on its keep-alive, or that server has sent an EOF. + """ + raise NotImplementedError() # pragma: nocover + + def is_idle(self) -> bool: + """ + Return `True` if the connection is currently idle. 
+ """ + raise NotImplementedError() # pragma: nocover + + def is_closed(self) -> bool: + """ + Return `True` if the connection has been closed. + + Used when a response is closed to determine if the connection may be + returned to the connection pool or not. + """ + raise NotImplementedError() # pragma: nocover diff --git a/venv/lib/python3.10/site-packages/httpcore/_async/socks_proxy.py b/venv/lib/python3.10/site-packages/httpcore/_async/socks_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..b363f55a0b071de6c5f377726be82dc2110e373c --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_async/socks_proxy.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +import logging +import ssl + +import socksio + +from .._backends.auto import AutoBackend +from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream +from .._exceptions import ConnectionNotAvailable, ProxyError +from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.socks") + + +AUTH_METHODS = { + b"\x00": "NO AUTHENTICATION REQUIRED", + b"\x01": "GSSAPI", + b"\x02": "USERNAME/PASSWORD", + b"\xff": "NO ACCEPTABLE METHODS", +} + +REPLY_CODES = { + b"\x00": "Succeeded", + b"\x01": "General SOCKS server failure", + b"\x02": "Connection not allowed by ruleset", + b"\x03": "Network unreachable", + b"\x04": "Host unreachable", + b"\x05": "Connection refused", + b"\x06": "TTL expired", + b"\x07": "Command not supported", + b"\x08": "Address type not supported", +} + + +async def _init_socks5_connection( + stream: AsyncNetworkStream, + *, + host: bytes, + port: int, + auth: tuple[bytes, bytes] | None = None, +) -> None: + conn = 
socksio.socks5.SOCKS5Connection() + + # Auth method request + auth_method = ( + socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED + if auth is None + else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD + ) + conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method])) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Auth method response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socksio.socks5.SOCKS5AuthReply) + if response.method != auth_method: + requested = AUTH_METHODS.get(auth_method, "UNKNOWN") + responded = AUTH_METHODS.get(response.method, "UNKNOWN") + raise ProxyError( + f"Requested {requested} from proxy server, but got {responded}." + ) + + if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD: + # Username/password request + assert auth is not None + username, password = auth + conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password)) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Username/password response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply) + if not response.success: + raise ProxyError("Invalid username/password") + + # Connect request + conn.send( + socksio.socks5.SOCKS5CommandRequest.from_address( + socksio.socks5.SOCKS5Command.CONNECT, (host, port) + ) + ) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Connect response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socksio.socks5.SOCKS5Reply) + if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED: + reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN") + raise ProxyError(f"Proxy Server could not connect: {reply_code}.") + + +class 
AsyncSOCKSProxy(AsyncConnectionPool): # pragma: nocover + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: URL | bytes | str, + proxy_auth: tuple[bytes | str, bytes | str] | None = None, + ssl_context: ssl.SSLContext | None = None, + max_connections: int | None = 10, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + network_backend: AsyncNetworkBackend | None = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. 
+ network_backend: A backend instance to use for handling network I/O. + """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + ) + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if proxy_auth is not None: + username, password = proxy_auth + username_bytes = enforce_bytes(username, name="proxy_auth") + password_bytes = enforce_bytes(password, name="proxy_auth") + self._proxy_auth: tuple[bytes, bytes] | None = ( + username_bytes, + password_bytes, + ) + else: + self._proxy_auth = None + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + return AsyncSocks5Connection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + proxy_auth=self._proxy_auth, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class AsyncSocks5Connection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_auth: tuple[bytes, bytes] | None = None, + ssl_context: ssl.SSLContext | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + network_backend: AsyncNetworkBackend | None = None, + ) -> None: + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._proxy_auth = proxy_auth + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connect_lock = AsyncLock() + self._connection: AsyncConnectionInterface | None = None + self._connect_failed = False + + async def 
handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if self._connection is None: + try: + # Connect to the proxy + kwargs = { + "host": self._proxy_origin.host.decode("ascii"), + "port": self._proxy_origin.port, + "timeout": timeout, + } + async with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + + # Connect to the remote host using socks5 + kwargs = { + "stream": stream, + "host": self._remote_origin.host.decode("ascii"), + "port": self._remote_origin.port, + "auth": self._proxy_auth, + } + async with Trace( + "setup_socks5_connection", logger, request, kwargs + ) as trace: + await _init_socks5_connection(**kwargs) + trace.return_value = stream + + # Upgrade the stream to SSL + if self._remote_origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ( + ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ) + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or ( + self._http2 and not self._http1 + ): # pragma: nocover + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + 
origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): # pragma: nocover + raise ConnectionNotAvailable() + + return await self._connection.handle_async_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + if self._connection is not None: + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: # pragma: nocover + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._remote_origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: # pragma: nocover + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__init__.py b/venv/lib/python3.10/site-packages/httpcore/_backends/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0f42573f89424c9d4d39eaa53765855a8aa475 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/anyio.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/anyio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53aec97ee5378c804f4f2bce11eba73099b7c9f5 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/anyio.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/auto.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/auto.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..670563658cf6d71fb43cd830c73586c69a83e6b4 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/auto.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/base.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b4aae61338758ae0f132a76b86e30e3500fa20d Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/base.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/mock.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/mock.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6cd179f0e6c1fecd9ab56c6757e71b95238a372f Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/mock.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/sync.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/sync.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1524cd777cc7e8576bb91880c2b54cd53542ffb4 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/sync.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/trio.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/trio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d19067460c46ecaa360f4169d32e30b836ee24eb Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_backends/__pycache__/trio.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/anyio.py b/venv/lib/python3.10/site-packages/httpcore/_backends/anyio.py new file mode 100644 index 0000000000000000000000000000000000000000..a140095e1b8de022f321a41c0125e0e5febc0749 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_backends/anyio.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import ssl +import typing + +import anyio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AnyIOStream(AsyncNetworkStream): + def __init__(self, stream: anyio.abc.ByteStream) -> None: + self._stream = stream + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + exc_map = { + TimeoutError: ReadTimeout, + 
anyio.BrokenResourceError: ReadError, + anyio.ClosedResourceError: ReadError, + anyio.EndOfStream: ReadError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + try: + return await self._stream.receive(max_bytes=max_bytes) + except anyio.EndOfStream: # pragma: nocover + return b"" + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + if not buffer: + return + + exc_map = { + TimeoutError: WriteTimeout, + anyio.BrokenResourceError: WriteError, + anyio.ClosedResourceError: WriteError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + await self._stream.send(item=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> AsyncNetworkStream: + exc_map = { + TimeoutError: ConnectTimeout, + anyio.BrokenResourceError: ConnectError, + anyio.EndOfStream: ConnectError, + ssl.SSLError: ConnectError, + } + with map_exceptions(exc_map): + try: + with anyio.fail_after(timeout): + ssl_stream = await anyio.streams.tls.TLSStream.wrap( + self._stream, + ssl_context=ssl_context, + hostname=server_hostname, + standard_compatible=False, + server_side=False, + ) + except Exception as exc: # pragma: nocover + await self.aclose() + raise exc + return AnyIOStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) + if info == "client_addr": + return self._stream.extra(anyio.abc.SocketAttribute.local_address, None) + if info == "server_addr": + return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) + if info == "socket": + return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + if info == "is_readable": + sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + return is_socket_readable(sock) + return 
None + + +class AnyIOBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_tcp( + remote_host=host, + remote_port=port, + local_host=local_address, + ) + # By default TCP sockets opened in `asyncio` include TCP_NODELAY. + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return AnyIOStream(stream) + + async def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_unix(path) + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return AnyIOStream(stream) + + async def sleep(self, seconds: float) -> None: + await anyio.sleep(seconds) # pragma: nocover diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/auto.py b/venv/lib/python3.10/site-packages/httpcore/_backends/auto.py new file mode 100644 index 0000000000000000000000000000000000000000..49f0e698c97ad5623f376d8182675352e21c2c3c --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_backends/auto.py @@ -0,0 +1,52 @@ +from __future__ import annotations + 
+import typing + +from .._synchronization import current_async_library +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AutoBackend(AsyncNetworkBackend): + async def _init_backend(self) -> None: + if not (hasattr(self, "_backend")): + backend = current_async_library() + if backend == "trio": + from .trio import TrioBackend + + self._backend: AsyncNetworkBackend = TrioBackend() + else: + from .anyio import AnyIOBackend + + self._backend = AnyIOBackend() + + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + await self._init_backend() + return await self._backend.connect_tcp( + host, + port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) + + async def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: # pragma: nocover + await self._init_backend() + return await self._backend.connect_unix_socket( + path, timeout=timeout, socket_options=socket_options + ) + + async def sleep(self, seconds: float) -> None: # pragma: nocover + await self._init_backend() + return await self._backend.sleep(seconds) diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/base.py b/venv/lib/python3.10/site-packages/httpcore/_backends/base.py new file mode 100644 index 0000000000000000000000000000000000000000..cf55c8b10eb543872550be863206fe2f760d0d8d --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_backends/base.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import ssl +import time +import typing + +SOCKET_OPTION = typing.Union[ + typing.Tuple[int, int, int], + typing.Tuple[int, int, typing.Union[bytes, bytearray]], + typing.Tuple[int, int, None, int], +] + + +class NetworkStream: + def read(self, 
max_bytes: int, timeout: float | None = None) -> bytes: + raise NotImplementedError() # pragma: nocover + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise NotImplementedError() # pragma: nocover + + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class NetworkBackend: + def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def sleep(self, seconds: float) -> None: + time.sleep(seconds) # pragma: nocover + + +class AsyncNetworkStream: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise NotImplementedError() # pragma: nocover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise NotImplementedError() # pragma: nocover + + async def aclose(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class AsyncNetworkBackend: + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + 
socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def sleep(self, seconds: float) -> None: + raise NotImplementedError() # pragma: nocover diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/mock.py b/venv/lib/python3.10/site-packages/httpcore/_backends/mock.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6edca03d4d4b34f355fd53e49d4b4c699c972c --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_backends/mock.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import ssl +import typing + +from .._exceptions import ReadError +from .base import ( + SOCKET_OPTION, + AsyncNetworkBackend, + AsyncNetworkStream, + NetworkBackend, + NetworkStream, +) + + +class MockSSLObject: + def __init__(self, http2: bool): + self._http2 = http2 + + def selected_alpn_protocol(self) -> str: + return "h2" if self._http2 else "http/1.1" + + +class MockStream(NetworkStream): + def __init__(self, buffer: list[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + self._closed = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + if not self._buffer: + return b"" + return self._buffer.pop(0) + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + self._closed = True + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> NetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == 
"ssl_object" else None + + def __repr__(self) -> str: + return "" + + +class MockBackend(NetworkBackend): + def __init__(self, buffer: list[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def sleep(self, seconds: float) -> None: + pass + + +class AsyncMockStream(AsyncNetworkStream): + def __init__(self, buffer: list[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + self._closed = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + if not self._buffer: + return b"" + return self._buffer.pop(0) + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + self._closed = True + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> AsyncNetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + def __repr__(self) -> str: + return "" + + +class AsyncMockBackend(AsyncNetworkBackend): + def __init__(self, buffer: list[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: 
typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def sleep(self, seconds: float) -> None: + pass diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/sync.py b/venv/lib/python3.10/site-packages/httpcore/_backends/sync.py new file mode 100644 index 0000000000000000000000000000000000000000..4018a09c6fb1e0ef1b03ab8d84b13ebef4031f7c --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_backends/sync.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import functools +import socket +import ssl +import sys +import typing + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ExceptionMapping, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, NetworkBackend, NetworkStream + + +class TLSinTLSStream(NetworkStream): # pragma: no cover + """ + Because the standard `SSLContext.wrap_socket` method does + not work for `SSLSocket` objects, we need this class + to implement TLS stream using an underlying `SSLObject` + instance in order to support TLS on top of TLS. 
+ """ + + # Defined in RFC 8449 + TLS_RECORD_SIZE = 16384 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ): + self._sock = sock + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + + self.ssl_obj = ssl_context.wrap_bio( + incoming=self._incoming, + outgoing=self._outgoing, + server_hostname=server_hostname, + ) + + self._sock.settimeout(timeout) + self._perform_io(self.ssl_obj.do_handshake) + + def _perform_io( + self, + func: typing.Callable[..., typing.Any], + ) -> typing.Any: + ret = None + + while True: + errno = None + try: + ret = func() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: + errno = e.errno + + self._sock.sendall(self._outgoing.read()) + + if errno == ssl.SSL_ERROR_WANT_READ: + buf = self._sock.recv(self.TLS_RECORD_SIZE) + + if buf: + self._incoming.write(buf) + else: + self._incoming.write_eof() + if errno is None: + return ret + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return typing.cast( + bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes)) + ) + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + while buffer: + nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer)) + buffer = buffer[nsent:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> NetworkStream: + raise NotImplementedError() + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self.ssl_obj + if info 
== "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncStream(NetworkStream): + def __init__(self, sock: socket.socket) -> None: + self._sock = sock + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return self._sock.recv(max_bytes) + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + if not buffer: + return + + exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} + with map_exceptions(exc_map): + while buffer: + self._sock.settimeout(timeout) + n = self._sock.send(buffer) + buffer = buffer[n:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> NetworkStream: + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + try: + if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover + # If the underlying socket has already been upgraded + # to the TLS layer (i.e. is an instance of SSLSocket), + # we need some additional smarts to support TLS-in-TLS. 
+ return TLSinTLSStream( + self._sock, ssl_context, server_hostname, timeout + ) + else: + self._sock.settimeout(timeout) + sock = ssl_context.wrap_socket( + self._sock, server_hostname=server_hostname + ) + except Exception as exc: # pragma: nocover + self.close() + raise exc + return SyncStream(sock) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket): + return self._sock._sslobj # type: ignore + if info == "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncBackend(NetworkBackend): + def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + # Note that we automatically include `TCP_NODELAY` + # in addition to any other custom socket options. + if socket_options is None: + socket_options = [] # pragma: no cover + address = (host, port) + source_address = None if local_address is None else (local_address, 0) + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + + with map_exceptions(exc_map): + sock = socket.create_connection( + address, + timeout, + source_address=source_address, + ) + for option in socket_options: + sock.setsockopt(*option) # pragma: no cover + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return SyncStream(sock) + + def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: # pragma: nocover + if sys.platform == "win32": + raise RuntimeError( + "Attempted to connect to a UNIX socket on a Windows system." 
+ ) + if socket_options is None: + socket_options = [] + + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + for option in socket_options: + sock.setsockopt(*option) + sock.settimeout(timeout) + sock.connect(path) + return SyncStream(sock) diff --git a/venv/lib/python3.10/site-packages/httpcore/_backends/trio.py b/venv/lib/python3.10/site-packages/httpcore/_backends/trio.py new file mode 100644 index 0000000000000000000000000000000000000000..6f53f5f2a025e01e9949e2530bd9ca6928859251 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_backends/trio.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import ssl +import typing + +import trio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ExceptionMapping, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class TrioStream(AsyncNetworkStream): + def __init__(self, stream: trio.abc.Stream) -> None: + self._stream = stream + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ReadTimeout, + trio.BrokenResourceError: ReadError, + trio.ClosedResourceError: ReadError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + data: bytes = await self._stream.receive_some(max_bytes=max_bytes) + return data + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + if not buffer: + return + + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: WriteTimeout, + trio.BrokenResourceError: WriteError, + trio.ClosedResourceError: WriteError, + } + with map_exceptions(exc_map): + with 
trio.fail_after(timeout_or_inf): + await self._stream.send_all(data=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> AsyncNetworkStream: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + ssl_stream = trio.SSLStream( + self._stream, + ssl_context=ssl_context, + server_hostname=server_hostname, + https_compatible=True, + server_side=False, + ) + with map_exceptions(exc_map): + try: + with trio.fail_after(timeout_or_inf): + await ssl_stream.do_handshake() + except Exception as exc: # pragma: nocover + await self.aclose() + raise exc + return TrioStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._stream, trio.SSLStream): + # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__. 
+ # Tracked at https://github.com/python-trio/trio/issues/542 + return self._stream._ssl_object # type: ignore[attr-defined] + if info == "client_addr": + return self._get_socket_stream().socket.getsockname() + if info == "server_addr": + return self._get_socket_stream().socket.getpeername() + if info == "socket": + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream.socket + if info == "is_readable": + socket = self.get_extra_info("socket") + return socket.is_readable() + return None + + def _get_socket_stream(self) -> trio.SocketStream: + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream + + +class TrioBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + # By default for TCP sockets, trio enables TCP_NODELAY. 
+ # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream + if socket_options is None: + socket_options = [] # pragma: no cover + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + OSError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_tcp_stream( + host=host, port=port, local_address=local_address + ) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return TrioStream(stream) + + async def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + OSError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_unix_socket(path) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return TrioStream(stream) + + async def sleep(self, seconds: float) -> None: + await trio.sleep(seconds) # pragma: nocover diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__init__.py b/venv/lib/python3.10/site-packages/httpcore/_sync/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b476d76d9a7ff45de8d18ec22d33d6af2982f92e --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/__init__.py @@ -0,0 +1,39 @@ +from .connection import HTTPConnection +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .http_proxy import HTTPProxy 
+from .interfaces import ConnectionInterface + +try: + from .http2 import HTTP2Connection +except ImportError: # pragma: nocover + + class HTTP2Connection: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use http2 support, but the `h2` package is not " + "installed. Use 'pip install httpcore[http2]'." + ) + + +try: + from .socks_proxy import SOCKSProxy +except ImportError: # pragma: nocover + + class SOCKSProxy: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use SOCKS support, but the `socksio` package is not " + "installed. Use 'pip install httpcore[socks]'." + ) + + +__all__ = [ + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", + "SOCKSProxy", +] diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c3b1626492ce186e1dcf12847362fdc7bf1a4d6 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/connection.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/connection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3920457253e25a53d25ab5cfa148e876b93c241 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/connection.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/connection_pool.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/connection_pool.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..edde6413f593ca93ef8e6abca59821234b72fb90 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/connection_pool.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http11.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http11.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e37420b9c48269c3bda2f831e8bac5336d9c3b5a Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http11.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http2.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..555aff9e6389abf15f2cade34d140895000b917d Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http2.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http_proxy.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http_proxy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..432f0bce7553b93e84558fe62a84993d549fd838 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/http_proxy.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/interfaces.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615ebb81e53ec3be7ef653377dcfdb9238990480 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/interfaces.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/socks_proxy.cpython-310.pyc 
b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/socks_proxy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5728a2c89f25998153c7e50fc600a31c446037db Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpcore/_sync/__pycache__/socks_proxy.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/connection.py b/venv/lib/python3.10/site-packages/httpcore/_sync/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..363f8be819d2576ea65365e625dd1596ea40429a --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/connection.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import itertools +import logging +import ssl +import types +import typing + +from .._backends.sync import SyncBackend +from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream +from .._exceptions import ConnectError, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + + +logger = logging.getLogger("httpcore.connection") + + +def exponential_backoff(factor: float) -> typing.Iterator[float]: + """ + Generate a geometric sequence that has a ratio of 2 and starts with 0. 
+ + For example: + - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...` + - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...` + """ + yield 0 + for n in itertools.count(): + yield factor * 2**n + + +class HTTPConnection(ConnectionInterface): + def __init__( + self, + origin: Origin, + ssl_context: ssl.SSLContext | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: str | None = None, + uds: str | None = None, + network_backend: NetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + self._origin = origin + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend: NetworkBackend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._connection: ConnectionInterface | None = None + self._connect_failed: bool = False + self._request_lock = Lock() + self._socket_options = socket_options + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) + + try: + with self._request_lock: + if self._connection is None: + stream = self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except BaseException as exc: + 
self._connect_failed = True + raise exc + + return self._connection.handle_request(request) + + def _connect(self, request: Request) -> NetworkStream: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + "socket_options": self._socket_options, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + else: + kwargs = { + "path": self._uds, + "timeout": timeout, + "socket_options": self._socket_options, + } + with Trace( + "connect_unix_socket", logger, request, kwargs + ) as trace: + stream = self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream + + if self._origin.scheme in (b"https", b"wss"): + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + return stream + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + with Trace("retry", logger, request, kwargs) as trace: + self._network_backend.sleep(delay) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def close(self) -> None: + if self._connection is not 
None: + with Trace("close", logger, None, {}): + self._connection.close() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
+ + def __enter__(self) -> HTTPConnection: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + self.close() diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/connection_pool.py b/venv/lib/python3.10/site-packages/httpcore/_sync/connection_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccfa53e597a29ee387f9d16f3af4f695ac0d33a --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/connection_pool.py @@ -0,0 +1,420 @@ +from __future__ import annotations + +import ssl +import sys +import types +import typing + +from .._backends.sync import SyncBackend +from .._backends.base import SOCKET_OPTION, NetworkBackend +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Proxy, Request, Response +from .._synchronization import Event, ShieldCancellation, ThreadLock +from .connection import HTTPConnection +from .interfaces import ConnectionInterface, RequestInterface + + +class PoolRequest: + def __init__(self, request: Request) -> None: + self.request = request + self.connection: ConnectionInterface | None = None + self._connection_acquired = Event() + + def assign_to_connection(self, connection: ConnectionInterface | None) -> None: + self.connection = connection + self._connection_acquired.set() + + def clear_connection(self) -> None: + self.connection = None + self._connection_acquired = Event() + + def wait_for_connection( + self, timeout: float | None = None + ) -> ConnectionInterface: + if self.connection is None: + self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + def is_queued(self) -> bool: + return self.connection is None + + +class ConnectionPool(RequestInterface): + """ + A connection pool for making HTTP requests. 
+ """ + + def __init__( + self, + ssl_context: ssl.SSLContext | None = None, + proxy: Proxy | None = None, + max_connections: int | None = 10, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: str | None = None, + uds: str | None = None, + network_backend: NetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ socket_options: Socket options that have to be included + in the TCP socket when the connection was established. + """ + self._ssl_context = ssl_context + self._proxy = proxy + self._max_connections = ( + sys.maxsize if max_connections is None else max_connections + ) + self._max_keepalive_connections = ( + sys.maxsize + if max_keepalive_connections is None + else max_keepalive_connections + ) + self._max_keepalive_connections = min( + self._max_connections, self._max_keepalive_connections + ) + + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._socket_options = socket_options + + # The mutable state on a connection pool is the queue of incoming requests, + # and the set of connections that are servicing those requests. + self._connections: list[ConnectionInterface] = [] + self._requests: list[PoolRequest] = [] + + # We only mutate the state of the connection pool within an 'optional_thread_lock' + # context. This holds a threading lock unless we're running in async mode, + # in which case it is a no-op. 
+ self._optional_thread_lock = ThreadLock() + + def create_connection(self, origin: Origin) -> ConnectionInterface: + if self._proxy is not None: + if self._proxy.url.scheme in (b"socks5", b"socks5h"): + from .socks_proxy import Socks5Connection + + return Socks5Connection( + proxy_origin=self._proxy.url.origin, + proxy_auth=self._proxy.auth, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + elif origin.scheme == b"http": + from .http_proxy import ForwardHTTPConnection + + return ForwardHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + ) + from .http_proxy import TunnelHTTPConnection + + return TunnelHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + return HTTPConnection( + origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + retries=self._retries, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, + socket_options=self._socket_options, + ) + + @property + def connections(self) -> list[ConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + , + , + , + ] + ``` + """ + return list(self._connections) + + def handle_request(self, request: Request) -> Response: + """ + Send an HTTP request, and return an HTTP response. 
+ + This is the core implementation that is called into by `.request()` or `.stream()`. + """ + scheme = request.url.scheme.decode() + if scheme == "": + raise UnsupportedProtocol( + "Request URL is missing an 'http://' or 'https://' protocol." + ) + if scheme not in ("http", "https", "ws", "wss"): + raise UnsupportedProtocol( + f"Request URL has an unsupported protocol '{scheme}://'." + ) + + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + + with self._optional_thread_lock: + # Add the incoming request to our request queue. + pool_request = PoolRequest(request) + self._requests.append(pool_request) + + try: + while True: + with self._optional_thread_lock: + # Assign incoming requests to available connections, + # closing or creating new connections as required. + closing = self._assign_requests_to_connections() + self._close_connections(closing) + + # Wait until this request has an assigned connection. + connection = pool_request.wait_for_connection(timeout=timeout) + + try: + # Send the request on the assigned connection. + response = connection.handle_request( + pool_request.request + ) + except ConnectionNotAvailable: + # In some cases a connection may initially be available to + # handle a request, but then become unavailable. + # + # In this case we clear the connection and try again. + pool_request.clear_connection() + else: + break # pragma: nocover + + except BaseException as exc: + with self._optional_thread_lock: + # For any exception or cancellation we remove the request from + # the queue, and then re-assign requests to connections. + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() + + self._close_connections(closing) + raise exc from None + + # Return the response. Note that in this case we still have to manage + # the point at which the response is closed. 
+ assert isinstance(response.stream, typing.Iterable) + return Response( + status=response.status, + headers=response.headers, + content=PoolByteStream( + stream=response.stream, pool_request=pool_request, pool=self + ), + extensions=response.extensions, + ) + + def _assign_requests_to_connections(self) -> list[ConnectionInterface]: + """ + Manage the state of the connection pool, assigning incoming + requests to connections as available. + + Called whenever a new request is added or removed from the pool. + + Any closing connections are returned, allowing the I/O for closing + those connections to be handled seperately. + """ + closing_connections = [] + + # First we handle cleaning up any connections that are closed, + # have expired their keep-alive, or surplus idle connections. + for connection in list(self._connections): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): + # log: "closing expired connection" + self._connections.remove(connection) + closing_connections.append(connection) + elif ( + connection.is_idle() + and len([connection.is_idle() for connection in self._connections]) + > self._max_keepalive_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) + + # Assign queued requests to connections. + queued_requests = [request for request in self._requests if request.is_queued()] + for pool_request in queued_requests: + origin = pool_request.request.url.origin + available_connections = [ + connection + for connection in self._connections + if connection.can_handle_request(origin) and connection.is_available() + ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] + + # There are three cases for how we may be able to handle the request: + # + # 1. There is an existing connection that can handle the request. + # 2. 
We can create a new connection to handle the request. + # 3. We can close an idle connection and then create a new connection + # to handle the request. + if available_connections: + # log: "reusing existing connection" + connection = available_connections[0] + pool_request.assign_to_connection(connection) + elif len(self._connections) < self._max_connections: + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + elif idle_connections: + # log: "closing idle connection" + connection = idle_connections[0] + self._connections.remove(connection) + closing_connections.append(connection) + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + + return closing_connections + + def _close_connections(self, closing: list[ConnectionInterface]) -> None: + # Close connections which have been removed from the pool. + with ShieldCancellation(): + for connection in closing: + connection.close() + + def close(self) -> None: + # Explicitly close the connection pool. + # Clears all existing requests and connections. 
+ with self._optional_thread_lock: + closing_connections = list(self._connections) + self._connections = [] + self._close_connections(closing_connections) + + def __enter__(self) -> ConnectionPool: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + self.close() + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + with self._optional_thread_lock: + request_is_queued = [request.is_queued() for request in self._requests] + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + + num_active_requests = request_is_queued.count(False) + num_queued_requests = request_is_queued.count(True) + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) + + requests_info = ( + f"Requests: {num_active_requests} active, {num_queued_requests} queued" + ) + connection_info = ( + f"Connections: {num_active_connections} active, {num_idle_connections} idle" + ) + + return f"<{class_name} [{requests_info} | {connection_info}]>" + + +class PoolByteStream: + def __init__( + self, + stream: typing.Iterable[bytes], + pool_request: PoolRequest, + pool: ConnectionPool, + ) -> None: + self._stream = stream + self._pool_request = pool_request + self._pool = pool + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + try: + for part in self._stream: + yield part + except BaseException as exc: + self.close() + raise exc from None + + def close(self) -> None: + if not self._closed: + self._closed = True + with ShieldCancellation(): + if hasattr(self._stream, "close"): + self._stream.close() + + with self._pool._optional_thread_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + + self._pool._close_connections(closing) diff --git 
a/venv/lib/python3.10/site-packages/httpcore/_sync/http11.py b/venv/lib/python3.10/site-packages/httpcore/_sync/http11.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd3a97480c720d418acb1285a7b75da19b62c8c --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/http11.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import enum +import logging +import ssl +import time +import types +import typing + +import h11 + +from .._backends.base import NetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + WriteError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import Lock, ShieldCancellation +from .._trace import Trace +from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.http11") + + +# A subset of `h11.Event` types supported by `_send_event` +H11SendEvent = typing.Union[ + h11.Request, + h11.Data, + h11.EndOfMessage, +] + + +class HTTPConnectionState(enum.IntEnum): + NEW = 0 + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class HTTP11Connection(ConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 + + def __init__( + self, + origin: Origin, + stream: NetworkStream, + keepalive_expiry: float | None = None, + ) -> None: + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: float | None = keepalive_expiry + self._expire_at: float | None = None + self._state = HTTPConnectionState.NEW + self._state_lock = Lock() + self._request_count = 0 + self._h11_state = h11.Connection( + our_role=h11.CLIENT, + max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, + ) + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + with self._state_lock: + if self._state 
in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + try: + with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + self._send_request_headers(**kwargs) + with Trace("send_request_body", logger, request, kwargs) as trace: + self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + + with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = HTTP11UpgradeStream(network_stream, trailing_data) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, + ) + except BaseException as exc: + with ShieldCancellation(): + with Trace("response_closed", logger, request) as trace: + self._response_closed() + raise exc + + # Sending the request... 
+ + def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + self._send_event(event, timeout=timeout) + + def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, typing.Iterable) + for chunk in request.stream: + event = h11.Data(data=chunk) + self._send_event(event, timeout=timeout) + + self._send_event(h11.EndOfMessage(), timeout=timeout) + + def _send_event(self, event: h11.Event, timeout: float | None = None) -> None: + bytes_to_send = self._h11_state.send(event) + if bytes_to_send is not None: + self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response... + + def _receive_response_headers( + self, request: Request + ) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = self._receive_event(timeout=timeout) + if isinstance(event, h11.Response): + break + if ( + isinstance(event, h11.InformationalResponse) + and event.status_code == 101 + ): + break + + http_version = b"HTTP/" + event.http_version + + # h11 version 0.11+ supports a `raw_items` interface to get the + # raw header casing, rather than the enforced lowercase headers. 
+ headers = event.headers.raw_items() + + trailing_data, _ = self._h11_state.trailing_data + + return http_version, event.status_code, event.reason, headers, trailing_data + + def _receive_response_body( + self, request: Request + ) -> typing.Iterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = self._receive_event(timeout=timeout) + if isinstance(event, h11.Data): + yield bytes(event.data) + elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + break + + def _receive_event( + self, timeout: float | None = None + ) -> h11.Event | type[h11.PAUSED]: + while True: + with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + event = self._h11_state.next_event() + + if event is h11.NEED_DATA: + data = self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) + + # If we feed this case through h11 we'll raise an exception like: + # + # httpcore.RemoteProtocolError: can't handle event type + # ConnectionClosed when role=SERVER and state=SEND_RESPONSE + # + # Which is accurate, but not very informative from an end-user + # perspective. Instead we handle this case distinctly and treat + # it as a ConnectError. + if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: + msg = "Server disconnected without sending a response." + raise RemoteProtocolError(msg) + + self._h11_state.receive_data(data) + else: + # mypy fails to narrow the type in the above if statement above + return event # type: ignore[return-value] + + def _response_closed(self) -> None: + with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self.close() + + # Once the connection is no longer required... 
+ + def close(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # The ConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request. + return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. 
+ server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + def __enter__(self) -> HTTP11Connection: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + self.close() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: HTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + kwargs = {"request": self._request} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. 
+ with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + with Trace("response_closed", logger, self._request): + self._connection._response_closed() + + +class HTTP11UpgradeStream(NetworkStream): + def __init__(self, stream: NetworkStream, leading_data: bytes) -> None: + self._stream = stream + self._leading_data = leading_data + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + if self._leading_data: + buffer = self._leading_data[:max_bytes] + self._leading_data = self._leading_data[max_bytes:] + return buffer + else: + return self._stream.read(max_bytes, timeout) + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self._stream.write(buffer, timeout) + + def close(self) -> None: + self._stream.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + timeout: float | None = None, + ) -> NetworkStream: + return self._stream.start_tls(ssl_context, server_hostname, timeout) + + def get_extra_info(self, info: str) -> typing.Any: + return self._stream.get_extra_info(info) diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/http2.py b/venv/lib/python3.10/site-packages/httpcore/_sync/http2.py new file mode 100644 index 0000000000000000000000000000000000000000..ddcc189001c50c37c6a03810dc21d955df919f10 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/http2.py @@ -0,0 +1,592 @@ +from __future__ import annotations + +import enum +import logging +import time +import types +import typing + +import h2.config +import h2.connection +import h2.events +import h2.exceptions +import h2.settings + +from .._backends.base import NetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, +) +from .._models import Origin, Request, Response +from .._synchronization import Lock, Semaphore, ShieldCancellation +from .._trace import Trace 
+from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.http2") + + +def has_body_headers(request: Request) -> bool: + return any( + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ) + + +class HTTPConnectionState(enum.IntEnum): + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class HTTP2Connection(ConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) + + def __init__( + self, + origin: Origin, + stream: NetworkStream, + keepalive_expiry: float | None = None, + ): + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: float | None = keepalive_expiry + self._h2_state = h2.connection.H2Connection(config=self.CONFIG) + self._state = HTTPConnectionState.IDLE + self._expire_at: float | None = None + self._request_count = 0 + self._init_lock = Lock() + self._state_lock = Lock() + self._read_lock = Lock() + self._write_lock = Lock() + self._sent_connection_init = False + self._used_all_stream_ids = False + self._connection_error = False + + # Mapping from stream ID to response stream events. + self._events: dict[ + int, + list[ + h2.events.ResponseReceived + | h2.events.DataReceived + | h2.events.StreamEnded + | h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. + self._connection_terminated: h2.events.ConnectionTerminated | None = None + + self._read_exception: Exception | None = None + self._write_exception: Exception | None = None + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + # This cannot occur in normal operation, since the connection pool + # will only send requests on connections that handle them. + # It's in place simply for resilience as a guard against incorrect + # usage, for anyone working directly with httpcore connections. 
+ raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() + + with self._init_lock: + if not self._sent_connection_init: + try: + sci_kwargs = {"request": request} + with Trace( + "send_connection_init", logger, request, sci_kwargs + ): + self._send_connection_init(**sci_kwargs) + except BaseException as exc: + with ShieldCancellation(): + self.close() + raise exc + + self._sent_connection_init = True + + # Initially start with just 1 until the remote server provides + # its max_concurrent_streams value + self._max_streams = 1 + + local_settings_max_streams = ( + self._h2_state.local_settings.max_concurrent_streams + ) + self._max_streams_semaphore = Semaphore(local_settings_max_streams) + + for _ in range(local_settings_max_streams - self._max_streams): + self._max_streams_semaphore.acquire() + + self._max_streams_semaphore.acquire() + + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + with Trace("send_request_headers", logger, request, kwargs): + self._send_request_headers(request=request, stream_id=stream_id) + with Trace("send_request_body", logger, request, kwargs): + self._send_request_body(request=request, stream_id=stream_id) + with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + 
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, + ) + except BaseException as exc: # noqa: PIE786 + with ShieldCancellation(): + kwargs = {"stream_id": stream_id} + with Trace("response_closed", logger, request, kwargs): + self._response_closed(stream_id=stream_id) + + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. + raise LocalProtocolError(exc) # pragma: nocover + + raise exc + + def _send_connection_init(self, request: Request) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. + """ + # Need to set these manually here instead of manipulating via + # __setitem__() otherwise the H2Connection will emit SettingsUpdate + # frames in addition to sending the undesired defaults. + self._h2_state.local_settings = h2.settings.Settings( + client=True, + initial_values={ + # Disable PUSH_PROMISE frames from the server since we don't do anything + # with them for now. Maybe when we support caching? 
+ h2.settings.SettingCodes.ENABLE_PUSH: 0, + # These two are taken from h2 for safe defaults + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + }, + ) + + # Some websites (*cough* Yahoo *cough*) balk at this setting being + # present in the initial handshake since it's not defined in the original + # RFC despite the RFC mandating ignoring settings you don't know about. + del self._h2_state.local_settings[ + h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL + ] + + self._h2_state.initiate_connection() + self._h2_state.increment_flow_control_window(2**24) + self._write_outgoing_data(request) + + # Sending the request... + + def _send_request_headers(self, request: Request, stream_id: int) -> None: + """ + Send the request headers to a given stream ID. + """ + end_stream = not has_body_headers(request) + + # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. + # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require + # HTTP/1.1 style headers, and map them appropriately if we end up on + # an HTTP/2 connection. + authority = [v for k, v in request.headers if k.lower() == b"host"][0] + + headers = [ + (b":method", request.method), + (b":authority", authority), + (b":scheme", request.url.scheme), + (b":path", request.url.target), + ] + [ + (k.lower(), v) + for k, v in request.headers + if k.lower() + not in ( + b"host", + b"transfer-encoding", + ) + ] + + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id) + self._write_outgoing_data(request) + + def _send_request_body(self, request: Request, stream_id: int) -> None: + """ + Iterate over the request body sending it to a given stream ID. 
+ """ + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.Iterable) + for data in request.stream: + self._send_stream_data(request, stream_id, data) + self._send_end_stream(request, stream_id) + + def _send_stream_data( + self, request: Request, stream_id: int, data: bytes + ) -> None: + """ + Send a single chunk of data in one or more data frames. + """ + while data: + max_flow = self._wait_for_outgoing_flow(request, stream_id) + chunk_size = min(len(data), max_flow) + chunk, data = data[:chunk_size], data[chunk_size:] + self._h2_state.send_data(stream_id, chunk) + self._write_outgoing_data(request) + + def _send_end_stream(self, request: Request, stream_id: int) -> None: + """ + Send an empty data frame on on a given stream ID with the END_STREAM flag set. + """ + self._h2_state.end_stream(stream_id) + self._write_outgoing_data(request) + + # Receiving the response... + + def _receive_response( + self, request: Request, stream_id: int + ) -> tuple[int, list[tuple[bytes, bytes]]]: + """ + Return the response status code and headers for a given stream ID. + """ + while True: + event = self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + assert event.headers is not None + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode("ascii", errors="ignore")) + elif not k.startswith(b":"): + headers.append((k, v)) + + return (status_code, headers) + + def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.Iterator[bytes]: + """ + Iterator that returns the bytes of the response body for a given stream ID. 
+ """ + while True: + event = self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.DataReceived): + assert event.flow_controlled_length is not None + assert event.data is not None + amount = event.flow_controlled_length + self._h2_state.acknowledge_received_data(amount, stream_id) + self._write_outgoing_data(request) + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + def _receive_stream_event( + self, request: Request, stream_id: int + ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded: + """ + Return the next available event for a given stream ID. + + Will read more data from the network if required. + """ + while not self._events.get(stream_id): + self._receive_events(request, stream_id) + event = self._events[stream_id].pop(0) + if isinstance(event, h2.events.StreamReset): + raise RemoteProtocolError(event) + return event + + def _receive_events( + self, request: Request, stream_id: int | None = None + ) -> None: + """ + Read some data from the network until we see one or more events + for a given stream ID. + """ + with self._read_lock: + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) + + # This conditional is a bit icky. We don't want to block reading if we've + # actually got an event to return for a given stream. We need to do that + # check *within* the atomic read lock. Though it also need to be optional, + # because when we call it from `_wait_for_outgoing_flow` we *do* want to + # block until we've available flow control, event when we have events + # pending for the stream ID we're attempting to send on. 
+ if stream_id is None or not self._events.get(stream_id): + events = self._read_incoming_data(request) + for event in events: + if isinstance(event, h2.events.RemoteSettingsChanged): + with Trace( + "receive_remote_settings", logger, request + ) as trace: + self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event + + self._write_outgoing_data(request) + + def _receive_remote_settings_change( + self, event: h2.events.RemoteSettingsChanged + ) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + self._max_streams_semaphore.acquire() + self._max_streams -= 1 + + def _response_closed(self, stream_id: int) -> None: + self._max_streams_semaphore.release() + del self._events[stream_id] + with self._state_lock: + if self._connection_terminated and not self._events: + self.close() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + self.close() + + def close(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in 
place around it. + self._h2_state.close_connection() + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # Wrappers around network read/write operations... + + def _read_incoming_data(self, request: Request) -> list[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._read_exception = exc + self._connection_error = True + raise exc + + events: list[h2.events.Event] = self._h2_state.receive_data(data) + + return events + + def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with self._write_lock: + data_to_send = self._h2_state.data_to_send() + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. 
+ self._write_exception = exc + self._connection_error = True + raise exc + + # Flow control... + + def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow: int = self._h2_state.local_flow_control_window(stream_id) + max_frame_size: int = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with 
connection instances directly. + + def __enter__(self) -> HTTP2Connection: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + self.close() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: HTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. 
+ with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + kwargs = {"stream_id": self._stream_id} + with Trace("response_closed", logger, self._request, kwargs): + self._connection._response_closed(stream_id=self._stream_id) diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/http_proxy.py b/venv/lib/python3.10/site-packages/httpcore/_sync/http_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..ecca88f7dc93b78f2aa26f16cf29d17a8a83ae27 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/http_proxy.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import base64 +import logging +import ssl +import typing + +from .._backends.base import SOCKET_OPTION, NetworkBackend +from .._exceptions import ProxyError +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, +) +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .connection import HTTPConnection +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +ByteOrStr = typing.Union[bytes, str] +HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]] +HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr] + + +logger = logging.getLogger("httpcore.proxy") + + +def merge_headers( + default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, + override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, +) -> list[tuple[bytes, bytes]]: + """ + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. 
+ """ + default_headers = [] if default_headers is None else list(default_headers) + override_headers = [] if override_headers is None else list(override_headers) + has_override = set(key.lower() for key, value in override_headers) + default_headers = [ + (key, value) + for key, value in default_headers + if key.lower() not in has_override + ] + return default_headers + override_headers + + +class HTTPProxy(ConnectionPool): # pragma: nocover + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: URL | bytes | str, + proxy_auth: tuple[bytes | str, bytes | str] | None = None, + proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None, + ssl_context: ssl.SSLContext | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + max_connections: int | None = 10, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: str | None = None, + uds: str | None = None, + network_backend: NetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + proxy_auth: Any proxy authentication as a two-tuple of + (username, password). May be either bytes or ascii-only str. + proxy_headers: Any HTTP headers to use for the proxy requests. + For example `{"Proxy-Authorization": "Basic :"}`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. 
Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, + socket_options=socket_options, + ) + + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if ( + self._proxy_url.scheme == b"http" and proxy_ssl_context is not None + ): # pragma: no cover + raise RuntimeError( + "The `proxy_ssl_context` argument is not allowed for the http scheme" + ) + + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + if proxy_auth is not None: + username = enforce_bytes(proxy_auth[0], name="proxy_auth") + password = enforce_bytes(proxy_auth[1], name="proxy_auth") + userpass = username + b":" + password + authorization = b"Basic " + base64.b64encode(userpass) + self._proxy_headers = [ + (b"Proxy-Authorization", authorization) + ] + self._proxy_headers + + def create_connection(self, origin: Origin) -> ConnectionInterface: + if origin.scheme == b"http": + return ForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + proxy_ssl_context=self._proxy_ssl_context, + ) + return TunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + ssl_context=self._ssl_context, + proxy_ssl_context=self._proxy_ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class ForwardHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_headers: HeadersAsMapping | HeadersAsSequence | 
None = None, + keepalive_expiry: float | None = None, + network_backend: NetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + ) -> None: + self._connection = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._remote_origin = remote_origin + + def handle_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), + ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return self._connection.handle_request(proxy_request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class TunnelHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: ssl.SSLContext | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, 
+ http2: bool = False, + network_backend: NetworkBackend | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + self._connection: ConnectionInterface = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._connect_lock = Lock() + self._connected = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = self._connection.handle_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + self._connection.close() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + 
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + return self._connection.handle_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/interfaces.py b/venv/lib/python3.10/site-packages/httpcore/_sync/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..e673d4cc1b1dd7e7ecdbde91fd6ada386c3de03f --- /dev/null +++ 
b/venv/lib/python3.10/site-packages/httpcore/_sync/interfaces.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import contextlib +import typing + +from .._models import ( + URL, + Extensions, + HeaderTypes, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class RequestInterface: + def request( + self, + method: bytes | str, + url: URL | bytes | str, + *, + headers: HeaderTypes = None, + content: bytes | typing.Iterator[bytes] | None = None, + extensions: Extensions | None = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + response.read() + finally: + response.close() + return response + + @contextlib.contextmanager + def stream( + self, + method: bytes | str, + url: URL | bytes | str, + *, + headers: HeaderTypes = None, + content: bytes | typing.Iterator[bytes] | None = None, + extensions: Extensions | None = None, + ) -> typing.Iterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. 
+ headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + yield response + finally: + response.close() + + def handle_request(self, request: Request) -> Response: + raise NotImplementedError() # pragma: nocover + + +class ConnectionInterface(RequestInterface): + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def info(self) -> str: + raise NotImplementedError() # pragma: nocover + + def can_handle_request(self, origin: Origin) -> bool: + raise NotImplementedError() # pragma: nocover + + def is_available(self) -> bool: + """ + Return `True` if the connection is currently able to accept an + outgoing request. + + An HTTP/1.1 connection will only be available if it is currently idle. + + An HTTP/2 connection will be available so long as the stream ID space is + not yet exhausted, and the connection is not in an error state. + + While the connection is being established we may not yet know if it is going + to result in an HTTP/1.1 or HTTP/2 connection. The connection should be + treated as being available, but might ultimately raise `NewConnectionRequired` + required exceptions if multiple requests are attempted over a connection + that ends up being established as HTTP/1.1. + """ + raise NotImplementedError() # pragma: nocover + + def has_expired(self) -> bool: + """ + Return `True` if the connection is in a state where it should be closed. + + This either means that the connection is idle and it has passed the + expiry time on its keep-alive, or that server has sent an EOF. + """ + raise NotImplementedError() # pragma: nocover + + def is_idle(self) -> bool: + """ + Return `True` if the connection is currently idle. + """ + raise NotImplementedError() # pragma: nocover + + def is_closed(self) -> bool: + """ + Return `True` if the connection has been closed. 
+ + Used when a response is closed to determine if the connection may be + returned to the connection pool or not. + """ + raise NotImplementedError() # pragma: nocover diff --git a/venv/lib/python3.10/site-packages/httpcore/_sync/socks_proxy.py b/venv/lib/python3.10/site-packages/httpcore/_sync/socks_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca96ddfb580b19413797f41e79f7abcecdd9d79 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_sync/socks_proxy.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +import logging +import ssl + +import socksio + +from .._backends.sync import SyncBackend +from .._backends.base import NetworkBackend, NetworkStream +from .._exceptions import ConnectionNotAvailable, ProxyError +from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.socks") + + +AUTH_METHODS = { + b"\x00": "NO AUTHENTICATION REQUIRED", + b"\x01": "GSSAPI", + b"\x02": "USERNAME/PASSWORD", + b"\xff": "NO ACCEPTABLE METHODS", +} + +REPLY_CODES = { + b"\x00": "Succeeded", + b"\x01": "General SOCKS server failure", + b"\x02": "Connection not allowed by ruleset", + b"\x03": "Network unreachable", + b"\x04": "Host unreachable", + b"\x05": "Connection refused", + b"\x06": "TTL expired", + b"\x07": "Command not supported", + b"\x08": "Address type not supported", +} + + +def _init_socks5_connection( + stream: NetworkStream, + *, + host: bytes, + port: int, + auth: tuple[bytes, bytes] | None = None, +) -> None: + conn = socksio.socks5.SOCKS5Connection() + + # Auth method request + auth_method = ( + socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED + if auth is None + else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD + ) + 
conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method])) + outgoing_bytes = conn.data_to_send() + stream.write(outgoing_bytes) + + # Auth method response + incoming_bytes = stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socksio.socks5.SOCKS5AuthReply) + if response.method != auth_method: + requested = AUTH_METHODS.get(auth_method, "UNKNOWN") + responded = AUTH_METHODS.get(response.method, "UNKNOWN") + raise ProxyError( + f"Requested {requested} from proxy server, but got {responded}." + ) + + if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD: + # Username/password request + assert auth is not None + username, password = auth + conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password)) + outgoing_bytes = conn.data_to_send() + stream.write(outgoing_bytes) + + # Username/password response + incoming_bytes = stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply) + if not response.success: + raise ProxyError("Invalid username/password") + + # Connect request + conn.send( + socksio.socks5.SOCKS5CommandRequest.from_address( + socksio.socks5.SOCKS5Command.CONNECT, (host, port) + ) + ) + outgoing_bytes = conn.data_to_send() + stream.write(outgoing_bytes) + + # Connect response + incoming_bytes = stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socksio.socks5.SOCKS5Reply) + if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED: + reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN") + raise ProxyError(f"Proxy Server could not connect: {reply_code}.") + + +class SOCKSProxy(ConnectionPool): # pragma: nocover + """ + A connection pool that sends requests via an HTTP proxy. 
+ """ + + def __init__( + self, + proxy_url: URL | bytes | str, + proxy_auth: tuple[bytes | str, bytes | str] | None = None, + ssl_context: ssl.SSLContext | None = None, + max_connections: int | None = 10, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + network_backend: NetworkBackend | None = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + ) + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if proxy_auth is not None: + username, password = proxy_auth + username_bytes = enforce_bytes(username, name="proxy_auth") + password_bytes = enforce_bytes(password, name="proxy_auth") + self._proxy_auth: tuple[bytes, bytes] | None = ( + username_bytes, + password_bytes, + ) + else: + self._proxy_auth = None + + def create_connection(self, origin: Origin) -> ConnectionInterface: + return Socks5Connection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + proxy_auth=self._proxy_auth, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class Socks5Connection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_auth: tuple[bytes, bytes] | None = None, + ssl_context: ssl.SSLContext | None = None, + keepalive_expiry: float | None = None, + http1: bool = True, + http2: bool = False, + network_backend: NetworkBackend | None = None, + ) -> None: + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._proxy_auth = proxy_auth + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + + self._network_backend: NetworkBackend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._connect_lock = Lock() + self._connection: ConnectionInterface | None = None + self._connect_failed = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = 
request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if self._connection is None: + try: + # Connect to the proxy + kwargs = { + "host": self._proxy_origin.host.decode("ascii"), + "port": self._proxy_origin.port, + "timeout": timeout, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + + # Connect to the remote host using socks5 + kwargs = { + "stream": stream, + "host": self._remote_origin.host.decode("ascii"), + "port": self._remote_origin.port, + "auth": self._proxy_auth, + } + with Trace( + "setup_socks5_connection", logger, request, kwargs + ) as trace: + _init_socks5_connection(**kwargs) + trace.return_value = stream + + # Upgrade the stream to SSL + if self._remote_origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ( + ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ) + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or ( + self._http2 and not self._http1 + ): # pragma: nocover + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + 
keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): # pragma: nocover + raise ConnectionNotAvailable() + + return self._connection.handle_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + if self._connection is not None: + self._connection.close() + + def is_available(self) -> bool: + if self._connection is None: # pragma: nocover + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._remote_origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: # pragma: nocover + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/venv/lib/python3.10/site-packages/httpcore/_trace.py b/venv/lib/python3.10/site-packages/httpcore/_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1cd7c47829ce17dbcf651ab56b4ffdce04a485 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_trace.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import inspect +import logging +import 
types +import typing + +from ._models import Request + + +class Trace: + def __init__( + self, + name: str, + logger: logging.Logger, + request: Request | None = None, + kwargs: dict[str, typing.Any] | None = None, + ) -> None: + self.name = name + self.logger = logger + self.trace_extension = ( + None if request is None else request.extensions.get("trace") + ) + self.debug = self.logger.isEnabledFor(logging.DEBUG) + self.kwargs = kwargs or {} + self.return_value: typing.Any = None + self.should_trace = self.debug or self.trace_extension is not None + self.prefix = self.logger.name.split(".")[-1] + + def trace(self, name: str, info: dict[str, typing.Any]) -> None: + if self.trace_extension is not None: + prefix_and_name = f"{self.prefix}.{name}" + ret = self.trace_extension(prefix_and_name, info) + if inspect.iscoroutine(ret): # pragma: no cover + raise TypeError( + "If you are using a synchronous interface, " + "the callback of the `trace` extension should " + "be a normal function instead of an asynchronous function." 
+ ) + + if self.debug: + if not info or "return_value" in info and info["return_value"] is None: + message = name + else: + args = " ".join([f"{key}={value!r}" for key, value in info.items()]) + message = f"{name} {args}" + self.logger.debug(message) + + def __enter__(self) -> Trace: + if self.should_trace: + info = self.kwargs + self.trace(f"{self.name}.started", info) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + if self.should_trace: + if exc_value is None: + info = {"return_value": self.return_value} + self.trace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + self.trace(f"{self.name}.failed", info) + + async def atrace(self, name: str, info: dict[str, typing.Any]) -> None: + if self.trace_extension is not None: + prefix_and_name = f"{self.prefix}.{name}" + coro = self.trace_extension(prefix_and_name, info) + if not inspect.iscoroutine(coro): # pragma: no cover + raise TypeError( + "If you're using an asynchronous interface, " + "the callback of the `trace` extension should " + "be an asynchronous function rather than a normal function." 
+ ) + await coro + + if self.debug: + if not info or "return_value" in info and info["return_value"] is None: + message = name + else: + args = " ".join([f"{key}={value!r}" for key, value in info.items()]) + message = f"{name} {args}" + self.logger.debug(message) + + async def __aenter__(self) -> Trace: + if self.should_trace: + info = self.kwargs + await self.atrace(f"{self.name}.started", info) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: + if self.should_trace: + if exc_value is None: + info = {"return_value": self.return_value} + await self.atrace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + await self.atrace(f"{self.name}.failed", info) diff --git a/venv/lib/python3.10/site-packages/httpcore/_utils.py b/venv/lib/python3.10/site-packages/httpcore/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c44ff93cb2f572afc6e679308024b744b65c3b0a --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpcore/_utils.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import select +import socket +import sys + + +def is_socket_readable(sock: socket.socket | None) -> bool: + """ + Return whether a socket, as identifed by its file descriptor, is readable. + "A socket is readable" means that the read buffer isn't empty, i.e. that calling + .recv() on it would immediately return some data. + """ + # NOTE: we want check for readability without actually attempting to read, because + # we don't want to block forever if it's not readable. + + # In the case that the socket no longer exists, or cannot return a file + # descriptor, we treat it as being readable, as if it the next read operation + # on it is ready to return the terminating `b""`. 
+ sock_fd = None if sock is None else sock.fileno() + if sock_fd is None or sock_fd < 0: # pragma: nocover + return True + + # The implementation below was stolen from: + # https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471-L478 + # See also: https://github.com/encode/httpcore/pull/193#issuecomment-703129316 + + # Use select.select on Windows, and when poll is unavailable and select.poll + # everywhere else. (E.g. When eventlet is in use. See #327) + if ( + sys.platform == "win32" or getattr(select, "poll", None) is None + ): # pragma: nocover + rready, _, _ = select.select([sock_fd], [], [], 0) + return bool(rready) + p = select.poll() + p.register(sock_fd, select.POLLIN) + return bool(p.poll(0)) diff --git a/venv/lib/python3.10/site-packages/httpcore/py.typed b/venv/lib/python3.10/site-packages/httpcore/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/INSTALLER b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/METADATA b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..b0d2b196385e98259971519793447c1fd7a9a643 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/METADATA @@ -0,0 +1,203 @@ +Metadata-Version: 2.3 +Name: httpx +Version: 0.28.1 +Summary: The next generation HTTP client. 
+Project-URL: Changelog, https://github.com/encode/httpx/blob/master/CHANGELOG.md +Project-URL: Documentation, https://www.python-httpx.org +Project-URL: Homepage, https://github.com/encode/httpx +Project-URL: Source, https://github.com/encode/httpx +Author-email: Tom Christie +License: BSD-3-Clause +Classifier: Development Status :: 4 - Beta +Classifier: Environment :: Web Environment +Classifier: Framework :: AsyncIO +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Requires-Dist: anyio +Requires-Dist: certifi +Requires-Dist: httpcore==1.* +Requires-Dist: idna +Provides-Extra: brotli +Requires-Dist: brotli; (platform_python_implementation == 'CPython') and extra == 'brotli' +Requires-Dist: brotlicffi; (platform_python_implementation != 'CPython') and extra == 'brotli' +Provides-Extra: cli +Requires-Dist: click==8.*; extra == 'cli' +Requires-Dist: pygments==2.*; extra == 'cli' +Requires-Dist: rich<14,>=10; extra == 'cli' +Provides-Extra: http2 +Requires-Dist: h2<5,>=3; extra == 'http2' +Provides-Extra: socks +Requires-Dist: socksio==1.*; extra == 'socks' +Provides-Extra: zstd +Requires-Dist: zstandard>=0.18.0; extra == 'zstd' +Description-Content-Type: text/markdown + +

+ HTTPX +

+ +

HTTPX - A next-generation HTTP client for Python.

+ +

+ + Test Suite + + + Package version + +

+ +HTTPX is a fully featured HTTP client library for Python 3. It includes **an integrated command line client**, has support for both **HTTP/1.1 and HTTP/2**, and provides both **sync and async APIs**. + +--- + +Install HTTPX using pip: + +```shell +$ pip install httpx +``` + +Now, let's get started: + +```pycon +>>> import httpx +>>> r = httpx.get('https://www.example.org/') +>>> r + +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'\n\n\nExample Domain...' +``` + +Or, using the command-line client. + +```shell +$ pip install 'httpx[cli]' # The command line client is an optional dependency. +``` + +Which now allows us to use HTTPX directly from the command-line... + +

+ httpx --help +

+ +Sending a request... + +

+ httpx http://httpbin.org/json +

+ +## Features + +HTTPX builds on the well-established usability of `requests`, and gives you: + +* A broadly [requests-compatible API](https://www.python-httpx.org/compatibility/). +* An integrated command-line client. +* HTTP/1.1 [and HTTP/2 support](https://www.python-httpx.org/http2/). +* Standard synchronous interface, but with [async support if you need it](https://www.python-httpx.org/async/). +* Ability to make requests directly to [WSGI applications](https://www.python-httpx.org/advanced/transports/#wsgi-transport) or [ASGI applications](https://www.python-httpx.org/advanced/transports/#asgi-transport). +* Strict timeouts everywhere. +* Fully type annotated. +* 100% test coverage. + +Plus all the standard features of `requests`... + +* International Domains and URLs +* Keep-Alive & Connection Pooling +* Sessions with Cookie Persistence +* Browser-style SSL Verification +* Basic/Digest Authentication +* Elegant Key/Value Cookies +* Automatic Decompression +* Automatic Content Decoding +* Unicode Response Bodies +* Multipart File Uploads +* HTTP(S) Proxy Support +* Connection Timeouts +* Streaming Downloads +* .netrc Support +* Chunked Requests + +## Installation + +Install with pip: + +```shell +$ pip install httpx +``` + +Or, to include the optional HTTP/2 support, use: + +```shell +$ pip install httpx[http2] +``` + +HTTPX requires Python 3.8+. + +## Documentation + +Project documentation is available at [https://www.python-httpx.org/](https://www.python-httpx.org/). + +For a run-through of all the basics, head over to the [QuickStart](https://www.python-httpx.org/quickstart/). + +For more advanced topics, see the [Advanced Usage](https://www.python-httpx.org/advanced/) section, the [async support](https://www.python-httpx.org/async/) section, or the [HTTP/2](https://www.python-httpx.org/http2/) section. + +The [Developer Interface](https://www.python-httpx.org/api/) provides a comprehensive API reference. 
+ +To find out about tools that integrate with HTTPX, see [Third Party Packages](https://www.python-httpx.org/third_party_packages/). + +## Contribute + +If you want to contribute with HTTPX check out the [Contributing Guide](https://www.python-httpx.org/contributing/) to learn how to start. + +## Dependencies + +The HTTPX project relies on these excellent libraries: + +* `httpcore` - The underlying transport implementation for `httpx`. + * `h11` - HTTP/1.1 support. +* `certifi` - SSL certificates. +* `idna` - Internationalized domain name support. +* `sniffio` - Async library autodetection. + +As well as these optional installs: + +* `h2` - HTTP/2 support. *(Optional, with `httpx[http2]`)* +* `socksio` - SOCKS proxy support. *(Optional, with `httpx[socks]`)* +* `rich` - Rich terminal support. *(Optional, with `httpx[cli]`)* +* `click` - Command line client support. *(Optional, with `httpx[cli]`)* +* `brotli` or `brotlicffi` - Decoding for "brotli" compressed responses. *(Optional, with `httpx[brotli]`)* +* `zstandard` - Decoding for "zstd" compressed responses. *(Optional, with `httpx[zstd]`)* + +A huge amount of credit is due to `requests` for the API layout that +much of this work follows, as well as to `urllib3` for plenty of design +inspiration around the lower-level networking details. + +--- + +

HTTPX is BSD licensed code.
Designed & crafted with care.

— 🦋 —

+ +## Release Information + +### Fixed + +* Reintroduced supposedly-private `URLTypes` shortcut. (#2673) + + +--- + +[Full changelog](https://github.com/encode/httpx/blob/master/CHANGELOG.md) diff --git a/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/RECORD b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..a2299ed5e3b59a7b106cae9540e39d6bffc65c82 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/RECORD @@ -0,0 +1,54 @@ +../../../bin/httpx,sha256=umGhlm6aL6JBms9rRafCiKdVzD-g-yz1wkVRwJaWvrk,369 +httpx-0.28.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +httpx-0.28.1.dist-info/METADATA,sha256=_rubD48-gNV8gZnDBPNcQzboWB0dGNeYPJJ2a4J5OyU,7052 +httpx-0.28.1.dist-info/RECORD,, +httpx-0.28.1.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87 +httpx-0.28.1.dist-info/entry_points.txt,sha256=2lVkdQmxLA1pNMgSN2eV89o90HCZezhmNwsy6ryKDSA,37 +httpx-0.28.1.dist-info/licenses/LICENSE.md,sha256=TsWdVE8StfU5o6cW_TIaxYzNgDC0ZSIfLIgCAM3yjY0,1508 +httpx/__init__.py,sha256=CsaZe6yZj0rHg6322AWKWHGTMVr9txgEfD5P3_Rrz60,2171 +httpx/__pycache__/__init__.cpython-310.pyc,, +httpx/__pycache__/__version__.cpython-310.pyc,, +httpx/__pycache__/_api.cpython-310.pyc,, +httpx/__pycache__/_auth.cpython-310.pyc,, +httpx/__pycache__/_client.cpython-310.pyc,, +httpx/__pycache__/_config.cpython-310.pyc,, +httpx/__pycache__/_content.cpython-310.pyc,, +httpx/__pycache__/_decoders.cpython-310.pyc,, +httpx/__pycache__/_exceptions.cpython-310.pyc,, +httpx/__pycache__/_main.cpython-310.pyc,, +httpx/__pycache__/_models.cpython-310.pyc,, +httpx/__pycache__/_multipart.cpython-310.pyc,, +httpx/__pycache__/_status_codes.cpython-310.pyc,, +httpx/__pycache__/_types.cpython-310.pyc,, +httpx/__pycache__/_urlparse.cpython-310.pyc,, +httpx/__pycache__/_urls.cpython-310.pyc,, +httpx/__pycache__/_utils.cpython-310.pyc,, 
+httpx/__version__.py,sha256=LoUyYeOXTieGzuP_64UL0wxdtxjuu_QbOvE7NOg-IqU,108 +httpx/_api.py,sha256=r_Zgs4jIpcPJLqK5dbbSayqo_iVMKFaxZCd-oOHxLEs,11743 +httpx/_auth.py,sha256=Yr3QwaUSK17rGYx-7j-FdicFIzz4Y9FFV-1F4-7RXX4,11891 +httpx/_client.py,sha256=xD-UG67-WMkeltAAOeGGj-cZ2RRTAm19sWRxlFY7_40,65714 +httpx/_config.py,sha256=pPp2U-wicfcKsF-KYRE1LYdt3e6ERGeIoXZ8Gjo3LWc,8547 +httpx/_content.py,sha256=LGGzrJTR3OvN4Mb1GVVNLXkXJH-6oKlwAttO9p5w_yg,8161 +httpx/_decoders.py,sha256=p0dX8I0NEHexs3UGp4SsZutiMhsXrrWl6-GnqVb0iKM,12041 +httpx/_exceptions.py,sha256=bxW7fxzgVMAdNTbwT0Vnq04gJDW1_gI_GFiQPuMyjL0,8527 +httpx/_main.py,sha256=Cg9GMabiTT_swaDfUgIRitSwxLRMSwUDOm7LdSGqlA4,15626 +httpx/_models.py,sha256=4__Guyv1gLxuZChwim8kfQNiIOcJ9acreFOSurvZfms,44700 +httpx/_multipart.py,sha256=KOHEZZl6oohg9mPaKyyu345qq1rJLg35TUG3YAzXB3Y,9843 +httpx/_status_codes.py,sha256=DYn-2ufBgMeXy5s8x3_TB7wjAuAAMewTakPrm5rXEsc,5639 +httpx/_transports/__init__.py,sha256=GbUoBSAOp7z-l-9j5YhMhR3DMIcn6FVLhj072O3Nnno,275 +httpx/_transports/__pycache__/__init__.cpython-310.pyc,, +httpx/_transports/__pycache__/asgi.cpython-310.pyc,, +httpx/_transports/__pycache__/base.cpython-310.pyc,, +httpx/_transports/__pycache__/default.cpython-310.pyc,, +httpx/_transports/__pycache__/mock.cpython-310.pyc,, +httpx/_transports/__pycache__/wsgi.cpython-310.pyc,, +httpx/_transports/asgi.py,sha256=HRfiDYMPt4wQH2gFgHZg4c-i3sblo6bL5GTqcET-xz8,5501 +httpx/_transports/base.py,sha256=kZS_VMbViYfF570pogUCJ1bulz-ybfL51Pqs9yktebU,2523 +httpx/_transports/default.py,sha256=AzeaRUyVwCccTyyNJexDf0n1dFfzzydpdIQgvw7PLnk,13983 +httpx/_transports/mock.py,sha256=PTo0d567RITXxGrki6kN7_67wwAxfwiMDcuXJiZCjEo,1232 +httpx/_transports/wsgi.py,sha256=NcPX3Xap_EwCFZWO_OaSyQNuInCYx1QMNbO8GAei6jY,4825 +httpx/_types.py,sha256=Jyh41GQq7AOev8IOWKDAg7zCbvHAfufmW5g_PiTtErY,2965 +httpx/_urlparse.py,sha256=ZAmH47ONfkxrrj-PPYhGeiHjb6AjKCS-ANWIN4OL_KY,18546 +httpx/_urls.py,sha256=dX99VR1DSOHpgo9Aq7PzYO4FKdxqKjwyNp8grf8dHN0,21550 
+httpx/_utils.py,sha256=_TVeqAKvxJkKHdz7dFeb4s0LZqQXgeFkXSgfiHBK_1o,8285 +httpx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/WHEEL b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..21aaa72961a8af71c17d2cb3b76d5f7f567100e4 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.26.3 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/entry_points.txt b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..8ae96007f7d725813fd02dc1d06d3834ee1939e4 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +httpx = httpx:main diff --git a/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/licenses/LICENSE.md b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/licenses/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..ab79d16a3f4c6c894c028d1f7431811e8711b42b --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx-0.28.1.dist-info/licenses/LICENSE.md @@ -0,0 +1,12 @@ +Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
+ +* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/venv/lib/python3.10/site-packages/httpx/__init__.py b/venv/lib/python3.10/site-packages/httpx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9addde071f81758baf350c4ab6bde2556340131 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/__init__.py @@ -0,0 +1,105 @@ +from .__version__ import __description__, __title__, __version__ +from ._api import * +from ._auth import * +from ._client import * +from ._config import * +from ._content import * +from ._exceptions import * +from ._models import * +from ._status_codes import * +from ._transports import * +from ._types import * +from ._urls import * + +try: + from ._main import main +except ImportError: # pragma: no cover + + def main() -> None: # type: ignore + import sys + + print( + "The httpx command line client could not run because the required " + "dependencies were not installed.\nMake sure you've installed " + "everything with: pip install 'httpx[cli]'" + ) + sys.exit(1) + + +__all__ = [ + "__description__", + "__title__", + 
"__version__", + "ASGITransport", + "AsyncBaseTransport", + "AsyncByteStream", + "AsyncClient", + "AsyncHTTPTransport", + "Auth", + "BaseTransport", + "BasicAuth", + "ByteStream", + "Client", + "CloseError", + "codes", + "ConnectError", + "ConnectTimeout", + "CookieConflict", + "Cookies", + "create_ssl_context", + "DecodingError", + "delete", + "DigestAuth", + "get", + "head", + "Headers", + "HTTPError", + "HTTPStatusError", + "HTTPTransport", + "InvalidURL", + "Limits", + "LocalProtocolError", + "main", + "MockTransport", + "NetRCAuth", + "NetworkError", + "options", + "patch", + "PoolTimeout", + "post", + "ProtocolError", + "Proxy", + "ProxyError", + "put", + "QueryParams", + "ReadError", + "ReadTimeout", + "RemoteProtocolError", + "request", + "Request", + "RequestError", + "RequestNotRead", + "Response", + "ResponseNotRead", + "stream", + "StreamClosed", + "StreamConsumed", + "StreamError", + "SyncByteStream", + "Timeout", + "TimeoutException", + "TooManyRedirects", + "TransportError", + "UnsupportedProtocol", + "URL", + "USE_CLIENT_DEFAULT", + "WriteError", + "WriteTimeout", + "WSGITransport", +] + + +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + setattr(__locals[__name], "__module__", "httpx") # noqa diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67837552176a313faf4e610f66eddd5b2b59aa43 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/__version__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/__version__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86bb23ae28ced92f8c6c4891b31d6d4083a79e21 Binary files /dev/null and 
b/venv/lib/python3.10/site-packages/httpx/__pycache__/__version__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_api.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4c52b6fb444094bda579c01a0bcb92972f106c7 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_api.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_auth.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_auth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec8e02906c73f1382bda3f4bbc09894a5bdc2ca Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_auth.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_client.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bf128d2e5203cc0510d317ce00e22ff89010d10 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_client.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_config.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2b74330baa40b0c872720be5da8869d1a9f2d0 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_config.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_content.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_content.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc45bc9d2459a71e41ab107a2cfdc761f2e40b22 Binary files /dev/null and 
b/venv/lib/python3.10/site-packages/httpx/__pycache__/_content.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_decoders.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_decoders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00904d1aec3d85d3d66715ed419cc39cfa91f9a9 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_decoders.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_exceptions.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..487ee14117280524cd090a8dcb07ff6abe84ed9a Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_exceptions.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_main.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..589dd2f31aeec893b9c18dc3f82717ef502cfc8b Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_main.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_models.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a385d7c7a48cf981b9220c02a63e6b8bc75551ff Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_models.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_multipart.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_multipart.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..292a69a13d993544a93f3a6289be8ce328731c3a Binary files /dev/null and 
b/venv/lib/python3.10/site-packages/httpx/__pycache__/_multipart.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_status_codes.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_status_codes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51bd17d70ec780efcddccac01dab4e7336df7489 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_status_codes.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_types.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60daeec437bd62aa7cab827be758d71d65ceb62d Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_types.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_urlparse.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_urlparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2ebc1be8a6f4bc3ec56f7be6630cb852b9a9185 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_urlparse.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_urls.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_urls.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fa01bb68557600bebde924539fb2f9ed1150451 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/__pycache__/_urls.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__pycache__/_utils.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..744d7c92ee816ac937da0e091893425f8ef5524d Binary files /dev/null and 
b/venv/lib/python3.10/site-packages/httpx/__pycache__/_utils.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/__version__.py b/venv/lib/python3.10/site-packages/httpx/__version__.py new file mode 100644 index 0000000000000000000000000000000000000000..801bfacf671017cfbebf1ac26ec385daa02ed260 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/__version__.py @@ -0,0 +1,3 @@ +__title__ = "httpx" +__description__ = "A next generation HTTP client, for Python 3." +__version__ = "0.28.1" diff --git a/venv/lib/python3.10/site-packages/httpx/_api.py b/venv/lib/python3.10/site-packages/httpx/_api.py new file mode 100644 index 0000000000000000000000000000000000000000..c3cda1ecda8629edbdca2e3bc04bc51dba5e1430 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_api.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import typing +from contextlib import contextmanager + +from ._client import Client +from ._config import DEFAULT_TIMEOUT_CONFIG +from ._models import Response +from ._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + ProxyTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, +) +from ._urls import URL + +if typing.TYPE_CHECKING: + import ssl # pragma: no cover + + +__all__ = [ + "delete", + "get", + "head", + "options", + "patch", + "post", + "put", + "request", + "stream", +] + + +def request( + method: str, + url: URL | str, + *, + params: QueryParamTypes | None = None, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, +) -> Response: + """ + Sends an HTTP request. 
+ + **Parameters:** + + * **method** - HTTP method for the new `Request` object: `GET`, `OPTIONS`, + `HEAD`, `POST`, `PUT`, `PATCH`, or `DELETE`. + * **url** - URL for the new `Request` object. + * **params** - *(optional)* Query parameters to include in the URL, as a + string, dictionary, or sequence of two-tuples. + * **content** - *(optional)* Binary content to include in the body of the + request, as bytes or a byte iterator. + * **data** - *(optional)* Form data to include in the body of the request, + as a dictionary. + * **files** - *(optional)* A dictionary of upload files to include in the + body of the request. + * **json** - *(optional)* A JSON serializable object to include in the body + of the request. + * **headers** - *(optional)* Dictionary of HTTP headers to include in the + request. + * **cookies** - *(optional)* Dictionary of Cookie items to include in the + request. + * **auth** - *(optional)* An authentication class to use when sending the + request. + * **proxy** - *(optional)* A proxy URL where all the traffic should be routed. + * **timeout** - *(optional)* The timeout configuration to use when sending + the request. + * **follow_redirects** - *(optional)* Enables or disables HTTP redirects. + * **verify** - *(optional)* Either `True` to use an SSL context with the + default CA bundle, `False` to disable verification, or an instance of + `ssl.SSLContext` to use a custom context. + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. 
+ + **Returns:** `Response` + + Usage: + + ``` + >>> import httpx + >>> response = httpx.request('GET', 'https://httpbin.org/get') + >>> response + + ``` + """ + with Client( + cookies=cookies, + proxy=proxy, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + return client.request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + ) + + +@contextmanager +def stream( + method: str, + url: URL | str, + *, + params: QueryParamTypes | None = None, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, +) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. 
+ + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + with Client( + cookies=cookies, + proxy=proxy, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.stream( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + ) as response: + yield response + + +def get( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `GET` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `GET` requests should not include a request body. + """ + return request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def options( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `OPTIONS` requests should not include a request body. 
+ """ + return request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def head( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `HEAD` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `HEAD` requests should not include a request body. + """ + return request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def post( + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `POST` request. + + **Parameters**: See `httpx.request`. 
+ """ + return request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def put( + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `PUT` request. + + **Parameters**: See `httpx.request`. + """ + return request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def patch( + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `PATCH` request. + + **Parameters**: See `httpx.request`. 
+ """ + return request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def delete( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, +) -> Response: + """ + Sends a `DELETE` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `DELETE` requests should not include a request body. + """ + return request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxy=proxy, + follow_redirects=follow_redirects, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) diff --git a/venv/lib/python3.10/site-packages/httpx/_auth.py b/venv/lib/python3.10/site-packages/httpx/_auth.py new file mode 100644 index 0000000000000000000000000000000000000000..b03971ab4b311d60790dc22ca24d9966426ec0a4 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_auth.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +import hashlib +import os +import re +import time +import typing +from base64 import b64encode +from urllib.request import parse_http_list + +from ._exceptions import ProtocolError +from ._models import Cookies, Request, Response +from ._utils import to_bytes, to_str, unquote + +if typing.TYPE_CHECKING: # pragma: no cover + from hashlib import _Hash + + +__all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"] + + +class Auth: + """ + Base class for all authentication schemes. 
+ + To implement a custom authentication scheme, subclass `Auth` and override + the `.auth_flow()` method. + + If the authentication scheme does I/O such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.sync_auth_flow()` + and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized + implementations that will be used by `Client` and `AsyncClient` respectively. + """ + + requires_request_body = False + requires_response_body = False + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow. + + To dispatch a request, `yield` it: + + ``` + yield request + ``` + + The client will `.send()` the response back into the flow generator. You can + access it like so: + + ``` + response = yield request + ``` + + A `return` (or reaching the end of the generator) will result in the + client returning the last response obtained from the server. + + You can dispatch as many requests as is necessary. + """ + yield request + + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow synchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + request.read() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + response.read() + + try: + request = flow.send(response) + except StopIteration: + break + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. 
+ """ + if self.requires_request_body: + await request.aread() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + await response.aread() + + try: + request = flow.send(response) + except StopIteration: + break + + +class FunctionAuth(Auth): + """ + Allows the 'auth' argument to be passed as a simple callable function, + that takes the request, and returns a new, modified request. + """ + + def __init__(self, func: typing.Callable[[Request], Request]) -> None: + self._func = func + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + yield self._func(request) + + +class BasicAuth(Auth): + """ + Allows the 'auth' argument to be passed as a (username, password) pair, + and uses HTTP Basic authentication. + """ + + def __init__(self, username: str | bytes, password: str | bytes) -> None: + self._auth_header = self._build_auth_header(username, password) + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + request.headers["Authorization"] = self._auth_header + yield request + + def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode() + return f"Basic {token}" + + +class NetRCAuth(Auth): + """ + Use a 'netrc' file to lookup basic auth credentials based on the url host. + """ + + def __init__(self, file: str | None = None) -> None: + # Lazily import 'netrc'. + # There's no need for us to load this module unless 'NetRCAuth' is being used. + import netrc + + self._netrc_info = netrc.netrc(file) + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + auth_info = self._netrc_info.authenticators(request.url.host) + if auth_info is None or not auth_info[2]: + # The netrc file did not have authentication credentials for this host. 
+ yield request + else: + # Build a basic auth header with credentials from the netrc file. + request.headers["Authorization"] = self._build_auth_header( + username=auth_info[0], password=auth_info[2] + ) + yield request + + def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode() + return f"Basic {token}" + + +class DigestAuth(Auth): + _ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = { + "MD5": hashlib.md5, + "MD5-SESS": hashlib.md5, + "SHA": hashlib.sha1, + "SHA-SESS": hashlib.sha1, + "SHA-256": hashlib.sha256, + "SHA-256-SESS": hashlib.sha256, + "SHA-512": hashlib.sha512, + "SHA-512-SESS": hashlib.sha512, + } + + def __init__(self, username: str | bytes, password: str | bytes) -> None: + self._username = to_bytes(username) + self._password = to_bytes(password) + self._last_challenge: _DigestAuthChallenge | None = None + self._nonce_count = 1 + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + if self._last_challenge: + request.headers["Authorization"] = self._build_auth_header( + request, self._last_challenge + ) + + response = yield request + + if response.status_code != 401 or "www-authenticate" not in response.headers: + # If the response is not a 401 then we don't + # need to build an authenticated request. + return + + for auth_header in response.headers.get_list("www-authenticate"): + if auth_header.lower().startswith("digest "): + break + else: + # If the response does not include a 'WWW-Authenticate: Digest ...' + # header, then we don't need to build an authenticated request. 
+ return + + self._last_challenge = self._parse_challenge(request, response, auth_header) + self._nonce_count = 1 + + request.headers["Authorization"] = self._build_auth_header( + request, self._last_challenge + ) + if response.cookies: + Cookies(response.cookies).set_cookie_header(request=request) + yield request + + def _parse_challenge( + self, request: Request, response: Response, auth_header: str + ) -> _DigestAuthChallenge: + """ + Returns a challenge from a Digest WWW-Authenticate header. + These take the form of: + `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"` + """ + scheme, _, fields = auth_header.partition(" ") + + # This method should only ever have been called with a Digest auth header. + assert scheme.lower() == "digest" + + header_dict: dict[str, str] = {} + for field in parse_http_list(fields): + key, value = field.strip().split("=", 1) + header_dict[key] = unquote(value) + + try: + realm = header_dict["realm"].encode() + nonce = header_dict["nonce"].encode() + algorithm = header_dict.get("algorithm", "MD5") + opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None + qop = header_dict["qop"].encode() if "qop" in header_dict else None + return _DigestAuthChallenge( + realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop + ) + except KeyError as exc: + message = "Malformed Digest WWW-Authenticate header" + raise ProtocolError(message, request=request) from exc + + def _build_auth_header( + self, request: Request, challenge: _DigestAuthChallenge + ) -> str: + hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()] + + def digest(data: bytes) -> bytes: + return hash_func(data).hexdigest().encode() + + A1 = b":".join((self._username, challenge.realm, self._password)) + + path = request.url.raw_path + A2 = b":".join((request.method.encode(), path)) + # TODO: implement auth-int + HA2 = digest(A2) + + nc_value = b"%08x" % self._nonce_count + cnonce = 
self._get_client_nonce(self._nonce_count, challenge.nonce) + self._nonce_count += 1 + + HA1 = digest(A1) + if challenge.algorithm.lower().endswith("-sess"): + HA1 = digest(b":".join((HA1, challenge.nonce, cnonce))) + + qop = self._resolve_qop(challenge.qop, request=request) + if qop is None: + # Following RFC 2069 + digest_data = [HA1, challenge.nonce, HA2] + else: + # Following RFC 2617/7616 + digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2] + + format_args = { + "username": self._username, + "realm": challenge.realm, + "nonce": challenge.nonce, + "uri": path, + "response": digest(b":".join(digest_data)), + "algorithm": challenge.algorithm.encode(), + } + if challenge.opaque: + format_args["opaque"] = challenge.opaque + if qop: + format_args["qop"] = b"auth" + format_args["nc"] = nc_value + format_args["cnonce"] = cnonce + + return "Digest " + self._get_header_value(format_args) + + def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes: + s = str(nonce_count).encode() + s += nonce + s += time.ctime().encode() + s += os.urandom(8) + + return hashlib.sha1(s).hexdigest()[:16].encode() + + def _get_header_value(self, header_fields: dict[str, bytes]) -> str: + NON_QUOTED_FIELDS = ("algorithm", "qop", "nc") + QUOTED_TEMPLATE = '{}="{}"' + NON_QUOTED_TEMPLATE = "{}={}" + + header_value = "" + for i, (field, value) in enumerate(header_fields.items()): + if i > 0: + header_value += ", " + template = ( + QUOTED_TEMPLATE + if field not in NON_QUOTED_FIELDS + else NON_QUOTED_TEMPLATE + ) + header_value += template.format(field, to_str(value)) + + return header_value + + def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None: + if qop is None: + return None + qops = re.split(b", ?", qop) + if b"auth" in qops: + return b"auth" + + if qops == [b"auth-int"]: + raise NotImplementedError("Digest auth-int support is not yet implemented") + + message = f'Unexpected qop value "{qop!r}" in digest auth' + raise ProtocolError(message, 
request=request) + + +class _DigestAuthChallenge(typing.NamedTuple): + realm: bytes + nonce: bytes + algorithm: str + opaque: bytes | None + qop: bytes | None diff --git a/venv/lib/python3.10/site-packages/httpx/_client.py b/venv/lib/python3.10/site-packages/httpx/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..2249231f8c3b912c731ff160344d3672e2f11738 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_client.py @@ -0,0 +1,2019 @@ +from __future__ import annotations + +import datetime +import enum +import logging +import time +import typing +import warnings +from contextlib import asynccontextmanager, contextmanager +from types import TracebackType + +from .__version__ import __version__ +from ._auth import Auth, BasicAuth, FunctionAuth +from ._config import ( + DEFAULT_LIMITS, + DEFAULT_MAX_REDIRECTS, + DEFAULT_TIMEOUT_CONFIG, + Limits, + Proxy, + Timeout, +) +from ._decoders import SUPPORTED_DECODERS +from ._exceptions import ( + InvalidURL, + RemoteProtocolError, + TooManyRedirects, + request_context, +) +from ._models import Cookies, Headers, Request, Response +from ._status_codes import codes +from ._transports.base import AsyncBaseTransport, BaseTransport +from ._transports.default import AsyncHTTPTransport, HTTPTransport +from ._types import ( + AsyncByteStream, + AuthTypes, + CertTypes, + CookieTypes, + HeaderTypes, + ProxyTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestExtensions, + RequestFiles, + SyncByteStream, + TimeoutTypes, +) +from ._urls import URL, QueryParams +from ._utils import URLPattern, get_environment_proxies + +if typing.TYPE_CHECKING: + import ssl # pragma: no cover + +__all__ = ["USE_CLIENT_DEFAULT", "AsyncClient", "Client"] + +# The type annotation for @classmethod and context managers here follows PEP 484 +# https://www.python.org/dev/peps/pep-0484/#annotating-instance-and-class-methods +T = typing.TypeVar("T", bound="Client") +U = typing.TypeVar("U", bound="AsyncClient") + + 
+def _is_https_redirect(url: URL, location: URL) -> bool: + """ + Return 'True' if 'location' is a HTTPS upgrade of 'url' + """ + if url.host != location.host: + return False + + return ( + url.scheme == "http" + and _port_or_default(url) == 80 + and location.scheme == "https" + and _port_or_default(location) == 443 + ) + + +def _port_or_default(url: URL) -> int | None: + if url.port is not None: + return url.port + return {"http": 80, "https": 443}.get(url.scheme) + + +def _same_origin(url: URL, other: URL) -> bool: + """ + Return 'True' if the given URLs share the same origin. + """ + return ( + url.scheme == other.scheme + and url.host == other.host + and _port_or_default(url) == _port_or_default(other) + ) + + +class UseClientDefault: + """ + For some parameters such as `auth=...` and `timeout=...` we need to be able + to indicate the default "unset" state, in a way that is distinctly different + to using `None`. + + The default "unset" state indicates that whatever default is set on the + client should be used. This is different to setting `None`, which + explicitly disables the parameter, possibly overriding a client default. + + For example we use `timeout=USE_CLIENT_DEFAULT` in the `request()` signature. + Omitting the `timeout` parameter will send a request using whatever default + timeout has been configured on the client. Including `timeout=None` will + ensure no timeout is used. + + Note that user code shouldn't need to use the `USE_CLIENT_DEFAULT` constant, + but it is used internally when a parameter is not included. + """ + + +USE_CLIENT_DEFAULT = UseClientDefault() + + +logger = logging.getLogger("httpx") + +USER_AGENT = f"python-httpx/{__version__}" +ACCEPT_ENCODING = ", ".join( + [key for key in SUPPORTED_DECODERS.keys() if key != "identity"] +) + + +class ClientState(enum.Enum): + # UNOPENED: + # The client has been instantiated, but has not been used to send a request, + # or been opened by entering the context of a `with` block. 
+ UNOPENED = 1 + # OPENED: + # The client has either sent a request, or is within a `with` block. + OPENED = 2 + # CLOSED: + # The client has either exited the `with` block, or `close()` has + # been called explicitly. + CLOSED = 3 + + +class BoundSyncStream(SyncByteStream): + """ + A byte stream that is bound to a given response instance, and that + ensures the `response.elapsed` is set once the response is closed. + """ + + def __init__( + self, stream: SyncByteStream, response: Response, start: float + ) -> None: + self._stream = stream + self._response = response + self._start = start + + def __iter__(self) -> typing.Iterator[bytes]: + for chunk in self._stream: + yield chunk + + def close(self) -> None: + elapsed = time.perf_counter() - self._start + self._response.elapsed = datetime.timedelta(seconds=elapsed) + self._stream.close() + + +class BoundAsyncStream(AsyncByteStream): + """ + An async byte stream that is bound to a given response instance, and that + ensures the `response.elapsed` is set once the response is closed. 
+ """ + + def __init__( + self, stream: AsyncByteStream, response: Response, start: float + ) -> None: + self._stream = stream + self._response = response + self._start = start + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async for chunk in self._stream: + yield chunk + + async def aclose(self) -> None: + elapsed = time.perf_counter() - self._start + self._response.elapsed = datetime.timedelta(seconds=elapsed) + await self._stream.aclose() + + +EventHook = typing.Callable[..., typing.Any] + + +class BaseClient: + def __init__( + self, + *, + auth: AuthTypes | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: None | (typing.Mapping[str, list[EventHook]]) = None, + base_url: URL | str = "", + trust_env: bool = True, + default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + ) -> None: + event_hooks = {} if event_hooks is None else event_hooks + + self._base_url = self._enforce_trailing_slash(URL(base_url)) + + self._auth = self._build_auth(auth) + self._params = QueryParams(params) + self.headers = Headers(headers) + self._cookies = Cookies(cookies) + self._timeout = Timeout(timeout) + self.follow_redirects = follow_redirects + self.max_redirects = max_redirects + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + self._trust_env = trust_env + self._default_encoding = default_encoding + self._state = ClientState.UNOPENED + + @property + def is_closed(self) -> bool: + """ + Check if the client being closed + """ + return self._state == ClientState.CLOSED + + @property + def trust_env(self) -> bool: + return self._trust_env + + def _enforce_trailing_slash(self, url: URL) -> URL: + if url.raw_path.endswith(b"/"): + return url + return 
url.copy_with(raw_path=url.raw_path + b"/") + + def _get_proxy_map( + self, proxy: ProxyTypes | None, allow_env_proxies: bool + ) -> dict[str, Proxy | None]: + if proxy is None: + if allow_env_proxies: + return { + key: None if url is None else Proxy(url=url) + for key, url in get_environment_proxies().items() + } + return {} + else: + proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy + return {"all://": proxy} + + @property + def timeout(self) -> Timeout: + return self._timeout + + @timeout.setter + def timeout(self, timeout: TimeoutTypes) -> None: + self._timeout = Timeout(timeout) + + @property + def event_hooks(self) -> dict[str, list[EventHook]]: + return self._event_hooks + + @event_hooks.setter + def event_hooks(self, event_hooks: dict[str, list[EventHook]]) -> None: + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def auth(self) -> Auth | None: + """ + Authentication class used when none is passed at the request-level. + + See also [Authentication][0]. + + [0]: /quickstart/#authentication + """ + return self._auth + + @auth.setter + def auth(self, auth: AuthTypes) -> None: + self._auth = self._build_auth(auth) + + @property + def base_url(self) -> URL: + """ + Base URL to use when sending requests with relative URLs. + """ + return self._base_url + + @base_url.setter + def base_url(self, url: URL | str) -> None: + self._base_url = self._enforce_trailing_slash(URL(url)) + + @property + def headers(self) -> Headers: + """ + HTTP headers to include when sending requests. 
+ """ + return self._headers + + @headers.setter + def headers(self, headers: HeaderTypes) -> None: + client_headers = Headers( + { + b"Accept": b"*/*", + b"Accept-Encoding": ACCEPT_ENCODING.encode("ascii"), + b"Connection": b"keep-alive", + b"User-Agent": USER_AGENT.encode("ascii"), + } + ) + client_headers.update(headers) + self._headers = client_headers + + @property + def cookies(self) -> Cookies: + """ + Cookie values to include when sending requests. + """ + return self._cookies + + @cookies.setter + def cookies(self, cookies: CookieTypes) -> None: + self._cookies = Cookies(cookies) + + @property + def params(self) -> QueryParams: + """ + Query parameters to include in the URL when sending requests. + """ + return self._params + + @params.setter + def params(self, params: QueryParamTypes) -> None: + self._params = QueryParams(params) + + def build_request( + self, + method: str, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Request: + """ + Build and return a request instance. + + * The `params`, `headers` and `cookies` arguments + are merged with any values set on the client. + * The `url` argument is merged with any `base_url` set on the client. 
+ + See also: [Request instances][0] + + [0]: /advanced/clients/#request-instances + """ + url = self._merge_url(url) + headers = self._merge_headers(headers) + cookies = self._merge_cookies(cookies) + params = self._merge_queryparams(params) + extensions = {} if extensions is None else extensions + if "timeout" not in extensions: + timeout = ( + self.timeout + if isinstance(timeout, UseClientDefault) + else Timeout(timeout) + ) + extensions = dict(**extensions, timeout=timeout.as_dict()) + return Request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + extensions=extensions, + ) + + def _merge_url(self, url: URL | str) -> URL: + """ + Merge a URL argument together with any 'base_url' on the client, + to create the URL used for the outgoing request. + """ + merge_url = URL(url) + if merge_url.is_relative_url: + # To merge URLs we always append to the base URL. To get this + # behaviour correct we always ensure the base URL ends in a '/' + # separator, and strip any leading '/' from the merge URL. + # + # So, eg... + # + # >>> client = Client(base_url="https://www.example.com/subpath") + # >>> client.base_url + # URL('https://www.example.com/subpath/') + # >>> client.build_request("GET", "/path").url + # URL('https://www.example.com/subpath/path') + merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/") + return self.base_url.copy_with(raw_path=merge_raw_path) + return merge_url + + def _merge_cookies(self, cookies: CookieTypes | None = None) -> CookieTypes | None: + """ + Merge a cookies argument together with any cookies on the client, + to create the cookies used for the outgoing request. 
+ """ + if cookies or self.cookies: + merged_cookies = Cookies(self.cookies) + merged_cookies.update(cookies) + return merged_cookies + return cookies + + def _merge_headers(self, headers: HeaderTypes | None = None) -> HeaderTypes | None: + """ + Merge a headers argument together with any headers on the client, + to create the headers used for the outgoing request. + """ + merged_headers = Headers(self.headers) + merged_headers.update(headers) + return merged_headers + + def _merge_queryparams( + self, params: QueryParamTypes | None = None + ) -> QueryParamTypes | None: + """ + Merge a queryparams argument together with any queryparams on the client, + to create the queryparams used for the outgoing request. + """ + if params or self.params: + merged_queryparams = QueryParams(self.params) + return merged_queryparams.merge(params) + return params + + def _build_auth(self, auth: AuthTypes | None) -> Auth | None: + if auth is None: + return None + elif isinstance(auth, tuple): + return BasicAuth(username=auth[0], password=auth[1]) + elif isinstance(auth, Auth): + return auth + elif callable(auth): + return FunctionAuth(func=auth) + else: + raise TypeError(f'Invalid "auth" argument: {auth!r}') + + def _build_request_auth( + self, + request: Request, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + ) -> Auth: + auth = ( + self._auth if isinstance(auth, UseClientDefault) else self._build_auth(auth) + ) + + if auth is not None: + return auth + + username, password = request.url.username, request.url.password + if username or password: + return BasicAuth(username=username, password=password) + + return Auth() + + def _build_redirect_request(self, request: Request, response: Response) -> Request: + """ + Given a request and a redirect response, return a new request that + should be used to effect the redirect. 
+ """ + method = self._redirect_method(request, response) + url = self._redirect_url(request, response) + headers = self._redirect_headers(request, url, method) + stream = self._redirect_stream(request, method) + cookies = Cookies(self.cookies) + return Request( + method=method, + url=url, + headers=headers, + cookies=cookies, + stream=stream, + extensions=request.extensions, + ) + + def _redirect_method(self, request: Request, response: Response) -> str: + """ + When being redirected we may want to change the method of the request + based on certain specs or browser behavior. + """ + method = request.method + + # https://tools.ietf.org/html/rfc7231#section-6.4.4 + if response.status_code == codes.SEE_OTHER and method != "HEAD": + method = "GET" + + # Do what the browsers do, despite standards... + # Turn 302s into GETs. + if response.status_code == codes.FOUND and method != "HEAD": + method = "GET" + + # If a POST is responded to with a 301, turn it into a GET. + # This bizarre behaviour is explained in 'requests' issue 1704. + if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": + method = "GET" + + return method + + def _redirect_url(self, request: Request, response: Response) -> URL: + """ + Return the URL for the redirect to follow. + """ + location = response.headers["Location"] + + try: + url = URL(location) + except InvalidURL as exc: + raise RemoteProtocolError( + f"Invalid URL in location header: {exc}.", request=request + ) from None + + # Handle malformed 'Location' headers that are "absolute" form, have no host. + # See: https://github.com/encode/httpx/issues/771 + if url.scheme and not url.host: + url = url.copy_with(host=request.url.host) + + # Facilitate relative 'Location' headers, as allowed by RFC 7231. + # (e.g. 
'/path/to/resource' instead of 'http://domain.tld/path/to/resource') + if url.is_relative_url: + url = request.url.join(url) + + # Attach previous fragment if needed (RFC 7231 7.1.2) + if request.url.fragment and not url.fragment: + url = url.copy_with(fragment=request.url.fragment) + + return url + + def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers: + """ + Return the headers that should be used for the redirect request. + """ + headers = Headers(request.headers) + + if not _same_origin(url, request.url): + if not _is_https_redirect(request.url, url): + # Strip Authorization headers when responses are redirected + # away from the origin. (Except for direct HTTP to HTTPS redirects.) + headers.pop("Authorization", None) + + # Update the Host header. + headers["Host"] = url.netloc.decode("ascii") + + if method != request.method and method == "GET": + # If we've switch to a 'GET' request, then strip any headers which + # are only relevant to the request body. + headers.pop("Content-Length", None) + headers.pop("Transfer-Encoding", None) + + # We should use the client cookie store to determine any cookie header, + # rather than whatever was on the original outgoing request. + headers.pop("Cookie", None) + + return headers + + def _redirect_stream( + self, request: Request, method: str + ) -> SyncByteStream | AsyncByteStream | None: + """ + Return the body that should be used for the redirect request. + """ + if method != request.method and method == "GET": + return None + + return request.stream + + def _set_timeout(self, request: Request) -> None: + if "timeout" not in request.extensions: + timeout = ( + self.timeout + if isinstance(self.timeout, UseClientDefault) + else Timeout(self.timeout) + ) + request.extensions = dict(**request.extensions, timeout=timeout.as_dict()) + + +class Client(BaseClient): + """ + An HTTP client, with connection pooling, HTTP/2, redirects, cookie persistence, etc. + + It can be shared between threads. 
+ + Usage: + + ```python + >>> client = httpx.Client() + >>> response = client.get('https://example.org') + ``` + + **Parameters:** + + * **auth** - *(optional)* An authentication class to use when sending + requests. + * **params** - *(optional)* Query parameters to include in request URLs, as + a string, dictionary, or sequence of two-tuples. + * **headers** - *(optional)* Dictionary of HTTP headers to include when + sending requests. + * **cookies** - *(optional)* Dictionary of Cookie items to include when + sending requests. + * **verify** - *(optional)* Either `True` to use an SSL context with the + default CA bundle, `False` to disable verification, or an instance of + `ssl.SSLContext` to use a custom context. + * **http2** - *(optional)* A boolean indicating if HTTP/2 support should be + enabled. Defaults to `False`. + * **proxy** - *(optional)* A proxy URL where all the traffic should be routed. + * **timeout** - *(optional)* The timeout configuration to use when sending + requests. + * **limits** - *(optional)* The limits configuration to use. + * **max_redirects** - *(optional)* The maximum number of redirect responses + that should be followed. + * **base_url** - *(optional)* A URL to use as the base when building + request URLs. + * **transport** - *(optional)* A transport class to use for sending requests + over the network. + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. + * **default_encoding** - *(optional)* The default encoding to use for decoding + response text, if no charset information is included in a response Content-Type + header. Set to a callable for automatic character set detection. Default: "utf-8". 
+ """ + + def __init__( + self, + *, + auth: AuthTypes | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + proxy: ProxyTypes | None = None, + mounts: None | (typing.Mapping[str, BaseTransport | None]) = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + limits: Limits = DEFAULT_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: None | (typing.Mapping[str, list[EventHook]]) = None, + base_url: URL | str = "", + transport: BaseTransport | None = None, + default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + ) -> None: + super().__init__( + auth=auth, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + event_hooks=event_hooks, + base_url=base_url, + trust_env=trust_env, + default_encoding=default_encoding, + ) + + if http2: + try: + import h2 # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using http2=True, but the 'h2' package is not installed. " + "Make sure to install httpx using `pip install httpx[http2]`." 
+ ) from None + + allow_env_proxies = trust_env and transport is None + proxy_map = self._get_proxy_map(proxy, allow_env_proxies) + + self._transport = self._init_transport( + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + transport=transport, + ) + self._mounts: dict[URLPattern, BaseTransport | None] = { + URLPattern(key): None + if proxy is None + else self._init_proxy_transport( + proxy, + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + ) + for key, proxy in proxy_map.items() + } + if mounts is not None: + self._mounts.update( + {URLPattern(key): transport for key, transport in mounts.items()} + ) + + self._mounts = dict(sorted(self._mounts.items())) + + def _init_transport( + self, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + transport: BaseTransport | None = None, + ) -> BaseTransport: + if transport is not None: + return transport + + return HTTPTransport( + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + ) + + def _init_proxy_transport( + self, + proxy: Proxy, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + ) -> BaseTransport: + return HTTPTransport( + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + proxy=proxy, + ) + + def _transport_for_url(self, url: URL) -> BaseTransport: + """ + Returns the transport instance that should be used for a given URL. + This will either be the standard connection pool, or a proxy. 
+ """ + for pattern, transport in self._mounts.items(): + if pattern.matches(url): + return self._transport if transport is None else transport + + return self._transport + + def request( + self, + method: str, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Build and send a request. + + Equivalent to: + + ```python + request = client.build_request(...) + response = client.send(request, ...) + ``` + + See `Client.build_request()`, `Client.send()` and + [Merging of configuration][0] for how the various parameters + are merged with client-level configuration. + + [0]: /advanced/clients/#merging-of-configuration + """ + if cookies is not None: + message = ( + "Setting per-request cookies=<...> is being deprecated, because " + "the expected behaviour on cookie persistence is ambiguous. Set " + "cookies directly on the client instance instead." 
+ ) + warnings.warn(message, DeprecationWarning, stacklevel=2) + + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + return self.send(request, auth=auth, follow_redirects=follow_redirects) + + @contextmanager + def stream( + self, + method: str, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + response = self.send( + request=request, + auth=auth, + follow_redirects=follow_redirects, + stream=True, + ) + try: + yield response + finally: + response.close() + + def send( + self, + request: Request, + *, + stream: bool = False, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + ) -> Response: + """ + Send a request. + + The request is sent as-is, unmodified. 
+ + Typically you'll want to build one with `Client.build_request()` + so that any client-level configuration is merged into the request, + but passing an explicit `httpx.Request()` is supported as well. + + See also: [Request instances][0] + + [0]: /advanced/clients/#request-instances + """ + if self._state == ClientState.CLOSED: + raise RuntimeError("Cannot send a request, as the client has been closed.") + + self._state = ClientState.OPENED + follow_redirects = ( + self.follow_redirects + if isinstance(follow_redirects, UseClientDefault) + else follow_redirects + ) + + self._set_timeout(request) + + auth = self._build_request_auth(request, auth) + + response = self._send_handling_auth( + request, + auth=auth, + follow_redirects=follow_redirects, + history=[], + ) + try: + if not stream: + response.read() + + return response + + except BaseException as exc: + response.close() + raise exc + + def _send_handling_auth( + self, + request: Request, + auth: Auth, + follow_redirects: bool, + history: list[Response], + ) -> Response: + auth_flow = auth.sync_auth_flow(request) + try: + request = next(auth_flow) + + while True: + response = self._send_handling_redirects( + request, + follow_redirects=follow_redirects, + history=history, + ) + try: + try: + next_request = auth_flow.send(response) + except StopIteration: + return response + + response.history = list(history) + response.read() + request = next_request + history.append(response) + + except BaseException as exc: + response.close() + raise exc + finally: + auth_flow.close() + + def _send_handling_redirects( + self, + request: Request, + follow_redirects: bool, + history: list[Response], + ) -> Response: + while True: + if len(history) > self.max_redirects: + raise TooManyRedirects( + "Exceeded maximum allowed redirects.", request=request + ) + + for hook in self._event_hooks["request"]: + hook(request) + + response = self._send_single_request(request) + try: + for hook in self._event_hooks["response"]: + 
hook(response) + response.history = list(history) + + if not response.has_redirect_location: + return response + + request = self._build_redirect_request(request, response) + history = history + [response] + + if follow_redirects: + response.read() + else: + response.next_request = request + return response + + except BaseException as exc: + response.close() + raise exc + + def _send_single_request(self, request: Request) -> Response: + """ + Sends a single request, without handling any redirections. + """ + transport = self._transport_for_url(request.url) + start = time.perf_counter() + + if not isinstance(request.stream, SyncByteStream): + raise RuntimeError( + "Attempted to send an async request with a sync Client instance." + ) + + with request_context(request=request): + response = transport.handle_request(request) + + assert isinstance(response.stream, SyncByteStream) + + response.request = request + response.stream = BoundSyncStream( + response.stream, response=response, start=start + ) + self.cookies.extract_cookies(response) + response.default_encoding = self._default_encoding + + logger.info( + 'HTTP Request: %s %s "%s %d %s"', + request.method, + request.url, + response.http_version, + response.status_code, + response.reason_phrase, + ) + + return response + + def get( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `GET` request. + + **Parameters**: See `httpx.request`. 
+ """ + return self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def options( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def head( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `HEAD` request. + + **Parameters**: See `httpx.request`. 
+ """ + return self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def post( + self, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def put( + self, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `PUT` request. + + **Parameters**: See `httpx.request`. 
+ """ + return self.request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def patch( + self, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `PATCH` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def delete( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `DELETE` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def close(self) -> None: + """ + Close transport and proxies. 
+ """ + if self._state != ClientState.CLOSED: + self._state = ClientState.CLOSED + + self._transport.close() + for transport in self._mounts.values(): + if transport is not None: + transport.close() + + def __enter__(self: T) -> T: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: ( + "Cannot reopen a client instance, once it has been closed." + ), + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + self._transport.__enter__() + for transport in self._mounts.values(): + if transport is not None: + transport.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + self._state = ClientState.CLOSED + + self._transport.__exit__(exc_type, exc_value, traceback) + for transport in self._mounts.values(): + if transport is not None: + transport.__exit__(exc_type, exc_value, traceback) + + +class AsyncClient(BaseClient): + """ + An asynchronous HTTP client, with connection pooling, HTTP/2, redirects, + cookie persistence, etc. + + It can be shared between tasks. + + Usage: + + ```python + >>> async with httpx.AsyncClient() as client: + >>> response = await client.get('https://example.org') + ``` + + **Parameters:** + + * **auth** - *(optional)* An authentication class to use when sending + requests. + * **params** - *(optional)* Query parameters to include in request URLs, as + a string, dictionary, or sequence of two-tuples. + * **headers** - *(optional)* Dictionary of HTTP headers to include when + sending requests. + * **cookies** - *(optional)* Dictionary of Cookie items to include when + sending requests. + * **verify** - *(optional)* Either `True` to use an SSL context with the + default CA bundle, `False` to disable verification, or an instance of + `ssl.SSLContext` to use a custom context. 
+ * **http2** - *(optional)* A boolean indicating if HTTP/2 support should be + enabled. Defaults to `False`. + * **proxy** - *(optional)* A proxy URL where all the traffic should be routed. + * **timeout** - *(optional)* The timeout configuration to use when sending + requests. + * **limits** - *(optional)* The limits configuration to use. + * **max_redirects** - *(optional)* The maximum number of redirect responses + that should be followed. + * **base_url** - *(optional)* A URL to use as the base when building + request URLs. + * **transport** - *(optional)* A transport class to use for sending requests + over the network. + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. + * **default_encoding** - *(optional)* The default encoding to use for decoding + response text, if no charset information is included in a response Content-Type + header. Set to a callable for automatic character set detection. Default: "utf-8". + """ + + def __init__( + self, + *, + auth: AuthTypes | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + http1: bool = True, + http2: bool = False, + proxy: ProxyTypes | None = None, + mounts: None | (typing.Mapping[str, AsyncBaseTransport | None]) = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + limits: Limits = DEFAULT_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: None | (typing.Mapping[str, list[EventHook]]) = None, + base_url: URL | str = "", + transport: AsyncBaseTransport | None = None, + trust_env: bool = True, + default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + ) -> None: + super().__init__( + auth=auth, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + 
event_hooks=event_hooks, + base_url=base_url, + trust_env=trust_env, + default_encoding=default_encoding, + ) + + if http2: + try: + import h2 # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using http2=True, but the 'h2' package is not installed. " + "Make sure to install httpx using `pip install httpx[http2]`." + ) from None + + allow_env_proxies = trust_env and transport is None + proxy_map = self._get_proxy_map(proxy, allow_env_proxies) + + self._transport = self._init_transport( + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + transport=transport, + ) + + self._mounts: dict[URLPattern, AsyncBaseTransport | None] = { + URLPattern(key): None + if proxy is None + else self._init_proxy_transport( + proxy, + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + ) + for key, proxy in proxy_map.items() + } + if mounts is not None: + self._mounts.update( + {URLPattern(key): transport for key, transport in mounts.items()} + ) + self._mounts = dict(sorted(self._mounts.items())) + + def _init_transport( + self, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + transport: AsyncBaseTransport | None = None, + ) -> AsyncBaseTransport: + if transport is not None: + return transport + + return AsyncHTTPTransport( + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + ) + + def _init_proxy_transport( + self, + proxy: Proxy, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + ) -> AsyncBaseTransport: + return AsyncHTTPTransport( + verify=verify, + cert=cert, + trust_env=trust_env, + http1=http1, + http2=http2, + limits=limits, + proxy=proxy, + ) 
+ + def _transport_for_url(self, url: URL) -> AsyncBaseTransport: + """ + Returns the transport instance that should be used for a given URL. + This will either be the standard connection pool, or a proxy. + """ + for pattern, transport in self._mounts.items(): + if pattern.matches(url): + return self._transport if transport is None else transport + + return self._transport + + async def request( + self, + method: str, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Build and send a request. + + Equivalent to: + + ```python + request = client.build_request(...) + response = await client.send(request, ...) + ``` + + See `AsyncClient.build_request()`, `AsyncClient.send()` + and [Merging of configuration][0] for how the various parameters + are merged with client-level configuration. + + [0]: /advanced/clients/#merging-of-configuration + """ + + if cookies is not None: # pragma: no cover + message = ( + "Setting per-request cookies=<...> is being deprecated, because " + "the expected behaviour on cookie persistence is ambiguous. Set " + "cookies directly on the client instance instead." 
+ ) + warnings.warn(message, DeprecationWarning, stacklevel=2) + + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + + @asynccontextmanager + async def stream( + self, + method: str, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> typing.AsyncIterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + response = await self.send( + request=request, + auth=auth, + follow_redirects=follow_redirects, + stream=True, + ) + try: + yield response + finally: + await response.aclose() + + async def send( + self, + request: Request, + *, + stream: bool = False, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + ) -> Response: + """ + Send a request. + + The request is sent as-is, unmodified. 
+ + Typically you'll want to build one with `AsyncClient.build_request()` + so that any client-level configuration is merged into the request, + but passing an explicit `httpx.Request()` is supported as well. + + See also: [Request instances][0] + + [0]: /advanced/clients/#request-instances + """ + if self._state == ClientState.CLOSED: + raise RuntimeError("Cannot send a request, as the client has been closed.") + + self._state = ClientState.OPENED + follow_redirects = ( + self.follow_redirects + if isinstance(follow_redirects, UseClientDefault) + else follow_redirects + ) + + self._set_timeout(request) + + auth = self._build_request_auth(request, auth) + + response = await self._send_handling_auth( + request, + auth=auth, + follow_redirects=follow_redirects, + history=[], + ) + try: + if not stream: + await response.aread() + + return response + + except BaseException as exc: + await response.aclose() + raise exc + + async def _send_handling_auth( + self, + request: Request, + auth: Auth, + follow_redirects: bool, + history: list[Response], + ) -> Response: + auth_flow = auth.async_auth_flow(request) + try: + request = await auth_flow.__anext__() + + while True: + response = await self._send_handling_redirects( + request, + follow_redirects=follow_redirects, + history=history, + ) + try: + try: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: + return response + + response.history = list(history) + await response.aread() + request = next_request + history.append(response) + + except BaseException as exc: + await response.aclose() + raise exc + finally: + await auth_flow.aclose() + + async def _send_handling_redirects( + self, + request: Request, + follow_redirects: bool, + history: list[Response], + ) -> Response: + while True: + if len(history) > self.max_redirects: + raise TooManyRedirects( + "Exceeded maximum allowed redirects.", request=request + ) + + for hook in self._event_hooks["request"]: + await hook(request) + + response = 
await self._send_single_request(request) + try: + for hook in self._event_hooks["response"]: + await hook(response) + + response.history = list(history) + + if not response.has_redirect_location: + return response + + request = self._build_redirect_request(request, response) + history = history + [response] + + if follow_redirects: + await response.aread() + else: + response.next_request = request + return response + + except BaseException as exc: + await response.aclose() + raise exc + + async def _send_single_request(self, request: Request) -> Response: + """ + Sends a single request, without handling any redirections. + """ + transport = self._transport_for_url(request.url) + start = time.perf_counter() + + if not isinstance(request.stream, AsyncByteStream): + raise RuntimeError( + "Attempted to send an sync request with an AsyncClient instance." + ) + + with request_context(request=request): + response = await transport.handle_async_request(request) + + assert isinstance(response.stream, AsyncByteStream) + response.request = request + response.stream = BoundAsyncStream( + response.stream, response=response, start=start + ) + self.cookies.extract_cookies(response) + response.default_encoding = self._default_encoding + + logger.info( + 'HTTP Request: %s %s "%s %d %s"', + request.method, + request.url, + response.http_version, + response.status_code, + response.reason_phrase, + ) + + return response + + async def get( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `GET` request. + + **Parameters**: See `httpx.request`. 
+ """ + return await self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def options( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def head( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `HEAD` request. + + **Parameters**: See `httpx.request`. 
+ """ + return await self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def post( + self, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def put( + self, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `PUT` request. + + **Parameters**: See `httpx.request`. 
+ """ + return await self.request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def patch( + self, + url: URL | str, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `PATCH` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def delete( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Response: + """ + Send a `DELETE` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def aclose(self) -> None: + """ + Close transport and proxies. 
+ """ + if self._state != ClientState.CLOSED: + self._state = ClientState.CLOSED + + await self._transport.aclose() + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.aclose() + + async def __aenter__(self: U) -> U: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: ( + "Cannot reopen a client instance, once it has been closed." + ), + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + await self._transport.__aenter__() + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + self._state = ClientState.CLOSED + + await self._transport.__aexit__(exc_type, exc_value, traceback) + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.__aexit__(exc_type, exc_value, traceback) diff --git a/venv/lib/python3.10/site-packages/httpx/_config.py b/venv/lib/python3.10/site-packages/httpx/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..467a6c90ae269babe3af7963d9d7c78b9f012268 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_config.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import os +import typing + +from ._models import Headers +from ._types import CertTypes, HeaderTypes, TimeoutTypes +from ._urls import URL + +if typing.TYPE_CHECKING: + import ssl # pragma: no cover + +__all__ = ["Limits", "Proxy", "Timeout", "create_ssl_context"] + + +class UnsetType: + pass # pragma: no cover + + +UNSET = UnsetType() + + +def create_ssl_context( + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, +) -> ssl.SSLContext: + import ssl + import warnings + + import certifi + + if verify is True: + 
if trust_env and os.environ.get("SSL_CERT_FILE"): # pragma: nocover + ctx = ssl.create_default_context(cafile=os.environ["SSL_CERT_FILE"]) + elif trust_env and os.environ.get("SSL_CERT_DIR"): # pragma: nocover + ctx = ssl.create_default_context(capath=os.environ["SSL_CERT_DIR"]) + else: + # Default case... + ctx = ssl.create_default_context(cafile=certifi.where()) + elif verify is False: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + elif isinstance(verify, str): # pragma: nocover + message = ( + "`verify=` is deprecated. " + "Use `verify=ssl.create_default_context(cafile=...)` " + "or `verify=ssl.create_default_context(capath=...)` instead." + ) + warnings.warn(message, DeprecationWarning) + if os.path.isdir(verify): + return ssl.create_default_context(capath=verify) + return ssl.create_default_context(cafile=verify) + else: + ctx = verify + + if cert: # pragma: nocover + message = ( + "`cert=...` is deprecated. Use `verify=` instead," + "with `.load_cert_chain()` to configure the certificate chain." + ) + warnings.warn(message, DeprecationWarning) + if isinstance(cert, str): + ctx.load_cert_chain(cert) + else: + ctx.load_cert_chain(*cert) + + return ctx + + +class Timeout: + """ + Timeout configuration. + + **Usage**: + + Timeout(None) # No timeouts. + Timeout(5.0) # 5s timeout on all operations. + Timeout(None, connect=5.0) # 5s timeout on connect, no other timeouts. + Timeout(5.0, connect=10.0) # 10s timeout on connect. 5s timeout elsewhere. + Timeout(5.0, pool=None) # No timeout on acquiring connection from pool. + # 5s timeout elsewhere. + """ + + def __init__( + self, + timeout: TimeoutTypes | UnsetType = UNSET, + *, + connect: None | float | UnsetType = UNSET, + read: None | float | UnsetType = UNSET, + write: None | float | UnsetType = UNSET, + pool: None | float | UnsetType = UNSET, + ) -> None: + if isinstance(timeout, Timeout): + # Passed as a single explicit Timeout. 
+ assert connect is UNSET + assert read is UNSET + assert write is UNSET + assert pool is UNSET + self.connect = timeout.connect # type: typing.Optional[float] + self.read = timeout.read # type: typing.Optional[float] + self.write = timeout.write # type: typing.Optional[float] + self.pool = timeout.pool # type: typing.Optional[float] + elif isinstance(timeout, tuple): + # Passed as a tuple. + self.connect = timeout[0] + self.read = timeout[1] + self.write = None if len(timeout) < 3 else timeout[2] + self.pool = None if len(timeout) < 4 else timeout[3] + elif not ( + isinstance(connect, UnsetType) + or isinstance(read, UnsetType) + or isinstance(write, UnsetType) + or isinstance(pool, UnsetType) + ): + self.connect = connect + self.read = read + self.write = write + self.pool = pool + else: + if isinstance(timeout, UnsetType): + raise ValueError( + "httpx.Timeout must either include a default, or set all " + "four parameters explicitly." + ) + self.connect = timeout if isinstance(connect, UnsetType) else connect + self.read = timeout if isinstance(read, UnsetType) else read + self.write = timeout if isinstance(write, UnsetType) else write + self.pool = timeout if isinstance(pool, UnsetType) else pool + + def as_dict(self) -> dict[str, float | None]: + return { + "connect": self.connect, + "read": self.read, + "write": self.write, + "pool": self.pool, + } + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.connect == other.connect + and self.read == other.read + and self.write == other.write + and self.pool == other.pool + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + if len({self.connect, self.read, self.write, self.pool}) == 1: + return f"{class_name}(timeout={self.connect})" + return ( + f"{class_name}(connect={self.connect}, " + f"read={self.read}, write={self.write}, pool={self.pool})" + ) + + +class Limits: + """ + Configuration for limits to various client behaviors. 
+ + **Parameters:** + + * **max_connections** - The maximum number of concurrent connections that may be + established. + * **max_keepalive_connections** - Allow the connection pool to maintain + keep-alive connections below this point. Should be less than or equal + to `max_connections`. + * **keepalive_expiry** - Time limit on idle keep-alive connections in seconds. + """ + + def __init__( + self, + *, + max_connections: int | None = None, + max_keepalive_connections: int | None = None, + keepalive_expiry: float | None = 5.0, + ) -> None: + self.max_connections = max_connections + self.max_keepalive_connections = max_keepalive_connections + self.keepalive_expiry = keepalive_expiry + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.max_connections == other.max_connections + and self.max_keepalive_connections == other.max_keepalive_connections + and self.keepalive_expiry == other.keepalive_expiry + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return ( + f"{class_name}(max_connections={self.max_connections}, " + f"max_keepalive_connections={self.max_keepalive_connections}, " + f"keepalive_expiry={self.keepalive_expiry})" + ) + + +class Proxy: + def __init__( + self, + url: URL | str, + *, + ssl_context: ssl.SSLContext | None = None, + auth: tuple[str, str] | None = None, + headers: HeaderTypes | None = None, + ) -> None: + url = URL(url) + headers = Headers(headers) + + if url.scheme not in ("http", "https", "socks5", "socks5h"): + raise ValueError(f"Unknown scheme for proxy URL {url!r}") + + if url.username or url.password: + # Remove any auth credentials from the URL. + auth = (url.username, url.password) + url = url.copy_with(username=None, password=None) + + self.url = url + self.auth = auth + self.headers = headers + self.ssl_context = ssl_context + + @property + def raw_auth(self) -> tuple[bytes, bytes] | None: + # The proxy authentication as raw bytes. 
+ return ( + None + if self.auth is None + else (self.auth[0].encode("utf-8"), self.auth[1].encode("utf-8")) + ) + + def __repr__(self) -> str: + # The authentication is represented with the password component masked. + auth = (self.auth[0], "********") if self.auth else None + + # Build a nice concise representation. + url_str = f"{str(self.url)!r}" + auth_str = f", auth={auth!r}" if auth else "" + headers_str = f", headers={dict(self.headers)!r}" if self.headers else "" + return f"Proxy({url_str}{auth_str}{headers_str})" + + +DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0) +DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20) +DEFAULT_MAX_REDIRECTS = 20 diff --git a/venv/lib/python3.10/site-packages/httpx/_content.py b/venv/lib/python3.10/site-packages/httpx/_content.py new file mode 100644 index 0000000000000000000000000000000000000000..6f479a0885f723b7395843d41164a87041820776 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_content.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import inspect +import warnings +from json import dumps as json_dumps +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Mapping, +) +from urllib.parse import urlencode + +from ._exceptions import StreamClosed, StreamConsumed +from ._multipart import MultipartStream +from ._types import ( + AsyncByteStream, + RequestContent, + RequestData, + RequestFiles, + ResponseContent, + SyncByteStream, +) +from ._utils import peek_filelike_length, primitive_value_to_str + +__all__ = ["ByteStream"] + + +class ByteStream(AsyncByteStream, SyncByteStream): + def __init__(self, stream: bytes) -> None: + self._stream = stream + + def __iter__(self) -> Iterator[bytes]: + yield self._stream + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._stream + + +class IteratorByteStream(SyncByteStream): + CHUNK_SIZE = 65_536 + + def __init__(self, stream: Iterable[bytes]) -> None: + self._stream = stream + 
self._is_stream_consumed = False + self._is_generator = inspect.isgenerator(stream) + + def __iter__(self) -> Iterator[bytes]: + if self._is_stream_consumed and self._is_generator: + raise StreamConsumed() + + self._is_stream_consumed = True + if hasattr(self._stream, "read"): + # File-like interfaces should use 'read' directly. + chunk = self._stream.read(self.CHUNK_SIZE) + while chunk: + yield chunk + chunk = self._stream.read(self.CHUNK_SIZE) + else: + # Otherwise iterate. + for part in self._stream: + yield part + + +class AsyncIteratorByteStream(AsyncByteStream): + CHUNK_SIZE = 65_536 + + def __init__(self, stream: AsyncIterable[bytes]) -> None: + self._stream = stream + self._is_stream_consumed = False + self._is_generator = inspect.isasyncgen(stream) + + async def __aiter__(self) -> AsyncIterator[bytes]: + if self._is_stream_consumed and self._is_generator: + raise StreamConsumed() + + self._is_stream_consumed = True + if hasattr(self._stream, "aread"): + # File-like interfaces should use 'aread' directly. + chunk = await self._stream.aread(self.CHUNK_SIZE) + while chunk: + yield chunk + chunk = await self._stream.aread(self.CHUNK_SIZE) + else: + # Otherwise iterate. + async for part in self._stream: + yield part + + +class UnattachedStream(AsyncByteStream, SyncByteStream): + """ + If a request or response is serialized using pickle, then it is no longer + attached to a stream for I/O purposes. Any stream operations should result + in `httpx.StreamClosed`. 
+ """ + + def __iter__(self) -> Iterator[bytes]: + raise StreamClosed() + + async def __aiter__(self) -> AsyncIterator[bytes]: + raise StreamClosed() + yield b"" # pragma: no cover + + +def encode_content( + content: str | bytes | Iterable[bytes] | AsyncIterable[bytes], +) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]: + if isinstance(content, (bytes, str)): + body = content.encode("utf-8") if isinstance(content, str) else content + content_length = len(body) + headers = {"Content-Length": str(content_length)} if body else {} + return headers, ByteStream(body) + + elif isinstance(content, Iterable) and not isinstance(content, dict): + # `not isinstance(content, dict)` is a bit oddly specific, but it + # catches a case that's easy for users to make in error, and would + # otherwise pass through here, like any other bytes-iterable, + # because `dict` happens to be iterable. See issue #2491. + content_length_or_none = peek_filelike_length(content) + + if content_length_or_none is None: + headers = {"Transfer-Encoding": "chunked"} + else: + headers = {"Content-Length": str(content_length_or_none)} + return headers, IteratorByteStream(content) # type: ignore + + elif isinstance(content, AsyncIterable): + headers = {"Transfer-Encoding": "chunked"} + return headers, AsyncIteratorByteStream(content) + + raise TypeError(f"Unexpected type for 'content', {type(content)!r}") + + +def encode_urlencoded_data( + data: RequestData, +) -> tuple[dict[str, str], ByteStream]: + plain_data = [] + for key, value in data.items(): + if isinstance(value, (list, tuple)): + plain_data.extend([(key, primitive_value_to_str(item)) for item in value]) + else: + plain_data.append((key, primitive_value_to_str(value))) + body = urlencode(plain_data, doseq=True).encode("utf-8") + content_length = str(len(body)) + content_type = "application/x-www-form-urlencoded" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def 
encode_multipart_data( + data: RequestData, files: RequestFiles, boundary: bytes | None +) -> tuple[dict[str, str], MultipartStream]: + multipart = MultipartStream(data=data, files=files, boundary=boundary) + headers = multipart.get_headers() + return headers, multipart + + +def encode_text(text: str) -> tuple[dict[str, str], ByteStream]: + body = text.encode("utf-8") + content_length = str(len(body)) + content_type = "text/plain; charset=utf-8" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_html(html: str) -> tuple[dict[str, str], ByteStream]: + body = html.encode("utf-8") + content_length = str(len(body)) + content_type = "text/html; charset=utf-8" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_json(json: Any) -> tuple[dict[str, str], ByteStream]: + body = json_dumps( + json, ensure_ascii=False, separators=(",", ":"), allow_nan=False + ).encode("utf-8") + content_length = str(len(body)) + content_type = "application/json" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_request( + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + boundary: bytes | None = None, +) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]: + """ + Handles encoding the given `content`, `data`, `files`, and `json`, + returning a two-tuple of (, ). + """ + if data is not None and not isinstance(data, Mapping): + # We prefer to separate `content=` + # for raw request content, and `data=
` for url encoded or + # multipart form content. + # + # However for compat with requests, we *do* still support + # `data=` usages. We deal with that case here, treating it + # as if `content=<...>` had been supplied instead. + message = "Use 'content=<...>' to upload raw bytes/text content." + warnings.warn(message, DeprecationWarning, stacklevel=2) + return encode_content(data) + + if content is not None: + return encode_content(content) + elif files: + return encode_multipart_data(data or {}, files, boundary) + elif data: + return encode_urlencoded_data(data) + elif json is not None: + return encode_json(json) + + return {}, ByteStream(b"") + + +def encode_response( + content: ResponseContent | None = None, + text: str | None = None, + html: str | None = None, + json: Any | None = None, +) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]: + """ + Handles encoding the given `content`, returning a two-tuple of + (, ). + """ + if content is not None: + return encode_content(content) + elif text is not None: + return encode_text(text) + elif html is not None: + return encode_html(html) + elif json is not None: + return encode_json(json) + + return {}, ByteStream(b"") diff --git a/venv/lib/python3.10/site-packages/httpx/_decoders.py b/venv/lib/python3.10/site-packages/httpx/_decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..899dfada878e1181fca6d3c75a79526a076abb9e --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_decoders.py @@ -0,0 +1,393 @@ +""" +Handlers for Content-Encoding. + +See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding +""" + +from __future__ import annotations + +import codecs +import io +import typing +import zlib + +from ._exceptions import DecodingError + +# Brotli support is optional +try: + # The C bindings in `brotli` are recommended for CPython. 
+ import brotli +except ImportError: # pragma: no cover + try: + # The CFFI bindings in `brotlicffi` are recommended for PyPy + # and other environments. + import brotlicffi as brotli + except ImportError: + brotli = None + + +# Zstandard support is optional +try: + import zstandard +except ImportError: # pragma: no cover + zstandard = None # type: ignore + + +class ContentDecoder: + def decode(self, data: bytes) -> bytes: + raise NotImplementedError() # pragma: no cover + + def flush(self) -> bytes: + raise NotImplementedError() # pragma: no cover + + +class IdentityDecoder(ContentDecoder): + """ + Handle unencoded data. + """ + + def decode(self, data: bytes) -> bytes: + return data + + def flush(self) -> bytes: + return b"" + + +class DeflateDecoder(ContentDecoder): + """ + Handle 'deflate' decoding. + + See: https://stackoverflow.com/questions/1838699 + """ + + def __init__(self) -> None: + self.first_attempt = True + self.decompressor = zlib.decompressobj() + + def decode(self, data: bytes) -> bytes: + was_first_attempt = self.first_attempt + self.first_attempt = False + try: + return self.decompressor.decompress(data) + except zlib.error as exc: + if was_first_attempt: + self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS) + return self.decode(data) + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + try: + return self.decompressor.flush() + except zlib.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class GZipDecoder(ContentDecoder): + """ + Handle 'gzip' decoding. 
+ + See: https://stackoverflow.com/questions/1838699 + """ + + def __init__(self) -> None: + self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16) + + def decode(self, data: bytes) -> bytes: + try: + return self.decompressor.decompress(data) + except zlib.error as exc: + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + try: + return self.decompressor.flush() + except zlib.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class BrotliDecoder(ContentDecoder): + """ + Handle 'brotli' decoding. + + Requires `pip install brotlipy`. See: https://brotlipy.readthedocs.io/ + or `pip install brotli`. See https://github.com/google/brotli + Supports both 'brotlipy' and 'Brotli' packages since they share an import + name. The top branches are for 'brotlipy' and bottom branches for 'Brotli' + """ + + def __init__(self) -> None: + if brotli is None: # pragma: no cover + raise ImportError( + "Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' " + "packages have been installed. " + "Make sure to install httpx using `pip install httpx[brotli]`." + ) from None + + self.decompressor = brotli.Decompressor() + self.seen_data = False + self._decompress: typing.Callable[[bytes], bytes] + if hasattr(self.decompressor, "decompress"): + # The 'brotlicffi' package. + self._decompress = self.decompressor.decompress # pragma: no cover + else: + # The 'brotli' package. + self._decompress = self.decompressor.process # pragma: no cover + + def decode(self, data: bytes) -> bytes: + if not data: + return b"" + self.seen_data = True + try: + return self._decompress(data) + except brotli.error as exc: + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + if not self.seen_data: + return b"" + try: + if hasattr(self.decompressor, "finish"): + # Only available in the 'brotlicffi' package. + + # As the decompressor decompresses eagerly, this + # will never actually emit any data. 
However, it will potentially throw + # errors if a truncated or damaged data stream has been used. + self.decompressor.finish() # pragma: no cover + return b"" + except brotli.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class ZStandardDecoder(ContentDecoder): + """ + Handle 'zstd' RFC 8878 decoding. + + Requires `pip install zstandard`. + Can be installed as a dependency of httpx using `pip install httpx[zstd]`. + """ + + # inspired by the ZstdDecoder implementation in urllib3 + def __init__(self) -> None: + if zstandard is None: # pragma: no cover + raise ImportError( + "Using 'ZStandardDecoder', ..." + "Make sure to install httpx using `pip install httpx[zstd]`." + ) from None + + self.decompressor = zstandard.ZstdDecompressor().decompressobj() + self.seen_data = False + + def decode(self, data: bytes) -> bytes: + assert zstandard is not None + self.seen_data = True + output = io.BytesIO() + try: + output.write(self.decompressor.decompress(data)) + while self.decompressor.eof and self.decompressor.unused_data: + unused_data = self.decompressor.unused_data + self.decompressor = zstandard.ZstdDecompressor().decompressobj() + output.write(self.decompressor.decompress(unused_data)) + except zstandard.ZstdError as exc: + raise DecodingError(str(exc)) from exc + return output.getvalue() + + def flush(self) -> bytes: + if not self.seen_data: + return b"" + ret = self.decompressor.flush() # note: this is a no-op + if not self.decompressor.eof: + raise DecodingError("Zstandard data is incomplete") # pragma: no cover + return bytes(ret) + + +class MultiDecoder(ContentDecoder): + """ + Handle the case where multiple encodings have been applied. + """ + + def __init__(self, children: typing.Sequence[ContentDecoder]) -> None: + """ + 'children' should be a sequence of decoders in the order in which + each was applied. + """ + # Note that we reverse the order for decoding. 
+ self.children = list(reversed(children)) + + def decode(self, data: bytes) -> bytes: + for child in self.children: + data = child.decode(data) + return data + + def flush(self) -> bytes: + data = b"" + for child in self.children: + data = child.decode(data) + child.flush() + return data + + +class ByteChunker: + """ + Handles returning byte content in fixed-size chunks. + """ + + def __init__(self, chunk_size: int | None = None) -> None: + self._buffer = io.BytesIO() + self._chunk_size = chunk_size + + def decode(self, content: bytes) -> list[bytes]: + if self._chunk_size is None: + return [content] if content else [] + + self._buffer.write(content) + if self._buffer.tell() >= self._chunk_size: + value = self._buffer.getvalue() + chunks = [ + value[i : i + self._chunk_size] + for i in range(0, len(value), self._chunk_size) + ] + if len(chunks[-1]) == self._chunk_size: + self._buffer.seek(0) + self._buffer.truncate() + return chunks + else: + self._buffer.seek(0) + self._buffer.write(chunks[-1]) + self._buffer.truncate() + return chunks[:-1] + else: + return [] + + def flush(self) -> list[bytes]: + value = self._buffer.getvalue() + self._buffer.seek(0) + self._buffer.truncate() + return [value] if value else [] + + +class TextChunker: + """ + Handles returning text content in fixed-size chunks. 
+ """ + + def __init__(self, chunk_size: int | None = None) -> None: + self._buffer = io.StringIO() + self._chunk_size = chunk_size + + def decode(self, content: str) -> list[str]: + if self._chunk_size is None: + return [content] if content else [] + + self._buffer.write(content) + if self._buffer.tell() >= self._chunk_size: + value = self._buffer.getvalue() + chunks = [ + value[i : i + self._chunk_size] + for i in range(0, len(value), self._chunk_size) + ] + if len(chunks[-1]) == self._chunk_size: + self._buffer.seek(0) + self._buffer.truncate() + return chunks + else: + self._buffer.seek(0) + self._buffer.write(chunks[-1]) + self._buffer.truncate() + return chunks[:-1] + else: + return [] + + def flush(self) -> list[str]: + value = self._buffer.getvalue() + self._buffer.seek(0) + self._buffer.truncate() + return [value] if value else [] + + +class TextDecoder: + """ + Handles incrementally decoding bytes into text + """ + + def __init__(self, encoding: str = "utf-8") -> None: + self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace") + + def decode(self, data: bytes) -> str: + return self.decoder.decode(data) + + def flush(self) -> str: + return self.decoder.decode(b"", True) + + +class LineDecoder: + """ + Handles incrementally reading lines from text. + + Has the same behaviour as the stdllib splitlines, + but handling the input iteratively. + """ + + def __init__(self) -> None: + self.buffer: list[str] = [] + self.trailing_cr: bool = False + + def decode(self, text: str) -> list[str]: + # See https://docs.python.org/3/library/stdtypes.html#str.splitlines + NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029" + + # We always push a trailing `\r` into the next decode iteration. 
+ if self.trailing_cr: + text = "\r" + text + self.trailing_cr = False + if text.endswith("\r"): + self.trailing_cr = True + text = text[:-1] + + if not text: + # NOTE: the edge case input of empty text doesn't occur in practice, + # because other httpx internals filter out this value + return [] # pragma: no cover + + trailing_newline = text[-1] in NEWLINE_CHARS + lines = text.splitlines() + + if len(lines) == 1 and not trailing_newline: + # No new lines, buffer the input and continue. + self.buffer.append(lines[0]) + return [] + + if self.buffer: + # Include any existing buffer in the first portion of the + # splitlines result. + lines = ["".join(self.buffer) + lines[0]] + lines[1:] + self.buffer = [] + + if not trailing_newline: + # If the last segment of splitlines is not newline terminated, + # then drop it from our output and start a new buffer. + self.buffer = [lines.pop()] + + return lines + + def flush(self) -> list[str]: + if not self.buffer and not self.trailing_cr: + return [] + + lines = ["".join(self.buffer)] + self.buffer = [] + self.trailing_cr = False + return lines + + +SUPPORTED_DECODERS = { + "identity": IdentityDecoder, + "gzip": GZipDecoder, + "deflate": DeflateDecoder, + "br": BrotliDecoder, + "zstd": ZStandardDecoder, +} + + +if brotli is None: + SUPPORTED_DECODERS.pop("br") # pragma: no cover +if zstandard is None: + SUPPORTED_DECODERS.pop("zstd") # pragma: no cover diff --git a/venv/lib/python3.10/site-packages/httpx/_exceptions.py b/venv/lib/python3.10/site-packages/httpx/_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..77f45a6d3986d15626fc8a5fd459d6a3e0fbe466 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_exceptions.py @@ -0,0 +1,379 @@ +""" +Our exception hierarchy: + +* HTTPError + x RequestError + + TransportError + - TimeoutException + · ConnectTimeout + · ReadTimeout + · WriteTimeout + · PoolTimeout + - NetworkError + · ConnectError + · ReadError + · WriteError + · CloseError + - 
ProtocolError + · LocalProtocolError + · RemoteProtocolError + - ProxyError + - UnsupportedProtocol + + DecodingError + + TooManyRedirects + x HTTPStatusError +* InvalidURL +* CookieConflict +* StreamError + x StreamConsumed + x StreamClosed + x ResponseNotRead + x RequestNotRead +""" + +from __future__ import annotations + +import contextlib +import typing + +if typing.TYPE_CHECKING: + from ._models import Request, Response # pragma: no cover + +__all__ = [ + "CloseError", + "ConnectError", + "ConnectTimeout", + "CookieConflict", + "DecodingError", + "HTTPError", + "HTTPStatusError", + "InvalidURL", + "LocalProtocolError", + "NetworkError", + "PoolTimeout", + "ProtocolError", + "ProxyError", + "ReadError", + "ReadTimeout", + "RemoteProtocolError", + "RequestError", + "RequestNotRead", + "ResponseNotRead", + "StreamClosed", + "StreamConsumed", + "StreamError", + "TimeoutException", + "TooManyRedirects", + "TransportError", + "UnsupportedProtocol", + "WriteError", + "WriteTimeout", +] + + +class HTTPError(Exception): + """ + Base class for `RequestError` and `HTTPStatusError`. + + Useful for `try...except` blocks when issuing a request, + and then calling `.raise_for_status()`. + + For example: + + ``` + try: + response = httpx.get("https://www.example.com") + response.raise_for_status() + except httpx.HTTPError as exc: + print(f"HTTP Exception for {exc.request.url} - {exc}") + ``` + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self._request: Request | None = None + + @property + def request(self) -> Request: + if self._request is None: + raise RuntimeError("The .request property has not been set.") + return self._request + + @request.setter + def request(self, request: Request) -> None: + self._request = request + + +class RequestError(HTTPError): + """ + Base class for all exceptions that may occur when issuing a `.request()`. 
+ """ + + def __init__(self, message: str, *, request: Request | None = None) -> None: + super().__init__(message) + # At the point an exception is raised we won't typically have a request + # instance to associate it with. + # + # The 'request_context' context manager is used within the Client and + # Response methods in order to ensure that any raised exceptions + # have a `.request` property set on them. + self._request = request + + +class TransportError(RequestError): + """ + Base class for all exceptions that occur at the level of the Transport API. + """ + + +# Timeout exceptions... + + +class TimeoutException(TransportError): + """ + The base class for timeout errors. + + An operation has timed out. + """ + + +class ConnectTimeout(TimeoutException): + """ + Timed out while connecting to the host. + """ + + +class ReadTimeout(TimeoutException): + """ + Timed out while receiving data from the host. + """ + + +class WriteTimeout(TimeoutException): + """ + Timed out while sending data to the host. + """ + + +class PoolTimeout(TimeoutException): + """ + Timed out waiting to acquire a connection from the pool. + """ + + +# Core networking exceptions... + + +class NetworkError(TransportError): + """ + The base class for network-related errors. + + An error occurred while interacting with the network. + """ + + +class ReadError(NetworkError): + """ + Failed to receive data from the network. + """ + + +class WriteError(NetworkError): + """ + Failed to send data through the network. + """ + + +class ConnectError(NetworkError): + """ + Failed to establish a connection. + """ + + +class CloseError(NetworkError): + """ + Failed to close a connection. + """ + + +# Other transport exceptions... + + +class ProxyError(TransportError): + """ + An error occurred while establishing a proxy connection. + """ + + +class UnsupportedProtocol(TransportError): + """ + Attempted to make a request to an unsupported protocol. 
+ + For example issuing a request to `ftp://www.example.com`. + """ + + +class ProtocolError(TransportError): + """ + The protocol was violated. + """ + + +class LocalProtocolError(ProtocolError): + """ + A protocol was violated by the client. + + For example if the user instantiated a `Request` instance explicitly, + failed to include the mandatory `Host:` header, and then issued it directly + using `client.send()`. + """ + + +class RemoteProtocolError(ProtocolError): + """ + The protocol was violated by the server. + + For example, returning malformed HTTP. + """ + + +# Other request exceptions... + + +class DecodingError(RequestError): + """ + Decoding of the response failed, due to a malformed encoding. + """ + + +class TooManyRedirects(RequestError): + """ + Too many redirects. + """ + + +# Client errors + + +class HTTPStatusError(HTTPError): + """ + The response had an error HTTP status of 4xx or 5xx. + + May be raised when calling `response.raise_for_status()` + """ + + def __init__(self, message: str, *, request: Request, response: Response) -> None: + super().__init__(message) + self.request = request + self.response = response + + +class InvalidURL(Exception): + """ + URL is improperly formed or cannot be parsed. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class CookieConflict(Exception): + """ + Attempted to lookup a cookie by name, but multiple cookies existed. + + Can occur when calling `response.cookies.get(...)`. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +# Stream exceptions... + +# These may occur as the result of a programming error, by accessing +# the request/response stream in an invalid manner. + + +class StreamError(RuntimeError): + """ + The base class for stream exceptions. + + The developer made an error in accessing the request stream in + an invalid way. 
+ """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class StreamConsumed(StreamError): + """ + Attempted to read or stream content, but the content has already + been streamed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. For requests, this could be due to passing " + "a generator as request content, and then receiving a redirect " + "response or a secondary request as part of an authentication flow." + "For responses, this could be due to attempting to stream the response " + "content more than once." + ) + super().__init__(message) + + +class StreamClosed(StreamError): + """ + Attempted to read or stream response content, but the request has been + closed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream content, but the stream has " "been closed." + ) + super().__init__(message) + + +class ResponseNotRead(StreamError): + """ + Attempted to access streaming response content, without having called `read()`. + """ + + def __init__(self) -> None: + message = ( + "Attempted to access streaming response content," + " without having called `read()`." + ) + super().__init__(message) + + +class RequestNotRead(StreamError): + """ + Attempted to access streaming request content, without having called `read()`. + """ + + def __init__(self) -> None: + message = ( + "Attempted to access streaming request content," + " without having called `read()`." + ) + super().__init__(message) + + +@contextlib.contextmanager +def request_context( + request: Request | None = None, +) -> typing.Iterator[None]: + """ + A context manager that can be used to attach the given request context + to any `RequestError` exceptions that are raised within the block. 
+ """ + try: + yield + except RequestError as exc: + if request is not None: + exc.request = request + raise exc diff --git a/venv/lib/python3.10/site-packages/httpx/_main.py b/venv/lib/python3.10/site-packages/httpx/_main.py new file mode 100644 index 0000000000000000000000000000000000000000..cffa4bb7db0f930f4db56653a061c4d7400ba4e6 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_main.py @@ -0,0 +1,506 @@ +from __future__ import annotations + +import functools +import json +import sys +import typing + +import click +import pygments.lexers +import pygments.util +import rich.console +import rich.markup +import rich.progress +import rich.syntax +import rich.table + +from ._client import Client +from ._exceptions import RequestError +from ._models import Response +from ._status_codes import codes + +if typing.TYPE_CHECKING: + import httpcore # pragma: no cover + + +def print_help() -> None: + console = rich.console.Console() + + console.print("[bold]HTTPX :butterfly:", justify="center") + console.print() + console.print("A next generation HTTP client.", justify="center") + console.print() + console.print( + "Usage: [bold]httpx[/bold] [cyan] [OPTIONS][/cyan] ", justify="left" + ) + console.print() + + table = rich.table.Table.grid(padding=1, pad_edge=True) + table.add_column("Parameter", no_wrap=True, justify="left", style="bold") + table.add_column("Description") + table.add_row( + "-m, --method [cyan]METHOD", + "Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD.\n" + "[Default: GET, or POST if a request body is included]", + ) + table.add_row( + "-p, --params [cyan] ...", + "Query parameters to include in the request URL.", + ) + table.add_row( + "-c, --content [cyan]TEXT", "Byte content to include in the request body." + ) + table.add_row( + "-d, --data [cyan] ...", "Form data to include in the request body." 
+ ) + table.add_row( + "-f, --files [cyan] ...", + "Form files to include in the request body.", + ) + table.add_row("-j, --json [cyan]TEXT", "JSON data to include in the request body.") + table.add_row( + "-h, --headers [cyan] ...", + "Include additional HTTP headers in the request.", + ) + table.add_row( + "--cookies [cyan] ...", "Cookies to include in the request." + ) + table.add_row( + "--auth [cyan]", + "Username and password to include in the request. Specify '-' for the password" + " to use a password prompt. Note that using --verbose/-v will expose" + " the Authorization header, including the password encoding" + " in a trivially reversible format.", + ) + + table.add_row( + "--proxy [cyan]URL", + "Send the request via a proxy. Should be the URL giving the proxy address.", + ) + + table.add_row( + "--timeout [cyan]FLOAT", + "Timeout value to use for network operations, such as establishing the" + " connection, reading some data, etc... [Default: 5.0]", + ) + + table.add_row("--follow-redirects", "Automatically follow redirects.") + table.add_row("--no-verify", "Disable SSL verification.") + table.add_row( + "--http2", "Send the request using HTTP/2, if the remote server supports it." + ) + + table.add_row( + "--download [cyan]FILE", + "Save the response content as a file, rather than displaying it.", + ) + + table.add_row("-v, --verbose", "Verbose output. 
Show request as well as response.") + table.add_row("--help", "Show this message and exit.") + console.print(table) + + +def get_lexer_for_response(response: Response) -> str: + content_type = response.headers.get("Content-Type") + if content_type is not None: + mime_type, _, _ = content_type.partition(";") + try: + return typing.cast( + str, pygments.lexers.get_lexer_for_mimetype(mime_type.strip()).name + ) + except pygments.util.ClassNotFound: # pragma: no cover + pass + return "" # pragma: no cover + + +def format_request_headers(request: httpcore.Request, http2: bool = False) -> str: + version = "HTTP/2" if http2 else "HTTP/1.1" + headers = [ + (name.lower() if http2 else name, value) for name, value in request.headers + ] + method = request.method.decode("ascii") + target = request.url.target.decode("ascii") + lines = [f"{method} {target} {version}"] + [ + f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers + ] + return "\n".join(lines) + + +def format_response_headers( + http_version: bytes, + status: int, + reason_phrase: bytes | None, + headers: list[tuple[bytes, bytes]], +) -> str: + version = http_version.decode("ascii") + reason = ( + codes.get_reason_phrase(status) + if reason_phrase is None + else reason_phrase.decode("ascii") + ) + lines = [f"{version} {status} {reason}"] + [ + f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers + ] + return "\n".join(lines) + + +def print_request_headers(request: httpcore.Request, http2: bool = False) -> None: + console = rich.console.Console() + http_text = format_request_headers(request, http2=http2) + syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + + +def print_response_headers( + http_version: bytes, + status: int, + reason_phrase: bytes | None, + headers: list[tuple[bytes, bytes]], +) -> None: + console 
= rich.console.Console() + http_text = format_response_headers(http_version, status, reason_phrase, headers) + syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + + +def print_response(response: Response) -> None: + console = rich.console.Console() + lexer_name = get_lexer_for_response(response) + if lexer_name: + if lexer_name.lower() == "json": + try: + data = response.json() + text = json.dumps(data, indent=4) + except ValueError: # pragma: no cover + text = response.text + else: + text = response.text + + syntax = rich.syntax.Syntax(text, lexer_name, theme="ansi_dark", word_wrap=True) + console.print(syntax) + else: + console.print(f"<{len(response.content)} bytes of binary data>") + + +_PCTRTT = typing.Tuple[typing.Tuple[str, str], ...] +_PCTRTTT = typing.Tuple[_PCTRTT, ...] +_PeerCertRetDictType = typing.Dict[str, typing.Union[str, _PCTRTTT, _PCTRTT]] + + +def format_certificate(cert: _PeerCertRetDictType) -> str: # pragma: no cover + lines = [] + for key, value in cert.items(): + if isinstance(value, (list, tuple)): + lines.append(f"* {key}:") + for item in value: + if key in ("subject", "issuer"): + for sub_item in item: + lines.append(f"* {sub_item[0]}: {sub_item[1]!r}") + elif isinstance(item, tuple) and len(item) == 2: + lines.append(f"* {item[0]}: {item[1]!r}") + else: + lines.append(f"* {item!r}") + else: + lines.append(f"* {key}: {value!r}") + return "\n".join(lines) + + +def trace( + name: str, info: typing.Mapping[str, typing.Any], verbose: bool = False +) -> None: + console = rich.console.Console() + if name == "connection.connect_tcp.started" and verbose: + host = info["host"] + console.print(f"* Connecting to {host!r}") + elif name == "connection.connect_tcp.complete" and verbose: + stream = info["return_value"] + server_addr = stream.get_extra_info("server_addr") + console.print(f"* 
Connected to {server_addr[0]!r} on port {server_addr[1]}") + elif name == "connection.start_tls.complete" and verbose: # pragma: no cover + stream = info["return_value"] + ssl_object = stream.get_extra_info("ssl_object") + version = ssl_object.version() + cipher = ssl_object.cipher() + server_cert = ssl_object.getpeercert() + alpn = ssl_object.selected_alpn_protocol() + console.print(f"* SSL established using {version!r} / {cipher[0]!r}") + console.print(f"* Selected ALPN protocol: {alpn!r}") + if server_cert: + console.print("* Server certificate:") + console.print(format_certificate(server_cert)) + elif name == "http11.send_request_headers.started" and verbose: + request = info["request"] + print_request_headers(request, http2=False) + elif name == "http2.send_request_headers.started" and verbose: # pragma: no cover + request = info["request"] + print_request_headers(request, http2=True) + elif name == "http11.receive_response_headers.complete": + http_version, status, reason_phrase, headers = info["return_value"] + print_response_headers(http_version, status, reason_phrase, headers) + elif name == "http2.receive_response_headers.complete": # pragma: no cover + status, headers = info["return_value"] + http_version = b"HTTP/2" + reason_phrase = None + print_response_headers(http_version, status, reason_phrase, headers) + + +def download_response(response: Response, download: typing.BinaryIO) -> None: + console = rich.console.Console() + console.print() + content_length = response.headers.get("Content-Length") + with rich.progress.Progress( + "[progress.description]{task.description}", + "[progress.percentage]{task.percentage:>3.0f}%", + rich.progress.BarColumn(bar_width=None), + rich.progress.DownloadColumn(), + rich.progress.TransferSpeedColumn(), + ) as progress: + description = f"Downloading [bold]{rich.markup.escape(download.name)}" + download_task = progress.add_task( + description, + total=int(content_length or 0), + start=content_length is not None, + ) + 
for chunk in response.iter_bytes(): + download.write(chunk) + progress.update(download_task, completed=response.num_bytes_downloaded) + + +def validate_json( + ctx: click.Context, + param: click.Option | click.Parameter, + value: typing.Any, +) -> typing.Any: + if value is None: + return None + + try: + return json.loads(value) + except json.JSONDecodeError: # pragma: no cover + raise click.BadParameter("Not valid JSON") + + +def validate_auth( + ctx: click.Context, + param: click.Option | click.Parameter, + value: typing.Any, +) -> typing.Any: + if value == (None, None): + return None + + username, password = value + if password == "-": # pragma: no cover + password = click.prompt("Password", hide_input=True) + return (username, password) + + +def handle_help( + ctx: click.Context, + param: click.Option | click.Parameter, + value: typing.Any, +) -> None: + if not value or ctx.resilient_parsing: + return + + print_help() + ctx.exit() + + +@click.command(add_help_option=False) +@click.argument("url", type=str) +@click.option( + "--method", + "-m", + "method", + type=str, + help=( + "Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD. 
" + "[Default: GET, or POST if a request body is included]" + ), +) +@click.option( + "--params", + "-p", + "params", + type=(str, str), + multiple=True, + help="Query parameters to include in the request URL.", +) +@click.option( + "--content", + "-c", + "content", + type=str, + help="Byte content to include in the request body.", +) +@click.option( + "--data", + "-d", + "data", + type=(str, str), + multiple=True, + help="Form data to include in the request body.", +) +@click.option( + "--files", + "-f", + "files", + type=(str, click.File(mode="rb")), + multiple=True, + help="Form files to include in the request body.", +) +@click.option( + "--json", + "-j", + "json", + type=str, + callback=validate_json, + help="JSON data to include in the request body.", +) +@click.option( + "--headers", + "-h", + "headers", + type=(str, str), + multiple=True, + help="Include additional HTTP headers in the request.", +) +@click.option( + "--cookies", + "cookies", + type=(str, str), + multiple=True, + help="Cookies to include in the request.", +) +@click.option( + "--auth", + "auth", + type=(str, str), + default=(None, None), + callback=validate_auth, + help=( + "Username and password to include in the request. " + "Specify '-' for the password to use a password prompt. " + "Note that using --verbose/-v will expose the Authorization header, " + "including the password encoding in a trivially reversible format." + ), +) +@click.option( + "--proxy", + "proxy", + type=str, + default=None, + help="Send the request via a proxy. Should be the URL giving the proxy address.", +) +@click.option( + "--timeout", + "timeout", + type=float, + default=5.0, + help=( + "Timeout value to use for network operations, such as establishing the " + "connection, reading some data, etc... 
[Default: 5.0]" + ), +) +@click.option( + "--follow-redirects", + "follow_redirects", + is_flag=True, + default=False, + help="Automatically follow redirects.", +) +@click.option( + "--no-verify", + "verify", + is_flag=True, + default=True, + help="Disable SSL verification.", +) +@click.option( + "--http2", + "http2", + type=bool, + is_flag=True, + default=False, + help="Send the request using HTTP/2, if the remote server supports it.", +) +@click.option( + "--download", + type=click.File("wb"), + help="Save the response content as a file, rather than displaying it.", +) +@click.option( + "--verbose", + "-v", + type=bool, + is_flag=True, + default=False, + help="Verbose. Show request as well as response.", +) +@click.option( + "--help", + is_flag=True, + is_eager=True, + expose_value=False, + callback=handle_help, + help="Show this message and exit.", +) +def main( + url: str, + method: str, + params: list[tuple[str, str]], + content: str, + data: list[tuple[str, str]], + files: list[tuple[str, click.File]], + json: str, + headers: list[tuple[str, str]], + cookies: list[tuple[str, str]], + auth: tuple[str, str] | None, + proxy: str, + timeout: float, + follow_redirects: bool, + verify: bool, + http2: bool, + download: typing.BinaryIO | None, + verbose: bool, +) -> None: + """ + An HTTP command line client. + Sends a request and displays the response. 
+ """ + if not method: + method = "POST" if content or data or files or json else "GET" + + try: + with Client(proxy=proxy, timeout=timeout, http2=http2, verify=verify) as client: + with client.stream( + method, + url, + params=list(params), + content=content, + data=dict(data), + files=files, # type: ignore + json=json, + headers=headers, + cookies=dict(cookies), + auth=auth, + follow_redirects=follow_redirects, + extensions={"trace": functools.partial(trace, verbose=verbose)}, + ) as response: + if download is not None: + download_response(response, download) + else: + response.read() + if response.content: + print_response(response) + + except RequestError as exc: + console = rich.console.Console() + console.print(f"[red]{type(exc).__name__}[/red]: {exc}") + sys.exit(1) + + sys.exit(0 if response.is_success else 1) diff --git a/venv/lib/python3.10/site-packages/httpx/_models.py b/venv/lib/python3.10/site-packages/httpx/_models.py new file mode 100644 index 0000000000000000000000000000000000000000..67d74bf86bfc80e22d9a4a3153572845accd9039 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_models.py @@ -0,0 +1,1277 @@ +from __future__ import annotations + +import codecs +import datetime +import email.message +import json as jsonlib +import re +import typing +import urllib.request +from collections.abc import Mapping +from http.cookiejar import Cookie, CookieJar + +from ._content import ByteStream, UnattachedStream, encode_request, encode_response +from ._decoders import ( + SUPPORTED_DECODERS, + ByteChunker, + ContentDecoder, + IdentityDecoder, + LineDecoder, + MultiDecoder, + TextChunker, + TextDecoder, +) +from ._exceptions import ( + CookieConflict, + HTTPStatusError, + RequestNotRead, + ResponseNotRead, + StreamClosed, + StreamConsumed, + request_context, +) +from ._multipart import get_multipart_boundary_from_content_type +from ._status_codes import codes +from ._types import ( + AsyncByteStream, + CookieTypes, + HeaderTypes, + QueryParamTypes, + 
RequestContent, + RequestData, + RequestExtensions, + RequestFiles, + ResponseContent, + ResponseExtensions, + SyncByteStream, +) +from ._urls import URL +from ._utils import to_bytes_or_str, to_str + +__all__ = ["Cookies", "Headers", "Request", "Response"] + +SENSITIVE_HEADERS = {"authorization", "proxy-authorization"} + + +def _is_known_encoding(encoding: str) -> bool: + """ + Return `True` if `encoding` is a known codec. + """ + try: + codecs.lookup(encoding) + except LookupError: + return False + return True + + +def _normalize_header_key(key: str | bytes, encoding: str | None = None) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header key. + """ + return key if isinstance(key, bytes) else key.encode(encoding or "ascii") + + +def _normalize_header_value(value: str | bytes, encoding: str | None = None) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header value. + """ + if isinstance(value, bytes): + return value + if not isinstance(value, str): + raise TypeError(f"Header value must be str or bytes, not {type(value)}") + return value.encode(encoding or "ascii") + + +def _parse_content_type_charset(content_type: str) -> str | None: + # We used to use `cgi.parse_header()` here, but `cgi` became a dead battery. 
+ # See: https://peps.python.org/pep-0594/#cgi + msg = email.message.Message() + msg["content-type"] = content_type + return msg.get_content_charset(failobj=None) + + +def _parse_header_links(value: str) -> list[dict[str, str]]: + """ + Returns a list of parsed link headers, for more info see: + https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link + The generic syntax of those is: + Link: < uri-reference >; param1=value1; param2="value2" + So for instance: + Link; '; type="image/jpeg",;' + would return + [ + {"url": "http:/.../front.jpeg", "type": "image/jpeg"}, + {"url": "http://.../back.jpeg"}, + ] + :param value: HTTP Link entity-header field + :return: list of parsed link headers + """ + links: list[dict[str, str]] = [] + replace_chars = " '\"" + value = value.strip(replace_chars) + if not value: + return links + for val in re.split(", *<", value): + try: + url, params = val.split(";", 1) + except ValueError: + url, params = val, "" + link = {"url": url.strip("<> '\"")} + for param in params.split(";"): + try: + key, value = param.split("=") + except ValueError: + break + link[key.strip(replace_chars)] = value.strip(replace_chars) + links.append(link) + return links + + +def _obfuscate_sensitive_headers( + items: typing.Iterable[tuple[typing.AnyStr, typing.AnyStr]], +) -> typing.Iterator[tuple[typing.AnyStr, typing.AnyStr]]: + for k, v in items: + if to_str(k.lower()) in SENSITIVE_HEADERS: + v = to_bytes_or_str("[secure]", match_type_of=v) + yield k, v + + +class Headers(typing.MutableMapping[str, str]): + """ + HTTP headers, as a case-insensitive multi-dict. 
+ """ + + def __init__( + self, + headers: HeaderTypes | None = None, + encoding: str | None = None, + ) -> None: + self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]] + + if isinstance(headers, Headers): + self._list = list(headers._list) + elif isinstance(headers, Mapping): + for k, v in headers.items(): + bytes_key = _normalize_header_key(k, encoding) + bytes_value = _normalize_header_value(v, encoding) + self._list.append((bytes_key, bytes_key.lower(), bytes_value)) + elif headers is not None: + for k, v in headers: + bytes_key = _normalize_header_key(k, encoding) + bytes_value = _normalize_header_value(v, encoding) + self._list.append((bytes_key, bytes_key.lower(), bytes_value)) + + self._encoding = encoding + + @property + def encoding(self) -> str: + """ + Header encoding is mandated as ascii, but we allow fallbacks to utf-8 + or iso-8859-1. + """ + if self._encoding is None: + for encoding in ["ascii", "utf-8"]: + for key, value in self.raw: + try: + key.decode(encoding) + value.decode(encoding) + except UnicodeDecodeError: + break + else: + # The else block runs if 'break' did not occur, meaning + # all values fitted the encoding. + self._encoding = encoding + break + else: + # The ISO-8859-1 encoding covers all 256 code points in a byte, + # so will never raise decode errors. + self._encoding = "iso-8859-1" + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + self._encoding = value + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + """ + Returns a list of the raw header items, as byte pairs. 
+ """ + return [(raw_key, value) for raw_key, _, value in self._list] + + def keys(self) -> typing.KeysView[str]: + return {key.decode(self.encoding): None for _, key, value in self._list}.keys() + + def values(self) -> typing.ValuesView[str]: + values_dict: dict[str, str] = {} + for _, key, value in self._list: + str_key = key.decode(self.encoding) + str_value = value.decode(self.encoding) + if str_key in values_dict: + values_dict[str_key] += f", {str_value}" + else: + values_dict[str_key] = str_value + return values_dict.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return `(key, value)` items of headers. Concatenate headers + into a single comma separated value when a key occurs multiple times. + """ + values_dict: dict[str, str] = {} + for _, key, value in self._list: + str_key = key.decode(self.encoding) + str_value = value.decode(self.encoding) + if str_key in values_dict: + values_dict[str_key] += f", {str_value}" + else: + values_dict[str_key] = str_value + return values_dict.items() + + def multi_items(self) -> list[tuple[str, str]]: + """ + Return a list of `(key, value)` pairs of headers. Allow multiple + occurrences of the same key without concatenating into a single + comma separated value. + """ + return [ + (key.decode(self.encoding), value.decode(self.encoding)) + for _, key, value in self._list + ] + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Return a header value. If multiple occurrences of the header occur + then concatenate them together with commas. + """ + try: + return self[key] + except KeyError: + return default + + def get_list(self, key: str, split_commas: bool = False) -> list[str]: + """ + Return a list of all header values for a given key. + If `split_commas=True` is passed, then any comma separated header + values are split into multiple return strings. 
+ """ + get_header_key = key.lower().encode(self.encoding) + + values = [ + item_value.decode(self.encoding) + for _, item_key, item_value in self._list + if item_key.lower() == get_header_key + ] + + if not split_commas: + return values + + split_values = [] + for value in values: + split_values.extend([item.strip() for item in value.split(",")]) + return split_values + + def update(self, headers: HeaderTypes | None = None) -> None: # type: ignore + headers = Headers(headers) + for key in headers.keys(): + if key in self: + self.pop(key) + self._list.extend(headers._list) + + def copy(self) -> Headers: + return Headers(self, encoding=self.encoding) + + def __getitem__(self, key: str) -> str: + """ + Return a single header value. + + If there are multiple headers with the same key, then we concatenate + them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + normalized_key = key.lower().encode(self.encoding) + + items = [ + header_value.decode(self.encoding) + for _, header_key, header_value in self._list + if header_key == normalized_key + ] + + if items: + return ", ".join(items) + + raise KeyError(key) + + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.encode(self._encoding or "utf-8") + set_value = value.encode(self._encoding or "utf-8") + lookup_key = set_key.lower() + + found_indexes = [ + idx + for idx, (_, item_key, _) in enumerate(self._list) + if item_key == lookup_key + ] + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, lookup_key, set_value) + else: + self._list.append((set_key, lookup_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. 
+ """ + del_key = key.lower().encode(self.encoding) + + pop_indexes = [ + idx + for idx, (_, item_key, _) in enumerate(self._list) + if item_key.lower() == del_key + ] + + if not pop_indexes: + raise KeyError(key) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __contains__(self, key: typing.Any) -> bool: + header_key = key.lower().encode(self.encoding) + return header_key in [key for _, key, _ in self._list] + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + try: + other_headers = Headers(other) + except ValueError: + return False + + self_list = [(key, value) for _, key, value in self._list] + other_list = [(key, value) for _, key, value in other_headers._list] + return sorted(self_list) == sorted(other_list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + + encoding_str = "" + if self.encoding != "ascii": + encoding_str = f", encoding={self.encoding!r}" + + as_list = list(_obfuscate_sensitive_headers(self.multi_items())) + as_dict = dict(as_list) + + no_duplicate_keys = len(as_dict) == len(as_list) + if no_duplicate_keys: + return f"{class_name}({as_dict!r}{encoding_str})" + return f"{class_name}({as_list!r}{encoding_str})" + + +class Request: + def __init__( + self, + method: str, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: typing.Any | None = None, + stream: SyncByteStream | AsyncByteStream | None = None, + extensions: RequestExtensions | None = None, + ) -> None: + self.method = method.upper() + self.url = URL(url) if params is None else URL(url, params=params) + self.headers = Headers(headers) + self.extensions = {} if extensions is None else dict(extensions) + + if 
cookies: + Cookies(cookies).set_cookie_header(self) + + if stream is None: + content_type: str | None = self.headers.get("content-type") + headers, stream = encode_request( + content=content, + data=data, + files=files, + json=json, + boundary=get_multipart_boundary_from_content_type( + content_type=content_type.encode(self.headers.encoding) + if content_type + else None + ), + ) + self._prepare(headers) + self.stream = stream + # Load the request body, except for streaming content. + if isinstance(stream, ByteStream): + self.read() + else: + # There's an important distinction between `Request(content=...)`, + # and `Request(stream=...)`. + # + # Using `content=...` implies automatically populated `Host` and content + # headers, of either `Content-Length: ...` or `Transfer-Encoding: chunked`. + # + # Using `stream=...` will not automatically include *any* + # auto-populated headers. + # + # As an end-user you don't really need `stream=...`. It's only + # useful when: + # + # * Preserving the request stream when copying requests, eg for redirects. + # * Creating request instances on the *server-side* of the transport API. + self.stream = stream + + def _prepare(self, default_headers: dict[str, str]) -> None: + for key, value in default_headers.items(): + # Ignore Transfer-Encoding if the Content-Length has been set explicitly. 
+ if key.lower() == "transfer-encoding" and "Content-Length" in self.headers: + continue + self.headers.setdefault(key, value) + + auto_headers: list[tuple[bytes, bytes]] = [] + + has_host = "Host" in self.headers + has_content_length = ( + "Content-Length" in self.headers or "Transfer-Encoding" in self.headers + ) + + if not has_host and self.url.host: + auto_headers.append((b"Host", self.url.netloc)) + if not has_content_length and self.method in ("POST", "PUT", "PATCH"): + auto_headers.append((b"Content-Length", b"0")) + + self.headers = Headers(auto_headers + self.headers.raw) + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + raise RequestNotRead() + return self._content + + def read(self) -> bytes: + """ + Read and return the request content. + """ + if not hasattr(self, "_content"): + assert isinstance(self.stream, typing.Iterable) + self._content = b"".join(self.stream) + if not isinstance(self.stream, ByteStream): + # If a streaming request has been read entirely into memory, then + # we can replace the stream with a raw bytes implementation, + # to ensure that any non-replayable streams can still be used. + self.stream = ByteStream(self._content) + return self._content + + async def aread(self) -> bytes: + """ + Read and return the request content. + """ + if not hasattr(self, "_content"): + assert isinstance(self.stream, typing.AsyncIterable) + self._content = b"".join([part async for part in self.stream]) + if not isinstance(self.stream, ByteStream): + # If a streaming request has been read entirely into memory, then + # we can replace the stream with a raw bytes implementation, + # to ensure that any non-replayable streams can still be used. 
+ self.stream = ByteStream(self._content) + return self._content + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url = str(self.url) + return f"<{class_name}({self.method!r}, {url!r})>" + + def __getstate__(self) -> dict[str, typing.Any]: + return { + name: value + for name, value in self.__dict__.items() + if name not in ["extensions", "stream"] + } + + def __setstate__(self, state: dict[str, typing.Any]) -> None: + for name, value in state.items(): + setattr(self, name, value) + self.extensions = {} + self.stream = UnattachedStream() + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: HeaderTypes | None = None, + content: ResponseContent | None = None, + text: str | None = None, + html: str | None = None, + json: typing.Any = None, + stream: SyncByteStream | AsyncByteStream | None = None, + request: Request | None = None, + extensions: ResponseExtensions | None = None, + history: list[Response] | None = None, + default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + ) -> None: + self.status_code = status_code + self.headers = Headers(headers) + + self._request: Request | None = request + + # When follow_redirects=False and a redirect is received, + # the client will set `response.next_request`. + self.next_request: Request | None = None + + self.extensions = {} if extensions is None else dict(extensions) + self.history = [] if history is None else list(history) + + self.is_closed = False + self.is_stream_consumed = False + + self.default_encoding = default_encoding + + if stream is None: + headers, stream = encode_response(content, text, html, json) + self._prepare(headers) + self.stream = stream + if isinstance(stream, ByteStream): + # Load the response body, except for streaming content. + self.read() + else: + # There's an important distinction between `Response(content=...)`, + # and `Response(stream=...)`. 
+ # + # Using `content=...` implies automatically populated content headers, + # of either `Content-Length: ...` or `Transfer-Encoding: chunked`. + # + # Using `stream=...` will not automatically include any content headers. + # + # As an end-user you don't really need `stream=...`. It's only + # useful when creating response instances having received a stream + # from the transport API. + self.stream = stream + + self._num_bytes_downloaded = 0 + + def _prepare(self, default_headers: dict[str, str]) -> None: + for key, value in default_headers.items(): + # Ignore Transfer-Encoding if the Content-Length has been set explicitly. + if key.lower() == "transfer-encoding" and "content-length" in self.headers: + continue + self.headers.setdefault(key, value) + + @property + def elapsed(self) -> datetime.timedelta: + """ + Returns the time taken for the complete request/response + cycle to complete. + """ + if not hasattr(self, "_elapsed"): + raise RuntimeError( + "'.elapsed' may only be accessed after the response " + "has been read or closed." + ) + return self._elapsed + + @elapsed.setter + def elapsed(self, elapsed: datetime.timedelta) -> None: + self._elapsed = elapsed + + @property + def request(self) -> Request: + """ + Returns the request instance associated to the current response. + """ + if self._request is None: + raise RuntimeError( + "The request instance has not been set on this response." 
+ ) + return self._request + + @request.setter + def request(self, value: Request) -> None: + self._request = value + + @property + def http_version(self) -> str: + try: + http_version: bytes = self.extensions["http_version"] + except KeyError: + return "HTTP/1.1" + else: + return http_version.decode("ascii", errors="ignore") + + @property + def reason_phrase(self) -> str: + try: + reason_phrase: bytes = self.extensions["reason_phrase"] + except KeyError: + return codes.get_reason_phrase(self.status_code) + else: + return reason_phrase.decode("ascii", errors="ignore") + + @property + def url(self) -> URL: + """ + Returns the URL for which the request was made. + """ + return self.request.url + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + raise ResponseNotRead() + return self._content + + @property + def text(self) -> str: + if not hasattr(self, "_text"): + content = self.content + if not content: + self._text = "" + else: + decoder = TextDecoder(encoding=self.encoding or "utf-8") + self._text = "".join([decoder.decode(self.content), decoder.flush()]) + return self._text + + @property + def encoding(self) -> str | None: + """ + Return an encoding to use for decoding the byte content into text. + The priority for determining this is given by... + + * `.encoding = <>` has been set explicitly. + * The encoding as specified by the charset parameter in the Content-Type header. + * The encoding as determined by `default_encoding`, which may either be + a string like "utf-8" indicating the encoding to use, or may be a callable + which enables charset autodetection. 
+ """ + if not hasattr(self, "_encoding"): + encoding = self.charset_encoding + if encoding is None or not _is_known_encoding(encoding): + if isinstance(self.default_encoding, str): + encoding = self.default_encoding + elif hasattr(self, "_content"): + encoding = self.default_encoding(self._content) + self._encoding = encoding or "utf-8" + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + """ + Set the encoding to use for decoding the byte content into text. + + If the `text` attribute has been accessed, attempting to set the + encoding will throw a ValueError. + """ + if hasattr(self, "_text"): + raise ValueError( + "Setting encoding after `text` has been accessed is not allowed." + ) + self._encoding = value + + @property + def charset_encoding(self) -> str | None: + """ + Return the encoding, as specified by the Content-Type header. + """ + content_type = self.headers.get("Content-Type") + if content_type is None: + return None + + return _parse_content_type_charset(content_type) + + def _get_content_decoder(self) -> ContentDecoder: + """ + Returns a decoder instance which can be used to decode the raw byte + content, depending on the Content-Encoding used in the response. + """ + if not hasattr(self, "_decoder"): + decoders: list[ContentDecoder] = [] + values = self.headers.get_list("content-encoding", split_commas=True) + for value in values: + value = value.strip().lower() + try: + decoder_cls = SUPPORTED_DECODERS[value] + decoders.append(decoder_cls()) + except KeyError: + continue + + if len(decoders) == 1: + self._decoder = decoders[0] + elif len(decoders) > 1: + self._decoder = MultiDecoder(children=decoders) + else: + self._decoder = IdentityDecoder() + + return self._decoder + + @property + def is_informational(self) -> bool: + """ + A property which is `True` for 1xx status codes, `False` otherwise. 
+ """ + return codes.is_informational(self.status_code) + + @property + def is_success(self) -> bool: + """ + A property which is `True` for 2xx status codes, `False` otherwise. + """ + return codes.is_success(self.status_code) + + @property + def is_redirect(self) -> bool: + """ + A property which is `True` for 3xx status codes, `False` otherwise. + + Note that not all responses with a 3xx status code indicate a URL redirect. + + Use `response.has_redirect_location` to determine responses with a properly + formed URL redirection. + """ + return codes.is_redirect(self.status_code) + + @property + def is_client_error(self) -> bool: + """ + A property which is `True` for 4xx status codes, `False` otherwise. + """ + return codes.is_client_error(self.status_code) + + @property + def is_server_error(self) -> bool: + """ + A property which is `True` for 5xx status codes, `False` otherwise. + """ + return codes.is_server_error(self.status_code) + + @property + def is_error(self) -> bool: + """ + A property which is `True` for 4xx and 5xx status codes, `False` otherwise. + """ + return codes.is_error(self.status_code) + + @property + def has_redirect_location(self) -> bool: + """ + Returns True for 3xx responses with a properly formed URL redirection, + `False` otherwise. + """ + return ( + self.status_code + in ( + # 301 (Cacheable redirect. Method may change to GET.) + codes.MOVED_PERMANENTLY, + # 302 (Uncacheable redirect. Method may change to GET.) + codes.FOUND, + # 303 (Client should make a GET or HEAD request.) + codes.SEE_OTHER, + # 307 (Equiv. 302, but retain method) + codes.TEMPORARY_REDIRECT, + # 308 (Equiv. 301, but retain method) + codes.PERMANENT_REDIRECT, + ) + and "Location" in self.headers + ) + + def raise_for_status(self) -> Response: + """ + Raise the `HTTPStatusError` if one occurred. 
+ """ + request = self._request + if request is None: + raise RuntimeError( + "Cannot call `raise_for_status` as the request " + "instance has not been set on this response." + ) + + if self.is_success: + return self + + if self.has_redirect_location: + message = ( + "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n" + "Redirect location: '{0.headers[location]}'\n" + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}" + ) + else: + message = ( + "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n" + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}" + ) + + status_class = self.status_code // 100 + error_types = { + 1: "Informational response", + 3: "Redirect response", + 4: "Client error", + 5: "Server error", + } + error_type = error_types.get(status_class, "Invalid status code") + message = message.format(self, error_type=error_type) + raise HTTPStatusError(message, request=request, response=self) + + def json(self, **kwargs: typing.Any) -> typing.Any: + return jsonlib.loads(self.content, **kwargs) + + @property + def cookies(self) -> Cookies: + if not hasattr(self, "_cookies"): + self._cookies = Cookies() + self._cookies.extract_cookies(self) + return self._cookies + + @property + def links(self) -> dict[str | None, dict[str, str]]: + """ + Returns the parsed header links of the response, if any + """ + header = self.headers.get("link") + if header is None: + return {} + + return { + (link.get("rel") or link.get("url")): link + for link in _parse_header_links(header) + } + + @property + def num_bytes_downloaded(self) -> int: + return self._num_bytes_downloaded + + def __repr__(self) -> str: + return f"" + + def __getstate__(self) -> dict[str, typing.Any]: + return { + name: value + for name, value in self.__dict__.items() + if name not in ["extensions", "stream", "is_closed", "_decoder"] + } + + def __setstate__(self, 
state: dict[str, typing.Any]) -> None: + for name, value in state.items(): + setattr(self, name, value) + self.is_closed = True + self.extensions = {} + self.stream = UnattachedStream() + + def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join(self.iter_bytes()) + return self._content + + def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, brotli, and zstd encoded responses. + """ + if hasattr(self, "_content"): + chunk_size = len(self._content) if chunk_size is None else chunk_size + for i in range(0, len(self._content), max(chunk_size, 1)): + yield self._content[i : i + chunk_size] + else: + decoder = self._get_content_decoder() + chunker = ByteChunker(chunk_size=chunk_size) + with request_context(request=self._request): + for raw_bytes in self.iter_raw(): + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() + for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + def iter_text(self, chunk_size: int | None = None) -> typing.Iterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. 
+ """ + decoder = TextDecoder(encoding=self.encoding or "utf-8") + chunker = TextChunker(chunk_size=chunk_size) + with request_context(request=self._request): + for byte_content in self.iter_bytes(): + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() + for chunk in chunker.decode(text_content): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + def iter_lines(self) -> typing.Iterator[str]: + decoder = LineDecoder() + with request_context(request=self._request): + for text in self.iter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + def iter_raw(self, chunk_size: int | None = None) -> typing.Iterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise StreamClosed() + if not isinstance(self.stream, SyncByteStream): + raise RuntimeError("Attempted to call a sync iterator on an async stream.") + + self.is_stream_consumed = True + self._num_bytes_downloaded = 0 + chunker = ByteChunker(chunk_size=chunk_size) + + with request_context(request=self._request): + for raw_stream_bytes in self.stream: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk + + for chunk in chunker.flush(): + yield chunk + + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not isinstance(self.stream, SyncByteStream): + raise RuntimeError("Attempted to call an sync close on an async stream.") + + if not self.is_closed: + self.is_closed = True + with request_context(request=self._request): + self.stream.close() + + async def aread(self) -> bytes: + """ + Read and return the response content. 
+ """ + if not hasattr(self, "_content"): + self._content = b"".join([part async for part in self.aiter_bytes()]) + return self._content + + async def aiter_bytes( + self, chunk_size: int | None = None + ) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, brotli, and zstd encoded responses. + """ + if hasattr(self, "_content"): + chunk_size = len(self._content) if chunk_size is None else chunk_size + for i in range(0, len(self._content), max(chunk_size, 1)): + yield self._content[i : i + chunk_size] + else: + decoder = self._get_content_decoder() + chunker = ByteChunker(chunk_size=chunk_size) + with request_context(request=self._request): + async for raw_bytes in self.aiter_raw(): + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() + for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + async def aiter_text( + self, chunk_size: int | None = None + ) -> typing.AsyncIterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. 
+ """ + decoder = TextDecoder(encoding=self.encoding or "utf-8") + chunker = TextChunker(chunk_size=chunk_size) + with request_context(request=self._request): + async for byte_content in self.aiter_bytes(): + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() + for chunk in chunker.decode(text_content): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + async def aiter_lines(self) -> typing.AsyncIterator[str]: + decoder = LineDecoder() + with request_context(request=self._request): + async for text in self.aiter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + async def aiter_raw( + self, chunk_size: int | None = None + ) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise StreamClosed() + if not isinstance(self.stream, AsyncByteStream): + raise RuntimeError("Attempted to call an async iterator on an sync stream.") + + self.is_stream_consumed = True + self._num_bytes_downloaded = 0 + chunker = ByteChunker(chunk_size=chunk_size) + + with request_context(request=self._request): + async for raw_stream_bytes in self.stream: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk + + for chunk in chunker.flush(): + yield chunk + + await self.aclose() + + async def aclose(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. 
+ """ + if not isinstance(self.stream, AsyncByteStream): + raise RuntimeError("Attempted to call an async close on an sync stream.") + + if not self.is_closed: + self.is_closed = True + with request_context(request=self._request): + await self.stream.aclose() + + +class Cookies(typing.MutableMapping[str, str]): + """ + HTTP Cookies, as a mutable mapping. + """ + + def __init__(self, cookies: CookieTypes | None = None) -> None: + if cookies is None or isinstance(cookies, dict): + self.jar = CookieJar() + if isinstance(cookies, dict): + for key, value in cookies.items(): + self.set(key, value) + elif isinstance(cookies, list): + self.jar = CookieJar() + for key, value in cookies: + self.set(key, value) + elif isinstance(cookies, Cookies): + self.jar = CookieJar() + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + else: + self.jar = cookies + + def extract_cookies(self, response: Response) -> None: + """ + Loads any cookies based on the response `Set-Cookie` headers. + """ + urllib_response = self._CookieCompatResponse(response) + urllib_request = self._CookieCompatRequest(response.request) + + self.jar.extract_cookies(urllib_response, urllib_request) # type: ignore + + def set_cookie_header(self, request: Request) -> None: + """ + Sets an appropriate 'Cookie:' HTTP header on the `Request`. + """ + urllib_request = self._CookieCompatRequest(request) + self.jar.add_cookie_header(urllib_request) + + def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None: + """ + Set a cookie value by name. May optionally include domain and path. 
+ """ + kwargs = { + "version": 0, + "name": name, + "value": value, + "port": None, + "port_specified": False, + "domain": domain, + "domain_specified": bool(domain), + "domain_initial_dot": domain.startswith("."), + "path": path, + "path_specified": bool(path), + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {"HttpOnly": None}, + "rfc2109": False, + } + cookie = Cookie(**kwargs) # type: ignore + self.jar.set_cookie(cookie) + + def get( # type: ignore + self, + name: str, + default: str | None = None, + domain: str | None = None, + path: str | None = None, + ) -> str | None: + """ + Get a cookie by name. May optionally include domain and path + in order to specify exactly which cookie to retrieve. + """ + value = None + for cookie in self.jar: + if cookie.name == name: + if domain is None or cookie.domain == domain: + if path is None or cookie.path == path: + if value is not None: + message = f"Multiple cookies exist with name={name}" + raise CookieConflict(message) + value = cookie.value + + if value is None: + return default + return value + + def delete( + self, + name: str, + domain: str | None = None, + path: str | None = None, + ) -> None: + """ + Delete a cookie by name. May optionally include domain and path + in order to specify exactly which cookie to delete. + """ + if domain is not None and path is not None: + return self.jar.clear(domain, path, name) + + remove = [ + cookie + for cookie in self.jar + if cookie.name == name + and (domain is None or cookie.domain == domain) + and (path is None or cookie.path == path) + ] + + for cookie in remove: + self.jar.clear(cookie.domain, cookie.path, cookie.name) + + def clear(self, domain: str | None = None, path: str | None = None) -> None: + """ + Delete all cookies. Optionally include a domain and path in + order to only delete a subset of all the cookies. 
+ """ + args = [] + if domain is not None: + args.append(domain) + if path is not None: + assert domain is not None + args.append(path) + self.jar.clear(*args) + + def update(self, cookies: CookieTypes | None = None) -> None: # type: ignore + cookies = Cookies(cookies) + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + + def __setitem__(self, name: str, value: str) -> None: + return self.set(name, value) + + def __getitem__(self, name: str) -> str: + value = self.get(name) + if value is None: + raise KeyError(name) + return value + + def __delitem__(self, name: str) -> None: + return self.delete(name) + + def __len__(self) -> int: + return len(self.jar) + + def __iter__(self) -> typing.Iterator[str]: + return (cookie.name for cookie in self.jar) + + def __bool__(self) -> bool: + for _ in self.jar: + return True + return False + + def __repr__(self) -> str: + cookies_repr = ", ".join( + [ + f"" + for cookie in self.jar + ] + ) + + return f"" + + class _CookieCompatRequest(urllib.request.Request): + """ + Wraps a `Request` instance up in a compatibility interface suitable + for use with `CookieJar` operations. + """ + + def __init__(self, request: Request) -> None: + super().__init__( + url=str(request.url), + headers=dict(request.headers), + method=request.method, + ) + self.request = request + + def add_unredirected_header(self, key: str, value: str) -> None: + super().add_unredirected_header(key, value) + self.request.headers[key] = value + + class _CookieCompatResponse: + """ + Wraps a `Request` instance up in a compatibility interface suitable + for use with `CookieJar` operations. + """ + + def __init__(self, response: Response) -> None: + self.response = response + + def info(self) -> email.message.Message: + info = email.message.Message() + for key, value in self.response.headers.multi_items(): + # Note that setting `info[key]` here is an "append" operation, + # not a "replace" operation. 
+ # https://docs.python.org/3/library/email.compat32-message.html#email.message.Message.__setitem__ + info[key] = value + return info diff --git a/venv/lib/python3.10/site-packages/httpx/_multipart.py b/venv/lib/python3.10/site-packages/httpx/_multipart.py new file mode 100644 index 0000000000000000000000000000000000000000..b4761af9b2cf384de5189269927d781a700dbe46 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_multipart.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import io +import mimetypes +import os +import re +import typing +from pathlib import Path + +from ._types import ( + AsyncByteStream, + FileContent, + FileTypes, + RequestData, + RequestFiles, + SyncByteStream, +) +from ._utils import ( + peek_filelike_length, + primitive_value_to_str, + to_bytes, +) + +_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"} +_HTML5_FORM_ENCODING_REPLACEMENTS.update( + {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B} +) +_HTML5_FORM_ENCODING_RE = re.compile( + r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()]) +) + + +def _format_form_param(name: str, value: str) -> bytes: + """ + Encode a name/value pair within a multipart form. + """ + + def replacer(match: typing.Match[str]) -> str: + return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)] + + value = _HTML5_FORM_ENCODING_RE.sub(replacer, value) + return f'{name}="{value}"'.encode() + + +def _guess_content_type(filename: str | None) -> str | None: + """ + Guesses the mimetype based on a filename. Defaults to `application/octet-stream`. + + Returns `None` if `filename` is `None` or empty. 
+ """ + if filename: + return mimetypes.guess_type(filename)[0] or "application/octet-stream" + return None + + +def get_multipart_boundary_from_content_type( + content_type: bytes | None, +) -> bytes | None: + if not content_type or not content_type.startswith(b"multipart/form-data"): + return None + # parse boundary according to + # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1 + if b";" in content_type: + for section in content_type.split(b";"): + if section.strip().lower().startswith(b"boundary="): + return section.strip()[len(b"boundary=") :].strip(b'"') + return None + + +class DataField: + """ + A single form field item, within a multipart form field. + """ + + def __init__(self, name: str, value: str | bytes | int | float | None) -> None: + if not isinstance(name, str): + raise TypeError( + f"Invalid type for name. Expected str, got {type(name)}: {name!r}" + ) + if value is not None and not isinstance(value, (str, bytes, int, float)): + raise TypeError( + "Invalid type for value. Expected primitive type," + f" got {type(value)}: {value!r}" + ) + self.name = name + self.value: str | bytes = ( + value if isinstance(value, bytes) else primitive_value_to_str(value) + ) + + def render_headers(self) -> bytes: + if not hasattr(self, "_headers"): + name = _format_form_param("name", self.name) + self._headers = b"".join( + [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"] + ) + + return self._headers + + def render_data(self) -> bytes: + if not hasattr(self, "_data"): + self._data = to_bytes(self.value) + + return self._data + + def get_length(self) -> int: + headers = self.render_headers() + data = self.render_data() + return len(headers) + len(data) + + def render(self) -> typing.Iterator[bytes]: + yield self.render_headers() + yield self.render_data() + + +class FileField: + """ + A single file field item, within a multipart form field. 
+ """ + + CHUNK_SIZE = 64 * 1024 + + def __init__(self, name: str, value: FileTypes) -> None: + self.name = name + + fileobj: FileContent + + headers: dict[str, str] = {} + content_type: str | None = None + + # This large tuple based API largely mirror's requests' API + # It would be good to think of better APIs for this that we could + # include in httpx 2.0 since variable length tuples(especially of 4 elements) + # are quite unwieldly + if isinstance(value, tuple): + if len(value) == 2: + # neither the 3rd parameter (content_type) nor the 4th (headers) + # was included + filename, fileobj = value + elif len(value) == 3: + filename, fileobj, content_type = value + else: + # all 4 parameters included + filename, fileobj, content_type, headers = value # type: ignore + else: + filename = Path(str(getattr(value, "name", "upload"))).name + fileobj = value + + if content_type is None: + content_type = _guess_content_type(filename) + + has_content_type_header = any("content-type" in key.lower() for key in headers) + if content_type is not None and not has_content_type_header: + # note that unlike requests, we ignore the content_type provided in the 3rd + # tuple element if it is also included in the headers requests does + # the opposite (it overwrites the headerwith the 3rd tuple element) + headers["Content-Type"] = content_type + + if isinstance(fileobj, io.StringIO): + raise TypeError( + "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'." + ) + if isinstance(fileobj, io.TextIOBase): + raise TypeError( + "Multipart file uploads must be opened in binary mode, not text mode." 
+ ) + + self.filename = filename + self.file = fileobj + self.headers = headers + + def get_length(self) -> int | None: + headers = self.render_headers() + + if isinstance(self.file, (str, bytes)): + return len(headers) + len(to_bytes(self.file)) + + file_length = peek_filelike_length(self.file) + + # If we can't determine the filesize without reading it into memory, + # then return `None` here, to indicate an unknown file length. + if file_length is None: + return None + + return len(headers) + file_length + + def render_headers(self) -> bytes: + if not hasattr(self, "_headers"): + parts = [ + b"Content-Disposition: form-data; ", + _format_form_param("name", self.name), + ] + if self.filename: + filename = _format_form_param("filename", self.filename) + parts.extend([b"; ", filename]) + for header_name, header_value in self.headers.items(): + key, val = f"\r\n{header_name}: ".encode(), header_value.encode() + parts.extend([key, val]) + parts.append(b"\r\n\r\n") + self._headers = b"".join(parts) + + return self._headers + + def render_data(self) -> typing.Iterator[bytes]: + if isinstance(self.file, (str, bytes)): + yield to_bytes(self.file) + return + + if hasattr(self.file, "seek"): + try: + self.file.seek(0) + except io.UnsupportedOperation: + pass + + chunk = self.file.read(self.CHUNK_SIZE) + while chunk: + yield to_bytes(chunk) + chunk = self.file.read(self.CHUNK_SIZE) + + def render(self) -> typing.Iterator[bytes]: + yield self.render_headers() + yield from self.render_data() + + +class MultipartStream(SyncByteStream, AsyncByteStream): + """ + Request content as streaming multipart encoded form data. 
+ """ + + def __init__( + self, + data: RequestData, + files: RequestFiles, + boundary: bytes | None = None, + ) -> None: + if boundary is None: + boundary = os.urandom(16).hex().encode("ascii") + + self.boundary = boundary + self.content_type = "multipart/form-data; boundary=%s" % boundary.decode( + "ascii" + ) + self.fields = list(self._iter_fields(data, files)) + + def _iter_fields( + self, data: RequestData, files: RequestFiles + ) -> typing.Iterator[FileField | DataField]: + for name, value in data.items(): + if isinstance(value, (tuple, list)): + for item in value: + yield DataField(name=name, value=item) + else: + yield DataField(name=name, value=value) + + file_items = files.items() if isinstance(files, typing.Mapping) else files + for name, value in file_items: + yield FileField(name=name, value=value) + + def iter_chunks(self) -> typing.Iterator[bytes]: + for field in self.fields: + yield b"--%s\r\n" % self.boundary + yield from field.render() + yield b"\r\n" + yield b"--%s--\r\n" % self.boundary + + def get_content_length(self) -> int | None: + """ + Return the length of the multipart encoded content, or `None` if + any of the files have a length that cannot be determined upfront. + """ + boundary_length = len(self.boundary) + length = 0 + + for field in self.fields: + field_length = field.get_length() + if field_length is None: + return None + + length += 2 + boundary_length + 2 # b"--{boundary}\r\n" + length += field_length + length += 2 # b"\r\n" + + length += 2 + boundary_length + 4 # b"--{boundary}--\r\n" + return length + + # Content stream interface. 
+ + def get_headers(self) -> dict[str, str]: + content_length = self.get_content_length() + content_type = self.content_type + if content_length is None: + return {"Transfer-Encoding": "chunked", "Content-Type": content_type} + return {"Content-Length": str(content_length), "Content-Type": content_type} + + def __iter__(self) -> typing.Iterator[bytes]: + for chunk in self.iter_chunks(): + yield chunk + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + for chunk in self.iter_chunks(): + yield chunk diff --git a/venv/lib/python3.10/site-packages/httpx/_status_codes.py b/venv/lib/python3.10/site-packages/httpx/_status_codes.py new file mode 100644 index 0000000000000000000000000000000000000000..133a6231a5b53fd2f073799ca1bd07c50abe40ae --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_status_codes.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from enum import IntEnum + +__all__ = ["codes"] + + +class codes(IntEnum): + """HTTP status codes and reason phrases + + Status codes from the following RFCs are all observed: + + * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616 + * RFC 6585: Additional HTTP Status Codes + * RFC 3229: Delta encoding in HTTP + * RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518 + * RFC 5842: Binding Extensions to WebDAV + * RFC 7238: Permanent Redirect + * RFC 2295: Transparent Content Negotiation in HTTP + * RFC 2774: An HTTP Extension Framework + * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2) + * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0) + * RFC 7725: An HTTP Status Code to Report Legal Obstacles + * RFC 8297: An HTTP Status Code for Indicating Hints + * RFC 8470: Using Early Data in HTTP + """ + + def __new__(cls, value: int, phrase: str = "") -> codes: + obj = int.__new__(cls, value) + obj._value_ = value + + obj.phrase = phrase # type: ignore[attr-defined] + return obj + + def __str__(self) -> str: + return str(self.value) + + @classmethod + def 
get_reason_phrase(cls, value: int) -> str: + try: + return codes(value).phrase # type: ignore + except ValueError: + return "" + + @classmethod + def is_informational(cls, value: int) -> bool: + """ + Returns `True` for 1xx status codes, `False` otherwise. + """ + return 100 <= value <= 199 + + @classmethod + def is_success(cls, value: int) -> bool: + """ + Returns `True` for 2xx status codes, `False` otherwise. + """ + return 200 <= value <= 299 + + @classmethod + def is_redirect(cls, value: int) -> bool: + """ + Returns `True` for 3xx status codes, `False` otherwise. + """ + return 300 <= value <= 399 + + @classmethod + def is_client_error(cls, value: int) -> bool: + """ + Returns `True` for 4xx status codes, `False` otherwise. + """ + return 400 <= value <= 499 + + @classmethod + def is_server_error(cls, value: int) -> bool: + """ + Returns `True` for 5xx status codes, `False` otherwise. + """ + return 500 <= value <= 599 + + @classmethod + def is_error(cls, value: int) -> bool: + """ + Returns `True` for 4xx or 5xx status codes, `False` otherwise. 
+ """ + return 400 <= value <= 599 + + # informational + CONTINUE = 100, "Continue" + SWITCHING_PROTOCOLS = 101, "Switching Protocols" + PROCESSING = 102, "Processing" + EARLY_HINTS = 103, "Early Hints" + + # success + OK = 200, "OK" + CREATED = 201, "Created" + ACCEPTED = 202, "Accepted" + NON_AUTHORITATIVE_INFORMATION = 203, "Non-Authoritative Information" + NO_CONTENT = 204, "No Content" + RESET_CONTENT = 205, "Reset Content" + PARTIAL_CONTENT = 206, "Partial Content" + MULTI_STATUS = 207, "Multi-Status" + ALREADY_REPORTED = 208, "Already Reported" + IM_USED = 226, "IM Used" + + # redirection + MULTIPLE_CHOICES = 300, "Multiple Choices" + MOVED_PERMANENTLY = 301, "Moved Permanently" + FOUND = 302, "Found" + SEE_OTHER = 303, "See Other" + NOT_MODIFIED = 304, "Not Modified" + USE_PROXY = 305, "Use Proxy" + TEMPORARY_REDIRECT = 307, "Temporary Redirect" + PERMANENT_REDIRECT = 308, "Permanent Redirect" + + # client error + BAD_REQUEST = 400, "Bad Request" + UNAUTHORIZED = 401, "Unauthorized" + PAYMENT_REQUIRED = 402, "Payment Required" + FORBIDDEN = 403, "Forbidden" + NOT_FOUND = 404, "Not Found" + METHOD_NOT_ALLOWED = 405, "Method Not Allowed" + NOT_ACCEPTABLE = 406, "Not Acceptable" + PROXY_AUTHENTICATION_REQUIRED = 407, "Proxy Authentication Required" + REQUEST_TIMEOUT = 408, "Request Timeout" + CONFLICT = 409, "Conflict" + GONE = 410, "Gone" + LENGTH_REQUIRED = 411, "Length Required" + PRECONDITION_FAILED = 412, "Precondition Failed" + REQUEST_ENTITY_TOO_LARGE = 413, "Request Entity Too Large" + REQUEST_URI_TOO_LONG = 414, "Request-URI Too Long" + UNSUPPORTED_MEDIA_TYPE = 415, "Unsupported Media Type" + REQUESTED_RANGE_NOT_SATISFIABLE = 416, "Requested Range Not Satisfiable" + EXPECTATION_FAILED = 417, "Expectation Failed" + IM_A_TEAPOT = 418, "I'm a teapot" + MISDIRECTED_REQUEST = 421, "Misdirected Request" + UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity" + LOCKED = 423, "Locked" + FAILED_DEPENDENCY = 424, "Failed Dependency" + TOO_EARLY = 425, "Too Early" 
+ UPGRADE_REQUIRED = 426, "Upgrade Required" + PRECONDITION_REQUIRED = 428, "Precondition Required" + TOO_MANY_REQUESTS = 429, "Too Many Requests" + REQUEST_HEADER_FIELDS_TOO_LARGE = 431, "Request Header Fields Too Large" + UNAVAILABLE_FOR_LEGAL_REASONS = 451, "Unavailable For Legal Reasons" + + # server errors + INTERNAL_SERVER_ERROR = 500, "Internal Server Error" + NOT_IMPLEMENTED = 501, "Not Implemented" + BAD_GATEWAY = 502, "Bad Gateway" + SERVICE_UNAVAILABLE = 503, "Service Unavailable" + GATEWAY_TIMEOUT = 504, "Gateway Timeout" + HTTP_VERSION_NOT_SUPPORTED = 505, "HTTP Version Not Supported" + VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates" + INSUFFICIENT_STORAGE = 507, "Insufficient Storage" + LOOP_DETECTED = 508, "Loop Detected" + NOT_EXTENDED = 510, "Not Extended" + NETWORK_AUTHENTICATION_REQUIRED = 511, "Network Authentication Required" + + +# Include lower-case styles for `requests` compatibility. +for code in codes: + setattr(codes, code._name_.lower(), int(code)) diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__init__.py b/venv/lib/python3.10/site-packages/httpx/_transports/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a321053b29bcd48698cf2bd74a1d19c8556aefb --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_transports/__init__.py @@ -0,0 +1,15 @@ +from .asgi import * +from .base import * +from .default import * +from .mock import * +from .wsgi import * + +__all__ = [ + "ASGITransport", + "AsyncBaseTransport", + "BaseTransport", + "AsyncHTTPTransport", + "HTTPTransport", + "MockTransport", + "WSGITransport", +] diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..850ffa33a354fc15bccab2dc6cdf12ecb41d1e92 Binary files /dev/null and 
b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/asgi.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/asgi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10dc7e1977b05bd384884020faeecc2a9261a142 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/asgi.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/base.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..108ad110cace417b122b8ada300fd13b29747f85 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/base.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/default.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/default.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a89f3621d122515dc9b2e8754643c9079db2cc45 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/default.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/mock.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/mock.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b8feca610a0a3faac148557b35cdf9645149c29 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/mock.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/wsgi.cpython-310.pyc b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/wsgi.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7a480f36985d0792552e62ad240fab021f2cf7e6 Binary files /dev/null and b/venv/lib/python3.10/site-packages/httpx/_transports/__pycache__/wsgi.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/asgi.py b/venv/lib/python3.10/site-packages/httpx/_transports/asgi.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc4efae0e1b14620f75f712eb15ecf500d14eef --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_transports/asgi.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import typing + +from .._models import Request, Response +from .._types import AsyncByteStream +from .base import AsyncBaseTransport + +if typing.TYPE_CHECKING: # pragma: no cover + import asyncio + + import trio + + Event = typing.Union[asyncio.Event, trio.Event] + + +_Message = typing.MutableMapping[str, typing.Any] +_Receive = typing.Callable[[], typing.Awaitable[_Message]] +_Send = typing.Callable[ + [typing.MutableMapping[str, typing.Any]], typing.Awaitable[None] +] +_ASGIApp = typing.Callable[ + [typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None] +] + +__all__ = ["ASGITransport"] + + +def is_running_trio() -> bool: + try: + # sniffio is a dependency of trio. + + # See https://github.com/python-trio/trio/issues/2802 + import sniffio + + if sniffio.current_async_library() == "trio": + return True + except ImportError: # pragma: nocover + pass + + return False + + +def create_event() -> Event: + if is_running_trio(): + import trio + + return trio.Event() + + import asyncio + + return asyncio.Event() + + +class ASGIResponseStream(AsyncByteStream): + def __init__(self, body: list[bytes]) -> None: + self._body = body + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + yield b"".join(self._body) + + +class ASGITransport(AsyncBaseTransport): + """ + A custom AsyncTransport that handles sending requests directly to an ASGI app. 
+ + ```python + transport = httpx.ASGITransport( + app=app, + root_path="/submount", + client=("1.2.3.4", 123) + ) + client = httpx.AsyncClient(transport=transport) + ``` + + Arguments: + + * `app` - The ASGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Default to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `root_path` - The root path on which the ASGI application should be mounted. + * `client` - A two-tuple indicating the client IP and port of incoming requests. + ``` + """ + + def __init__( + self, + app: _ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.root_path = root_path + self.client = client + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?")[0], + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": self.client, + "root_path": self.root_path, + } + + # Request. + request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response. + status_code = None + response_headers = None + body_parts = [] + response_started = False + response_complete = create_event() + + # ASGI callables. 
+ + async def receive() -> dict[str, typing.Any]: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: typing.MutableMapping[str, typing.Any]) -> None: + nonlocal status_code, response_headers, response_started + + if message["type"] == "http.response.start": + assert not response_started + + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + elif message["type"] == "http.response.body": + assert not response_complete.is_set() + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + body_parts.append(body) + + if not more_body: + response_complete.set() + + try: + await self.app(scope, receive, send) + except Exception: # noqa: PIE-786 + if self.raise_app_exceptions: + raise + + response_complete.set() + if status_code is None: + status_code = 500 + if response_headers is None: + response_headers = {} + + assert response_complete.is_set() + assert status_code is not None + assert response_headers is not None + + stream = ASGIResponseStream(body_parts) + + return Response(status_code, headers=response_headers, stream=stream) diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/base.py b/venv/lib/python3.10/site-packages/httpx/_transports/base.py new file mode 100644 index 0000000000000000000000000000000000000000..66fd99d702480b555c06694fe14715ea6df3dfc3 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_transports/base.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import typing +from types import TracebackType + +from .._models import Request, Response + +T = 
typing.TypeVar("T", bound="BaseTransport") +A = typing.TypeVar("A", bound="AsyncBaseTransport") + +__all__ = ["AsyncBaseTransport", "BaseTransport"] + + +class BaseTransport: + def __enter__(self: T) -> T: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + self.close() + + def handle_request(self, request: Request) -> Response: + """ + Send a single HTTP request and return a response. + + Developers shouldn't typically ever need to call into this API directly, + since the Client class provides all the higher level user-facing API + niceties. + + In order to properly release any network resources, the response + stream should *either* be consumed immediately, with a call to + `response.stream.read()`, or else the `handle_request` call should + be followed with a try/finally block to ensure the stream is + always closed. + + Example usage: + + with httpx.HTTPTransport() as transport: + req = httpx.Request( + method=b"GET", + url=(b"https", b"www.example.com", 443, b"/"), + headers=[(b"Host", b"www.example.com")], + ) + resp = transport.handle_request(req) + body = resp.stream.read() + print(resp.status_code, resp.headers, body) + + + Takes a `Request` instance as the only argument. + + Returns a `Response` instance. + """ + raise NotImplementedError( + "The 'handle_request' method must be implemented." + ) # pragma: no cover + + def close(self) -> None: + pass + + +class AsyncBaseTransport: + async def __aenter__(self: A) -> A: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + await self.aclose() + + async def handle_async_request( + self, + request: Request, + ) -> Response: + raise NotImplementedError( + "The 'handle_async_request' method must be implemented." 
+ ) # pragma: no cover + + async def aclose(self) -> None: + pass diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/default.py b/venv/lib/python3.10/site-packages/httpx/_transports/default.py new file mode 100644 index 0000000000000000000000000000000000000000..d5aa05ff234fd3fbf4fee88c4a7d3e3c151a538f --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_transports/default.py @@ -0,0 +1,406 @@ +""" +Custom transports, with nicely configured defaults. + +The following additional keyword arguments are currently supported by httpcore... + +* uds: str +* local_address: str +* retries: int + +Example usages... + +# Disable HTTP/2 on a single specific domain. +mounts = { + "all://": httpx.HTTPTransport(http2=True), + "all://*example.org": httpx.HTTPTransport() +} + +# Using advanced httpcore configuration, with connection retries. +transport = httpx.HTTPTransport(retries=1) +client = httpx.Client(transport=transport) + +# Using advanced httpcore configuration, with unix domain sockets. 
+transport = httpx.HTTPTransport(uds="socket.uds") +client = httpx.Client(transport=transport) +""" + +from __future__ import annotations + +import contextlib +import typing +from types import TracebackType + +if typing.TYPE_CHECKING: + import ssl # pragma: no cover + + import httpx # pragma: no cover + +from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context +from .._exceptions import ( + ConnectError, + ConnectTimeout, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + TimeoutException, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from .._models import Request, Response +from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream +from .._urls import URL +from .base import AsyncBaseTransport, BaseTransport + +T = typing.TypeVar("T", bound="HTTPTransport") +A = typing.TypeVar("A", bound="AsyncHTTPTransport") + +SOCKET_OPTION = typing.Union[ + typing.Tuple[int, int, int], + typing.Tuple[int, int, typing.Union[bytes, bytearray]], + typing.Tuple[int, int, None, int], +] + +__all__ = ["AsyncHTTPTransport", "HTTPTransport"] + +HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {} + + +def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]: + import httpcore + + return { + httpcore.TimeoutException: TimeoutException, + httpcore.ConnectTimeout: ConnectTimeout, + httpcore.ReadTimeout: ReadTimeout, + httpcore.WriteTimeout: WriteTimeout, + httpcore.PoolTimeout: PoolTimeout, + httpcore.NetworkError: NetworkError, + httpcore.ConnectError: ConnectError, + httpcore.ReadError: ReadError, + httpcore.WriteError: WriteError, + httpcore.ProxyError: ProxyError, + httpcore.UnsupportedProtocol: UnsupportedProtocol, + httpcore.ProtocolError: ProtocolError, + httpcore.LocalProtocolError: LocalProtocolError, + httpcore.RemoteProtocolError: RemoteProtocolError, + } + + +@contextlib.contextmanager +def 
map_httpcore_exceptions() -> typing.Iterator[None]: + global HTTPCORE_EXC_MAP + if len(HTTPCORE_EXC_MAP) == 0: + HTTPCORE_EXC_MAP = _load_httpcore_exceptions() + try: + yield + except Exception as exc: + mapped_exc = None + + for from_exc, to_exc in HTTPCORE_EXC_MAP.items(): + if not isinstance(exc, from_exc): + continue + # We want to map to the most specific exception we can find. + # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to + # `httpx.ReadTimeout`, not just `httpx.TimeoutException`. + if mapped_exc is None or issubclass(to_exc, mapped_exc): + mapped_exc = to_exc + + if mapped_exc is None: # pragma: no cover + raise + + message = str(exc) + raise mapped_exc(message) from exc + + +class ResponseStream(SyncByteStream): + def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None: + self._httpcore_stream = httpcore_stream + + def __iter__(self) -> typing.Iterator[bytes]: + with map_httpcore_exceptions(): + for part in self._httpcore_stream: + yield part + + def close(self) -> None: + if hasattr(self._httpcore_stream, "close"): + self._httpcore_stream.close() + + +class HTTPTransport(BaseTransport): + def __init__( + self, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + proxy: ProxyTypes | None = None, + uds: str | None = None, + local_address: str | None = None, + retries: int = 0, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + import httpcore + + proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy + ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + + if proxy is None: + self._pool = httpcore.ConnectionPool( + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + uds=uds, + 
local_address=local_address, + retries=retries, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("http", "https"): + self._pool = httpcore.HTTPProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + proxy_headers=proxy.headers.raw, + ssl_context=ssl_context, + proxy_ssl_context=proxy.ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("socks5", "socks5h"): + try: + import socksio # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using SOCKS proxy, but the 'socksio' package is not installed. " + "Make sure to install httpx using `pip install httpx[socks]`." + ) from None + + self._pool = httpcore.SOCKSProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + ) + else: # pragma: no cover + raise ValueError( + "Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h'," + f" but got {proxy.url.scheme!r}." + ) + + def __enter__(self: T) -> T: # Use generics for subclass support. 
+ self._pool.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + with map_httpcore_exceptions(): + self._pool.__exit__(exc_type, exc_value, traceback) + + def handle_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, SyncByteStream) + import httpcore + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = self._pool.handle_request(req) + + assert isinstance(resp.stream, typing.Iterable) + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=ResponseStream(resp.stream), + extensions=resp.extensions, + ) + + def close(self) -> None: + self._pool.close() + + +class AsyncResponseStream(AsyncByteStream): + def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None: + self._httpcore_stream = httpcore_stream + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + with map_httpcore_exceptions(): + async for part in self._httpcore_stream: + yield part + + async def aclose(self) -> None: + if hasattr(self._httpcore_stream, "aclose"): + await self._httpcore_stream.aclose() + + +class AsyncHTTPTransport(AsyncBaseTransport): + def __init__( + self, + verify: ssl.SSLContext | str | bool = True, + cert: CertTypes | None = None, + trust_env: bool = True, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + proxy: ProxyTypes | None = None, + uds: str | None = None, + local_address: str | None = None, + retries: int = 0, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> None: + import httpcore + + proxy = Proxy(url=proxy) if 
isinstance(proxy, (str, URL)) else proxy + ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + + if proxy is None: + self._pool = httpcore.AsyncConnectionPool( + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + uds=uds, + local_address=local_address, + retries=retries, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("http", "https"): + self._pool = httpcore.AsyncHTTPProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + proxy_headers=proxy.headers.raw, + proxy_ssl_context=proxy.ssl_context, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("socks5", "socks5h"): + try: + import socksio # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using SOCKS proxy, but the 'socksio' package is not installed. " + "Make sure to install httpx using `pip install httpx[socks]`." + ) from None + + self._pool = httpcore.AsyncSOCKSProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + ) + else: # pragma: no cover + raise ValueError( + "Proxy protocol must be either 'http', 'https', 'socks5', or 'socks5h'," + f" but got {proxy.url.scheme!r}." 
+ ) + + async def __aenter__(self: A) -> A: # Use generics for subclass support. + await self._pool.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + with map_httpcore_exceptions(): + await self._pool.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + import httpcore + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + + assert isinstance(resp.stream, typing.AsyncIterable) + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=AsyncResponseStream(resp.stream), + extensions=resp.extensions, + ) + + async def aclose(self) -> None: + await self._pool.aclose() diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/mock.py b/venv/lib/python3.10/site-packages/httpx/_transports/mock.py new file mode 100644 index 0000000000000000000000000000000000000000..8c418f59e06cae43abdbb626ec21cafc7e8c6277 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_transports/mock.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import typing + +from .._models import Request, Response +from .base import AsyncBaseTransport, BaseTransport + +SyncHandler = typing.Callable[[Request], Response] +AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]] + + +__all__ = ["MockTransport"] + + +class MockTransport(AsyncBaseTransport, BaseTransport): + def __init__(self, handler: SyncHandler | AsyncHandler) -> None: + 
self.handler = handler + + def handle_request( + self, + request: Request, + ) -> Response: + request.read() + response = self.handler(request) + if not isinstance(response, Response): # pragma: no cover + raise TypeError("Cannot use an async handler in a sync Client") + return response + + async def handle_async_request( + self, + request: Request, + ) -> Response: + await request.aread() + response = self.handler(request) + + # Allow handler to *optionally* be an `async` function. + # If it is, then the `response` variable needs to be awaited to actually + # return the result. + + if not isinstance(response, Response): + response = await response + + return response diff --git a/venv/lib/python3.10/site-packages/httpx/_transports/wsgi.py b/venv/lib/python3.10/site-packages/httpx/_transports/wsgi.py new file mode 100644 index 0000000000000000000000000000000000000000..8592ffe017a87367cc7578184540096a9682908d --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_transports/wsgi.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import io +import itertools +import sys +import typing + +from .._models import Request, Response +from .._types import SyncByteStream +from .base import BaseTransport + +if typing.TYPE_CHECKING: + from _typeshed import OptExcInfo # pragma: no cover + from _typeshed.wsgi import WSGIApplication # pragma: no cover + +_T = typing.TypeVar("_T") + + +__all__ = ["WSGITransport"] + + +def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]: + body = iter(body) + for chunk in body: + if chunk: + return itertools.chain([chunk], body) + return [] + + +class WSGIByteStream(SyncByteStream): + def __init__(self, result: typing.Iterable[bytes]) -> None: + self._close = getattr(result, "close", None) + self._result = _skip_leading_empty_chunks(result) + + def __iter__(self) -> typing.Iterator[bytes]: + for part in self._result: + yield part + + def close(self) -> None: + if self._close is not None: + self._close() + + 
+class WSGITransport(BaseTransport): + """ + A custom transport that handles sending requests directly to a WSGI app. + The simplest way to use this functionality is to use the `app` argument. + + ``` + client = httpx.Client(app=app) + ``` + + Alternatively, you can set up the transport instance explicitly. + This allows you to include any additional configuration arguments specific + to the WSGITransport class: + + ``` + transport = httpx.WSGITransport( + app=app, + script_name="/submount", + remote_addr="1.2.3.4" + ) + client = httpx.Client(transport=transport) + ``` + + Arguments: + + * `app` - The WSGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Defaults to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `script_name` - The root path on which the WSGI application should be mounted. + * `remote_addr` - A string indicating the client IP of incoming requests. + ``` + """ + + def __init__( + self, + app: WSGIApplication, + raise_app_exceptions: bool = True, + script_name: str = "", + remote_addr: str = "127.0.0.1", + wsgi_errors: typing.TextIO | None = None, + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.script_name = script_name + self.remote_addr = remote_addr + self.wsgi_errors = wsgi_errors + + def handle_request(self, request: Request) -> Response: + request.read() + wsgi_input = io.BytesIO(request.content) + + port = request.url.port or {"http": 80, "https": 443}[request.url.scheme] + environ = { + "wsgi.version": (1, 0), + "wsgi.url_scheme": request.url.scheme, + "wsgi.input": wsgi_input, + "wsgi.errors": self.wsgi_errors or sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "REQUEST_METHOD": request.method, + "SCRIPT_NAME": self.script_name, + "PATH_INFO": request.url.path, + "QUERY_STRING": request.url.query.decode("ascii"), + "SERVER_NAME": 
request.url.host, + "SERVER_PORT": str(port), + "SERVER_PROTOCOL": "HTTP/1.1", + "REMOTE_ADDR": self.remote_addr, + } + for header_key, header_value in request.headers.raw: + key = header_key.decode("ascii").upper().replace("-", "_") + if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): + key = "HTTP_" + key + environ[key] = header_value.decode("ascii") + + seen_status = None + seen_response_headers = None + seen_exc_info = None + + def start_response( + status: str, + response_headers: list[tuple[str, str]], + exc_info: OptExcInfo | None = None, + ) -> typing.Callable[[bytes], typing.Any]: + nonlocal seen_status, seen_response_headers, seen_exc_info + seen_status = status + seen_response_headers = response_headers + seen_exc_info = exc_info + return lambda _: None + + result = self.app(environ, start_response) + + stream = WSGIByteStream(result) + + assert seen_status is not None + assert seen_response_headers is not None + if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions: + raise seen_exc_info[1] + + status_code = int(seen_status.split()[0]) + headers = [ + (key.encode("ascii"), value.encode("ascii")) + for key, value in seen_response_headers + ] + + return Response(status_code, headers=headers, stream=stream) diff --git a/venv/lib/python3.10/site-packages/httpx/_types.py b/venv/lib/python3.10/site-packages/httpx/_types.py new file mode 100644 index 0000000000000000000000000000000000000000..704dfdffc8ba61eb913fa918072381e410b23c00 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_types.py @@ -0,0 +1,114 @@ +""" +Type definitions for type checking purposes. 
+""" + +from http.cookiejar import CookieJar +from typing import ( + IO, + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +if TYPE_CHECKING: # pragma: no cover + from ._auth import Auth # noqa: F401 + from ._config import Proxy, Timeout # noqa: F401 + from ._models import Cookies, Headers, Request # noqa: F401 + from ._urls import URL, QueryParams # noqa: F401 + + +PrimitiveData = Optional[Union[str, int, float, bool]] + +URLTypes = Union["URL", str] + +QueryParamTypes = Union[ + "QueryParams", + Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]], + List[Tuple[str, PrimitiveData]], + Tuple[Tuple[str, PrimitiveData], ...], + str, + bytes, +] + +HeaderTypes = Union[ + "Headers", + Mapping[str, str], + Mapping[bytes, bytes], + Sequence[Tuple[str, str]], + Sequence[Tuple[bytes, bytes]], +] + +CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] + +TimeoutTypes = Union[ + Optional[float], + Tuple[Optional[float], Optional[float], Optional[float], Optional[float]], + "Timeout", +] +ProxyTypes = Union["URL", str, "Proxy"] +CertTypes = Union[str, Tuple[str, str], Tuple[str, str, str]] + +AuthTypes = Union[ + Tuple[Union[str, bytes], Union[str, bytes]], + Callable[["Request"], "Request"], + "Auth", +] + +RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +ResponseExtensions = Mapping[str, Any] + +RequestData = Mapping[str, Any] + +FileContent = Union[IO[bytes], bytes, str] +FileTypes = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] 
+RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + +RequestExtensions = Mapping[str, Any] + +__all__ = ["AsyncByteStream", "SyncByteStream"] + + +class SyncByteStream: + def __iter__(self) -> Iterator[bytes]: + raise NotImplementedError( + "The '__iter__' method must be implemented." + ) # pragma: no cover + yield b"" # pragma: no cover + + def close(self) -> None: + """ + Subclasses can override this method to release any network resources + after a request/response cycle is complete. + """ + + +class AsyncByteStream: + async def __aiter__(self) -> AsyncIterator[bytes]: + raise NotImplementedError( + "The '__aiter__' method must be implemented." + ) # pragma: no cover + yield b"" # pragma: no cover + + async def aclose(self) -> None: + pass diff --git a/venv/lib/python3.10/site-packages/httpx/_urlparse.py b/venv/lib/python3.10/site-packages/httpx/_urlparse.py new file mode 100644 index 0000000000000000000000000000000000000000..bf190fd560ee4fc8a11af371a15fc5f1dc284d34 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_urlparse.py @@ -0,0 +1,527 @@ +""" +An implementation of `urlparse` that provides URL validation and normalization +as described by RFC3986. + +We rely on this implementation rather than the one in Python's stdlib, because: + +* It provides more complete URL validation. +* It properly differentiates between an empty querystring and an absent querystring, + to distinguish URLs with a trailing '?'. +* It handles scheme, hostname, port, and path normalization. +* It supports IDNA hostnames, normalizing them to their encoded form. +* The API supports passing individual components, as well as the complete URL string. + +Previously we relied on the excellent `rfc3986` package to handle URL parsing and +validation, but this module provides a simpler alternative, with less indirection +required. 
+""" + +from __future__ import annotations + +import ipaddress +import re +import typing + +import idna + +from ._exceptions import InvalidURL + +MAX_URL_LENGTH = 65536 + +# https://datatracker.ietf.org/doc/html/rfc3986.html#section-2.3 +UNRESERVED_CHARACTERS = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" +) +SUB_DELIMS = "!$&'()*+,;=" + +PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}") + +# https://url.spec.whatwg.org/#percent-encoded-bytes + +# The fragment percent-encode set is the C0 control percent-encode set +# and U+0020 SPACE, U+0022 ("), U+003C (<), U+003E (>), and U+0060 (`). +FRAG_SAFE = "".join( + [chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x3C, 0x3E, 0x60)] +) + +# The query percent-encode set is the C0 control percent-encode set +# and U+0020 SPACE, U+0022 ("), U+0023 (#), U+003C (<), and U+003E (>). +QUERY_SAFE = "".join( + [chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E)] +) + +# The path percent-encode set is the query percent-encode set +# and U+003F (?), U+0060 (`), U+007B ({), and U+007D (}). +PATH_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + (0x3F, 0x60, 0x7B, 0x7D) + ] +) + +# The userinfo percent-encode set is the path percent-encode set +# and U+002F (/), U+003A (:), U+003B (;), U+003D (=), U+0040 (@), +# U+005B ([) to U+005E (^), inclusive, and U+007C (|). +USERNAME_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i + not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + + (0x3F, 0x60, 0x7B, 0x7D) + + (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C) + ] +) +PASSWORD_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i + not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + + (0x3F, 0x60, 0x7B, 0x7D) + + (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C) + ] +) +# Note... 
The terminology 'userinfo' percent-encode set in the WHATWG document +# is used for the username and password quoting. For the joint userinfo component +# we remove U+003A (:) from the safe set. +USERINFO_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i + not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + + (0x3F, 0x60, 0x7B, 0x7D) + + (0x2F, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C) + ] +) + + +# {scheme}: (optional) +# //{authority} (optional) +# {path} +# ?{query} (optional) +# #{fragment} (optional) +URL_REGEX = re.compile( + ( + r"(?:(?P{scheme}):)?" + r"(?://(?P{authority}))?" + r"(?P{path})" + r"(?:\?(?P{query}))?" + r"(?:#(?P{fragment}))?" + ).format( + scheme="([a-zA-Z][a-zA-Z0-9+.-]*)?", + authority="[^/?#]*", + path="[^?#]*", + query="[^#]*", + fragment=".*", + ) +) + +# {userinfo}@ (optional) +# {host} +# :{port} (optional) +AUTHORITY_REGEX = re.compile( + ( + r"(?:(?P{userinfo})@)?" r"(?P{host})" r":?(?P{port})?" + ).format( + userinfo=".*", # Any character sequence. + host="(\\[.*\\]|[^:@]*)", # Either any character sequence excluding ':' or '@', + # or an IPv6 address enclosed within square brackets. + port=".*", # Any character sequence. + ) +) + + +# If we call urlparse with an individual component, then we need to regex +# validate that component individually. +# Note that we're duplicating the same strings as above. Shock! Horror!! +COMPONENT_REGEX = { + "scheme": re.compile("([a-zA-Z][a-zA-Z0-9+.-]*)?"), + "authority": re.compile("[^/?#]*"), + "path": re.compile("[^?#]*"), + "query": re.compile("[^#]*"), + "fragment": re.compile(".*"), + "userinfo": re.compile("[^@]*"), + "host": re.compile("(\\[.*\\]|[^:]*)"), + "port": re.compile(".*"), +} + + +# We use these simple regexs as a first pass before handing off to +# the stdlib 'ipaddress' module for IP address validation. 
+IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") +IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$") + + +class ParseResult(typing.NamedTuple): + scheme: str + userinfo: str + host: str + port: int | None + path: str + query: str | None + fragment: str | None + + @property + def authority(self) -> str: + return "".join( + [ + f"{self.userinfo}@" if self.userinfo else "", + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + @property + def netloc(self) -> str: + return "".join( + [ + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + def copy_with(self, **kwargs: str | None) -> ParseResult: + if not kwargs: + return self + + defaults = { + "scheme": self.scheme, + "authority": self.authority, + "path": self.path, + "query": self.query, + "fragment": self.fragment, + } + defaults.update(kwargs) + return urlparse("", **defaults) + + def __str__(self) -> str: + authority = self.authority + return "".join( + [ + f"{self.scheme}:" if self.scheme else "", + f"//{authority}" if authority else "", + self.path, + f"?{self.query}" if self.query is not None else "", + f"#{self.fragment}" if self.fragment is not None else "", + ] + ) + + +def urlparse(url: str = "", **kwargs: str | None) -> ParseResult: + # Initial basic checks on allowable URLs. + # --------------------------------------- + + # Hard limit the maximum allowable URL length. + if len(url) > MAX_URL_LENGTH: + raise InvalidURL("URL too long") + + # If a URL includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in url): + char = next(char for char in url if char.isascii() and not char.isprintable()) + idx = url.find(char) + error = ( + f"Invalid non-printable ASCII character in URL, {char!r} at position {idx}." 
+ ) + raise InvalidURL(error) + + # Some keyword arguments require special handling. + # ------------------------------------------------ + + # Coerce "port" to a string, if it is provided as an integer. + if "port" in kwargs: + port = kwargs["port"] + kwargs["port"] = str(port) if isinstance(port, int) else port + + # Replace "netloc" with "host and "port". + if "netloc" in kwargs: + netloc = kwargs.pop("netloc") or "" + kwargs["host"], _, kwargs["port"] = netloc.partition(":") + + # Replace "username" and/or "password" with "userinfo". + if "username" in kwargs or "password" in kwargs: + username = quote(kwargs.pop("username", "") or "", safe=USERNAME_SAFE) + password = quote(kwargs.pop("password", "") or "", safe=PASSWORD_SAFE) + kwargs["userinfo"] = f"{username}:{password}" if password else username + + # Replace "raw_path" with "path" and "query". + if "raw_path" in kwargs: + raw_path = kwargs.pop("raw_path") or "" + kwargs["path"], seperator, kwargs["query"] = raw_path.partition("?") + if not seperator: + kwargs["query"] = None + + # Ensure that IPv6 "host" addresses are always escaped with "[...]". + if "host" in kwargs: + host = kwargs.get("host") or "" + if ":" in host and not (host.startswith("[") and host.endswith("]")): + kwargs["host"] = f"[{host}]" + + # If any keyword arguments are provided, ensure they are valid. + # ------------------------------------------------------------- + + for key, value in kwargs.items(): + if value is not None: + if len(value) > MAX_URL_LENGTH: + raise InvalidURL(f"URL component '{key}' too long") + + # If a component includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in value): + char = next( + char for char in value if char.isascii() and not char.isprintable() + ) + idx = value.find(char) + error = ( + f"Invalid non-printable ASCII character in URL {key} component, " + f"{char!r} at position {idx}." 
+ ) + raise InvalidURL(error) + + # Ensure that keyword arguments match as a valid regex. + if not COMPONENT_REGEX[key].fullmatch(value): + raise InvalidURL(f"Invalid URL component '{key}'") + + # The URL_REGEX will always match, but may have empty components. + url_match = URL_REGEX.match(url) + assert url_match is not None + url_dict = url_match.groupdict() + + # * 'scheme', 'authority', and 'path' may be empty strings. + # * 'query' may be 'None', indicating no trailing "?" portion. + # Any string including the empty string, indicates a trailing "?". + # * 'fragment' may be 'None', indicating no trailing "#" portion. + # Any string including the empty string, indicates a trailing "#". + scheme = kwargs.get("scheme", url_dict["scheme"]) or "" + authority = kwargs.get("authority", url_dict["authority"]) or "" + path = kwargs.get("path", url_dict["path"]) or "" + query = kwargs.get("query", url_dict["query"]) + frag = kwargs.get("fragment", url_dict["fragment"]) + + # The AUTHORITY_REGEX will always match, but may have empty components. + authority_match = AUTHORITY_REGEX.match(authority) + assert authority_match is not None + authority_dict = authority_match.groupdict() + + # * 'userinfo' and 'host' may be empty strings. + # * 'port' may be 'None'. + userinfo = kwargs.get("userinfo", authority_dict["userinfo"]) or "" + host = kwargs.get("host", authority_dict["host"]) or "" + port = kwargs.get("port", authority_dict["port"]) + + # Normalize and validate each component. + # We end up with a parsed representation of the URL, + # with components that are plain ASCII bytestrings. 
+ parsed_scheme: str = scheme.lower() + parsed_userinfo: str = quote(userinfo, safe=USERINFO_SAFE) + parsed_host: str = encode_host(host) + parsed_port: int | None = normalize_port(port, scheme) + + has_scheme = parsed_scheme != "" + has_authority = ( + parsed_userinfo != "" or parsed_host != "" or parsed_port is not None + ) + validate_path(path, has_scheme=has_scheme, has_authority=has_authority) + if has_scheme or has_authority: + path = normalize_path(path) + + parsed_path: str = quote(path, safe=PATH_SAFE) + parsed_query: str | None = None if query is None else quote(query, safe=QUERY_SAFE) + parsed_frag: str | None = None if frag is None else quote(frag, safe=FRAG_SAFE) + + # The parsed ASCII bytestrings are our canonical form. + # All properties of the URL are derived from these. + return ParseResult( + parsed_scheme, + parsed_userinfo, + parsed_host, + parsed_port, + parsed_path, + parsed_query, + parsed_frag, + ) + + +def encode_host(host: str) -> str: + if not host: + return "" + + elif IPv4_STYLE_HOSTNAME.match(host): + # Validate IPv4 hostnames like #.#.#.# + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet + try: + ipaddress.IPv4Address(host) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv4 address: {host!r}") + return host + + elif IPv6_STYLE_HOSTNAME.match(host): + # Validate IPv6 hostnames like [...] + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # "A host identified by an Internet Protocol literal address, version 6 + # [RFC3513] or later, is distinguished by enclosing the IP literal + # within square brackets ("[" and "]"). This is the only place where + # square bracket characters are allowed in the URI syntax." 
+ try: + ipaddress.IPv6Address(host[1:-1]) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv6 address: {host!r}") + return host[1:-1] + + elif host.isascii(): + # Regular ASCII hostnames + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # reg-name = *( unreserved / pct-encoded / sub-delims ) + WHATWG_SAFE = '"`{}%|\\' + return quote(host.lower(), safe=SUB_DELIMS + WHATWG_SAFE) + + # IDNA hostnames + try: + return idna.encode(host.lower()).decode("ascii") + except idna.IDNAError: + raise InvalidURL(f"Invalid IDNA hostname: {host!r}") + + +def normalize_port(port: str | int | None, scheme: str) -> int | None: + # From https://tools.ietf.org/html/rfc3986#section-3.2.3 + # + # "A scheme may define a default port. For example, the "http" scheme + # defines a default port of "80", corresponding to its reserved TCP + # port number. The type of port designated by the port number (e.g., + # TCP, UDP, SCTP) is defined by the URI scheme. URI producers and + # normalizers should omit the port component and its ":" delimiter if + # port is empty or if its value would be the same as that of the + # scheme's default." + if port is None or port == "": + return None + + try: + port_as_int = int(port) + except ValueError: + raise InvalidURL(f"Invalid port: {port!r}") + + # See https://url.spec.whatwg.org/#url-miscellaneous + default_port = {"ftp": 21, "http": 80, "https": 443, "ws": 80, "wss": 443}.get( + scheme + ) + if port_as_int == default_port: + return None + return port_as_int + + +def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None: + """ + Path validation rules that depend on if the URL contains + a scheme or authority component. + + See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3 + """ + if has_authority: + # If a URI contains an authority component, then the path component + # must either be empty or begin with a slash ("/") character." 
+ if path and not path.startswith("/"): + raise InvalidURL("For absolute URLs, path must be empty or begin with '/'") + + if not has_scheme and not has_authority: + # If a URI does not contain an authority component, then the path cannot begin + # with two slash characters ("//"). + if path.startswith("//"): + raise InvalidURL("Relative URLs cannot have a path starting with '//'") + + # In addition, a URI reference (Section 4.1) may be a relative-path reference, + # in which case the first path segment cannot contain a colon (":") character. + if path.startswith(":"): + raise InvalidURL("Relative URLs cannot have a path starting with ':'") + + +def normalize_path(path: str) -> str: + """ + Drop "." and ".." segments from a URL path. + + For example: + + normalize_path("/path/./to/somewhere/..") == "/path/to" + """ + # Fast return when no '.' characters in the path. + if "." not in path: + return path + + components = path.split("/") + + # Fast return when no '.' or '..' components in the path. + if "." not in components and ".." not in components: + return path + + # https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4 + output: list[str] = [] + for component in components: + if component == ".": + pass + elif component == "..": + if output and output != [""]: + output.pop() + else: + output.append(component) + return "/".join(output) + + +def PERCENT(string: str) -> str: + return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")]) + + +def percent_encoded(string: str, safe: str) -> str: + """ + Use percent-encoding to quote a string. + """ + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + + # Fast path for strings that don't need escaping. + if not string.rstrip(NON_ESCAPED_CHARS): + return string + + return "".join( + [char if char in NON_ESCAPED_CHARS else PERCENT(char) for char in string] + ) + + +def quote(string: str, safe: str) -> str: + """ + Use percent-encoding to quote a string, omitting existing '%xx' escape sequences. 
+ + See: https://www.rfc-editor.org/rfc/rfc3986#section-2.1 + + * `string`: The string to be percent-escaped. + * `safe`: A string containing characters that may be treated as safe, and do not + need to be escaped. Unreserved characters are always treated as safe. + See: https://www.rfc-editor.org/rfc/rfc3986#section-2.3 + """ + parts = [] + current_position = 0 + for match in re.finditer(PERCENT_ENCODED_REGEX, string): + start_position, end_position = match.start(), match.end() + matched_text = match.group(0) + # Add any text up to the '%xx' escape sequence. + if start_position != current_position: + leading_text = string[current_position:start_position] + parts.append(percent_encoded(leading_text, safe=safe)) + + # Add the '%xx' escape sequence. + parts.append(matched_text) + current_position = end_position + + # Add any text after the final '%xx' escape sequence. + if current_position != len(string): + trailing_text = string[current_position:] + parts.append(percent_encoded(trailing_text, safe=safe)) + + return "".join(parts) diff --git a/venv/lib/python3.10/site-packages/httpx/_urls.py b/venv/lib/python3.10/site-packages/httpx/_urls.py new file mode 100644 index 0000000000000000000000000000000000000000..147a8fa333acaf31618d37ba2896e3a5bf5e4d02 --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_urls.py @@ -0,0 +1,641 @@ +from __future__ import annotations + +import typing +from urllib.parse import parse_qs, unquote, urlencode + +import idna + +from ._types import QueryParamTypes +from ._urlparse import urlparse +from ._utils import primitive_value_to_str + +__all__ = ["URL", "QueryParams"] + + +class URL: + """ + url = httpx.URL("HTTPS://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink") + + assert url.scheme == "https" + assert url.username == "jo@email.com" + assert url.password == "a secret" + assert url.userinfo == b"jo%40email.com:a%20secret" + assert url.host == "müller.de" + assert url.raw_host == b"xn--mller-kva.de" + 
assert url.port == 1234 + assert url.netloc == b"xn--mller-kva.de:1234" + assert url.path == "/pa th" + assert url.query == b"?search=ab" + assert url.raw_path == b"/pa%20th?search=ab" + assert url.fragment == "anchorlink" + + The components of a URL are broken down like this: + + https://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink + [scheme] [ username ] [password] [ host ][port][ path ] [ query ] [fragment] + [ userinfo ] [ netloc ][ raw_path ] + + Note that: + + * `url.scheme` is normalized to always be lowercased. + + * `url.host` is normalized to always be lowercased. Internationalized domain + names are represented in unicode, without IDNA encoding applied. For instance: + + url = httpx.URL("http://中国.icom.museum") + assert url.host == "中国.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "中国.icom.museum" + + * `url.raw_host` is normalized to always be lowercased, and is IDNA encoded. + + url = httpx.URL("http://中国.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + * `url.port` is either None or an integer. URLs that include the default port for + "http", "https", "ws", "wss", and "ftp" schemes have their port + normalized to `None`. + + assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80") + assert httpx.URL("http://example.com").port is None + assert httpx.URL("http://example.com:80").port is None + + * `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work + with `url.username` and `url.password` instead, which handle the URL escaping. + + * `url.raw_path` is raw bytes of both the path and query, without URL escaping. + This portion is used as the target when constructing HTTP requests. Usually you'll + want to work with `url.path` instead. + + * `url.query` is raw bytes, without URL escaping. 
A URL query string portion can + only be properly URL escaped when decoding the parameter names and values + themselves. + """ + + def __init__(self, url: URL | str = "", **kwargs: typing.Any) -> None: + if kwargs: + allowed = { + "scheme": str, + "username": str, + "password": str, + "userinfo": bytes, + "host": str, + "port": int, + "netloc": bytes, + "path": str, + "query": bytes, + "raw_path": bytes, + "fragment": str, + "params": object, + } + + # Perform type checking for all supported keyword arguments. + for key, value in kwargs.items(): + if key not in allowed: + message = f"{key!r} is an invalid keyword argument for URL()" + raise TypeError(message) + if value is not None and not isinstance(value, allowed[key]): + expected = allowed[key].__name__ + seen = type(value).__name__ + message = f"Argument {key!r} must be {expected} but got {seen}" + raise TypeError(message) + if isinstance(value, bytes): + kwargs[key] = value.decode("ascii") + + if "params" in kwargs: + # Replace any "params" keyword with the raw "query" instead. + # + # Ensure that empty params use `kwargs["query"] = None` rather + # than `kwargs["query"] = ""`, so that generated URLs do not + # include an empty trailing "?". + params = kwargs.pop("params") + kwargs["query"] = None if not params else str(QueryParams(params)) + + if isinstance(url, str): + self._uri_reference = urlparse(url, **kwargs) + elif isinstance(url, URL): + self._uri_reference = url._uri_reference.copy_with(**kwargs) + else: + raise TypeError( + "Invalid type for url. Expected str or httpx.URL," + f" got {type(url)}: {url!r}" + ) + + @property + def scheme(self) -> str: + """ + The URL scheme, such as "http", "https". + Always normalised to lowercase. + """ + return self._uri_reference.scheme + + @property + def raw_scheme(self) -> bytes: + """ + The raw bytes representation of the URL scheme, such as b"http", b"https". + Always normalised to lowercase. 
+ """ + return self._uri_reference.scheme.encode("ascii") + + @property + def userinfo(self) -> bytes: + """ + The URL userinfo as a raw bytestring. + For example: b"jo%40email.com:a%20secret". + """ + return self._uri_reference.userinfo.encode("ascii") + + @property + def username(self) -> str: + """ + The URL username as a string, with URL decoding applied. + For example: "jo@email.com" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[0]) + + @property + def password(self) -> str: + """ + The URL password as a string, with URL decoding applied. + For example: "a secret" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[2]) + + @property + def host(self) -> str: + """ + The URL host as a string. + Always normalized to lowercase, with IDNA hosts decoded into unicode. + + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.host == "www.example.org" + + url = httpx.URL("http://中国.icom.museum") + assert url.host == "中国.icom.museum" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "中国.icom.museum" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.host == "::ffff:192.168.0.1" + """ + host: str = self._uri_reference.host + + if host.startswith("xn--"): + host = idna.decode(host) + + return host + + @property + def raw_host(self) -> bytes: + """ + The raw bytes representation of the URL host. + Always normalized to lowercase, and IDNA encoded. 
+ + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.raw_host == b"www.example.org" + + url = httpx.URL("http://中国.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.raw_host == b"::ffff:192.168.0.1" + """ + return self._uri_reference.host.encode("ascii") + + @property + def port(self) -> int | None: + """ + The URL port as an integer. + + Note that the URL class performs port normalization as per the WHATWG spec. + Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always + treated as `None`. + + For example: + + assert httpx.URL("http://www.example.com") == httpx.URL("http://www.example.com:80") + assert httpx.URL("http://www.example.com:80").port is None + """ + return self._uri_reference.port + + @property + def netloc(self) -> bytes: + """ + Either `` or `:` as bytes. + Always normalized to lowercase, and IDNA encoded. + + This property may be used for generating the value of a request + "Host" header. + """ + return self._uri_reference.netloc.encode("ascii") + + @property + def path(self) -> str: + """ + The URL path as a string. Excluding the query string, and URL decoded. + + For example: + + url = httpx.URL("https://example.com/pa%20th") + assert url.path == "/pa th" + """ + path = self._uri_reference.path or "/" + return unquote(path) + + @property + def query(self) -> bytes: + """ + The URL query string, as raw bytes, excluding the leading b"?". + + This is necessarily a bytewise interface, because we cannot + perform URL decoding of this representation until we've parsed + the keys and values into a QueryParams instance. 
+ + For example: + + url = httpx.URL("https://example.com/?filter=some%20search%20terms") + assert url.query == b"filter=some%20search%20terms" + """ + query = self._uri_reference.query or "" + return query.encode("ascii") + + @property + def params(self) -> QueryParams: + """ + The URL query parameters, neatly parsed and packaged into an immutable + multidict representation. + """ + return QueryParams(self._uri_reference.query) + + @property + def raw_path(self) -> bytes: + """ + The complete URL path and query string as raw bytes. + Used as the target when constructing HTTP requests. + + For example: + + GET /users?search=some%20text HTTP/1.1 + Host: www.example.org + Connection: close + """ + path = self._uri_reference.path or "/" + if self._uri_reference.query is not None: + path += "?" + self._uri_reference.query + return path.encode("ascii") + + @property + def fragment(self) -> str: + """ + The URL fragments, as used in HTML anchors. + As a string, without the leading '#'. + """ + return unquote(self._uri_reference.fragment or "") + + @property + def is_absolute_url(self) -> bool: + """ + Return `True` for absolute URLs such as 'http://example.com/path', + and `False` for relative URLs such as '/path'. + """ + # We don't use `.is_absolute` from `rfc3986` because it treats + # URLs with a fragment portion as not absolute. + # What we actually care about is if the URL provides + # a scheme and hostname to which connections should be made. + return bool(self._uri_reference.scheme and self._uri_reference.host) + + @property + def is_relative_url(self) -> bool: + """ + Return `False` for absolute URLs such as 'http://example.com/path', + and `True` for relative URLs such as '/path'. + """ + return not self.is_absolute_url + + def copy_with(self, **kwargs: typing.Any) -> URL: + """ + Copy this URL, returning a new URL with some components altered. + Accepts the same set of parameters as the components that are made + available via properties on the `URL` class. 
+ + For example: + + url = httpx.URL("https://www.example.com").copy_with( + username="jo@gmail.com", password="a secret" + ) + assert url == "https://jo%40email.com:a%20secret@www.example.com" + """ + return URL(self, **kwargs) + + def copy_set_param(self, key: str, value: typing.Any = None) -> URL: + return self.copy_with(params=self.params.set(key, value)) + + def copy_add_param(self, key: str, value: typing.Any = None) -> URL: + return self.copy_with(params=self.params.add(key, value)) + + def copy_remove_param(self, key: str) -> URL: + return self.copy_with(params=self.params.remove(key)) + + def copy_merge_params(self, params: QueryParamTypes) -> URL: + return self.copy_with(params=self.params.merge(params)) + + def join(self, url: URL | str) -> URL: + """ + Return an absolute URL, using this URL as the base. + + Eg. + + url = httpx.URL("https://www.example.com/test") + url = url.join("/new/path") + assert url == "https://www.example.com/new/path" + """ + from urllib.parse import urljoin + + return URL(urljoin(str(self), str(URL(url)))) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, (URL, str)) and str(self) == str(URL(other)) + + def __str__(self) -> str: + return str(self._uri_reference) + + def __repr__(self) -> str: + scheme, userinfo, host, port, path, query, fragment = self._uri_reference + + if ":" in userinfo: + # Mask any password component. 
+ userinfo = f'{userinfo.split(":")[0]}:[secure]' + + authority = "".join( + [ + f"{userinfo}@" if userinfo else "", + f"[{host}]" if ":" in host else host, + f":{port}" if port is not None else "", + ] + ) + url = "".join( + [ + f"{self.scheme}:" if scheme else "", + f"//{authority}" if authority else "", + path, + f"?{query}" if query is not None else "", + f"#{fragment}" if fragment is not None else "", + ] + ) + + return f"{self.__class__.__name__}({url!r})" + + @property + def raw(self) -> tuple[bytes, bytes, int, bytes]: # pragma: nocover + import collections + import warnings + + warnings.warn("URL.raw is deprecated.") + RawURL = collections.namedtuple( + "RawURL", ["raw_scheme", "raw_host", "port", "raw_path"] + ) + return RawURL( + raw_scheme=self.raw_scheme, + raw_host=self.raw_host, + port=self.port, + raw_path=self.raw_path, + ) + + +class QueryParams(typing.Mapping[str, str]): + """ + URL query parameters, as a multi-dict. + """ + + def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None: + assert len(args) < 2, "Too many arguments." + assert not (args and kwargs), "Cannot mix named and unnamed arguments." 
+ + value = args[0] if args else kwargs + + if value is None or isinstance(value, (str, bytes)): + value = value.decode("ascii") if isinstance(value, bytes) else value + self._dict = parse_qs(value, keep_blank_values=True) + elif isinstance(value, QueryParams): + self._dict = {k: list(v) for k, v in value._dict.items()} + else: + dict_value: dict[typing.Any, list[typing.Any]] = {} + if isinstance(value, (list, tuple)): + # Convert list inputs like: + # [("a", "123"), ("a", "456"), ("b", "789")] + # To a dict representation, like: + # {"a": ["123", "456"], "b": ["789"]} + for item in value: + dict_value.setdefault(item[0], []).append(item[1]) + else: + # Convert dict inputs like: + # {"a": "123", "b": ["456", "789"]} + # To dict inputs where values are always lists, like: + # {"a": ["123"], "b": ["456", "789"]} + dict_value = { + k: list(v) if isinstance(v, (list, tuple)) else [v] + for k, v in value.items() + } + + # Ensure that keys and values are neatly coerced to strings. + # We coerce values `True` and `False` to JSON-like "true" and "false" + # representations, and coerce `None` values to the empty string. + self._dict = { + str(k): [primitive_value_to_str(item) for item in v] + for k, v in dict_value.items() + } + + def keys(self) -> typing.KeysView[str]: + """ + Return all the keys in the query params. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.keys()) == ["a", "b"] + """ + return self._dict.keys() + + def values(self) -> typing.ValuesView[str]: + """ + Return all the values in the query params. If a key occurs more than once + only the first item for that key is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.values()) == ["123", "789"] + """ + return {k: v[0] for k, v in self._dict.items()}.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return all items in the query params. If a key occurs more than once + only the first item for that key is returned. 
+ + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.items()) == [("a", "123"), ("b", "789")] + """ + return {k: v[0] for k, v in self._dict.items()}.items() + + def multi_items(self) -> list[tuple[str, str]]: + """ + Return all items in the query params. Allow duplicate keys to occur. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")] + """ + multi_items: list[tuple[str, str]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items + + def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + """ + Get a value from the query param for a given key. If the key occurs + more than once, then only the first value is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert q.get("a") == "123" + """ + if key in self._dict: + return self._dict[str(key)][0] + return default + + def get_list(self, key: str) -> list[str]: + """ + Get all values from the query param for a given key. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert q.get_list("a") == ["123", "456"] + """ + return list(self._dict.get(str(key), [])) + + def set(self, key: str, value: typing.Any = None) -> QueryParams: + """ + Return a new QueryParams instance, setting the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.set("a", "456") + assert q == httpx.QueryParams("a=456") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict[str(key)] = [primitive_value_to_str(value)] + return q + + def add(self, key: str, value: typing.Any = None) -> QueryParams: + """ + Return a new QueryParams instance, setting or appending the value of a key. 
+ + Usage: + + q = httpx.QueryParams("a=123") + q = q.add("a", "456") + assert q == httpx.QueryParams("a=123&a=456") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)] + return q + + def remove(self, key: str) -> QueryParams: + """ + Return a new QueryParams instance, removing the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.remove("a") + assert q == httpx.QueryParams("") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict.pop(str(key), None) + return q + + def merge(self, params: QueryParamTypes | None = None) -> QueryParams: + """ + Return a new QueryParams instance, updated with. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.merge({"b": "456"}) + assert q == httpx.QueryParams("a=123&b=456") + + q = httpx.QueryParams("a=123") + q = q.merge({"a": "456", "b": "789"}) + assert q == httpx.QueryParams("a=456&b=789") + """ + q = QueryParams(params) + q._dict = {**self._dict, **q._dict} + return q + + def __getitem__(self, key: typing.Any) -> str: + return self._dict[key][0] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __bool__(self) -> bool: + return bool(self._dict) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, self.__class__): + return False + return sorted(self.multi_items()) == sorted(other.multi_items()) + + def __str__(self) -> str: + return urlencode(self.multi_items()) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + def update(self, params: QueryParamTypes | None = None) -> None: + raise RuntimeError( + "QueryParams are immutable since 0.18.0. 
" + "Use `q = q.merge(...)` to create an updated copy." + ) + + def __setitem__(self, key: str, value: str) -> None: + raise RuntimeError( + "QueryParams are immutable since 0.18.0. " + "Use `q = q.set(key, value)` to create an updated copy." + ) diff --git a/venv/lib/python3.10/site-packages/httpx/_utils.py b/venv/lib/python3.10/site-packages/httpx/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe827da4d071b32ea6da44328629699d6fc88ce --- /dev/null +++ b/venv/lib/python3.10/site-packages/httpx/_utils.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import ipaddress +import os +import re +import typing +from urllib.request import getproxies + +from ._types import PrimitiveData + +if typing.TYPE_CHECKING: # pragma: no cover + from ._urls import URL + + +def primitive_value_to_str(value: PrimitiveData) -> str: + """ + Coerce a primitive data type into a string value. + + Note that we prefer JSON-style 'true'/'false' for boolean values here. + """ + if value is True: + return "true" + elif value is False: + return "false" + elif value is None: + return "" + return str(value) + + +def get_environment_proxies() -> dict[str, str | None]: + """Gets proxy information from the environment""" + + # urllib.request.getproxies() falls back on System + # Registry and Config for proxies on Windows and macOS. + # We don't want to propagate non-HTTP proxies into + # our configuration such as 'TRAVIS_APT_PROXY'. + proxy_info = getproxies() + mounts: dict[str, str | None] = {} + + for scheme in ("http", "https", "all"): + if proxy_info.get(scheme): + hostname = proxy_info[scheme] + mounts[f"{scheme}://"] = ( + hostname if "://" in hostname else f"http://{hostname}" + ) + + no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")] + for hostname in no_proxy_hosts: + # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details + # on how names in `NO_PROXY` are handled. 
+ if hostname == "*": + # If NO_PROXY=* is used or if "*" occurs as any one of the comma + # separated hostnames, then we should just bypass any information + # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore + # proxies. + return {} + elif hostname: + # NO_PROXY=.google.com is marked as "all://*.google.com, + # which disables "www.google.com" but not "google.com" + # NO_PROXY=google.com is marked as "all://*google.com, + # which disables "www.google.com" and "google.com". + # (But not "wwwgoogle.com") + # NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost" + # NO_PROXY=example.com,::1,localhost,192.168.0.0/16 + if "://" in hostname: + mounts[hostname] = None + elif is_ipv4_hostname(hostname): + mounts[f"all://{hostname}"] = None + elif is_ipv6_hostname(hostname): + mounts[f"all://[{hostname}]"] = None + elif hostname.lower() == "localhost": + mounts[f"all://{hostname}"] = None + else: + mounts[f"all://*{hostname}"] = None + + return mounts + + +def to_bytes(value: str | bytes, encoding: str = "utf-8") -> bytes: + return value.encode(encoding) if isinstance(value, str) else value + + +def to_str(value: str | bytes, encoding: str = "utf-8") -> str: + return value if isinstance(value, str) else value.decode(encoding) + + +def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr: + return value if isinstance(match_type_of, str) else value.encode() + + +def unquote(value: str) -> str: + return value[1:-1] if value[0] == value[-1] == '"' else value + + +def peek_filelike_length(stream: typing.Any) -> int | None: + """ + Given a file-like stream object, return its length in number of bytes + without reading it into memory. + """ + try: + # Is it an actual file? + fd = stream.fileno() + # Yup, seems to be an actual file. + length = os.fstat(fd).st_size + except (AttributeError, OSError): + # No... Maybe it's something that supports random access, like `io.BytesIO`? 
+ try: + # Assuming so, go to end of stream to figure out its length, + # then put it back in place. + offset = stream.tell() + length = stream.seek(0, os.SEEK_END) + stream.seek(offset) + except (AttributeError, OSError): + # Not even that? Sorry, we're doomed... + return None + + return length + + +class URLPattern: + """ + A utility class currently used for making lookups against proxy keys... + + # Wildcard matching... + >>> pattern = URLPattern("all://") + >>> pattern.matches(httpx.URL("http://example.com")) + True + + # Witch scheme matching... + >>> pattern = URLPattern("https://") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + False + + # With domain matching... + >>> pattern = URLPattern("https://example.com") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + False + >>> pattern.matches(httpx.URL("https://other.com")) + False + + # Wildcard scheme, with domain matching... + >>> pattern = URLPattern("all://example.com") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + True + >>> pattern.matches(httpx.URL("https://other.com")) + False + + # With port matching... + >>> pattern = URLPattern("https://example.com:1234") + >>> pattern.matches(httpx.URL("https://example.com:1234")) + True + >>> pattern.matches(httpx.URL("https://example.com")) + False + """ + + def __init__(self, pattern: str) -> None: + from ._urls import URL + + if pattern and ":" not in pattern: + raise ValueError( + f"Proxy keys should use proper URL forms rather " + f"than plain scheme strings. 
" + f'Instead of "{pattern}", use "{pattern}://"' + ) + + url = URL(pattern) + self.pattern = pattern + self.scheme = "" if url.scheme == "all" else url.scheme + self.host = "" if url.host == "*" else url.host + self.port = url.port + if not url.host or url.host == "*": + self.host_regex: typing.Pattern[str] | None = None + elif url.host.startswith("*."): + # *.example.com should match "www.example.com", but not "example.com" + domain = re.escape(url.host[2:]) + self.host_regex = re.compile(f"^.+\\.{domain}$") + elif url.host.startswith("*"): + # *example.com should match "www.example.com" and "example.com" + domain = re.escape(url.host[1:]) + self.host_regex = re.compile(f"^(.+\\.)?{domain}$") + else: + # example.com should match "example.com" but not "www.example.com" + domain = re.escape(url.host) + self.host_regex = re.compile(f"^{domain}$") + + def matches(self, other: URL) -> bool: + if self.scheme and self.scheme != other.scheme: + return False + if ( + self.host + and self.host_regex is not None + and not self.host_regex.match(other.host) + ): + return False + if self.port is not None and self.port != other.port: + return False + return True + + @property + def priority(self) -> tuple[int, int, int]: + """ + The priority allows URLPattern instances to be sortable, so that + we can match from most specific to least specific. + """ + # URLs with a port should take priority over URLs without a port. + port_priority = 0 if self.port is not None else 1 + # Longer hostnames should match first. + host_priority = -len(self.host) + # Longer schemes should match first. 
+ scheme_priority = -len(self.scheme) + return (port_priority, host_priority, scheme_priority) + + def __hash__(self) -> int: + return hash(self.pattern) + + def __lt__(self, other: URLPattern) -> bool: + return self.priority < other.priority + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, URLPattern) and self.pattern == other.pattern + + +def is_ipv4_hostname(hostname: str) -> bool: + try: + ipaddress.IPv4Address(hostname.split("/")[0]) + except Exception: + return False + return True + + +def is_ipv6_hostname(hostname: str) -> bool: + try: + ipaddress.IPv6Address(hostname.split("/")[0]) + except Exception: + return False + return True diff --git a/venv/lib/python3.10/site-packages/httpx/py.typed b/venv/lib/python3.10/site-packages/httpx/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__init__.py b/venv/lib/python3.10/site-packages/huggingface_hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4f04bddd58756c91dde5e0f4b2d331bc8ed6ca --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/__init__.py @@ -0,0 +1,1620 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# *********** +# `huggingface_hub` init has 2 modes: +# - Normal usage: +# If imported to use it, all modules and functions are lazy-loaded. This means +# they exist at top level in module but are imported only the first time they are +# used. This way, `from huggingface_hub import something` will import `something` +# quickly without the hassle of importing all the features from `huggingface_hub`. +# - Static check: +# If statically analyzed, all modules and functions are loaded normally. This way +# static typing check works properly as well as autocomplete in text editors and +# IDEs. +# +# The static model imports are done inside the `if TYPE_CHECKING:` statement at +# the bottom of this file. Since module/functions imports are duplicated, it is +# mandatory to make sure to add them twice when adding one. This is checked in the +# `make quality` command. +# +# To update the static imports, please run the following command and commit the changes. +# ``` +# # Use script +# python utils/check_static_imports.py --update-file +# +# # Or run style on codebase +# make style +# ``` +# +# *********** +# Lazy loader vendored from https://github.com/scientific-python/lazy_loader +import importlib +import os +import sys +from typing import TYPE_CHECKING + + +__version__ = "1.4.1" + +# Alphabetical order of definitions is ensured in tests +# WARNING: any comment added in this dictionary definition will be lost when +# re-generating the file ! 
+_SUBMOD_ATTRS = { + "_commit_scheduler": [ + "CommitScheduler", + ], + "_eval_results": [ + "EvalResultEntry", + "eval_result_entries_to_yaml", + "parse_eval_result_entries", + ], + "_inference_endpoints": [ + "InferenceEndpoint", + "InferenceEndpointError", + "InferenceEndpointStatus", + "InferenceEndpointTimeoutError", + "InferenceEndpointType", + ], + "_jobs_api": [ + "JobAccelerator", + "JobHardware", + "JobInfo", + "JobOwner", + "JobStage", + "JobStatus", + ], + "_login": [ + "auth_list", + "auth_switch", + "interpreter_login", + "login", + "logout", + "notebook_login", + ], + "_oauth": [ + "OAuthInfo", + "OAuthOrgInfo", + "OAuthUserInfo", + "attach_huggingface_oauth", + "parse_huggingface_oauth", + ], + "_snapshot_download": [ + "snapshot_download", + ], + "_space_api": [ + "SpaceHardware", + "SpaceRuntime", + "SpaceStage", + "SpaceStorage", + "SpaceVariable", + ], + "_tensorboard_logger": [ + "HFSummaryWriter", + ], + "_webhooks_payload": [ + "WebhookPayload", + "WebhookPayloadComment", + "WebhookPayloadDiscussion", + "WebhookPayloadDiscussionChanges", + "WebhookPayloadEvent", + "WebhookPayloadMovedTo", + "WebhookPayloadRepo", + "WebhookPayloadUrl", + "WebhookPayloadWebhook", + ], + "_webhooks_server": [ + "WebhooksServer", + "webhook_endpoint", + ], + "cli._cli_utils": [ + "check_cli_update", + "typer_factory", + ], + "community": [ + "Discussion", + "DiscussionComment", + "DiscussionCommit", + "DiscussionEvent", + "DiscussionStatusChange", + "DiscussionTitleChange", + "DiscussionWithDetails", + ], + "constants": [ + "CONFIG_NAME", + "FLAX_WEIGHTS_NAME", + "HUGGINGFACE_CO_URL_HOME", + "HUGGINGFACE_CO_URL_TEMPLATE", + "PYTORCH_WEIGHTS_NAME", + "REPO_TYPE_DATASET", + "REPO_TYPE_MODEL", + "REPO_TYPE_SPACE", + "TF2_WEIGHTS_NAME", + "TF_WEIGHTS_NAME", + "is_offline_mode", + ], + "fastai_utils": [ + "_save_pretrained_fastai", + "from_pretrained_fastai", + "push_to_hub_fastai", + ], + "file_download": [ + "DryRunFileInfo", + "HfFileMetadata", + 
"_CACHED_NO_EXIST", + "get_hf_file_metadata", + "hf_hub_download", + "hf_hub_url", + "try_to_load_from_cache", + ], + "hf_api": [ + "Collection", + "CollectionItem", + "CommitInfo", + "CommitOperation", + "CommitOperationAdd", + "CommitOperationCopy", + "CommitOperationDelete", + "DatasetInfo", + "GitCommitInfo", + "GitRefInfo", + "GitRefs", + "HfApi", + "ModelInfo", + "Organization", + "RepoFile", + "RepoFolder", + "RepoUrl", + "SpaceInfo", + "User", + "UserLikes", + "WebhookInfo", + "WebhookWatchedItem", + "accept_access_request", + "add_collection_item", + "add_space_secret", + "add_space_variable", + "auth_check", + "cancel_access_request", + "cancel_job", + "change_discussion_status", + "comment_discussion", + "create_branch", + "create_collection", + "create_commit", + "create_discussion", + "create_inference_endpoint", + "create_inference_endpoint_from_catalog", + "create_pull_request", + "create_repo", + "create_scheduled_job", + "create_scheduled_uv_job", + "create_tag", + "create_webhook", + "dataset_info", + "delete_branch", + "delete_collection", + "delete_collection_item", + "delete_file", + "delete_folder", + "delete_inference_endpoint", + "delete_repo", + "delete_scheduled_job", + "delete_space_secret", + "delete_space_storage", + "delete_space_variable", + "delete_tag", + "delete_webhook", + "disable_webhook", + "duplicate_space", + "edit_discussion_comment", + "enable_webhook", + "fetch_job_logs", + "fetch_job_metrics", + "file_exists", + "get_collection", + "get_dataset_tags", + "get_discussion_details", + "get_full_repo_name", + "get_inference_endpoint", + "get_local_safetensors_metadata", + "get_model_tags", + "get_organization_overview", + "get_paths_info", + "get_repo_discussions", + "get_safetensors_metadata", + "get_space_runtime", + "get_space_variables", + "get_user_overview", + "get_webhook", + "grant_access", + "inspect_job", + "inspect_scheduled_job", + "list_accepted_access_requests", + "list_collections", + "list_daily_papers", + 
"list_datasets", + "list_inference_catalog", + "list_inference_endpoints", + "list_jobs", + "list_jobs_hardware", + "list_lfs_files", + "list_liked_repos", + "list_models", + "list_organization_followers", + "list_organization_members", + "list_papers", + "list_pending_access_requests", + "list_rejected_access_requests", + "list_repo_commits", + "list_repo_files", + "list_repo_likers", + "list_repo_refs", + "list_repo_tree", + "list_spaces", + "list_user_followers", + "list_user_following", + "list_webhooks", + "merge_pull_request", + "model_info", + "move_repo", + "paper_info", + "parse_local_safetensors_file_metadata", + "parse_safetensors_file_metadata", + "pause_inference_endpoint", + "pause_space", + "permanently_delete_lfs_files", + "preupload_lfs_files", + "reject_access_request", + "rename_discussion", + "repo_exists", + "repo_info", + "repo_type_and_id_from_hf_id", + "request_space_hardware", + "request_space_storage", + "restart_space", + "resume_inference_endpoint", + "resume_scheduled_job", + "revision_exists", + "run_as_future", + "run_job", + "run_uv_job", + "scale_to_zero_inference_endpoint", + "set_space_sleep_time", + "space_info", + "super_squash_history", + "suspend_scheduled_job", + "unlike", + "update_collection_item", + "update_collection_metadata", + "update_inference_endpoint", + "update_repo_settings", + "update_webhook", + "upload_file", + "upload_folder", + "upload_large_folder", + "verify_repo_checksums", + "whoami", + ], + "hf_file_system": [ + "HfFileSystem", + "HfFileSystemFile", + "HfFileSystemResolvedPath", + "HfFileSystemStreamFile", + "hffs", + ], + "hub_mixin": [ + "ModelHubMixin", + "PyTorchModelHubMixin", + ], + "inference._client": [ + "InferenceClient", + "InferenceTimeoutError", + ], + "inference._generated._async_client": [ + "AsyncInferenceClient", + ], + "inference._generated.types": [ + "AudioClassificationInput", + "AudioClassificationOutputElement", + "AudioClassificationOutputTransform", + 
"AudioClassificationParameters", + "AudioToAudioInput", + "AudioToAudioOutputElement", + "AutomaticSpeechRecognitionEarlyStoppingEnum", + "AutomaticSpeechRecognitionGenerationParameters", + "AutomaticSpeechRecognitionInput", + "AutomaticSpeechRecognitionOutput", + "AutomaticSpeechRecognitionOutputChunk", + "AutomaticSpeechRecognitionParameters", + "ChatCompletionInput", + "ChatCompletionInputFunctionDefinition", + "ChatCompletionInputFunctionName", + "ChatCompletionInputGrammarType", + "ChatCompletionInputJSONSchema", + "ChatCompletionInputMessage", + "ChatCompletionInputMessageChunk", + "ChatCompletionInputMessageChunkType", + "ChatCompletionInputResponseFormatJSONObject", + "ChatCompletionInputResponseFormatJSONSchema", + "ChatCompletionInputResponseFormatText", + "ChatCompletionInputStreamOptions", + "ChatCompletionInputTool", + "ChatCompletionInputToolCall", + "ChatCompletionInputToolChoiceClass", + "ChatCompletionInputToolChoiceEnum", + "ChatCompletionInputURL", + "ChatCompletionOutput", + "ChatCompletionOutputComplete", + "ChatCompletionOutputFunctionDefinition", + "ChatCompletionOutputLogprob", + "ChatCompletionOutputLogprobs", + "ChatCompletionOutputMessage", + "ChatCompletionOutputToolCall", + "ChatCompletionOutputTopLogprob", + "ChatCompletionOutputUsage", + "ChatCompletionStreamOutput", + "ChatCompletionStreamOutputChoice", + "ChatCompletionStreamOutputDelta", + "ChatCompletionStreamOutputDeltaToolCall", + "ChatCompletionStreamOutputFunction", + "ChatCompletionStreamOutputLogprob", + "ChatCompletionStreamOutputLogprobs", + "ChatCompletionStreamOutputTopLogprob", + "ChatCompletionStreamOutputUsage", + "DepthEstimationInput", + "DepthEstimationOutput", + "DocumentQuestionAnsweringInput", + "DocumentQuestionAnsweringInputData", + "DocumentQuestionAnsweringOutputElement", + "DocumentQuestionAnsweringParameters", + "FeatureExtractionInput", + "FeatureExtractionInputTruncationDirection", + "FillMaskInput", + "FillMaskOutputElement", + "FillMaskParameters", + 
"ImageClassificationInput", + "ImageClassificationOutputElement", + "ImageClassificationOutputTransform", + "ImageClassificationParameters", + "ImageSegmentationInput", + "ImageSegmentationOutputElement", + "ImageSegmentationParameters", + "ImageSegmentationSubtask", + "ImageTextToImageInput", + "ImageTextToImageOutput", + "ImageTextToImageParameters", + "ImageTextToImageTargetSize", + "ImageTextToVideoInput", + "ImageTextToVideoOutput", + "ImageTextToVideoParameters", + "ImageTextToVideoTargetSize", + "ImageToImageInput", + "ImageToImageOutput", + "ImageToImageParameters", + "ImageToImageTargetSize", + "ImageToTextEarlyStoppingEnum", + "ImageToTextGenerationParameters", + "ImageToTextInput", + "ImageToTextOutput", + "ImageToTextParameters", + "ImageToVideoInput", + "ImageToVideoOutput", + "ImageToVideoParameters", + "ImageToVideoTargetSize", + "ObjectDetectionBoundingBox", + "ObjectDetectionInput", + "ObjectDetectionOutputElement", + "ObjectDetectionParameters", + "Padding", + "QuestionAnsweringInput", + "QuestionAnsweringInputData", + "QuestionAnsweringOutputElement", + "QuestionAnsweringParameters", + "SentenceSimilarityInput", + "SentenceSimilarityInputData", + "SummarizationInput", + "SummarizationOutput", + "SummarizationParameters", + "SummarizationTruncationStrategy", + "TableQuestionAnsweringInput", + "TableQuestionAnsweringInputData", + "TableQuestionAnsweringOutputElement", + "TableQuestionAnsweringParameters", + "Text2TextGenerationInput", + "Text2TextGenerationOutput", + "Text2TextGenerationParameters", + "Text2TextGenerationTruncationStrategy", + "TextClassificationInput", + "TextClassificationOutputElement", + "TextClassificationOutputTransform", + "TextClassificationParameters", + "TextGenerationInput", + "TextGenerationInputGenerateParameters", + "TextGenerationInputGrammarType", + "TextGenerationOutput", + "TextGenerationOutputBestOfSequence", + "TextGenerationOutputDetails", + "TextGenerationOutputFinishReason", + 
"TextGenerationOutputPrefillToken", + "TextGenerationOutputToken", + "TextGenerationStreamOutput", + "TextGenerationStreamOutputStreamDetails", + "TextGenerationStreamOutputToken", + "TextToAudioEarlyStoppingEnum", + "TextToAudioGenerationParameters", + "TextToAudioInput", + "TextToAudioOutput", + "TextToAudioParameters", + "TextToImageInput", + "TextToImageOutput", + "TextToImageParameters", + "TextToSpeechEarlyStoppingEnum", + "TextToSpeechGenerationParameters", + "TextToSpeechInput", + "TextToSpeechOutput", + "TextToSpeechParameters", + "TextToVideoInput", + "TextToVideoOutput", + "TextToVideoParameters", + "TokenClassificationAggregationStrategy", + "TokenClassificationInput", + "TokenClassificationOutputElement", + "TokenClassificationParameters", + "TranslationInput", + "TranslationOutput", + "TranslationParameters", + "TranslationTruncationStrategy", + "TypeEnum", + "VideoClassificationInput", + "VideoClassificationOutputElement", + "VideoClassificationOutputTransform", + "VideoClassificationParameters", + "VisualQuestionAnsweringInput", + "VisualQuestionAnsweringInputData", + "VisualQuestionAnsweringOutputElement", + "VisualQuestionAnsweringParameters", + "ZeroShotClassificationInput", + "ZeroShotClassificationOutputElement", + "ZeroShotClassificationParameters", + "ZeroShotImageClassificationInput", + "ZeroShotImageClassificationOutputElement", + "ZeroShotImageClassificationParameters", + "ZeroShotObjectDetectionBoundingBox", + "ZeroShotObjectDetectionInput", + "ZeroShotObjectDetectionOutputElement", + "ZeroShotObjectDetectionParameters", + ], + "inference._mcp.agent": [ + "Agent", + ], + "inference._mcp.mcp_client": [ + "MCPClient", + ], + "repocard": [ + "DatasetCard", + "ModelCard", + "RepoCard", + "SpaceCard", + "metadata_eval_result", + "metadata_load", + "metadata_save", + "metadata_update", + ], + "repocard_data": [ + "CardData", + "DatasetCardData", + "EvalResult", + "ModelCardData", + "SpaceCardData", + ], + "serialization": [ + "StateDictSplit", 
+ "get_torch_storage_id", + "get_torch_storage_size", + "load_state_dict_from_file", + "load_torch_model", + "save_torch_model", + "save_torch_state_dict", + "split_state_dict_into_shards_factory", + "split_torch_state_dict_into_shards", + ], + "serialization._dduf": [ + "DDUFEntry", + "export_entries_as_dduf", + "export_folder_as_dduf", + "read_dduf_file", + ], + "utils": [ + "ASYNC_CLIENT_FACTORY_T", + "CLIENT_FACTORY_T", + "CacheNotFound", + "CachedFileInfo", + "CachedRepoInfo", + "CachedRevisionInfo", + "CorruptedCacheException", + "DeleteCacheStrategy", + "HFCacheInfo", + "cached_assets_path", + "close_session", + "dump_environment_info", + "get_async_session", + "get_session", + "get_token", + "hf_raise_for_status", + "logging", + "scan_cache_dir", + "set_async_client_factory", + "set_client_factory", + ], +} + +# WARNING: __all__ is generated automatically, Any manual edit will be lost when re-generating this file ! +# +# To update the static imports, please run the following command and commit the changes. 
+# ``` +# # Use script +# python utils/check_all_variable.py --update +# +# # Or run style on codebase +# make style +# ``` + +__all__ = [ + "ASYNC_CLIENT_FACTORY_T", + "Agent", + "AsyncInferenceClient", + "AudioClassificationInput", + "AudioClassificationOutputElement", + "AudioClassificationOutputTransform", + "AudioClassificationParameters", + "AudioToAudioInput", + "AudioToAudioOutputElement", + "AutomaticSpeechRecognitionEarlyStoppingEnum", + "AutomaticSpeechRecognitionGenerationParameters", + "AutomaticSpeechRecognitionInput", + "AutomaticSpeechRecognitionOutput", + "AutomaticSpeechRecognitionOutputChunk", + "AutomaticSpeechRecognitionParameters", + "CLIENT_FACTORY_T", + "CONFIG_NAME", + "CacheNotFound", + "CachedFileInfo", + "CachedRepoInfo", + "CachedRevisionInfo", + "CardData", + "ChatCompletionInput", + "ChatCompletionInputFunctionDefinition", + "ChatCompletionInputFunctionName", + "ChatCompletionInputGrammarType", + "ChatCompletionInputJSONSchema", + "ChatCompletionInputMessage", + "ChatCompletionInputMessageChunk", + "ChatCompletionInputMessageChunkType", + "ChatCompletionInputResponseFormatJSONObject", + "ChatCompletionInputResponseFormatJSONSchema", + "ChatCompletionInputResponseFormatText", + "ChatCompletionInputStreamOptions", + "ChatCompletionInputTool", + "ChatCompletionInputToolCall", + "ChatCompletionInputToolChoiceClass", + "ChatCompletionInputToolChoiceEnum", + "ChatCompletionInputURL", + "ChatCompletionOutput", + "ChatCompletionOutputComplete", + "ChatCompletionOutputFunctionDefinition", + "ChatCompletionOutputLogprob", + "ChatCompletionOutputLogprobs", + "ChatCompletionOutputMessage", + "ChatCompletionOutputToolCall", + "ChatCompletionOutputTopLogprob", + "ChatCompletionOutputUsage", + "ChatCompletionStreamOutput", + "ChatCompletionStreamOutputChoice", + "ChatCompletionStreamOutputDelta", + "ChatCompletionStreamOutputDeltaToolCall", + "ChatCompletionStreamOutputFunction", + "ChatCompletionStreamOutputLogprob", + 
"ChatCompletionStreamOutputLogprobs", + "ChatCompletionStreamOutputTopLogprob", + "ChatCompletionStreamOutputUsage", + "Collection", + "CollectionItem", + "CommitInfo", + "CommitOperation", + "CommitOperationAdd", + "CommitOperationCopy", + "CommitOperationDelete", + "CommitScheduler", + "CorruptedCacheException", + "DDUFEntry", + "DatasetCard", + "DatasetCardData", + "DatasetInfo", + "DeleteCacheStrategy", + "DepthEstimationInput", + "DepthEstimationOutput", + "Discussion", + "DiscussionComment", + "DiscussionCommit", + "DiscussionEvent", + "DiscussionStatusChange", + "DiscussionTitleChange", + "DiscussionWithDetails", + "DocumentQuestionAnsweringInput", + "DocumentQuestionAnsweringInputData", + "DocumentQuestionAnsweringOutputElement", + "DocumentQuestionAnsweringParameters", + "DryRunFileInfo", + "EvalResult", + "EvalResultEntry", + "FLAX_WEIGHTS_NAME", + "FeatureExtractionInput", + "FeatureExtractionInputTruncationDirection", + "FillMaskInput", + "FillMaskOutputElement", + "FillMaskParameters", + "GitCommitInfo", + "GitRefInfo", + "GitRefs", + "HFCacheInfo", + "HFSummaryWriter", + "HUGGINGFACE_CO_URL_HOME", + "HUGGINGFACE_CO_URL_TEMPLATE", + "HfApi", + "HfFileMetadata", + "HfFileSystem", + "HfFileSystemFile", + "HfFileSystemResolvedPath", + "HfFileSystemStreamFile", + "ImageClassificationInput", + "ImageClassificationOutputElement", + "ImageClassificationOutputTransform", + "ImageClassificationParameters", + "ImageSegmentationInput", + "ImageSegmentationOutputElement", + "ImageSegmentationParameters", + "ImageSegmentationSubtask", + "ImageTextToImageInput", + "ImageTextToImageOutput", + "ImageTextToImageParameters", + "ImageTextToImageTargetSize", + "ImageTextToVideoInput", + "ImageTextToVideoOutput", + "ImageTextToVideoParameters", + "ImageTextToVideoTargetSize", + "ImageToImageInput", + "ImageToImageOutput", + "ImageToImageParameters", + "ImageToImageTargetSize", + "ImageToTextEarlyStoppingEnum", + "ImageToTextGenerationParameters", + "ImageToTextInput", + 
"ImageToTextOutput", + "ImageToTextParameters", + "ImageToVideoInput", + "ImageToVideoOutput", + "ImageToVideoParameters", + "ImageToVideoTargetSize", + "InferenceClient", + "InferenceEndpoint", + "InferenceEndpointError", + "InferenceEndpointStatus", + "InferenceEndpointTimeoutError", + "InferenceEndpointType", + "InferenceTimeoutError", + "JobAccelerator", + "JobHardware", + "JobInfo", + "JobOwner", + "JobStage", + "JobStatus", + "MCPClient", + "ModelCard", + "ModelCardData", + "ModelHubMixin", + "ModelInfo", + "OAuthInfo", + "OAuthOrgInfo", + "OAuthUserInfo", + "ObjectDetectionBoundingBox", + "ObjectDetectionInput", + "ObjectDetectionOutputElement", + "ObjectDetectionParameters", + "Organization", + "PYTORCH_WEIGHTS_NAME", + "Padding", + "PyTorchModelHubMixin", + "QuestionAnsweringInput", + "QuestionAnsweringInputData", + "QuestionAnsweringOutputElement", + "QuestionAnsweringParameters", + "REPO_TYPE_DATASET", + "REPO_TYPE_MODEL", + "REPO_TYPE_SPACE", + "RepoCard", + "RepoFile", + "RepoFolder", + "RepoUrl", + "SentenceSimilarityInput", + "SentenceSimilarityInputData", + "SpaceCard", + "SpaceCardData", + "SpaceHardware", + "SpaceInfo", + "SpaceRuntime", + "SpaceStage", + "SpaceStorage", + "SpaceVariable", + "StateDictSplit", + "SummarizationInput", + "SummarizationOutput", + "SummarizationParameters", + "SummarizationTruncationStrategy", + "TF2_WEIGHTS_NAME", + "TF_WEIGHTS_NAME", + "TableQuestionAnsweringInput", + "TableQuestionAnsweringInputData", + "TableQuestionAnsweringOutputElement", + "TableQuestionAnsweringParameters", + "Text2TextGenerationInput", + "Text2TextGenerationOutput", + "Text2TextGenerationParameters", + "Text2TextGenerationTruncationStrategy", + "TextClassificationInput", + "TextClassificationOutputElement", + "TextClassificationOutputTransform", + "TextClassificationParameters", + "TextGenerationInput", + "TextGenerationInputGenerateParameters", + "TextGenerationInputGrammarType", + "TextGenerationOutput", + 
"TextGenerationOutputBestOfSequence", + "TextGenerationOutputDetails", + "TextGenerationOutputFinishReason", + "TextGenerationOutputPrefillToken", + "TextGenerationOutputToken", + "TextGenerationStreamOutput", + "TextGenerationStreamOutputStreamDetails", + "TextGenerationStreamOutputToken", + "TextToAudioEarlyStoppingEnum", + "TextToAudioGenerationParameters", + "TextToAudioInput", + "TextToAudioOutput", + "TextToAudioParameters", + "TextToImageInput", + "TextToImageOutput", + "TextToImageParameters", + "TextToSpeechEarlyStoppingEnum", + "TextToSpeechGenerationParameters", + "TextToSpeechInput", + "TextToSpeechOutput", + "TextToSpeechParameters", + "TextToVideoInput", + "TextToVideoOutput", + "TextToVideoParameters", + "TokenClassificationAggregationStrategy", + "TokenClassificationInput", + "TokenClassificationOutputElement", + "TokenClassificationParameters", + "TranslationInput", + "TranslationOutput", + "TranslationParameters", + "TranslationTruncationStrategy", + "TypeEnum", + "User", + "UserLikes", + "VideoClassificationInput", + "VideoClassificationOutputElement", + "VideoClassificationOutputTransform", + "VideoClassificationParameters", + "VisualQuestionAnsweringInput", + "VisualQuestionAnsweringInputData", + "VisualQuestionAnsweringOutputElement", + "VisualQuestionAnsweringParameters", + "WebhookInfo", + "WebhookPayload", + "WebhookPayloadComment", + "WebhookPayloadDiscussion", + "WebhookPayloadDiscussionChanges", + "WebhookPayloadEvent", + "WebhookPayloadMovedTo", + "WebhookPayloadRepo", + "WebhookPayloadUrl", + "WebhookPayloadWebhook", + "WebhookWatchedItem", + "WebhooksServer", + "ZeroShotClassificationInput", + "ZeroShotClassificationOutputElement", + "ZeroShotClassificationParameters", + "ZeroShotImageClassificationInput", + "ZeroShotImageClassificationOutputElement", + "ZeroShotImageClassificationParameters", + "ZeroShotObjectDetectionBoundingBox", + "ZeroShotObjectDetectionInput", + "ZeroShotObjectDetectionOutputElement", + 
"ZeroShotObjectDetectionParameters", + "_CACHED_NO_EXIST", + "_save_pretrained_fastai", + "accept_access_request", + "add_collection_item", + "add_space_secret", + "add_space_variable", + "attach_huggingface_oauth", + "auth_check", + "auth_list", + "auth_switch", + "cached_assets_path", + "cancel_access_request", + "cancel_job", + "change_discussion_status", + "check_cli_update", + "close_session", + "comment_discussion", + "create_branch", + "create_collection", + "create_commit", + "create_discussion", + "create_inference_endpoint", + "create_inference_endpoint_from_catalog", + "create_pull_request", + "create_repo", + "create_scheduled_job", + "create_scheduled_uv_job", + "create_tag", + "create_webhook", + "dataset_info", + "delete_branch", + "delete_collection", + "delete_collection_item", + "delete_file", + "delete_folder", + "delete_inference_endpoint", + "delete_repo", + "delete_scheduled_job", + "delete_space_secret", + "delete_space_storage", + "delete_space_variable", + "delete_tag", + "delete_webhook", + "disable_webhook", + "dump_environment_info", + "duplicate_space", + "edit_discussion_comment", + "enable_webhook", + "eval_result_entries_to_yaml", + "export_entries_as_dduf", + "export_folder_as_dduf", + "fetch_job_logs", + "fetch_job_metrics", + "file_exists", + "from_pretrained_fastai", + "get_async_session", + "get_collection", + "get_dataset_tags", + "get_discussion_details", + "get_full_repo_name", + "get_hf_file_metadata", + "get_inference_endpoint", + "get_local_safetensors_metadata", + "get_model_tags", + "get_organization_overview", + "get_paths_info", + "get_repo_discussions", + "get_safetensors_metadata", + "get_session", + "get_space_runtime", + "get_space_variables", + "get_token", + "get_torch_storage_id", + "get_torch_storage_size", + "get_user_overview", + "get_webhook", + "grant_access", + "hf_hub_download", + "hf_hub_url", + "hf_raise_for_status", + "hffs", + "inspect_job", + "inspect_scheduled_job", + "interpreter_login", + 
"is_offline_mode", + "list_accepted_access_requests", + "list_collections", + "list_daily_papers", + "list_datasets", + "list_inference_catalog", + "list_inference_endpoints", + "list_jobs", + "list_jobs_hardware", + "list_lfs_files", + "list_liked_repos", + "list_models", + "list_organization_followers", + "list_organization_members", + "list_papers", + "list_pending_access_requests", + "list_rejected_access_requests", + "list_repo_commits", + "list_repo_files", + "list_repo_likers", + "list_repo_refs", + "list_repo_tree", + "list_spaces", + "list_user_followers", + "list_user_following", + "list_webhooks", + "load_state_dict_from_file", + "load_torch_model", + "logging", + "login", + "logout", + "merge_pull_request", + "metadata_eval_result", + "metadata_load", + "metadata_save", + "metadata_update", + "model_info", + "move_repo", + "notebook_login", + "paper_info", + "parse_eval_result_entries", + "parse_huggingface_oauth", + "parse_local_safetensors_file_metadata", + "parse_safetensors_file_metadata", + "pause_inference_endpoint", + "pause_space", + "permanently_delete_lfs_files", + "preupload_lfs_files", + "push_to_hub_fastai", + "read_dduf_file", + "reject_access_request", + "rename_discussion", + "repo_exists", + "repo_info", + "repo_type_and_id_from_hf_id", + "request_space_hardware", + "request_space_storage", + "restart_space", + "resume_inference_endpoint", + "resume_scheduled_job", + "revision_exists", + "run_as_future", + "run_job", + "run_uv_job", + "save_torch_model", + "save_torch_state_dict", + "scale_to_zero_inference_endpoint", + "scan_cache_dir", + "set_async_client_factory", + "set_client_factory", + "set_space_sleep_time", + "snapshot_download", + "space_info", + "split_state_dict_into_shards_factory", + "split_torch_state_dict_into_shards", + "super_squash_history", + "suspend_scheduled_job", + "try_to_load_from_cache", + "typer_factory", + "unlike", + "update_collection_item", + "update_collection_metadata", + "update_inference_endpoint", + 
"update_repo_settings", + "update_webhook", + "upload_file", + "upload_folder", + "upload_large_folder", + "verify_repo_checksums", + "webhook_endpoint", + "whoami", +] + + +def _attach(package_name, submodules=None, submod_attrs=None): + """Attach lazily loaded submodules, functions, or other attributes. + + Typically, modules import submodules and attributes as follows: + + ```py + import mysubmodule + import anothersubmodule + + from .foo import someattr + ``` + + The idea is to replace a package's `__getattr__`, `__dir__`, such that all imports + work exactly the way they would with normal imports, except that the import occurs + upon first use. + + The typical way to call this function, replacing the above imports, is: + + ```python + __getattr__, __dir__ = lazy.attach( + __name__, + ['mysubmodule', 'anothersubmodule'], + {'foo': ['someattr']} + ) + ``` + This functionality requires Python 3.7 or higher. + + Args: + package_name (`str`): + Typically use `__name__`. + submodules (`set`): + List of submodules to attach. + submod_attrs (`dict`): + Dictionary of submodule -> list of attributes / functions. + These attributes are imported as they are used. 
+ + Returns: + __getattr__, __dir__, __all__ + + """ + if submod_attrs is None: + submod_attrs = {} + + if submodules is None: + submodules = set() + else: + submodules = set(submodules) + + attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs} + + def __getattr__(name): + if name in submodules: + try: + return importlib.import_module(f"{package_name}.{name}") + except Exception as e: + print(f"Error importing {package_name}.{name}: {e}") + raise + elif name in attr_to_modules: + submod_path = f"{package_name}.{attr_to_modules[name]}" + try: + submod = importlib.import_module(submod_path) + except Exception as e: + print(f"Error importing {submod_path}: {e}") + raise + attr = getattr(submod, name) + + # If the attribute lives in a file (module) with the same + # name as the attribute, ensure that the attribute and *not* + # the module is accessible on the package. + if name == attr_to_modules[name]: + pkg = sys.modules[package_name] + pkg.__dict__[name] = attr + + return attr + else: + raise AttributeError(f"No {package_name} attribute {name}") + + def __dir__(): + return __all__ + + return __getattr__, __dir__ + + +__getattr__, __dir__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS) + +if os.environ.get("EAGER_IMPORT", ""): + for attr in __all__: + __getattr__(attr) + +# WARNING: any content below this statement is generated automatically. Any manual edit +# will be lost when re-generating this file ! +# +# To update the static imports, please run the following command and commit the changes. 
+# ``` +# # Use script +# python utils/check_static_imports.py --update +# +# # Or run style on codebase +# make style +# ``` +if TYPE_CHECKING: # pragma: no cover + from ._commit_scheduler import CommitScheduler # noqa: F401 + from ._eval_results import ( + EvalResultEntry, # noqa: F401 + eval_result_entries_to_yaml, # noqa: F401 + parse_eval_result_entries, # noqa: F401 + ) + from ._inference_endpoints import ( + InferenceEndpoint, # noqa: F401 + InferenceEndpointError, # noqa: F401 + InferenceEndpointStatus, # noqa: F401 + InferenceEndpointTimeoutError, # noqa: F401 + InferenceEndpointType, # noqa: F401 + ) + from ._jobs_api import ( + JobAccelerator, # noqa: F401 + JobHardware, # noqa: F401 + JobInfo, # noqa: F401 + JobOwner, # noqa: F401 + JobStage, # noqa: F401 + JobStatus, # noqa: F401 + ) + from ._login import ( + auth_list, # noqa: F401 + auth_switch, # noqa: F401 + interpreter_login, # noqa: F401 + login, # noqa: F401 + logout, # noqa: F401 + notebook_login, # noqa: F401 + ) + from ._oauth import ( + OAuthInfo, # noqa: F401 + OAuthOrgInfo, # noqa: F401 + OAuthUserInfo, # noqa: F401 + attach_huggingface_oauth, # noqa: F401 + parse_huggingface_oauth, # noqa: F401 + ) + from ._snapshot_download import snapshot_download # noqa: F401 + from ._space_api import ( + SpaceHardware, # noqa: F401 + SpaceRuntime, # noqa: F401 + SpaceStage, # noqa: F401 + SpaceStorage, # noqa: F401 + SpaceVariable, # noqa: F401 + ) + from ._tensorboard_logger import HFSummaryWriter # noqa: F401 + from ._webhooks_payload import ( + WebhookPayload, # noqa: F401 + WebhookPayloadComment, # noqa: F401 + WebhookPayloadDiscussion, # noqa: F401 + WebhookPayloadDiscussionChanges, # noqa: F401 + WebhookPayloadEvent, # noqa: F401 + WebhookPayloadMovedTo, # noqa: F401 + WebhookPayloadRepo, # noqa: F401 + WebhookPayloadUrl, # noqa: F401 + WebhookPayloadWebhook, # noqa: F401 + ) + from ._webhooks_server import ( + WebhooksServer, # noqa: F401 + webhook_endpoint, # noqa: F401 + ) + from 
.cli._cli_utils import ( + check_cli_update, # noqa: F401 + typer_factory, # noqa: F401 + ) + from .community import ( + Discussion, # noqa: F401 + DiscussionComment, # noqa: F401 + DiscussionCommit, # noqa: F401 + DiscussionEvent, # noqa: F401 + DiscussionStatusChange, # noqa: F401 + DiscussionTitleChange, # noqa: F401 + DiscussionWithDetails, # noqa: F401 + ) + from .constants import ( + CONFIG_NAME, # noqa: F401 + FLAX_WEIGHTS_NAME, # noqa: F401 + HUGGINGFACE_CO_URL_HOME, # noqa: F401 + HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 + PYTORCH_WEIGHTS_NAME, # noqa: F401 + REPO_TYPE_DATASET, # noqa: F401 + REPO_TYPE_MODEL, # noqa: F401 + REPO_TYPE_SPACE, # noqa: F401 + TF2_WEIGHTS_NAME, # noqa: F401 + TF_WEIGHTS_NAME, # noqa: F401 + is_offline_mode, # noqa: F401 + ) + from .fastai_utils import ( + _save_pretrained_fastai, # noqa: F401 + from_pretrained_fastai, # noqa: F401 + push_to_hub_fastai, # noqa: F401 + ) + from .file_download import ( + _CACHED_NO_EXIST, # noqa: F401 + DryRunFileInfo, # noqa: F401 + HfFileMetadata, # noqa: F401 + get_hf_file_metadata, # noqa: F401 + hf_hub_download, # noqa: F401 + hf_hub_url, # noqa: F401 + try_to_load_from_cache, # noqa: F401 + ) + from .hf_api import ( + Collection, # noqa: F401 + CollectionItem, # noqa: F401 + CommitInfo, # noqa: F401 + CommitOperation, # noqa: F401 + CommitOperationAdd, # noqa: F401 + CommitOperationCopy, # noqa: F401 + CommitOperationDelete, # noqa: F401 + DatasetInfo, # noqa: F401 + GitCommitInfo, # noqa: F401 + GitRefInfo, # noqa: F401 + GitRefs, # noqa: F401 + HfApi, # noqa: F401 + ModelInfo, # noqa: F401 + Organization, # noqa: F401 + RepoFile, # noqa: F401 + RepoFolder, # noqa: F401 + RepoUrl, # noqa: F401 + SpaceInfo, # noqa: F401 + User, # noqa: F401 + UserLikes, # noqa: F401 + WebhookInfo, # noqa: F401 + WebhookWatchedItem, # noqa: F401 + accept_access_request, # noqa: F401 + add_collection_item, # noqa: F401 + add_space_secret, # noqa: F401 + add_space_variable, # noqa: F401 + auth_check, # noqa: 
F401 + cancel_access_request, # noqa: F401 + cancel_job, # noqa: F401 + change_discussion_status, # noqa: F401 + comment_discussion, # noqa: F401 + create_branch, # noqa: F401 + create_collection, # noqa: F401 + create_commit, # noqa: F401 + create_discussion, # noqa: F401 + create_inference_endpoint, # noqa: F401 + create_inference_endpoint_from_catalog, # noqa: F401 + create_pull_request, # noqa: F401 + create_repo, # noqa: F401 + create_scheduled_job, # noqa: F401 + create_scheduled_uv_job, # noqa: F401 + create_tag, # noqa: F401 + create_webhook, # noqa: F401 + dataset_info, # noqa: F401 + delete_branch, # noqa: F401 + delete_collection, # noqa: F401 + delete_collection_item, # noqa: F401 + delete_file, # noqa: F401 + delete_folder, # noqa: F401 + delete_inference_endpoint, # noqa: F401 + delete_repo, # noqa: F401 + delete_scheduled_job, # noqa: F401 + delete_space_secret, # noqa: F401 + delete_space_storage, # noqa: F401 + delete_space_variable, # noqa: F401 + delete_tag, # noqa: F401 + delete_webhook, # noqa: F401 + disable_webhook, # noqa: F401 + duplicate_space, # noqa: F401 + edit_discussion_comment, # noqa: F401 + enable_webhook, # noqa: F401 + fetch_job_logs, # noqa: F401 + fetch_job_metrics, # noqa: F401 + file_exists, # noqa: F401 + get_collection, # noqa: F401 + get_dataset_tags, # noqa: F401 + get_discussion_details, # noqa: F401 + get_full_repo_name, # noqa: F401 + get_inference_endpoint, # noqa: F401 + get_local_safetensors_metadata, # noqa: F401 + get_model_tags, # noqa: F401 + get_organization_overview, # noqa: F401 + get_paths_info, # noqa: F401 + get_repo_discussions, # noqa: F401 + get_safetensors_metadata, # noqa: F401 + get_space_runtime, # noqa: F401 + get_space_variables, # noqa: F401 + get_user_overview, # noqa: F401 + get_webhook, # noqa: F401 + grant_access, # noqa: F401 + inspect_job, # noqa: F401 + inspect_scheduled_job, # noqa: F401 + list_accepted_access_requests, # noqa: F401 + list_collections, # noqa: F401 + list_daily_papers, # 
noqa: F401 + list_datasets, # noqa: F401 + list_inference_catalog, # noqa: F401 + list_inference_endpoints, # noqa: F401 + list_jobs, # noqa: F401 + list_jobs_hardware, # noqa: F401 + list_lfs_files, # noqa: F401 + list_liked_repos, # noqa: F401 + list_models, # noqa: F401 + list_organization_followers, # noqa: F401 + list_organization_members, # noqa: F401 + list_papers, # noqa: F401 + list_pending_access_requests, # noqa: F401 + list_rejected_access_requests, # noqa: F401 + list_repo_commits, # noqa: F401 + list_repo_files, # noqa: F401 + list_repo_likers, # noqa: F401 + list_repo_refs, # noqa: F401 + list_repo_tree, # noqa: F401 + list_spaces, # noqa: F401 + list_user_followers, # noqa: F401 + list_user_following, # noqa: F401 + list_webhooks, # noqa: F401 + merge_pull_request, # noqa: F401 + model_info, # noqa: F401 + move_repo, # noqa: F401 + paper_info, # noqa: F401 + parse_local_safetensors_file_metadata, # noqa: F401 + parse_safetensors_file_metadata, # noqa: F401 + pause_inference_endpoint, # noqa: F401 + pause_space, # noqa: F401 + permanently_delete_lfs_files, # noqa: F401 + preupload_lfs_files, # noqa: F401 + reject_access_request, # noqa: F401 + rename_discussion, # noqa: F401 + repo_exists, # noqa: F401 + repo_info, # noqa: F401 + repo_type_and_id_from_hf_id, # noqa: F401 + request_space_hardware, # noqa: F401 + request_space_storage, # noqa: F401 + restart_space, # noqa: F401 + resume_inference_endpoint, # noqa: F401 + resume_scheduled_job, # noqa: F401 + revision_exists, # noqa: F401 + run_as_future, # noqa: F401 + run_job, # noqa: F401 + run_uv_job, # noqa: F401 + scale_to_zero_inference_endpoint, # noqa: F401 + set_space_sleep_time, # noqa: F401 + space_info, # noqa: F401 + super_squash_history, # noqa: F401 + suspend_scheduled_job, # noqa: F401 + unlike, # noqa: F401 + update_collection_item, # noqa: F401 + update_collection_metadata, # noqa: F401 + update_inference_endpoint, # noqa: F401 + update_repo_settings, # noqa: F401 + update_webhook, # 
noqa: F401 + upload_file, # noqa: F401 + upload_folder, # noqa: F401 + upload_large_folder, # noqa: F401 + verify_repo_checksums, # noqa: F401 + whoami, # noqa: F401 + ) + from .hf_file_system import ( + HfFileSystem, # noqa: F401 + HfFileSystemFile, # noqa: F401 + HfFileSystemResolvedPath, # noqa: F401 + HfFileSystemStreamFile, # noqa: F401 + hffs, # noqa: F401 + ) + from .hub_mixin import ( + ModelHubMixin, # noqa: F401 + PyTorchModelHubMixin, # noqa: F401 + ) + from .inference._client import ( + InferenceClient, # noqa: F401 + InferenceTimeoutError, # noqa: F401 + ) + from .inference._generated._async_client import AsyncInferenceClient # noqa: F401 + from .inference._generated.types import ( + AudioClassificationInput, # noqa: F401 + AudioClassificationOutputElement, # noqa: F401 + AudioClassificationOutputTransform, # noqa: F401 + AudioClassificationParameters, # noqa: F401 + AudioToAudioInput, # noqa: F401 + AudioToAudioOutputElement, # noqa: F401 + AutomaticSpeechRecognitionEarlyStoppingEnum, # noqa: F401 + AutomaticSpeechRecognitionGenerationParameters, # noqa: F401 + AutomaticSpeechRecognitionInput, # noqa: F401 + AutomaticSpeechRecognitionOutput, # noqa: F401 + AutomaticSpeechRecognitionOutputChunk, # noqa: F401 + AutomaticSpeechRecognitionParameters, # noqa: F401 + ChatCompletionInput, # noqa: F401 + ChatCompletionInputFunctionDefinition, # noqa: F401 + ChatCompletionInputFunctionName, # noqa: F401 + ChatCompletionInputGrammarType, # noqa: F401 + ChatCompletionInputJSONSchema, # noqa: F401 + ChatCompletionInputMessage, # noqa: F401 + ChatCompletionInputMessageChunk, # noqa: F401 + ChatCompletionInputMessageChunkType, # noqa: F401 + ChatCompletionInputResponseFormatJSONObject, # noqa: F401 + ChatCompletionInputResponseFormatJSONSchema, # noqa: F401 + ChatCompletionInputResponseFormatText, # noqa: F401 + ChatCompletionInputStreamOptions, # noqa: F401 + ChatCompletionInputTool, # noqa: F401 + ChatCompletionInputToolCall, # noqa: F401 + 
ChatCompletionInputToolChoiceClass, # noqa: F401 + ChatCompletionInputToolChoiceEnum, # noqa: F401 + ChatCompletionInputURL, # noqa: F401 + ChatCompletionOutput, # noqa: F401 + ChatCompletionOutputComplete, # noqa: F401 + ChatCompletionOutputFunctionDefinition, # noqa: F401 + ChatCompletionOutputLogprob, # noqa: F401 + ChatCompletionOutputLogprobs, # noqa: F401 + ChatCompletionOutputMessage, # noqa: F401 + ChatCompletionOutputToolCall, # noqa: F401 + ChatCompletionOutputTopLogprob, # noqa: F401 + ChatCompletionOutputUsage, # noqa: F401 + ChatCompletionStreamOutput, # noqa: F401 + ChatCompletionStreamOutputChoice, # noqa: F401 + ChatCompletionStreamOutputDelta, # noqa: F401 + ChatCompletionStreamOutputDeltaToolCall, # noqa: F401 + ChatCompletionStreamOutputFunction, # noqa: F401 + ChatCompletionStreamOutputLogprob, # noqa: F401 + ChatCompletionStreamOutputLogprobs, # noqa: F401 + ChatCompletionStreamOutputTopLogprob, # noqa: F401 + ChatCompletionStreamOutputUsage, # noqa: F401 + DepthEstimationInput, # noqa: F401 + DepthEstimationOutput, # noqa: F401 + DocumentQuestionAnsweringInput, # noqa: F401 + DocumentQuestionAnsweringInputData, # noqa: F401 + DocumentQuestionAnsweringOutputElement, # noqa: F401 + DocumentQuestionAnsweringParameters, # noqa: F401 + FeatureExtractionInput, # noqa: F401 + FeatureExtractionInputTruncationDirection, # noqa: F401 + FillMaskInput, # noqa: F401 + FillMaskOutputElement, # noqa: F401 + FillMaskParameters, # noqa: F401 + ImageClassificationInput, # noqa: F401 + ImageClassificationOutputElement, # noqa: F401 + ImageClassificationOutputTransform, # noqa: F401 + ImageClassificationParameters, # noqa: F401 + ImageSegmentationInput, # noqa: F401 + ImageSegmentationOutputElement, # noqa: F401 + ImageSegmentationParameters, # noqa: F401 + ImageSegmentationSubtask, # noqa: F401 + ImageTextToImageInput, # noqa: F401 + ImageTextToImageOutput, # noqa: F401 + ImageTextToImageParameters, # noqa: F401 + ImageTextToImageTargetSize, # noqa: F401 + 
ImageTextToVideoInput, # noqa: F401 + ImageTextToVideoOutput, # noqa: F401 + ImageTextToVideoParameters, # noqa: F401 + ImageTextToVideoTargetSize, # noqa: F401 + ImageToImageInput, # noqa: F401 + ImageToImageOutput, # noqa: F401 + ImageToImageParameters, # noqa: F401 + ImageToImageTargetSize, # noqa: F401 + ImageToTextEarlyStoppingEnum, # noqa: F401 + ImageToTextGenerationParameters, # noqa: F401 + ImageToTextInput, # noqa: F401 + ImageToTextOutput, # noqa: F401 + ImageToTextParameters, # noqa: F401 + ImageToVideoInput, # noqa: F401 + ImageToVideoOutput, # noqa: F401 + ImageToVideoParameters, # noqa: F401 + ImageToVideoTargetSize, # noqa: F401 + ObjectDetectionBoundingBox, # noqa: F401 + ObjectDetectionInput, # noqa: F401 + ObjectDetectionOutputElement, # noqa: F401 + ObjectDetectionParameters, # noqa: F401 + Padding, # noqa: F401 + QuestionAnsweringInput, # noqa: F401 + QuestionAnsweringInputData, # noqa: F401 + QuestionAnsweringOutputElement, # noqa: F401 + QuestionAnsweringParameters, # noqa: F401 + SentenceSimilarityInput, # noqa: F401 + SentenceSimilarityInputData, # noqa: F401 + SummarizationInput, # noqa: F401 + SummarizationOutput, # noqa: F401 + SummarizationParameters, # noqa: F401 + SummarizationTruncationStrategy, # noqa: F401 + TableQuestionAnsweringInput, # noqa: F401 + TableQuestionAnsweringInputData, # noqa: F401 + TableQuestionAnsweringOutputElement, # noqa: F401 + TableQuestionAnsweringParameters, # noqa: F401 + Text2TextGenerationInput, # noqa: F401 + Text2TextGenerationOutput, # noqa: F401 + Text2TextGenerationParameters, # noqa: F401 + Text2TextGenerationTruncationStrategy, # noqa: F401 + TextClassificationInput, # noqa: F401 + TextClassificationOutputElement, # noqa: F401 + TextClassificationOutputTransform, # noqa: F401 + TextClassificationParameters, # noqa: F401 + TextGenerationInput, # noqa: F401 + TextGenerationInputGenerateParameters, # noqa: F401 + TextGenerationInputGrammarType, # noqa: F401 + TextGenerationOutput, # noqa: F401 + 
TextGenerationOutputBestOfSequence, # noqa: F401 + TextGenerationOutputDetails, # noqa: F401 + TextGenerationOutputFinishReason, # noqa: F401 + TextGenerationOutputPrefillToken, # noqa: F401 + TextGenerationOutputToken, # noqa: F401 + TextGenerationStreamOutput, # noqa: F401 + TextGenerationStreamOutputStreamDetails, # noqa: F401 + TextGenerationStreamOutputToken, # noqa: F401 + TextToAudioEarlyStoppingEnum, # noqa: F401 + TextToAudioGenerationParameters, # noqa: F401 + TextToAudioInput, # noqa: F401 + TextToAudioOutput, # noqa: F401 + TextToAudioParameters, # noqa: F401 + TextToImageInput, # noqa: F401 + TextToImageOutput, # noqa: F401 + TextToImageParameters, # noqa: F401 + TextToSpeechEarlyStoppingEnum, # noqa: F401 + TextToSpeechGenerationParameters, # noqa: F401 + TextToSpeechInput, # noqa: F401 + TextToSpeechOutput, # noqa: F401 + TextToSpeechParameters, # noqa: F401 + TextToVideoInput, # noqa: F401 + TextToVideoOutput, # noqa: F401 + TextToVideoParameters, # noqa: F401 + TokenClassificationAggregationStrategy, # noqa: F401 + TokenClassificationInput, # noqa: F401 + TokenClassificationOutputElement, # noqa: F401 + TokenClassificationParameters, # noqa: F401 + TranslationInput, # noqa: F401 + TranslationOutput, # noqa: F401 + TranslationParameters, # noqa: F401 + TranslationTruncationStrategy, # noqa: F401 + TypeEnum, # noqa: F401 + VideoClassificationInput, # noqa: F401 + VideoClassificationOutputElement, # noqa: F401 + VideoClassificationOutputTransform, # noqa: F401 + VideoClassificationParameters, # noqa: F401 + VisualQuestionAnsweringInput, # noqa: F401 + VisualQuestionAnsweringInputData, # noqa: F401 + VisualQuestionAnsweringOutputElement, # noqa: F401 + VisualQuestionAnsweringParameters, # noqa: F401 + ZeroShotClassificationInput, # noqa: F401 + ZeroShotClassificationOutputElement, # noqa: F401 + ZeroShotClassificationParameters, # noqa: F401 + ZeroShotImageClassificationInput, # noqa: F401 + ZeroShotImageClassificationOutputElement, # noqa: F401 + 
ZeroShotImageClassificationParameters, # noqa: F401 + ZeroShotObjectDetectionBoundingBox, # noqa: F401 + ZeroShotObjectDetectionInput, # noqa: F401 + ZeroShotObjectDetectionOutputElement, # noqa: F401 + ZeroShotObjectDetectionParameters, # noqa: F401 + ) + from .inference._mcp.agent import Agent # noqa: F401 + from .inference._mcp.mcp_client import MCPClient # noqa: F401 + from .repocard import ( + DatasetCard, # noqa: F401 + ModelCard, # noqa: F401 + RepoCard, # noqa: F401 + SpaceCard, # noqa: F401 + metadata_eval_result, # noqa: F401 + metadata_load, # noqa: F401 + metadata_save, # noqa: F401 + metadata_update, # noqa: F401 + ) + from .repocard_data import ( + CardData, # noqa: F401 + DatasetCardData, # noqa: F401 + EvalResult, # noqa: F401 + ModelCardData, # noqa: F401 + SpaceCardData, # noqa: F401 + ) + from .serialization import ( + StateDictSplit, # noqa: F401 + get_torch_storage_id, # noqa: F401 + get_torch_storage_size, # noqa: F401 + load_state_dict_from_file, # noqa: F401 + load_torch_model, # noqa: F401 + save_torch_model, # noqa: F401 + save_torch_state_dict, # noqa: F401 + split_state_dict_into_shards_factory, # noqa: F401 + split_torch_state_dict_into_shards, # noqa: F401 + ) + from .serialization._dduf import ( + DDUFEntry, # noqa: F401 + export_entries_as_dduf, # noqa: F401 + export_folder_as_dduf, # noqa: F401 + read_dduf_file, # noqa: F401 + ) + from .utils import ( + ASYNC_CLIENT_FACTORY_T, # noqa: F401 + CLIENT_FACTORY_T, # noqa: F401 + CachedFileInfo, # noqa: F401 + CachedRepoInfo, # noqa: F401 + CachedRevisionInfo, # noqa: F401 + CacheNotFound, # noqa: F401 + CorruptedCacheException, # noqa: F401 + DeleteCacheStrategy, # noqa: F401 + HFCacheInfo, # noqa: F401 + cached_assets_path, # noqa: F401 + close_session, # noqa: F401 + dump_environment_info, # noqa: F401 + get_async_session, # noqa: F401 + get_session, # noqa: F401 + get_token, # noqa: F401 + hf_raise_for_status, # noqa: F401 + logging, # noqa: F401 + scan_cache_dir, # noqa: F401 + 
set_async_client_factory, # noqa: F401 + set_client_factory, # noqa: F401 + ) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/__init__.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0770e1407f96351b74aefccc58f72ea7415d3e Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/__init__.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e62d9516069cd1ac78eb58a24048090776ad7f Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1d8774fe2eeb58e9eda0e3ccc15dd2bc8e5fa1 Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_eval_results.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_eval_results.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c54c71cd2cf67ff4c660e795f2a78d4ed6b6487 Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_eval_results.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-310.pyc 
b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b36aaf63d8182ceccc5df13fbc6a8450d96ae4ce Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_jobs_api.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_jobs_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e35173f25f1e9a8f9d8152609df617b36f9c62f Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_jobs_api.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_local_folder.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_local_folder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddddf41c0f371f763196893d34f8271dd110881d Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_local_folder.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_login.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_login.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..889577ee28e388e74b226b3f3b8db5d60efb5d6a Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_login.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_oauth.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_oauth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a3eb50d4331a21df585e0df0ce6e9b4cbba49d Binary files /dev/null and 
b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_oauth.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8767fb700bc6544f6596651942ae03e1e222a21 Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_space_api.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_space_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85851c35089ef3b238950611b14f61547ba04993 Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_space_api.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-310.pyc b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..609392bcda79a7e1bd7cd56cb5b7d5c78bc772da Binary files /dev/null and b/venv/lib/python3.10/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-310.pyc differ diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_commit_api.py b/venv/lib/python3.10/site-packages/huggingface_hub/_commit_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f93be55a454a056542a1b33845fb52db3bbd691b --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_commit_api.py @@ -0,0 +1,966 @@ +""" +Type definitions and utilities for the `create_commit` API +""" + +import base64 +import io +import os +import warnings +from collections import defaultdict +from contextlib import contextmanager +from 
@dataclass
class CommitOperationDelete:
    """
    Describes the deletion of a single file or a single folder from a repository
    on the Hub.

    Args:
        path_in_repo (`str`):
            Path relative to the repo root, for example `"checkpoints/1fec34a/weights.bin"`
            for a file or `"checkpoints/1fec34a/"` for a folder. The value is passed
            through `_validate_path_in_repo` and the result is stored back.
        is_folder (`bool` or `Literal["auto"]`, *optional*):
            Whether the deletion targets a folder. With `"auto"` (the default), the
            kind is guessed from the validated path: a trailing `"/"` means folder,
            anything else means file. Pass `True` or `False` to set it explicitly.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `is_folder` is neither a `bool` nor `"auto"`.
    """

    path_in_repo: str
    is_folder: Union[bool, Literal["auto"]] = "auto"

    def __post_init__(self):
        validated_path = _validate_path_in_repo(self.path_in_repo)
        self.path_in_repo = validated_path

        # Resolve the "auto" placeholder from the validated path.
        folder_flag = self.is_folder
        if folder_flag == "auto":
            folder_flag = validated_path.endswith("/")
            self.is_folder = folder_flag

        if isinstance(folder_flag, bool):
            return
        raise ValueError(
            f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'."
        )


@dataclass
class CommitOperationCopy:
    """
    Describes the copy of a file inside a repository on the Hub.

    Limitations:
      - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it
      - Cross-repository copies are not supported.

    Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub.

    Args:
        src_path_in_repo (`str`):
            Path, relative to the repo root, of the file to copy, e.g. `"checkpoints/1fec34a/weights.bin"`.
        path_in_repo (`str`):
            Path, relative to the repo root, of the copy to create, e.g. `"checkpoints/1fec34a/weights_copy.bin"`.
        src_revision (`str`, *optional*):
            The git revision of the file to be copied. Can be any valid git revision.
            Default to the target commit revision.
    """

    src_path_in_repo: str
    path_in_repo: str
    src_revision: Optional[str] = None
    # OID of the source file when it has already been uploaded;
    # used to decide whether a commit would be empty.
    _src_oid: Optional[str] = None
    # OID of the destination file when it has already been uploaded;
    # used to decide whether a commit would be empty.
    _dest_oid: Optional[str] = None

    def __post_init__(self):
        # Validate the source path first, then the destination path.
        for path_field in ("src_path_in_repo", "path_in_repo"):
            setattr(self, path_field, _validate_path_in_repo(getattr(self, path_field)))
+ + Args: + path_in_repo (`str`): + Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` + path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`): + Either: + - a path to a local file (as `str` or `pathlib.Path`) to upload + - a buffer of bytes (`bytes`) holding the content of the file to upload + - a "file object" (subclass of `io.BufferedIOBase`), typically obtained + with `open(path, "rb")`. It must support `seek()` and `tell()` methods. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both + `seek()` and `tell()`. + """ + + path_in_repo: str + path_or_fileobj: Union[str, Path, bytes, BinaryIO] + upload_info: UploadInfo = field(init=False, repr=False) + + # Internal attributes + + # set to "lfs" or "regular" once known + _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None) + + # set to True if .gitignore rules prevent the file from being uploaded as LFS + # (server-side check) + _should_ignore: Optional[bool] = field(init=False, repr=False, default=None) + + # set to the remote OID of the file if it has already been uploaded + # useful to determine if a commit will be empty or not + _remote_oid: Optional[str] = field(init=False, repr=False, default=None) + + # set to True once the file has been uploaded as LFS + _is_uploaded: bool = field(init=False, repr=False, default=False) + + # set to True once the file has been committed + _is_committed: bool = field(init=False, repr=False, default=False) + + def __post_init__(self) -> None: + """Validates `path_or_fileobj` and compute 
`upload_info`.""" + self.path_in_repo = _validate_path_in_repo(self.path_in_repo) + + # Validate `path_or_fileobj` value + if isinstance(self.path_or_fileobj, Path): + self.path_or_fileobj = str(self.path_or_fileobj) + if isinstance(self.path_or_fileobj, str): + path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj)) + if not os.path.isfile(path_or_fileobj): + raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system") + elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)): + # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode + raise ValueError( + "path_or_fileobj must be either an instance of str, bytes or" + " io.BufferedIOBase. If you passed a file-like object, make sure it is" + " in binary mode." + ) + if isinstance(self.path_or_fileobj, io.BufferedIOBase): + try: + self.path_or_fileobj.tell() + self.path_or_fileobj.seek(0, os.SEEK_CUR) + except (OSError, AttributeError) as exc: + raise ValueError( + "path_or_fileobj is a file-like object but does not implement seek() and tell()" + ) from exc + + # Compute "upload_info" attribute + if isinstance(self.path_or_fileobj, str): + self.upload_info = UploadInfo.from_path(self.path_or_fileobj) + elif isinstance(self.path_or_fileobj, bytes): + self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj) + else: + self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj) + + @contextmanager + def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]: + """ + A context manager that yields a file-like object allowing to read the underlying + data behind `path_or_fileobj`. + + Args: + with_tqdm (`bool`, *optional*, defaults to `False`): + If True, iterating over the file object will display a progress bar. Only + works if the file-like object is a path to a file. Pure bytes and buffers + are not supported. 
+ + Example: + + ```python + >>> operation = CommitOperationAdd( + ... path_in_repo="remote/dir/weights.h5", + ... path_or_fileobj="./local/weights.h5", + ... ) + CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5') + + >>> with operation.as_file() as file: + ... content = file.read() + + >>> with operation.as_file(with_tqdm=True) as file: + ... while True: + ... data = file.read(1024) + ... if not data: + ... break + config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] + + >>> with operation.as_file(with_tqdm=True) as file: + ... httpx.put(..., data=file) + config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] + ``` + """ + if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path): + if with_tqdm: + with tqdm_stream_file(self.path_or_fileobj) as file: + yield file + else: + with open(self.path_or_fileobj, "rb") as file: + yield file + elif isinstance(self.path_or_fileobj, bytes): + yield io.BytesIO(self.path_or_fileobj) + elif isinstance(self.path_or_fileobj, io.BufferedIOBase): + prev_pos = self.path_or_fileobj.tell() + yield self.path_or_fileobj + self.path_or_fileobj.seek(prev_pos, io.SEEK_SET) + + def b64content(self) -> bytes: + """ + The base64-encoded content of `path_or_fileobj` + + Returns: `bytes` + """ + with self.as_file() as file: + return base64.b64encode(file.read()) + + @property + def _local_oid(self) -> Optional[str]: + """Return the OID of the local file. + + This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one. + If the file did not change, we won't upload it again to prevent empty commits. + + For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref). + For regular files, the OID corresponds to the SHA1 of the file content. 
+ Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1 of the + pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes + and more convenient client-side. + """ + if self._upload_mode is None: + return None + elif self._upload_mode == "lfs": + return self.upload_info.sha256.hex() + else: + # Regular file => compute sha1 + # => no need to read by chunk since the file is guaranteed to be <=5MB. + with self.as_file() as file: + return sha.git_hash(file.read()) + + +def _validate_path_in_repo(path_in_repo: str) -> str: + # Validate `path_in_repo` value to prevent a server-side issue + if path_in_repo.startswith("/"): + path_in_repo = path_in_repo[1:] + if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"): + raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'") + if path_in_repo.startswith("./"): + path_in_repo = path_in_repo[2:] + for forbidden in FORBIDDEN_FOLDERS: + if any(part == forbidden for part in path_in_repo.split("/")): + raise ValueError( + f"Invalid `path_in_repo` in CommitOperation: cannot update files under a '{forbidden}/' folder (path:" + f" '{path_in_repo}')." + ) + return path_in_repo + + +CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete] + + +def _warn_on_overwriting_operations(operations: list[CommitOperation]) -> None: + """ + Warn user when a list of operations is expected to overwrite itself in a single + commit. + + Rules: + - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning + message is triggered. + - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted + by a `CommitOperationDelete`, a warning is triggered. + - If a `CommitOperationDelete` deletes a filepath that is then updated by a + `CommitOperationAdd`, no warning is triggered. 
This is usually useless (no need to + delete before upload) but can happen if a user deletes an entire folder and then + add new files to it. + """ + nb_additions_per_path: dict[str, int] = defaultdict(int) + for operation in operations: + path_in_repo = operation.path_in_repo + if isinstance(operation, CommitOperationAdd): + if nb_additions_per_path[path_in_repo] > 0: + warnings.warn( + "About to update multiple times the same file in the same commit:" + f" '{path_in_repo}'. This can cause undesired inconsistencies in" + " your repo." + ) + nb_additions_per_path[path_in_repo] += 1 + for parent in PurePosixPath(path_in_repo).parents: + # Also keep track of number of updated files per folder + # => warns if deleting a folder overwrite some contained files + nb_additions_per_path[str(parent)] += 1 + if isinstance(operation, CommitOperationDelete): + if nb_additions_per_path[str(PurePosixPath(path_in_repo))] > 0: + if operation.is_folder: + warnings.warn( + "About to delete a folder containing files that have just been" + f" updated within the same commit: '{path_in_repo}'. This can" + " cause undesired inconsistencies in your repo." + ) + else: + warnings.warn( + "About to delete a file that have just been updated within the" + f" same commit: '{path_in_repo}'. This can cause undesired" + " inconsistencies in your repo." + ) + + +@validate_hf_hub_args +def _upload_files( + *, + additions: list[CommitOperationAdd], + repo_type: str, + repo_id: str, + headers: dict[str, str], + endpoint: Optional[str] = None, + num_threads: int = 5, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, +): + """ + Negotiates per-file transfer (LFS vs Xet) and uploads in batches. 
+ """ + xet_additions: list[CommitOperationAdd] = [] + lfs_actions: list[dict[str, Any]] = [] + lfs_oid2addop: dict[str, CommitOperationAdd] = {} + + for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES): + chunk_list = [op for op in chunk] + + transfers: list[str] = ["basic", "multipart"] + has_buffered_io_data = any(isinstance(op.path_or_fileobj, io.BufferedIOBase) for op in chunk_list) + if is_xet_available(): + if not has_buffered_io_data: + transfers.append("xet") + else: + logger.warning( + "Uploading files as a binary IO buffer is not supported by Xet Storage. " + "Falling back to HTTP upload." + ) + + actions_chunk, errors_chunk, chosen_transfer = post_lfs_batch_info( + upload_infos=[op.upload_info for op in chunk_list], + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + endpoint=endpoint, + headers=headers, + token=None, # already passed in 'headers' + transfers=transfers, + ) + if errors_chunk: + message = "\n".join( + [ + f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}" + for err in errors_chunk + ] + ) + raise ValueError(f"LFS batch API returned errors:\n{message}") + + # If server returns a transfer we didn't offer (e.g "xet" while uploading from BytesIO), + # fall back to LFS for this chunk. 
+ if chosen_transfer == "xet" and ("xet" in transfers): + xet_additions.extend(chunk_list) + else: + lfs_actions.extend(actions_chunk) + for op in chunk_list: + lfs_oid2addop[op.upload_info.sha256.hex()] = op + + if len(lfs_actions) > 0: + _upload_lfs_files( + actions=lfs_actions, + oid2addop=lfs_oid2addop, + headers=headers, + endpoint=endpoint, + num_threads=num_threads, + ) + + if len(xet_additions) > 0: + _upload_xet_files( + additions=xet_additions, + repo_type=repo_type, + repo_id=repo_id, + headers=headers, + endpoint=endpoint, + revision=revision, + create_pr=create_pr, + ) + + +@validate_hf_hub_args +def _upload_lfs_files( + *, + actions: list[dict[str, Any]], + oid2addop: dict[str, CommitOperationAdd], + headers: dict[str, str], + endpoint: Optional[str] = None, + num_threads: int = 5, +): + """ + Uploads the content of `additions` to the Hub using the large file storage protocol. + + Relevant external documentation: + - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + + Args: + actions (`list[dict[str, Any]]`): + LFS batch actions returned by the server. + oid2addop (`dict[str, CommitOperationAdd]`): + A dictionary mapping the OID of the file to the corresponding `CommitOperationAdd` object. + headers (`dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + endpoint (`str`, *optional*): + The endpoint to use for the request. Defaults to `constants.ENDPOINT`. + num_threads (`int`, *optional*): + The number of concurrent threads to use when uploading. Defaults to 5. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If an upload failed for any reason + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. 
+ headers (`dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + num_threads (`int`, *optional*): + The number of concurrent threads to use when uploading. Defaults to 5. + revision (`str`, *optional*): + The git revision to upload to. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If an upload failed for any reason + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the server returns malformed responses + [`HfHubHTTPError`] + If the LFS batch endpoint returned an HTTP error. + """ + # Filter out files already present upstream + filtered_actions = [] + for action in actions: + if action.get("actions") is None: + logger.debug( + f"Content of file {oid2addop[action['oid']].path_in_repo} is already present upstream - skipping upload." + ) + else: + filtered_actions.append(action) + + # Upload according to server-provided actions + def _wrapped_lfs_upload(batch_action) -> None: + try: + operation = oid2addop[batch_action["oid"]] + lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint) + except Exception as exc: + raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc + + if len(filtered_actions) == 1: + logger.debug("Uploading 1 LFS file to the Hub") + _wrapped_lfs_upload(filtered_actions[0]) + else: + logger.debug( + f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently" + ) + thread_map( + _wrapped_lfs_upload, + filtered_actions, + desc=f"Upload {len(filtered_actions)} LFS files", + max_workers=num_threads, + tqdm_class=hf_tqdm, + ) + + +@validate_hf_hub_args +def _upload_xet_files( + *, + additions: list[CommitOperationAdd], + repo_type: str, + repo_id: str, + headers: dict[str, str], + endpoint: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, +): + """ + 
Uploads the content of `additions` to the Hub using the xet storage protocol. + This chunks the files and deduplicates the chunks before uploading them to xetcas storage. + + Args: + additions (`` of `CommitOperationAdd`): + The files to be uploaded. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + headers (`dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + endpoint: (`str`, *optional*): + The endpoint to use for the xetcas service. Defaults to `constants.ENDPOINT`. + revision (`str`, *optional*): + The git revision to upload to. + create_pr (`bool`, *optional*): + Whether or not to create a Pull Request with that commit. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If an upload failed for any reason. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the server returns malformed responses or if the user is unauthorized to upload to xet storage. + [`HfHubHTTPError`] + If the LFS batch endpoint returned an HTTP error. + + **How it works:** + The file download system uses Xet storage, which is a content-addressable storage system that breaks files into chunks + for efficient storage and transfer. + + `hf_xet.upload_files` manages uploading files by: + - Taking a list of file paths to upload + - Breaking files into smaller chunks for efficient storage + - Avoiding duplicate storage by recognizing identical chunks across files + - Connecting to a storage server (CAS server) that manages these chunks + + The upload process works like this: + 1. Create a local folder at ~/.cache/huggingface/xet/chunk-cache to store file chunks for reuse. + 2. Process files in parallel (up to 8 files at once): + 2.1. Read the file content. + 2.2. 
Split the file content into smaller chunks based on content patterns: each chunk gets a unique ID based on what's in it. + 2.3. For each chunk: + - Check if it already exists in storage. + - Skip uploading chunks that already exist. + 2.4. Group chunks into larger blocks for efficient transfer. + 2.5. Upload these blocks to the storage server. + 2.6. Create and upload information about how the file is structured. + 3. Return reference files that contain information about the uploaded files, which can be used later to download them. + """ + if len(additions) == 0: + return + + # at this point, we know that hf_xet is installed + from hf_xet import upload_bytes, upload_files + + from .utils._xet_progress_reporting import XetProgressReporter + + try: + xet_connection_info = fetch_xet_connection_info_from_repo_info( + token_type=XetTokenType.WRITE, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + headers=headers, + endpoint=endpoint, + params={"create_pr": "1"} if create_pr else None, + ) + except HfHubHTTPError as e: + if e.response.status_code == 401: + raise XetAuthorizationError( + f"You are unauthorized to upload to xet storage for {repo_type}/{repo_id}. " + f"Please check that you have configured your access token with write access to the repo." 
+ ) from e + raise + + xet_endpoint = xet_connection_info.endpoint + access_token_info = (xet_connection_info.access_token, xet_connection_info.expiration_unix_epoch) + + def token_refresher() -> tuple[str, int]: + new_xet_connection = fetch_xet_connection_info_from_repo_info( + token_type=XetTokenType.WRITE, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + headers=headers, + endpoint=endpoint, + params={"create_pr": "1"} if create_pr else None, + ) + if new_xet_connection is None: + raise XetRefreshTokenError("Failed to refresh xet token") + return new_xet_connection.access_token, new_xet_connection.expiration_unix_epoch + + if not are_progress_bars_disabled(): + progress = XetProgressReporter() + progress_callback = progress.update_progress + else: + progress, progress_callback = None, None + + try: + all_bytes_ops = [op for op in additions if isinstance(op.path_or_fileobj, bytes)] + all_paths_ops = [op for op in additions if isinstance(op.path_or_fileobj, (str, Path))] + + if len(all_paths_ops) > 0: + all_paths = [str(op.path_or_fileobj) for op in all_paths_ops] + upload_files( + all_paths, + xet_endpoint, + access_token_info, + token_refresher, + progress_callback, + repo_type, + ) + + if len(all_bytes_ops) > 0: + all_bytes = [op.path_or_fileobj for op in all_bytes_ops] + upload_bytes( + all_bytes, + xet_endpoint, + access_token_info, + token_refresher, + progress_callback, + repo_type, + ) + + finally: + if progress is not None: + progress.close(False) + + return + + +def _validate_preupload_info(preupload_info: dict): + files = preupload_info.get("files") + if not isinstance(files, list): + raise ValueError("preupload_info is improperly formatted") + for file_info in files: + if not ( + isinstance(file_info, dict) + and isinstance(file_info.get("path"), str) + and isinstance(file_info.get("uploadMode"), str) + and (file_info["uploadMode"] in ("lfs", "regular")) + ): + raise ValueError("preupload_info is improperly formatted:") + return 
preupload_info + + +@validate_hf_hub_args +def _fetch_upload_modes( + additions: Iterable[CommitOperationAdd], + repo_type: str, + repo_id: str, + headers: dict[str, str], + revision: str, + endpoint: Optional[str] = None, + create_pr: bool = False, + gitignore_content: Optional[str] = None, +) -> None: + """ + Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob, + as a git LFS blob, or as a XET file. Input `additions` are mutated in-place with the upload mode. + + Args: + additions (`Iterable` of :class:`CommitOperationAdd`): + Iterable of :class:`CommitOperationAdd` describing the files to + upload to the Hub. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + headers (`dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + revision (`str`): + The git revision to upload the files to. Can be any valid git revision. + gitignore_content (`str`, *optional*): + The content of the `.gitignore` file to know which files should be ignored. The order of priority + is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present + in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub + (if any). + Raises: + [`~utils.HfHubHTTPError`] + If the Hub API returned an error. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the Hub API response is improperly formatted. + """ + endpoint = endpoint if endpoint is not None else constants.ENDPOINT + + # Fetch upload mode (LFS or regular) chunk by chunk. 
+ upload_modes: dict[str, UploadMode] = {} + should_ignore_info: dict[str, bool] = {} + oid_info: dict[str, Optional[str]] = {} + + for chunk in chunk_iterable(additions, 256): + payload: dict = { + "files": [ + { + "path": op.path_in_repo, + "sample": base64.b64encode(op.upload_info.sample).decode("ascii"), + "size": op.upload_info.size, + } + for op in chunk + ] + } + if gitignore_content is not None: + payload["gitIgnore"] = gitignore_content + + resp = http_backoff( + "POST", + f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}", + json=payload, + headers=headers, + params={"create_pr": "1"} if create_pr else None, + ) + hf_raise_for_status(resp) + preupload_info = _validate_preupload_info(resp.json()) + upload_modes.update(**{file["path"]: file["uploadMode"] for file in preupload_info["files"]}) + should_ignore_info.update(**{file["path"]: file["shouldIgnore"] for file in preupload_info["files"]}) + oid_info.update(**{file["path"]: file.get("oid") for file in preupload_info["files"]}) + + # Set upload mode for each addition operation + for addition in additions: + addition._upload_mode = upload_modes[addition.path_in_repo] + addition._should_ignore = should_ignore_info[addition.path_in_repo] + addition._remote_oid = oid_info[addition.path_in_repo] + + # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented) + # => empty files are uploaded as "regular" to still allow users to commit them. + for addition in additions: + if addition.upload_info.size == 0: + addition._upload_mode = "regular" + + +@validate_hf_hub_args +def _fetch_files_to_copy( + copies: Iterable[CommitOperationCopy], + repo_type: str, + repo_id: str, + headers: dict[str, str], + revision: str, + endpoint: Optional[str] = None, +) -> dict[tuple[str, Optional[str]], Union["RepoFile", bytes]]: + """ + Fetch information about the files to copy. 
+ + For LFS files, we only need their metadata (file size and sha256) while for regular files + we need to download the raw content from the Hub. + + Args: + copies (`Iterable` of :class:`CommitOperationCopy`): + Iterable of :class:`CommitOperationCopy` describing the files to + copy on the Hub. + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + headers (`dict[str, str]`): + Headers to use for the request, including authorization headers and user agent. + revision (`str`): + The git revision to upload the files to. Can be any valid git revision. + + Returns: `dict[tuple[str, Optional[str]], Union[RepoFile, bytes]]]` + Key is the file path and revision of the file to copy. + Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files). + + Raises: + [`~utils.HfHubHTTPError`] + If the Hub API returned an error. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the Hub API response is improperly formatted. + """ + from .hf_api import HfApi, RepoFolder + + hf_api = HfApi(endpoint=endpoint, headers=headers) + files_to_copy: dict[tuple[str, Optional[str]], Union["RepoFile", bytes]] = {} + # Store (path, revision) -> oid mapping + oid_info: dict[tuple[str, Optional[str]], Optional[str]] = {} + # 1. Fetch OIDs for destination paths in batches. + dest_paths = [op.path_in_repo for op in copies] + for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE): + dest_repo_files = hf_api.get_paths_info( + repo_id=repo_id, + paths=dest_paths[offset : offset + FETCH_LFS_BATCH_SIZE], + revision=revision, + repo_type=repo_type, + ) + for file in dest_repo_files: + if not isinstance(file, RepoFolder): + oid_info[(file.path, revision)] = file.blob_id + + # 2. Group by source revision and fetch source file info in batches. 
+ for src_revision, operations in groupby(copies, key=lambda op: op.src_revision): + operations = list(operations) # type: ignore + src_paths = [op.src_path_in_repo for op in operations] + for offset in range(0, len(src_paths), FETCH_LFS_BATCH_SIZE): + src_repo_files = hf_api.get_paths_info( + repo_id=repo_id, + paths=src_paths[offset : offset + FETCH_LFS_BATCH_SIZE], + revision=src_revision or revision, + repo_type=repo_type, + ) + + for src_repo_file in src_repo_files: + if isinstance(src_repo_file, RepoFolder): + raise NotImplementedError("Copying a folder is not implemented.") + oid_info[(src_repo_file.path, src_revision)] = src_repo_file.blob_id + # If it's an LFS file, store the RepoFile object. Otherwise, download raw bytes. + if src_repo_file.lfs: + files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file + else: + # TODO: (optimization) download regular files to copy concurrently + url = hf_hub_url( + endpoint=endpoint, + repo_type=repo_type, + repo_id=repo_id, + revision=src_revision or revision, + filename=src_repo_file.path, + ) + response = get_session().get(url, headers=headers) + hf_raise_for_status(response) + files_to_copy[(src_repo_file.path, src_revision)] = response.content + # 3. Ensure all operations found a corresponding file in the Hub + # and track src/dest OIDs for each operation. + for operation in operations: + if (operation.src_path_in_repo, src_revision) not in files_to_copy: + raise EntryNotFoundError( + f"Cannot copy {operation.src_path_in_repo} at revision " + f"{src_revision or revision}: file is missing on repo." 
+ ) + operation._src_oid = oid_info.get((operation.src_path_in_repo, operation.src_revision)) + operation._dest_oid = oid_info.get((operation.path_in_repo, revision)) + return files_to_copy + + +def _prepare_commit_payload( + operations: Iterable[CommitOperation], + files_to_copy: dict[tuple[str, Optional[str]], Union["RepoFile", bytes]], + commit_message: str, + commit_description: Optional[str] = None, + parent_commit: Optional[str] = None, +) -> Iterable[dict[str, Any]]: + """ + Builds the payload to POST to the `/commit` API of the Hub. + + Payload is returned as an iterator so that it can be streamed as a ndjson in the + POST request. + + For more information, see: + - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073 + - http://ndjson.org/ + """ + commit_description = commit_description if commit_description is not None else "" + + # 1. Send a header item with the commit metadata + header_value = {"summary": commit_message, "description": commit_description} + if parent_commit is not None: + header_value["parentCommit"] = parent_commit + yield {"key": "header", "value": header_value} + + nb_ignored_files = 0 + + # 2. Send operations, one per line + for operation in operations: + # Skip ignored files + if isinstance(operation, CommitOperationAdd) and operation._should_ignore: + logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).") + nb_ignored_files += 1 + continue + + # 2.a. Case adding a regular file + if isinstance(operation, CommitOperationAdd) and operation._upload_mode == "regular": + yield { + "key": "file", + "value": { + "content": operation.b64content().decode(), + "path": operation.path_in_repo, + "encoding": "base64", + }, + } + # 2.b. 
Case adding an LFS file + elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == "lfs": + yield { + "key": "lfsFile", + "value": { + "path": operation.path_in_repo, + "algo": "sha256", + "oid": operation.upload_info.sha256.hex(), + "size": operation.upload_info.size, + }, + } + # 2.c. Case deleting a file or folder + elif isinstance(operation, CommitOperationDelete): + yield { + "key": "deletedFolder" if operation.is_folder else "deletedFile", + "value": {"path": operation.path_in_repo}, + } + # 2.d. Case copying a file or folder + elif isinstance(operation, CommitOperationCopy): + file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)] + if isinstance(file_to_copy, bytes): + yield { + "key": "file", + "value": { + "content": base64.b64encode(file_to_copy).decode(), + "path": operation.path_in_repo, + "encoding": "base64", + }, + } + elif file_to_copy.lfs: + yield { + "key": "lfsFile", + "value": { + "path": operation.path_in_repo, + "algo": "sha256", + "oid": file_to_copy.lfs.sha256, + }, + } + else: + raise ValueError( + "Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info." + ) + # 2.e. Never expected to happen + else: + raise ValueError( + f"Unknown operation to commit. Operation: {operation}. 
Upload mode:" + f" {getattr(operation, '_upload_mode', None)}" + ) + + if nb_ignored_files > 0: + logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).") diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_commit_scheduler.py b/venv/lib/python3.10/site-packages/huggingface_hub/_commit_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..497c9a0be52d23d1d8bf4fe36f3dcfb54e7d665f --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_commit_scheduler.py @@ -0,0 +1,353 @@ +import atexit +import logging +import os +import time +from concurrent.futures import Future +from dataclasses import dataclass +from io import SEEK_END, SEEK_SET, BytesIO +from pathlib import Path +from threading import Lock, Thread +from typing import Optional, Union + +from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi +from .utils import filter_repo_objects + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class _FileToUpload: + """Temporary dataclass to store info about files to upload. Not meant to be used directly.""" + + local_path: Path + path_in_repo: str + size_limit: int + last_modified: float + + +class CommitScheduler: + """ + Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes). + + The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is + properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually + with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads) + to learn more about how to use it. + + Args: + repo_id (`str`): + The id of the repo to commit to. + folder_path (`str` or `Path`): + Path to the local folder to upload regularly. 
+ every (`int` or `float`, *optional*): + The number of minutes between each commit. Defaults to 5 minutes. + path_in_repo (`str`, *optional*): + Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder + of the repository. + repo_type (`str`, *optional*): + The type of the repo to commit to. Defaults to `model`. + revision (`str`, *optional*): + The revision of the repo to commit to. Defaults to `main`. + private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + token (`str`, *optional*): + The token to use to commit to the repo. Defaults to the token saved on the machine. + allow_patterns (`list[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are uploaded. + ignore_patterns (`list[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not uploaded. + squash_history (`bool`, *optional*): + Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is + useful to avoid degraded performances on the repo when it grows too large. + hf_api (`HfApi`, *optional*): + The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...). + + Example: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import CommitScheduler + + # Scheduler uploads every 10 minutes + >>> csv_path = Path("watched_folder/data.csv") + >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10) + + >>> with csv_path.open("a") as f: + ... f.write("first line") + + # Some time later (...) + >>> with csv_path.open("a") as f: + ... 
f.write("second line") + ``` + + Example using a context manager: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import CommitScheduler + + >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler: + ... csv_path = Path("watched_folder/data.csv") + ... with csv_path.open("a") as f: + ... f.write("first line") + ... (...) + ... with csv_path.open("a") as f: + ... f.write("second line") + + # Scheduler is now stopped and last commit have been triggered + ``` + """ + + def __init__( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + every: Union[int, float] = 5, + path_in_repo: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + private: Optional[bool] = None, + token: Optional[str] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + squash_history: bool = False, + hf_api: Optional["HfApi"] = None, + ) -> None: + self.api = hf_api or HfApi(token=token) + + # Folder + self.folder_path = Path(folder_path).expanduser().resolve() + self.path_in_repo = path_in_repo or "" + self.allow_patterns = allow_patterns + + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS + + if self.folder_path.is_file(): + raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.") + self.folder_path.mkdir(parents=True, exist_ok=True) + + # Repository + repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True) + self.repo_id = repo_url.repo_id + self.repo_type = repo_type + self.revision = revision + self.token = token + + # Keep track of already uploaded files + self.last_uploaded: dict[Path, float] = {} # key is local path, value is timestamp + + # Scheduler + if not 
every > 0: + raise ValueError(f"'every' must be a positive integer, not '{every}'.") + self.lock = Lock() + self.every = every + self.squash_history = squash_history + + logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.") + self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True) + self._scheduler_thread.start() + atexit.register(self._push_to_hub) + + self.__stopped = False + + def stop(self) -> None: + """Stop the scheduler. + + A stopped scheduler cannot be restarted. Mostly for tests purposes. + """ + self.__stopped = True + + def __enter__(self) -> "CommitScheduler": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + # Upload last changes before exiting + self.trigger().result() + self.stop() + return + + def _run_scheduler(self) -> None: + """Dumb thread waiting between each scheduled push to Hub.""" + while True: + self.last_future = self.trigger() + time.sleep(self.every * 60) + if self.__stopped: + break + + def trigger(self) -> Future: + """Trigger a `push_to_hub` and return a future. + + This method is automatically called every `every` minutes. You can also call it manually to trigger a commit + immediately, without waiting for the next scheduled commit. 
+ """ + return self.api.run_as_future(self._push_to_hub) + + def _push_to_hub(self) -> Optional[CommitInfo]: + if self.__stopped: # If stopped, already scheduled commits are ignored + return None + + logger.info("(Background) scheduled commit triggered.") + try: + value = self.push_to_hub() + if self.squash_history: + logger.info("(Background) squashing repo history.") + self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision) + return value + except Exception as e: + logger.error(f"Error while pushing to Hub: {e}") # Depending on the setup, error might be silenced + raise + + def push_to_hub(self) -> Optional[CommitInfo]: + """ + Push folder to the Hub and return the commit info. + + > [!WARNING] + > This method is not meant to be called directly. It is run in the background by the scheduler, respecting a + > queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency + > issues. + + The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and + uploads only changed files. If no changes are found, the method returns without committing anything. If you want + to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful + for example to compress data together in a single file before committing. For more details and examples, check + out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads). 
+ """ + # Check files to upload (with lock) + with self.lock: + logger.debug("Listing files to upload for scheduled commit.") + + # List files from folder (taken from `_prepare_upload_folder_additions`) + relpath_to_abspath = { + path.relative_to(self.folder_path).as_posix(): path + for path in sorted(self.folder_path.glob("**/*")) # sorted to be deterministic + if path.is_file() + } + prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else "" + + # Filter with pattern + filter out unchanged files + retrieve current file size + files_to_upload: list[_FileToUpload] = [] + for relpath in filter_repo_objects( + relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns + ): + local_path = relpath_to_abspath[relpath] + stat = local_path.stat() + if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime: + files_to_upload.append( + _FileToUpload( + local_path=local_path, + path_in_repo=prefix + relpath, + size_limit=stat.st_size, + last_modified=stat.st_mtime, + ) + ) + + # Return if nothing to upload + if len(files_to_upload) == 0: + logger.debug("Dropping schedule commit: no changed file to upload.") + return None + + # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size) + logger.debug("Removing unchanged files since previous scheduled commit.") + add_operations = [ + CommitOperationAdd( + # Cap the file to its current size, even if the user append data to it while a scheduled commit is happening + path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit), + path_in_repo=file_to_upload.path_in_repo, + ) + for file_to_upload in files_to_upload + ] + + # Upload files (append mode expected - no need for lock) + logger.debug("Uploading files for scheduled commit.") + commit_info = self.api.create_commit( + repo_id=self.repo_id, + repo_type=self.repo_type, + operations=add_operations, + 
commit_message="Scheduled Commit", + revision=self.revision, + ) + + # Successful commit: keep track of the latest "last_modified" for each file + for file in files_to_upload: + self.last_uploaded[file.local_path] = file.last_modified + return commit_info + + +class PartialFileIO(BytesIO): + """A file-like object that reads only the first part of a file. + + Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the + file is uploaded (i.e. the part that was available when the filesystem was first scanned). + + In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal + disturbance for the user. The object is passed to `CommitOperationAdd`. + + Only supports `read`, `tell` and `seek` methods. + + Args: + file_path (`str` or `Path`): + Path to the file to read. + size_limit (`int`): + The maximum number of bytes to read from the file. If the file is larger than this, only the first part + will be read (and uploaded). 
+ """ + + def __init__(self, file_path: Union[str, Path], size_limit: int) -> None: + self._file_path = Path(file_path) + self._file = self._file_path.open("rb") + self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size) + + def __del__(self) -> None: + self._file.close() + return super().__del__() + + def __repr__(self) -> str: + return f"" + + def __len__(self) -> int: + return self._size_limit + + def __getattribute__(self, name: str): + if name.startswith("_") or name in ("read", "tell", "seek", "fileno"): # only 4 public methods supported + return super().__getattribute__(name) + raise NotImplementedError(f"PartialFileIO does not support '{name}'.") + + def fileno(self): + raise AttributeError("PartialFileIO does not have a fileno.") + + def tell(self) -> int: + """Return the current file position.""" + return self._file.tell() + + def seek(self, __offset: int, __whence: int = SEEK_SET) -> int: + """Change the stream position to the given offset. + + Behavior is the same as a regular file, except that the position is capped to the size limit. + """ + if __whence == SEEK_END: + # SEEK_END => set from the truncated end + __offset = len(self) + __offset + __whence = SEEK_SET + + pos = self._file.seek(__offset, __whence) + if pos > self._size_limit: + return self._file.seek(self._size_limit) + return pos + + def read(self, __size: Optional[int] = -1) -> bytes: + """Read at most `__size` bytes from the file. + + Behavior is the same as a regular file, except that it is capped to the size limit. 
+ """ + current = self._file.tell() + if __size is None or __size < 0: + # Read until file limit + truncated_size = self._size_limit - current + else: + # Read until file limit or __size + truncated_size = min(__size, self._size_limit - current) + return self._file.read(truncated_size) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_eval_results.py b/venv/lib/python3.10/site-packages/huggingface_hub/_eval_results.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b753693cc8953d5c4e2128c7b56880598ff2f5 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_eval_results.py @@ -0,0 +1,211 @@ +"""Evaluation results utilities for the `.eval_results/*.yaml` format. + +See https://huggingface.co/docs/hub/eval-results for more details. +Specifications are available at https://github.com/huggingface/hub-docs/blob/main/eval_results.yaml. +""" + +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass +class EvalResultEntry: + """ + Evaluation result entry for the `.eval_results/*.yaml` format. + + Represents evaluation scores stored in model repos that automatically appear on + the model page and the benchmark dataset's leaderboard. + + For the legacy `model-index` format in `README.md`, use [`EvalResult`] instead. + + See https://huggingface.co/docs/hub/eval-results for more details. + + Args: + dataset_id (`str`): + Benchmark dataset ID from the Hub. Example: "cais/hle", "Idavidrein/gpqa". + task_id (`str`): + Task identifier within the benchmark. Example: "gpqa_diamond". + value (`Any`): + The metric value. Example: 20.90. + dataset_revision (`str`, *optional*): + Git SHA of the benchmark dataset. + verify_token (`str`, *optional*): + A signature that can be used to prove that evaluation is provably auditable and reproducible. + date (`str`, *optional*): + When the evaluation was run (ISO-8601 datetime). Defaults to git commit time. 
+ source_url (`str`, *optional*): + Link to the evaluation source (e.g., https://huggingface.co/spaces/SaylorTwift/smollm3-mmlu-pro). Required if `source_name`, `source_user`, or `source_org` is provided. + source_name (`str`, *optional*): + Display name for the source. Example: "Eval Logs". + source_user (`str`, *optional*): + HF user name for attribution. Example: "celinah". + source_org (`str`, *optional*): + HF org name for attribution. Example: "cais". + notes (`str`, *optional*): + Details about the evaluation setup. Example: "tools", "no-tools", "chain-of-thought". + + Example: + ```python + >>> from huggingface_hub import EvalResultEntry + >>> # Minimal example with required fields only + >>> result = EvalResultEntry( + ... dataset_id="Idavidrein/gpqa", + ... task_id="gpqa_diamond", + ... value=0.412, + ... ) + >>> # Full example with all fields + >>> result = EvalResultEntry( + ... dataset_id="cais/hle", + ... task_id="default", + ... value=20.90, + ... dataset_revision="5503434ddd753f426f4b38109466949a1217c2bb", + ... verify_token="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + ... date="2025-01-15T10:30:00Z", + ... source_url="https://huggingface.co/datasets/cais/hle", + ... source_name="CAIS HLE", + ... source_org="cais", + ... notes="no-tools", + ... ) + + ``` + """ + + dataset_id: str + task_id: str + value: Any + dataset_revision: Optional[str] = None + verify_token: Optional[str] = None + date: Optional[str] = None + source_url: Optional[str] = None + source_name: Optional[str] = None + source_user: Optional[str] = None + source_org: Optional[str] = None + notes: Optional[str] = None + + def __post_init__(self) -> None: + if ( + self.source_name is not None or self.source_user is not None or self.source_org is not None + ) and self.source_url is None: + raise ValueError( + "If `source_name`, `source_user`, or `source_org` is provided, `source_url` must also be provided." 
+ ) + + +def eval_result_entries_to_yaml(entries: list[EvalResultEntry]) -> list[dict[str, Any]]: + """Convert a list of [`EvalResultEntry`] objects to a YAML-serializable list of dicts. + + This produces the format expected in `.eval_results/*.yaml` files. + + Args: + entries (`list[EvalResultEntry]`): + List of evaluation result entries to serialize. + + Returns: + `list[dict[str, Any]]`: A list of dictionaries ready to be dumped to YAML. + + Example: + ```python + >>> from huggingface_hub import EvalResultEntry, eval_result_entries_to_yaml + >>> entries = [ + ... EvalResultEntry(dataset_id="cais/hle", task_id="default", value=20.90), + ... EvalResultEntry(dataset_id="Idavidrein/gpqa", task_id="gpqa_diamond", value=0.412), + ... ] + >>> yaml_data = eval_result_entries_to_yaml(entries) + >>> yaml_data[0] + {'dataset': {'id': 'cais/hle', 'task_id': 'default'}, 'value': 20.9} + + ``` + + To upload eval results to the Hub: + ```python + >>> import yaml + >>> from huggingface_hub import upload_file, EvalResultEntry, eval_result_entries_to_yaml + >>> entries = [ + ... EvalResultEntry(dataset_id="cais/hle", task_id="default", value=20.90), + ... ] + >>> yaml_content = yaml.dump(eval_result_entries_to_yaml(entries)) + >>> upload_file( + ... path_or_fileobj=yaml_content.encode(), + ... path_in_repo=".eval_results/hle.yaml", + ... repo_id="your-username/your-model", + ... 
) + + ``` + """ + result = [] + for entry in entries: + # build the dataset object + dataset: dict[str, Any] = {"id": entry.dataset_id, "task_id": entry.task_id} + if entry.dataset_revision is not None: + dataset["revision"] = entry.dataset_revision + + data: dict[str, Any] = {"dataset": dataset, "value": entry.value} + if entry.verify_token is not None: + data["verifyToken"] = entry.verify_token + if entry.date is not None: + data["date"] = entry.date + # build the source object + if entry.source_url is not None: + source: dict[str, Any] = {"url": entry.source_url} + if entry.source_name is not None: + source["name"] = entry.source_name + if entry.source_user is not None: + source["user"] = entry.source_user + if entry.source_org is not None: + source["org"] = entry.source_org + data["source"] = source + if entry.notes is not None: + data["notes"] = entry.notes + + result.append(data) + return result + + +def parse_eval_result_entries(data: list[dict[str, Any]]) -> list[EvalResultEntry]: + """Parse a list of dicts into [`EvalResultEntry`] objects. + + This parses the `.eval_results/*.yaml` format. For the legacy `model-index` format, + use [`model_index_to_eval_results`] instead. + + Args: + data (`list[dict[str, Any]]`): + A list of dictionaries (e.g., parsed from YAML or API response). + + Returns: + `list[EvalResultEntry]`: A list of evaluation result entry objects. + + Example: + ```python + >>> from huggingface_hub import parse_eval_result_entries + >>> data = [ + ... {"dataset": {"id": "cais/hle", "task_id": "default"}, "value": 20.90}, + ... {"dataset": {"id": "Idavidrein/gpqa", "task_id": "gpqa_diamond"}, "value": 0.412}, + ... 
] + >>> entries = parse_eval_result_entries(data) + >>> entries[0].dataset_id + 'cais/hle' + >>> entries[0].value + 20.9 + + ``` + """ + entries = [] + for item in data: + entry_data = item.get("data", item) + dataset = entry_data.get("dataset", {}) + source = entry_data.get("source", {}) + entry = EvalResultEntry( + dataset_id=dataset["id"], + value=entry_data["value"], + task_id=dataset["task_id"], + dataset_revision=dataset.get("revision"), + verify_token=entry_data.get("verifyToken"), + date=entry_data.get("date"), + source_url=source.get("url") if source else None, + source_name=source.get("name") if source else None, + source_user=source.get("user") if source else None, + source_org=source.get("org") if source else None, + notes=entry_data.get("notes"), + ) + entries.append(entry) + return entries diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_inference_endpoints.py b/venv/lib/python3.10/site-packages/huggingface_hub/_inference_endpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..1a680850c4aa8856e43e91174ee47f75e3dc5341 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_inference_endpoints.py @@ -0,0 +1,418 @@ +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError + +from .utils import get_session, logging, parse_datetime + + +if TYPE_CHECKING: + from .hf_api import HfApi + from .inference._client import InferenceClient + from .inference._generated._async_client import AsyncInferenceClient + +logger = logging.get_logger(__name__) + + +class InferenceEndpointStatus(str, Enum): + PENDING = "pending" + INITIALIZING = "initializing" + UPDATING = "updating" + UPDATE_FAILED = "updateFailed" + RUNNING = "running" + PAUSED = "paused" + FAILED = "failed" + SCALED_TO_ZERO = "scaledToZero" + + +class 
InferenceEndpointType(str, Enum): + PUBlIC = "public" + PROTECTED = "protected" + PRIVATE = "private" + + +class InferenceEndpointScalingMetric(str, Enum): + PENDING_REQUESTS = "pendingRequests" + HARDWARE_USAGE = "hardwareUsage" + + +@dataclass +class InferenceEndpoint: + """ + Contains information about a deployed Inference Endpoint. + + Args: + name (`str`): + The unique name of the Inference Endpoint. + namespace (`str`): + The namespace where the Inference Endpoint is located. + repository (`str`): + The name of the model repository deployed on this Inference Endpoint. + status ([`InferenceEndpointStatus`]): + The current status of the Inference Endpoint. + url (`str`, *optional*): + The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL. + framework (`str`): + The machine learning framework used for the model. + revision (`str`): + The specific model revision deployed on the Inference Endpoint. + task (`str`): + The task associated with the deployed model. + created_at (`datetime.datetime`): + The timestamp when the Inference Endpoint was created. + updated_at (`datetime.datetime`): + The timestamp of the last update of the Inference Endpoint. + type ([`InferenceEndpointType`]): + The type of the Inference Endpoint (public, protected, private). + raw (`dict`): + The raw dictionary data returned from the API. + token (`str` or `bool`, *optional*): + Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the + locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server. + + Example: + ```python + >>> from huggingface_hub import get_inference_endpoint + >>> endpoint = get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) 
+ + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) + + # Pause endpoint to save $$$ + >>> endpoint.pause() + + # ... + # Resume and wait for deployment + >>> endpoint.resume() + >>> endpoint.wait() + >>> endpoint.client.text_to_image(...) + ``` + """ + + # Field in __repr__ + name: str = field(init=False) + namespace: str + repository: str = field(init=False) + status: InferenceEndpointStatus = field(init=False) + health_route: str = field(init=False) + url: Optional[str] = field(init=False) + + # Other fields + framework: str = field(repr=False, init=False) + revision: str = field(repr=False, init=False) + task: str = field(repr=False, init=False) + created_at: datetime = field(repr=False, init=False) + updated_at: datetime = field(repr=False, init=False) + type: InferenceEndpointType = field(repr=False, init=False) + + # Raw dict from the API + raw: dict = field(repr=False) + + # Internal fields + _token: Union[str, bool, None] = field(repr=False, compare=False) + _api: "HfApi" = field(repr=False, compare=False) + + @classmethod + def from_raw( + cls, raw: dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None + ) -> "InferenceEndpoint": + """Initialize object from raw dictionary.""" + if api is None: + from .hf_api import HfApi + + api = HfApi() + if token is None: + token = api.token + + # All other fields are populated in __post_init__ + return cls(raw=raw, namespace=namespace, _token=token, _api=api) + + def __post_init__(self) -> None: + """Populate fields from raw dictionary.""" + self._populate_from_raw() + + @property + def client(self) -> "InferenceClient": + """Returns a client to make predictions on this Inference Endpoint. + + Returns: + [`InferenceClient`]: an inference client pointing to the deployed endpoint. 
+ + Raises: + [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed. + """ + if self.url is None: + raise InferenceEndpointError( + "Cannot create a client for this Inference Endpoint as it is not yet deployed. " + "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." + ) + from .inference._client import InferenceClient + + return InferenceClient( + model=self.url, + token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. + ) + + @property + def async_client(self) -> "AsyncInferenceClient": + """Returns a client to make predictions on this Inference Endpoint. + + Returns: + [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint. + + Raises: + [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed. + """ + if self.url is None: + raise InferenceEndpointError( + "Cannot create a client for this Inference Endpoint as it is not yet deployed. " + "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." + ) + from .inference._generated._async_client import AsyncInferenceClient + + return AsyncInferenceClient( + model=self.url, + token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. + ) + + def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint": + """Wait for the Inference Endpoint to be deployed. + + Information from the server will be fetched every 1s. If the Inference Endpoint is not deployed after `timeout` + seconds, a [`InferenceEndpointTimeoutError`] will be raised. The [`InferenceEndpoint`] will be mutated in place with the latest + data. + + Args: + timeout (`int`, *optional*): + The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait + indefinitely. 
+ refresh_every (`int`, *optional*): + The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + + Raises: + [`InferenceEndpointError`] + If the Inference Endpoint ended up in a failed state. + [`InferenceEndpointTimeoutError`] + If the Inference Endpoint is not deployed after `timeout` seconds. + """ + if timeout is not None and timeout < 0: + raise ValueError("`timeout` cannot be negative.") + if refresh_every <= 0: + raise ValueError("`refresh_every` must be positive.") + + start = time.time() + while True: + if self.status == InferenceEndpointStatus.FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.UPDATE_FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to update. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.RUNNING and self.url is not None: + # Verify the endpoint is actually reachable + _health_url = f"{self.url.rstrip('/')}/{self.health_route.lstrip('/')}" + response = get_session().get(_health_url, headers=self._api._build_hf_headers(token=self._token)) + if response.status_code == 200: + logger.info("Inference Endpoint is ready to be used.") + return self + + if timeout is not None: + if time.time() - start > timeout: + raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") + logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...") + time.sleep(refresh_every) + self.fetch() + + def fetch(self) -> "InferenceEndpoint": + """Fetch latest information about the Inference Endpoint. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. 
+ """ + obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def update( + self, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + scale_to_zero_timeout: Optional[int] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[dict] = None, + secrets: Optional[dict[str, str]] = None, + ) -> "InferenceEndpoint": + """Update the Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, or both. All arguments are + optional but at least one must be provided. + + This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Args: + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero. + + repository (`str`, *optional*): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). 
+ framework (`str`, *optional*): + The machine learning framework used for the model (e.g. `"custom"`). + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + secrets (`dict[str, str]`, *optional*): + Secret values to inject in the container environment. + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + # Make API call + obj = self._api.update_inference_endpoint( + name=self.name, + namespace=self.namespace, + accelerator=accelerator, + instance_size=instance_size, + instance_type=instance_type, + min_replica=min_replica, + max_replica=max_replica, + scale_to_zero_timeout=scale_to_zero_timeout, + repository=repository, + framework=framework, + revision=revision, + task=task, + custom_image=custom_image, + secrets=secrets, + token=self._token, # type: ignore [arg-type] + ) + + # Mutate current object + self.raw = obj.raw + self._populate_from_raw() + return self + + def pause(self) -> "InferenceEndpoint": + """Pause the Inference Endpoint. + + A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`]. + This is different from scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which + would be automatically restarted when a request is made to it. + + This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. 
+ """ + obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def resume(self, running_ok: bool = True) -> "InferenceEndpoint": + """Resume the Inference Endpoint. + + This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Args: + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.resume_inference_endpoint( + name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token + ) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def scale_to_zero(self) -> "InferenceEndpoint": + """Scale Inference Endpoint to zero. + + An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a + cold start delay. This is different from pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which + would require a manual resume with [`InferenceEndpoint.resume`]. + + This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def delete(self) -> None: + """Delete the Inference Endpoint. + + This operation is not reversible. 
If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`]. + + This is an alias for [`HfApi.delete_inference_endpoint`]. + """ + self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + + def _populate_from_raw(self) -> None: + """Populate fields from raw dictionary. + + Called in __post_init__ + each time the Inference Endpoint is updated. + """ + # Repr fields + self.name = self.raw["name"] + self.repository = self.raw["model"]["repository"] + self.status = self.raw["status"]["state"] + self.url = self.raw["status"].get("url") + self.health_route = self.raw["healthRoute"] + + # Other fields + self.framework = self.raw["model"]["framework"] + self.revision = self.raw["model"]["revision"] + self.task = self.raw["model"]["task"] + self.created_at = parse_datetime(self.raw["status"]["createdAt"]) + self.updated_at = parse_datetime(self.raw["status"]["updatedAt"]) + self.type = self.raw["type"] diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_jobs_api.py b/venv/lib/python3.10/site-packages/huggingface_hub/_jobs_api.py new file mode 100644 index 0000000000000000000000000000000000000000..914ec6d4c30c8593c0f6868389b8aa47de3ac33f --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_jobs_api.py @@ -0,0 +1,399 @@ +# coding=utf-8 +# Copyright 2025-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Optional, Union + +from huggingface_hub import constants +from huggingface_hub._space_api import SpaceHardware +from huggingface_hub.utils._datetime import parse_datetime + + +class JobStage(str, Enum): + """ + Enumeration of possible stage of a Job on the Hub. + + Value can be compared to a string: + ```py + assert JobStage.COMPLETED == "COMPLETED" + ``` + Possible values are: `COMPLETED`, `CANCELED`, `ERROR`, `DELETED`, `RUNNING`. + Taken from https://github.com/huggingface/moon-landing/blob/main/server/job_types/JobInfo.ts#L61 (private url). + """ + + # Copied from moon-landing > server > lib > Job.ts + COMPLETED = "COMPLETED" + CANCELED = "CANCELED" + ERROR = "ERROR" + DELETED = "DELETED" + RUNNING = "RUNNING" + + +@dataclass +class JobStatus: + stage: JobStage + message: Optional[str] + + +@dataclass +class JobOwner: + id: str + name: str + type: str + + +@dataclass +class JobInfo: + """ + Contains information about a Job. + + Args: + id (`str`): + Job ID. + created_at (`datetime` or `None`): + When the Job was created. + docker_image (`str` or `None`): + The Docker image from Docker Hub used for the Job. + Can be None if space_id is present instead. + space_id (`str` or `None`): + The Docker image from Hugging Face Spaces used for the Job. + Can be None if docker_image is present instead. + command (`list[str]` or `None`): + Command of the Job, e.g. `["python", "-c", "print('hello world')"]` + arguments (`list[str]` or `None`): + Arguments passed to the command + environment (`dict[str]` or `None`): + Environment variables of the Job as a dictionary. + secrets (`dict[str]` or `None`): + Secret environment variables of the Job (encrypted). + flavor (`str` or `None`): + Flavor for the hardware, as in Hugging Face Spaces. 
See [`SpaceHardware`] for possible values. + E.g. `"cpu-basic"`. + labels (`dict[str, str]` or `None`): + Labels to attach to the job (key-value pairs). + status: (`JobStatus` or `None`): + Status of the Job, e.g. `JobStatus(stage="RUNNING", message=None)` + See [`JobStage`] for possible stage values. + owner: (`JobOwner` or `None`): + Owner of the Job, e.g. `JobOwner(id="5e9ecfc04957053f60648a3e", name="lhoestq", type="user")` + + Example: + + ```python + >>> from huggingface_hub import run_job + >>> job = run_job( + ... image="python:3.12", + ... command=["python", "-c", "print('Hello from the cloud!')"] + ... ) + >>> job + JobInfo(id='687fb701029421ae5549d998', created_at=datetime.datetime(2025, 7, 22, 16, 6, 25, 79000, tzinfo=datetime.timezone.utc), docker_image='python:3.12', space_id=None, command=['python', '-c', "print('Hello from the cloud!')"], arguments=[], environment={}, secrets={}, flavor='cpu-basic', labels=None, status=JobStatus(stage='RUNNING', message=None), owner=JobOwner(id='5e9ecfc04957053f60648a3e', name='lhoestq', type='user'), endpoint='https://huggingface.co', url='https://huggingface.co/jobs/lhoestq/687fb701029421ae5549d998') + >>> job.id + '687fb701029421ae5549d998' + >>> job.url + 'https://huggingface.co/jobs/lhoestq/687fb701029421ae5549d998' + >>> job.status.stage + 'RUNNING' + ``` + """ + + id: str + created_at: Optional[datetime] + docker_image: Optional[str] + space_id: Optional[str] + command: Optional[list[str]] + arguments: Optional[list[str]] + environment: Optional[dict[str, Any]] + secrets: Optional[dict[str, Any]] + flavor: Optional[SpaceHardware] + labels: Optional[dict[str, str]] + status: JobStatus + owner: JobOwner + + # Inferred fields + endpoint: str + url: str + + def __init__(self, **kwargs) -> None: + self.id = kwargs["id"] + created_at = kwargs.get("createdAt") or kwargs.get("created_at") + self.created_at = parse_datetime(created_at) if created_at else None + self.docker_image = kwargs.get("dockerImage") or 
kwargs.get("docker_image") + self.space_id = kwargs.get("spaceId") or kwargs.get("space_id") + owner = kwargs.get("owner", {}) + self.owner = JobOwner(id=owner["id"], name=owner["name"], type=owner["type"]) + self.command = kwargs.get("command") + self.arguments = kwargs.get("arguments") + self.environment = kwargs.get("environment") + self.secrets = kwargs.get("secrets") + self.flavor = kwargs.get("flavor") + self.labels = kwargs.get("labels") + status = kwargs.get("status", {}) + self.status = JobStatus(stage=status["stage"], message=status.get("message")) + + # Inferred fields + self.endpoint = kwargs.get("endpoint", constants.ENDPOINT) + self.url = f"{self.endpoint}/jobs/{self.owner.name}/{self.id}" + + +@dataclass +class JobSpec: + docker_image: Optional[str] + space_id: Optional[str] + command: Optional[list[str]] + arguments: Optional[list[str]] + environment: Optional[dict[str, Any]] + secrets: Optional[dict[str, Any]] + flavor: Optional[SpaceHardware] + timeout: Optional[int] + tags: Optional[list[str]] + arch: Optional[str] + labels: Optional[dict[str, str]] + + def __init__(self, **kwargs) -> None: + self.docker_image = kwargs.get("dockerImage") or kwargs.get("docker_image") + self.space_id = kwargs.get("spaceId") or kwargs.get("space_id") + self.command = kwargs.get("command") + self.arguments = kwargs.get("arguments") + self.environment = kwargs.get("environment") + self.secrets = kwargs.get("secrets") + self.flavor = kwargs.get("flavor") + self.timeout = kwargs.get("timeout") + self.tags = kwargs.get("tags") + self.arch = kwargs.get("arch") + self.labels = kwargs.get("labels") + + +@dataclass +class LastJobInfo: + id: str + at: datetime + + def __init__(self, **kwargs) -> None: + self.id = kwargs["id"] + self.at = parse_datetime(kwargs["at"]) + + +@dataclass +class ScheduledJobStatus: + last_job: Optional[LastJobInfo] + next_job_run_at: Optional[datetime] + + def __init__(self, **kwargs) -> None: + last_job = kwargs.get("lastJob") or 
kwargs.get("last_job") + self.last_job = LastJobInfo(**last_job) if last_job else None + next_job_run_at = kwargs.get("nextJobRunAt") or kwargs.get("next_job_run_at") + self.next_job_run_at = parse_datetime(str(next_job_run_at)) if next_job_run_at else None + + +@dataclass +class ScheduledJobInfo: + """ + Contains information about a Job. + + Args: + id (`str`): + Scheduled Job ID. + created_at (`datetime` or `None`): + When the scheduled Job was created. + tags (`list[str]` or `None`): + The tags of the scheduled Job. + schedule (`str` or `None`): + One of "@annually", "@yearly", "@monthly", "@weekly", "@daily", "@hourly", or a + CRON schedule expression (e.g., '0 9 * * 1' for 9 AM every Monday). + suspend (`bool` or `None`): + Whether the scheduled job is suspended (paused). + concurrency (`bool` or `None`): + Whether multiple instances of this Job can run concurrently. + status (`ScheduledJobStatus` or `None`): + Status of the scheduled Job. + owner: (`JobOwner` or `None`): + Owner of the scheduled Job, e.g. `JobOwner(id="5e9ecfc04957053f60648a3e", name="lhoestq", type="user")` + job_spec: (`JobSpec` or `None`): + Specifications of the Job. + + Example: + + ```python + >>> from huggingface_hub import run_job + >>> scheduled_job = create_scheduled_job( + ... image="python:3.12", + ... command=["python", "-c", "print('Hello from the cloud!')"], + ... schedule="@hourly", + ... 
) + >>> scheduled_job.id + '687fb701029421ae5549d999' + >>> scheduled_job.status.next_job_run_at + datetime.datetime(2025, 7, 22, 17, 6, 25, 79000, tzinfo=datetime.timezone.utc) + ``` + """ + + id: str + created_at: Optional[datetime] + job_spec: JobSpec + schedule: Optional[str] + suspend: Optional[bool] + concurrency: Optional[bool] + status: ScheduledJobStatus + owner: JobOwner + + def __init__(self, **kwargs) -> None: + self.id = kwargs["id"] + created_at = kwargs.get("createdAt") or kwargs.get("created_at") + self.created_at = parse_datetime(created_at) if created_at else None + self.job_spec = JobSpec(**(kwargs.get("job_spec") or kwargs.get("jobSpec", {}))) + self.schedule = kwargs.get("schedule") + self.suspend = kwargs.get("suspend") + self.concurrency = kwargs.get("concurrency") + status = kwargs.get("status", {}) + self.status = ScheduledJobStatus( + last_job=status.get("last_job") or status.get("lastJob"), + next_job_run_at=status.get("next_job_run_at") or status.get("nextJobRunAt"), + ) + owner = kwargs.get("owner", {}) + self.owner = JobOwner(id=owner["id"], name=owner["name"], type=owner["type"]) + + +@dataclass +class JobAccelerator: + """ + Contains information about a Job accelerator (GPU). + + Args: + type (`str`): + Type of accelerator, e.g. `"gpu"`. + model (`str`): + Model of accelerator, e.g. `"T4"`, `"A10G"`, `"A100"`, `"L4"`, `"L40S"`. + quantity (`str`): + Number of accelerators, e.g. `"1"`, `"2"`, `"4"`, `"8"`. + vram (`str`): + Total VRAM, e.g. `"16 GB"`, `"24 GB"`. + manufacturer (`str`): + Manufacturer of the accelerator, e.g. `"Nvidia"`. + """ + + type: str + model: str + quantity: str + vram: str + manufacturer: str + + def __init__(self, **kwargs) -> None: + self.type = kwargs["type"] + self.model = kwargs["model"] + self.quantity = kwargs["quantity"] + self.vram = kwargs["vram"] + self.manufacturer = kwargs["manufacturer"] + + +@dataclass +class JobHardware: + """ + Contains information about available Job hardware. 
+ + Args: + name (`str`): + Machine identifier, e.g. `"cpu-basic"`, `"a10g-large"`. + pretty_name (`str`): + Human-readable name, e.g. `"CPU Basic"`, `"Nvidia A10G - large"`. + cpu (`str`): + CPU specification, e.g. `"2 vCPU"`, `"12 vCPU"`. + ram (`str`): + RAM specification, e.g. `"16 GB"`, `"46 GB"`. + accelerator (`JobAccelerator` or `None`): + GPU/accelerator details if available. + unit_cost_micro_usd (`int`): + Cost in micro-dollars per unit, e.g. `167` (= $0.000167). + unit_cost_usd (`float`): + Cost in USD per unit, e.g. `0.000167`. + unit_label (`str`): + Cost unit period, e.g. `"minute"`. + + Example: + + ```python + >>> from huggingface_hub import list_jobs_hardware + >>> hardware_list = list_jobs_hardware() + >>> hardware_list[0] + JobHardware(name='cpu-basic', pretty_name='CPU Basic', cpu='2 vCPU', ram='16 GB', accelerator=None, unit_cost_micro_usd=167, unit_cost_usd=0.000167, unit_label='minute') + >>> hardware_list[0].name + 'cpu-basic' + ``` + """ + + name: str + pretty_name: str + cpu: str + ram: str + accelerator: Optional[JobAccelerator] + unit_cost_micro_usd: int + unit_cost_usd: float + unit_label: str + + def __init__(self, **kwargs) -> None: + self.name = kwargs["name"] + self.pretty_name = kwargs["prettyName"] + self.cpu = kwargs["cpu"] + self.ram = kwargs["ram"] + accelerator = kwargs.get("accelerator") + self.accelerator = JobAccelerator(**accelerator) if accelerator else None + self.unit_cost_micro_usd = kwargs["unitCostMicroUSD"] + self.unit_cost_usd = kwargs["unitCostUSD"] + self.unit_label = kwargs["unitLabel"] + + +def _create_job_spec( + *, + image: str, + command: list[str], + env: Optional[dict[str, Any]], + secrets: Optional[dict[str, Any]], + flavor: Optional[SpaceHardware], + timeout: Optional[Union[int, float, str]], + labels: Optional[dict[str, str]] = None, +) -> dict[str, Any]: + # prepare job spec to send to HF Jobs API + job_spec: dict[str, Any] = { + "command": command, + "arguments": [], + "environment": env or {}, + 
"flavor": flavor or SpaceHardware.CPU_BASIC, + } + # secrets are optional + if secrets: + job_spec["secrets"] = secrets + # timeout is optional + if timeout: + time_units_factors = {"s": 1, "m": 60, "h": 3600, "d": 3600 * 24} + if isinstance(timeout, str) and timeout[-1] in time_units_factors: + job_spec["timeoutSeconds"] = int(float(timeout[:-1]) * time_units_factors[timeout[-1]]) + else: + job_spec["timeoutSeconds"] = int(timeout) + # labels are optional + if labels: + job_spec["labels"] = labels + # input is either from docker hub or from HF spaces + for prefix in ( + "https://huggingface.co/spaces/", + "https://hf.co/spaces/", + "huggingface.co/spaces/", + "hf.co/spaces/", + ): + if image.startswith(prefix): + job_spec["spaceId"] = image[len(prefix) :] + break + else: + job_spec["dockerImage"] = image + return job_spec diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_local_folder.py b/venv/lib/python3.10/site-packages/huggingface_hub/_local_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..f213fe9e2724c0d607407397a96e59639a6f105d --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_local_folder.py @@ -0,0 +1,451 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle the `../.cache/huggingface` folder in local directories. 
+ +First discussed in https://github.com/huggingface/huggingface_hub/issues/1738 to store +download metadata when downloading files from the hub to a local directory (without +using the cache). + +./.cache/huggingface folder structure: +[4.0K] data +├── [4.0K] .cache +│ └── [4.0K] huggingface +│ └── [4.0K] download +│ ├── [ 16] file.parquet.metadata +│ ├── [ 16] file.txt.metadata +│ └── [4.0K] folder +│ └── [ 16] file.parquet.metadata +│ +├── [6.5G] file.parquet +├── [1.5K] file.txt +└── [4.0K] folder + └── [ 16] file.parquet + + +Download metadata file structure: +``` +# file.txt.metadata +11c5a3d5811f50298f278a704980280950aedb10 +a16a55fda99d2f2e7b69cce5cf93ff4ad3049930 +1712656091.123 + +# file.parquet.metadata +11c5a3d5811f50298f278a704980280950aedb10 +7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421 +1712656091.123 +} +``` +""" + +import base64 +import hashlib +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .utils import WeakFileLock + + +logger = logging.getLogger(__name__) + + +@dataclass +class LocalDownloadFilePaths: + """ + Paths to the files related to a download process in a local dir. + + Returned by [`get_local_download_paths`]. + + Attributes: + file_path (`Path`): + Path where the file will be saved. + lock_path (`Path`): + Path to the lock file used to ensure atomicity when reading/writing metadata. + metadata_path (`Path`): + Path to the metadata file. + """ + + file_path: Path + lock_path: Path + metadata_path: Path + + def incomplete_path(self, etag: str) -> Path: + """Return the path where a file will be temporarily downloaded before being moved to `file_path`.""" + path = self.metadata_path.parent / f"{_short_hash(self.metadata_path.name)}.{etag}.incomplete" + resolved_path = str(path.resolve()) + # Some Windows versions do not allow for paths longer than 255 characters. 
+ # In this case, we must specify it as an extended path by using the "\\?\" prefix. + if os.name == "nt" and len(resolved_path) > 255 and not resolved_path.startswith("\\\\?\\"): + path = Path("\\\\?\\" + resolved_path) + return path + + +@dataclass(frozen=True) +class LocalUploadFilePaths: + """ + Paths to the files related to an upload process in a local dir. + + Returned by [`get_local_upload_paths`]. + + Attributes: + path_in_repo (`str`): + Path of the file in the repo. + file_path (`Path`): + Path where the file will be saved. + lock_path (`Path`): + Path to the lock file used to ensure atomicity when reading/writing metadata. + metadata_path (`Path`): + Path to the metadata file. + """ + + path_in_repo: str + file_path: Path + lock_path: Path + metadata_path: Path + + +@dataclass +class LocalDownloadFileMetadata: + """ + Metadata about a file in the local directory related to a download process. + + Attributes: + filename (`str`): + Path of the file in the repo. + commit_hash (`str`): + Commit hash of the file in the repo. + etag (`str`): + ETag of the file in the repo. Used to check if the file has changed. + For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. + timestamp (`int`): + Unix timestamp of when the metadata was saved i.e. when the metadata was accurate. + """ + + filename: str + commit_hash: str + etag: str + timestamp: float + + +@dataclass +class LocalUploadFileMetadata: + """ + Metadata about a file in the local directory related to an upload process. 
+ """ + + size: int + + # Default values correspond to "we don't know yet" + timestamp: Optional[float] = None + should_ignore: Optional[bool] = None + sha256: Optional[str] = None + upload_mode: Optional[str] = None + remote_oid: Optional[str] = None + is_uploaded: bool = False + is_committed: bool = False + + def save(self, paths: LocalUploadFilePaths) -> None: + """Save the metadata to disk.""" + with WeakFileLock(paths.lock_path): + with paths.metadata_path.open("w") as f: + new_timestamp = time.time() + f.write(str(new_timestamp) + "\n") + + f.write(str(self.size)) # never None + f.write("\n") + + if self.should_ignore is not None: + f.write(str(int(self.should_ignore))) + f.write("\n") + + if self.sha256 is not None: + f.write(self.sha256) + f.write("\n") + + if self.upload_mode is not None: + f.write(self.upload_mode) + f.write("\n") + + if self.remote_oid is not None: + f.write(self.remote_oid) + f.write("\n") + + f.write(str(int(self.is_uploaded)) + "\n") + f.write(str(int(self.is_committed)) + "\n") + + self.timestamp = new_timestamp + + +def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFilePaths: + """Compute paths to the files related to a download process. + + Folders containing the paths are all guaranteed to exist. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + filename (`str`): + Path of the file in the repo. + + Return: + [`LocalDownloadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path, incomplete_path). + """ + # filename is the path in the Hub repository (separated by '/') + # make sure to have a cross-platform transcription + sanitized_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if sanitized_filename.startswith("..\\") or "\\..\\" in sanitized_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{sanitized_filename}' on Windows. Please ask the repository" + " owner to rename this file." 
+ ) + file_path = local_dir / sanitized_filename + metadata_path = _huggingface_dir(local_dir) / "download" / f"{sanitized_filename}.metadata" + lock_path = metadata_path.with_suffix(".lock") + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix + if os.name == "nt": + if not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_path)) > 255: + file_path = Path("\\\\?\\" + os.path.abspath(file_path)) + lock_path = Path("\\\\?\\" + os.path.abspath(lock_path)) + metadata_path = Path("\\\\?\\" + os.path.abspath(metadata_path)) + + file_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + return LocalDownloadFilePaths(file_path=file_path, lock_path=lock_path, metadata_path=metadata_path) + + +def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths: + """Compute paths to the files related to an upload process. + + Folders containing the paths are all guaranteed to exist. + + Args: + local_dir (`Path`): + Path to the local directory that is uploaded. + filename (`str`): + Path of the file in the repo. + + Return: + [`LocalUploadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path). + """ + # filename is the path in the Hub repository (separated by '/') + # make sure to have a cross-platform transcription + sanitized_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if sanitized_filename.startswith("..\\") or "\\..\\" in sanitized_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{sanitized_filename}' on Windows. Please ask the repository" + " owner to rename this file." 
+ ) + file_path = local_dir / sanitized_filename + metadata_path = _huggingface_dir(local_dir) / "upload" / f"{sanitized_filename}.metadata" + lock_path = metadata_path.with_suffix(".lock") + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix + if os.name == "nt": + if not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_path)) > 255: + file_path = Path("\\\\?\\" + os.path.abspath(file_path)) + lock_path = Path("\\\\?\\" + os.path.abspath(lock_path)) + metadata_path = Path("\\\\?\\" + os.path.abspath(metadata_path)) + + file_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + return LocalUploadFilePaths( + path_in_repo=filename, file_path=file_path, lock_path=lock_path, metadata_path=metadata_path + ) + + +def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDownloadFileMetadata]: + """Read metadata about a file in the local directory related to a download process. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + filename (`str`): + Path of the file in the repo. + + Return: + `[LocalDownloadFileMetadata]` or `None`: the metadata if it exists, `None` otherwise. + """ + paths = get_local_download_paths(local_dir, filename) + with WeakFileLock(paths.lock_path): + if paths.metadata_path.exists(): + try: + with paths.metadata_path.open() as f: + commit_hash = f.readline().strip() + etag = f.readline().strip() + timestamp = float(f.readline().strip()) + metadata = LocalDownloadFileMetadata( + filename=filename, + commit_hash=commit_hash, + etag=etag, + timestamp=timestamp, + ) + except Exception as e: + # remove the metadata file if it is corrupted / not the right format + logger.warning( + f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue." 
+ ) + try: + paths.metadata_path.unlink() + except Exception as e: + logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}") + return None + + try: + # check if the file exists and hasn't been modified since the metadata was saved + stat = paths.file_path.stat() + if ( + stat.st_mtime - 1 <= metadata.timestamp + ): # allow 1s difference as stat.st_mtime might not be precise + return metadata + logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.") + except FileNotFoundError: + # file does not exist => metadata is outdated + return None + return None + + +def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetadata: + """Read metadata about a file in the local directory related to an upload process. + + TODO: factorize logic with `read_download_metadata`. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + filename (`str`): + Path of the file in the repo. + + Return: + `[LocalUploadFileMetadata]` or `None`: the metadata if it exists, `None` otherwise. 
+ """ + paths = get_local_upload_paths(local_dir, filename) + with WeakFileLock(paths.lock_path): + if paths.metadata_path.exists(): + try: + with paths.metadata_path.open() as f: + timestamp = float(f.readline().strip()) + + size = int(f.readline().strip()) # never None + + _should_ignore = f.readline().strip() + should_ignore = None if _should_ignore == "" else bool(int(_should_ignore)) + + _sha256 = f.readline().strip() + sha256 = None if _sha256 == "" else _sha256 + + _upload_mode = f.readline().strip() + upload_mode = None if _upload_mode == "" else _upload_mode + if upload_mode not in (None, "regular", "lfs"): + raise ValueError(f"Invalid upload mode in metadata {paths.path_in_repo}: {upload_mode}") + + _remote_oid = f.readline().strip() + remote_oid = None if _remote_oid == "" else _remote_oid + + is_uploaded = bool(int(f.readline().strip())) + is_committed = bool(int(f.readline().strip())) + + metadata = LocalUploadFileMetadata( + timestamp=timestamp, + size=size, + should_ignore=should_ignore, + sha256=sha256, + upload_mode=upload_mode, + remote_oid=remote_oid, + is_uploaded=is_uploaded, + is_committed=is_committed, + ) + except Exception as e: + # remove the metadata file if it is corrupted / not the right format + logger.warning( + f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue." + ) + try: + paths.metadata_path.unlink() + except Exception as e: + logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}") + + # corrupted metadata => we don't know anything expect its size + return LocalUploadFileMetadata(size=paths.file_path.stat().st_size) + + # TODO: can we do better? 
+ if ( + metadata.timestamp is not None + and metadata.is_uploaded # file was uploaded + and not metadata.is_committed # but not committed + and time.time() - metadata.timestamp > 20 * 3600 # and it's been more than 20 hours + ): # => we consider it as garbage-collected by S3 + metadata.is_uploaded = False + + # check if the file exists and hasn't been modified since the metadata was saved + try: + if metadata.timestamp is not None and paths.file_path.stat().st_mtime <= metadata.timestamp: + return metadata + logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.") + except FileNotFoundError: + # file does not exist => metadata is outdated + pass + + # empty metadata => we don't know anything expect its size + return LocalUploadFileMetadata(size=paths.file_path.stat().st_size) + + +def write_download_metadata(local_dir: Path, filename: str, commit_hash: str, etag: str) -> None: + """Write metadata about a file in the local directory related to a download process. + + Args: + local_dir (`Path`): + Path to the local directory in which files are downloaded. + """ + paths = get_local_download_paths(local_dir, filename) + with WeakFileLock(paths.lock_path): + with paths.metadata_path.open("w") as f: + f.write(f"{commit_hash}\n{etag}\n{time.time()}\n") + + +def _huggingface_dir(local_dir: Path) -> Path: + """Return the path to the `.cache/huggingface` directory in a local directory.""" + # Wrap in lru_cache to avoid overwriting the .gitignore file if called multiple times + path = local_dir / ".cache" / "huggingface" + path.mkdir(exist_ok=True, parents=True) + + # Create a .gitignore file in the .cache/huggingface directory if it doesn't exist + # Should be thread-safe enough like this. 
+ gitignore = path / ".gitignore" + gitignore_lock = path / ".gitignore.lock" + if not gitignore.exists(): + try: + with WeakFileLock(gitignore_lock, timeout=0.1): + gitignore.write_text("*") + except IndexError: + pass + except OSError: # TimeoutError, FileNotFoundError, PermissionError, etc. + pass + try: + gitignore_lock.unlink() + except OSError: + pass + return path + + +def _short_hash(filename: str) -> str: + return base64.urlsafe_b64encode(hashlib.sha1(filename.encode()).digest()).decode() diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_login.py b/venv/lib/python3.10/site-packages/huggingface_hub/_login.py new file mode 100644 index 0000000000000000000000000000000000000000..fe266d2deac58d94109aaedb4e9a6ee162d8f0a5 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_login.py @@ -0,0 +1,492 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains methods to log in to the Hub.""" + +import os +import subprocess +from getpass import getpass +from pathlib import Path +from typing import Optional + +import typer + +from . 
import constants +from .utils import ( + ANSI, + capture_output, + get_token, + is_google_colab, + is_notebook, + list_credential_helpers, + logging, + run_subprocess, + set_git_credential, + unset_git_credential, +) +from .utils._auth import ( + _get_token_by_name, + _get_token_from_environment, + _get_token_from_file, + _get_token_from_google_colab, + _save_stored_tokens, + _save_token, + get_stored_tokens, +) + + +logger = logging.get_logger(__name__) + +_HF_LOGO_ASCII = """ + _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| + _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| + _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| + _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| + _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| +""" + + +def login( + token: Optional[str] = None, + *, + add_to_git_credential: bool = False, + skip_if_logged_in: bool = False, +) -> None: + """Login the machine to access the Hub. + + The `token` is persisted in cache and set as a git credential. Once done, the machine + is logged in and the access token will be available across all `huggingface_hub` + components. If `token` is not provided, it will be prompted to the user either with + a widget (in a notebook) or via the terminal. + + To log in from outside of a script, one can also use `hf auth login` which is + a cli command that wraps [`login`]. + + > [!TIP] + > [`login`] is a drop-in replacement method for [`notebook_login`] as it wraps and + > extends its capabilities. + + > [!TIP] + > When the token is not passed, [`login`] will automatically detect if the script runs + > in a notebook or not. However, this detection might not be accurate due to the + > variety of notebooks that exists nowadays. If that is the case, you can always force + > the UI by using [`notebook_login`] or [`interpreter_login`]. 
+ + Args: + token (`str`, *optional*): + User access token to generate from https://huggingface.co/settings/token. + add_to_git_credential (`bool`, defaults to `False`): + If `True`, token will be set as git credential. If no git credential helper + is configured, a warning will be displayed to the user. If `token` is `None`, + the value of `add_to_git_credential` is ignored and will be prompted again + to the end user. + skip_if_logged_in (`bool`, defaults to `False`): + If `True`, do not prompt for token if user is already logged in. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If an organization token is passed. Only personal account tokens are valid + to log in. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If token is invalid. + [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) + If running in a notebook but `ipywidgets` is not installed. + """ + if token is not None: + if not add_to_git_credential: + logger.info( + "The token has not been saved to the git credentials helper. Pass " + "`add_to_git_credential=True` in this function directly or " + "`--add-to-git-credential` if using via `hf`CLI if " + "you want to set the git credential as well." + ) + _login(token, add_to_git_credential=add_to_git_credential) + elif is_notebook(): + notebook_login(skip_if_logged_in=skip_if_logged_in) + else: + interpreter_login(skip_if_logged_in=skip_if_logged_in) + + +def logout(token_name: Optional[str] = None) -> None: + """Logout the machine from the Hub. + + Token is deleted from the machine and removed from git credential. + + Args: + token_name (`str`, *optional*): + Name of the access token to logout from. If `None`, will log out from all saved access tokens. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): + If the access token name is not found. 
+ """ + if get_token() is None and not get_stored_tokens(): # No active token and no saved access tokens + logger.warning("Not logged in!") + return + if not token_name: + # Delete all saved access tokens and token + for file_path in (constants.HF_TOKEN_PATH, constants.HF_STORED_TOKENS_PATH): + try: + Path(file_path).unlink() + except FileNotFoundError: + pass + logger.info("Successfully logged out from all access tokens.") + else: + _logout_from_token(token_name) + logger.info(f"Successfully logged out from access token: {token_name}.") + + unset_git_credential() + + # Check if still logged in + if _get_token_from_google_colab() is not None: + raise EnvironmentError( + "You are automatically logged in using a Google Colab secret.\n" + "To log out, you must unset the `HF_TOKEN` secret in your Colab settings." + ) + if _get_token_from_environment() is not None: + raise EnvironmentError( + "Token has been deleted from your machine but you are still logged in.\n" + "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables." + ) + + +def auth_switch(token_name: str, add_to_git_credential: bool = False) -> None: + """Switch to a different access token. + + Args: + token_name (`str`): + Name of the access token to switch to. + add_to_git_credential (`bool`, defaults to `False`): + If `True`, token will be set as git credential. If no git credential helper + is configured, a warning will be displayed to the user. If `token` is `None`, + the value of `add_to_git_credential` is ignored and will be prompted again + to the end user. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): + If the access token name is not found. 
+ """ + token = _get_token_by_name(token_name) + if not token: + raise ValueError(f"Access token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}") + # Write token to HF_TOKEN_PATH + _set_active_token(token_name, add_to_git_credential) + logger.info(f"The current active token is: {token_name}") + token_from_environment = _get_token_from_environment() + if token_from_environment is not None and token_from_environment != token: + logger.warning( + "The environment variable `HF_TOKEN` is set and will override the access token you've just switched to." + ) + + +def auth_list() -> None: + """List all stored access tokens.""" + tokens = get_stored_tokens() + + if not tokens: + if _get_token_from_environment(): + logger.info("No stored access tokens found.") + logger.warning("Note: Environment variable `HF_TOKEN` is set and is the current active token.") + else: + logger.info("No access tokens found.") + return + # Find current token + current_token = get_token() + current_token_name = None + for token_name in tokens: + if tokens.get(token_name) == current_token: + current_token_name = token_name + # Print header + max_offset = max(len("token"), max(len(token) for token in tokens)) + 2 + print(f" {{:<{max_offset}}}| {{:<15}}".format("name", "token")) + print("-" * (max_offset + 2) + "|" + "-" * 15) + + # Print saved access tokens + for token_name in tokens: + token = tokens.get(token_name, "") + masked_token = f"{token[:3]}****{token[-4:]}" if token != "" else token + is_current = "*" if token == current_token else " " + + print(f"{is_current} {{:<{max_offset}}}| {{:<15}}".format(token_name, masked_token)) + + if _get_token_from_environment(): + logger.warning( + "\nNote: Environment variable `HF_TOKEN` is set and is the current active token independently from the stored tokens listed above." + ) + elif current_token_name is None: + logger.warning( + "\nNote: No active token is set and no environment variable `HF_TOKEN` is found. Use `hf auth login` to log in." 
+ ) + + +### +# Interpreter-based login (text) +### + + +def interpreter_login(*, skip_if_logged_in: bool = False) -> None: + """ + Displays a prompt to log in to the HF website and store the token. + + This is equivalent to [`login`] without passing a token when not run in a notebook. + [`interpreter_login`] is useful if you want to force the use of the terminal prompt + instead of a notebook widget. + + For more details, see [`login`]. + + Args: + skip_if_logged_in (`bool`, defaults to `False`): + If `True`, do not prompt for token if user is already logged in. + """ + if not skip_if_logged_in and get_token() is not None: + logger.info("User is already logged in.") + return + + print(_HF_LOGO_ASCII) + if get_token() is not None: + logger.info( + " A token is already saved on your machine. Run `hf auth whoami`" + " to get more information or `hf auth logout` if you want" + " to log out." + ) + logger.info(" Setting a new token will erase the existing one.") + + logger.info( + " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens ." + ) + if os.name == "nt": + logger.info("Token can be pasted using 'Right-Click'.") + token = getpass("Enter your token (input will not be visible): ") + add_to_git_credential = typer.confirm("Add token as git credential?") + + _login(token=token, add_to_git_credential=add_to_git_credential) + + +### +# Notebook-based login (widget) +### + +NOTEBOOK_LOGIN_PASSWORD_HTML = """

Immediately click login after typing your password or +it might be stored in plain text in this notebook file.
""" + + +NOTEBOOK_LOGIN_TOKEN_HTML_START = """

Copy a token from your Hugging Face +tokens page and paste it below.
Immediately click login after copying +your token or it might be stored in plain text in this notebook file.
""" + + +NOTEBOOK_LOGIN_TOKEN_HTML_END = """ +Pro Tip: If you don't already have one, you can create a dedicated +'notebooks' token with 'write' access, that you can then easily reuse for all +notebooks. """ + + +def notebook_login(*, skip_if_logged_in: bool = False) -> None: + """ + Displays a widget to log in to the HF website and store the token. + + This is equivalent to [`login`] without passing a token when run in a notebook. + [`notebook_login`] is useful if you want to force the use of the notebook widget + instead of a prompt in the terminal. + + For more details, see [`login`]. + + Args: + skip_if_logged_in (`bool`, defaults to `False`): + If `True`, do not prompt for token if user is already logged in. + """ + try: + import ipywidgets.widgets as widgets # type: ignore + from IPython.display import display # type: ignore + except ImportError: + raise ImportError( + "The `notebook_login` function can only be used in a notebook (Jupyter or" + " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`." + ) + if not skip_if_logged_in and get_token() is not None: + logger.info("User is already logged in.") + return + + box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%") + + token_widget = widgets.Password(description="Token:") + git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?") + token_finish_button = widgets.Button(description="Login") + + login_token_widget = widgets.VBox( + [ + widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START), + token_widget, + git_checkbox_widget, + token_finish_button, + widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END), + ], + layout=box_layout, + ) + display(login_token_widget) + + # On click events + def login_token_event(t): + """Event handler for the login button.""" + token = token_widget.value + add_to_git_credential = git_checkbox_widget.value + # Erase token and clear value to make sure it's not saved in the notebook. 
+ token_widget.value = "" + # Hide inputs + login_token_widget.children = [widgets.Label("Connecting...")] + try: + with capture_output() as captured: + _login(token, add_to_git_credential=add_to_git_credential) + message = captured.getvalue() + except Exception as error: + message = str(error) + # Print result (success message or error) + login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()] + + token_finish_button.on_click(login_token_event) + + +### +# Login private helpers +### + + +def _login( + token: str, + add_to_git_credential: bool, +) -> None: + from .hf_api import whoami # avoid circular import + + if token.startswith("api_org"): + raise ValueError("You must use your personal account token, not an organization token.") + + token_info = whoami(token) + permission = token_info["auth"]["accessToken"]["role"] + logger.info(f"Token is valid (permission: {permission}).") + + token_name = token_info["auth"]["accessToken"]["displayName"] + # Store token locally + _save_token(token=token, token_name=token_name) + # Set active token + _set_active_token(token_name=token_name, add_to_git_credential=add_to_git_credential) + logger.info("Login successful.") + if _get_token_from_environment(): + logger.warning( + "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured." + ) + else: + logger.info(f"The current active token is: `{token_name}`") + + +def _logout_from_token(token_name: str) -> None: + """Logout from a specific access token. + + Args: + token_name (`str`): + The name of the access token to logout from. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): + If the access token name is not found. 
+ """ + stored_tokens = get_stored_tokens() + # If there is no access tokens saved or the access token name is not found, do nothing + if not stored_tokens or token_name not in stored_tokens: + return + + token = stored_tokens.pop(token_name) + _save_stored_tokens(stored_tokens) + + if token == _get_token_from_file(): + logger.warning(f"Active token '{token_name}' has been deleted.") + Path(constants.HF_TOKEN_PATH).unlink(missing_ok=True) + + +def _set_active_token( + token_name: str, + add_to_git_credential: bool, +) -> None: + """Set the active access token. + + Args: + token_name (`str`): + The name of the token to set as active. + """ + token = _get_token_by_name(token_name) + if not token: + raise ValueError(f"Token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}") + if add_to_git_credential: + if _is_git_credential_helper_configured(): + set_git_credential(token) + logger.info( + "Your token has been saved in your configured git credential helpers" + + f" ({','.join(list_credential_helpers())})." + ) + else: + logger.warning("Token has not been saved to git credential helper.") + # Write token to HF_TOKEN_PATH + path = Path(constants.HF_TOKEN_PATH) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(token) + logger.info(f"Your token has been saved to {constants.HF_TOKEN_PATH}") + + +def _is_git_credential_helper_configured() -> bool: + """Check if a git credential helper is configured. + + Warns user if not the case (except for Google Colab where "store" is set by default + by `huggingface_hub`). 
+ """ + helpers = list_credential_helpers() + if len(helpers) > 0: + return True # Do not warn: at least 1 helper is set + + # Only in Google Colab to avoid the warning message + # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710 + if is_google_colab(): + _set_store_as_git_credential_helper_globally() + return True # Do not warn: "store" is used by default in Google Colab + + # Otherwise, warn user + print( + ANSI.red( + "Cannot authenticate through git-credential as no helper is defined on your" + " machine.\nYou might have to re-authenticate when pushing to the Hugging" + " Face Hub.\nRun the following command in your terminal in case you want to" + " set the 'store' credential helper as default.\n\ngit config --global" + " credential.helper store\n\nRead" + " https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more" + " details." + ) + ) + return False + + +def _set_store_as_git_credential_helper_globally() -> None: + """Set globally the credential.helper to `store`. + + To be used only in Google Colab as we assume the user doesn't care about the git + credential config. It is the only particular case where we don't want to display the + warning message in [`notebook_login()`]. 
+ + Related: + - https://github.com/huggingface/huggingface_hub/issues/1043 + - https://github.com/huggingface/huggingface_hub/issues/1051 + - https://git-scm.com/docs/git-credential-store + """ + try: + run_subprocess("git config --global credential.helper store") + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_oauth.py b/venv/lib/python3.10/site-packages/huggingface_hub/_oauth.py new file mode 100644 index 0000000000000000000000000000000000000000..0594c7b56b808d12a37bbf419c59d335a31c03ff --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_oauth.py @@ -0,0 +1,460 @@ +import datetime +import hashlib +import logging +import os +import time +import urllib.parse +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, Optional, Union + +from . import constants +from .hf_api import whoami +from .utils import experimental, get_token + + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import fastapi + + +@dataclass +class OAuthOrgInfo: + """ + Information about an organization linked to a user logged in with OAuth. + + Attributes: + sub (`str`): + Unique identifier for the org. OpenID Connect field. + name (`str`): + The org's full name. OpenID Connect field. + preferred_username (`str`): + The org's username. OpenID Connect field. + picture (`str`): + The org's profile picture URL. OpenID Connect field. + plan (`str`, *optional*): + The org's plan (e.g., "enterprise", "team"). Hugging Face field. + can_pay (`Optional[bool]`, *optional*): + Whether the org has a payment method set up. Hugging Face field. + role_in_org (`Optional[str]`, *optional*): + The user's role in the org. Hugging Face field. + security_restrictions (`Optional[list[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*): + Array of security restrictions that the user hasn't completed for this org. 
Possible values: "ip", "token-policy", "mfa", "sso". Hugging Face field. + """ + + sub: str + name: str + preferred_username: str + picture: str + plan: Optional[str] = None + can_pay: Optional[bool] = None + role_in_org: Optional[str] = None + security_restrictions: Optional[list[Literal["ip", "token-policy", "mfa", "sso"]]] = None + + +@dataclass +class OAuthUserInfo: + """ + Information about a user logged in with OAuth. + + Attributes: + sub (`str`): + Unique identifier for the user, even in case of rename. OpenID Connect field. + name (`str`): + The user's full name. OpenID Connect field. + preferred_username (`str`): + The user's username. OpenID Connect field. + email_verified (`Optional[bool]`, *optional*): + Indicates if the user's email is verified. OpenID Connect field. + email (`Optional[str]`, *optional*): + The user's email address. OpenID Connect field. + picture (`str`): + The user's profile picture URL. OpenID Connect field. + profile (`str`): + The user's profile URL. OpenID Connect field. + website (`Optional[str]`, *optional*): + The user's website URL. OpenID Connect field. + is_pro (`bool`): + Whether the user is a pro user. Hugging Face field. + can_pay (`Optional[bool]`, *optional*): + Whether the user has a payment method set up. Hugging Face field. + orgs (`Optional[list[OrgInfo]]`, *optional*): + List of organizations the user is part of. Hugging Face field. + """ + + sub: str + name: str + preferred_username: str + email_verified: Optional[bool] + email: Optional[str] + picture: str + profile: str + website: Optional[str] + is_pro: bool + can_pay: Optional[bool] + orgs: Optional[list[OAuthOrgInfo]] + + +@dataclass +class OAuthInfo: + """ + Information about the OAuth login. + + Attributes: + access_token (`str`): + The access token. + access_token_expires_at (`datetime.datetime`): + The expiration date of the access token. + user_info ([`OAuthUserInfo`]): + The user information. 
+ state (`str`, *optional*): + State passed to the OAuth provider in the original request to the OAuth provider. + scope (`str`): + Granted scope. + """ + + access_token: str + access_token_expires_at: datetime.datetime + user_info: OAuthUserInfo + state: Optional[str] + scope: str + + +@experimental +def attach_huggingface_oauth(app: "fastapi.FastAPI", route_prefix: str = "/"): + """ + Add OAuth endpoints to a FastAPI app to enable OAuth login with Hugging Face. + + How to use: + - Call this method on your FastAPI app to add the OAuth endpoints. + - Inside your route handlers, call `parse_huggingface_oauth(request)` to retrieve the OAuth info. + - If user is logged in, an [`OAuthInfo`] object is returned with the user's info. If not, `None` is returned. + - In your app, make sure to add links to `/oauth/huggingface/login` and `/oauth/huggingface/logout` for the user to log in and out. + + Example: + ```py + from huggingface_hub import attach_huggingface_oauth, parse_huggingface_oauth + + # Create a FastAPI app + app = FastAPI() + + # Add OAuth endpoints to the FastAPI app + attach_huggingface_oauth(app) + + # Add a route that greets the user if they are logged in + @app.get("/") + def greet_json(request: Request): + # Retrieve the OAuth info from the request + oauth_info = parse_huggingface_oauth(request) # e.g. OAuthInfo dataclass + if oauth_info is None: + return {"msg": "Not logged in!"} + return {"msg": f"Hello, {oauth_info.user_info.preferred_username}!"} + ``` + """ + # TODO: handle generic case (handling OAuth in a non-Space environment with custom dev values) (low priority) + + # Add SessionMiddleware to the FastAPI app to store the OAuth info in the session. + # Session Middleware requires a secret key to sign the cookies. Let's use a hash + # of the OAuth secret key to make it unique to the Space + updated in case OAuth + # config gets updated. When ran locally, we use an empty string as a secret key. 
+ try: + from starlette.middleware.sessions import SessionMiddleware + except ImportError as e: + raise ImportError( + "Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add " + "`huggingface_hub[oauth]` to your requirements.txt file in order to install the required dependencies." + ) from e + session_secret = (constants.OAUTH_CLIENT_SECRET or "") + "-v1" + app.add_middleware( + SessionMiddleware, # type: ignore[arg-type] + secret_key=hashlib.sha256(session_secret.encode()).hexdigest(), + same_site="none", + https_only=True, + ) # type: ignore + + # Add OAuth endpoints to the FastAPI app: + # - {route_prefix}/oauth/huggingface/login + # - {route_prefix}/oauth/huggingface/callback + # - {route_prefix}/oauth/huggingface/logout + # If the app is running in a Space, OAuth is enabled normally. + # Otherwise, we mock the endpoints to make the user log in with a fake user profile - without any calls to hf.co. + route_prefix = route_prefix.strip("/") + if os.getenv("SPACE_ID") is not None: + logger.info("OAuth is enabled in the Space. Adding OAuth routes.") + _add_oauth_routes(app, route_prefix=route_prefix) + else: + logger.info("App is not running in a Space. Adding mocked OAuth routes.") + _add_mocked_oauth_routes(app, route_prefix=route_prefix) + + +def parse_huggingface_oauth(request: "fastapi.Request") -> Optional[OAuthInfo]: + """ + Returns the information from a logged-in user as a [`OAuthInfo`] object. + + For flexibility and future-proofing, this method is very lax in its parsing and does not raise errors. + Missing fields are set to `None` without a warning. + + Return `None`, if the user is not logged in (no info in session cookie). + + See [`attach_huggingface_oauth`] for an example on how to use this method. 
+ """ + if "oauth_info" not in request.session: + logger.debug("No OAuth info in session.") + return None + + logger.debug("Parsing OAuth info from session.") + oauth_data = request.session["oauth_info"] + user_data = oauth_data.get("userinfo", {}) + orgs_data = user_data.get("orgs", []) + + orgs = ( + [ + OAuthOrgInfo( + sub=org.get("sub"), + name=org.get("name"), + preferred_username=org.get("preferred_username"), + picture=org.get("picture"), + plan=org.get("plan"), + can_pay=org.get("canPay"), + role_in_org=org.get("roleInOrg"), + security_restrictions=org.get("securityRestrictions"), + ) + for org in orgs_data + ] + if orgs_data + else None + ) + + user_info = OAuthUserInfo( + sub=user_data.get("sub"), + name=user_data.get("name"), + preferred_username=user_data.get("preferred_username"), + email_verified=user_data.get("email_verified"), + email=user_data.get("email"), + picture=user_data.get("picture"), + profile=user_data.get("profile"), + website=user_data.get("website"), + is_pro=user_data.get("isPro"), + can_pay=user_data.get("canPay"), + orgs=orgs, + ) + + return OAuthInfo( + access_token=oauth_data.get("access_token"), + access_token_expires_at=datetime.datetime.fromtimestamp(oauth_data.get("expires_at")), + user_info=user_info, + state=oauth_data.get("state"), + scope=oauth_data.get("scope"), + ) + + +def _add_oauth_routes(app: "fastapi.FastAPI", route_prefix: str) -> None: + """Add OAuth routes to the FastAPI app (login, callback handler and logout).""" + try: + import fastapi + from authlib.integrations.base_client.errors import MismatchingStateError + from authlib.integrations.starlette_client import OAuth + from fastapi.responses import RedirectResponse + except ImportError as e: + raise ImportError( + "Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add " + "`huggingface_hub[oauth]` to your requirements.txt file." 
+ ) from e + + # Check environment variables + msg = ( + "OAuth is required but '{}' environment variable is not set. Make sure you've enabled OAuth in your Space by" + " setting `hf_oauth: true` in the Space metadata." + ) + if constants.OAUTH_CLIENT_ID is None: + raise ValueError(msg.format("OAUTH_CLIENT_ID")) + if constants.OAUTH_CLIENT_SECRET is None: + raise ValueError(msg.format("OAUTH_CLIENT_SECRET")) + if constants.OAUTH_SCOPES is None: + raise ValueError(msg.format("OAUTH_SCOPES")) + if constants.OPENID_PROVIDER_URL is None: + raise ValueError(msg.format("OPENID_PROVIDER_URL")) + + # Register OAuth server + oauth = OAuth() + oauth.register( + name="huggingface", + client_id=constants.OAUTH_CLIENT_ID, + client_secret=constants.OAUTH_CLIENT_SECRET, + client_kwargs={"scope": constants.OAUTH_SCOPES}, + server_metadata_url=constants.OPENID_PROVIDER_URL + "/.well-known/openid-configuration", + ) + + login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix) + + # Register OAuth endpoints + @app.get(login_uri) + async def oauth_login(request: fastapi.Request) -> RedirectResponse: + """Endpoint that redirects to HF OAuth page.""" + redirect_uri = _generate_redirect_uri(request) + return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore + + @app.get(callback_uri) + async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse: + """Endpoint that handles the OAuth callback.""" + try: + oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore + except MismatchingStateError: + # Parse query params + nb_redirects = int(request.query_params.get("_nb_redirects", 0)) + target_url = request.query_params.get("_target_url") + + # Build redirect URI with the same query params as before and bump nb_redirects count + query_params: dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1} + if target_url: + query_params["_target_url"] = target_url + + redirect_uri = 
f"{login_uri}?{urllib.parse.urlencode(query_params)}" + + # If the user is redirected more than 3 times, it is very likely that the cookie is not working properly. + # (e.g. browser is blocking third-party cookies in iframe). In this case, redirect the user in the + # non-iframe view. + if nb_redirects > constants.OAUTH_MAX_REDIRECTS: + host = os.environ.get("SPACE_HOST") + if host is None: # cannot happen in a Space + raise RuntimeError( + "App is not running in a Space (SPACE_HOST environment variable is not set). Cannot redirect to non-iframe view." + ) from None + host_url = "https://" + host.rstrip("/") + return RedirectResponse(host_url + redirect_uri) + + # Redirect the user to the login page again + return RedirectResponse(redirect_uri) + + # OAuth login worked => store the user info in the session and redirect + logger.debug("Successfully logged in with OAuth. Storing user info in session.") + request.session["oauth_info"] = oauth_info + return RedirectResponse(_get_redirect_target(request)) + + @app.get(logout_uri) + async def oauth_logout(request: fastapi.Request) -> RedirectResponse: + """Endpoint that logs out the user (e.g. delete info from cookie session).""" + logger.debug("Logged out with OAuth. Removing user info from session.") + request.session.pop("oauth_info", None) + return RedirectResponse(_get_redirect_target(request)) + + +def _add_mocked_oauth_routes(app: "fastapi.FastAPI", route_prefix: str = "/") -> None: + """Add fake oauth routes if app is run locally and OAuth is enabled. + + Using OAuth will have the same behavior as in a Space but instead of authenticating with HF, a mocked user profile + is added to the session. + """ + try: + import fastapi + from fastapi.responses import RedirectResponse + from starlette.datastructures import URL + except ImportError as e: + raise ImportError( + "Cannot initialize OAuth to due a missing library. 
Please run `pip install huggingface_hub[oauth]` or add " + "`huggingface_hub[oauth]` to your requirements.txt file." + ) from e + + warnings.warn( + "OAuth is not supported outside of a Space environment. To help you debug your app locally, the oauth endpoints" + " are mocked to return your profile and token. To make it work, your machine must be logged in to Huggingface." + ) + mocked_oauth_info = _get_mocked_oauth_info() + + login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix) + + # Define OAuth routes + @app.get(login_uri) + async def oauth_login(request: fastapi.Request) -> RedirectResponse: + """Fake endpoint that redirects to HF OAuth page.""" + # Define target (where to redirect after login) + redirect_uri = _generate_redirect_uri(request) + return RedirectResponse(callback_uri + "?" + urllib.parse.urlencode({"_target_url": redirect_uri})) + + @app.get(callback_uri) + async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse: + """Endpoint that handles the OAuth callback.""" + request.session["oauth_info"] = mocked_oauth_info + return RedirectResponse(_get_redirect_target(request)) + + @app.get(logout_uri) + async def oauth_logout(request: fastapi.Request) -> RedirectResponse: + """Endpoint that logs out the user (e.g. delete cookie session).""" + request.session.pop("oauth_info", None) + logout_url = URL("/").include_query_params(**request.query_params) + return RedirectResponse(url=logout_url, status_code=302) # see https://github.com/gradio-app/gradio/pull/9659 + + +def _generate_redirect_uri(request: "fastapi.Request") -> str: + if "_target_url" in request.query_params: + # if `_target_url` already in query params => respect it + target = request.query_params["_target_url"] + else: + # otherwise => keep query params + target = "/?" 
+ urllib.parse.urlencode(request.query_params) + + redirect_uri = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target) + redirect_uri_as_str = str(redirect_uri) + if redirect_uri.netloc.endswith(".hf.space"): + # In Space, FastAPI redirect as http but we want https + redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://") + return redirect_uri_as_str + + +def _get_redirect_target(request: "fastapi.Request", default_target: str = "/") -> str: + return request.query_params.get("_target_url", default_target) + + +def _get_mocked_oauth_info() -> dict: + token = get_token() + if token is None: + raise ValueError( + "Your machine must be logged in to HF to debug an OAuth app locally. Please" + " run `hf auth login` or set `HF_TOKEN` as environment variable " + "with one of your access token. You can generate a new token in your " + "settings page (https://huggingface.co/settings/tokens)." + ) + + user = whoami() + if user["type"] != "user": + raise ValueError( + "Your machine is not logged in with a personal account. Please use a " + "personal access token. You can generate a new token in your settings page" + " (https://huggingface.co/settings/tokens)." 
+ ) + + return { + "access_token": token, + "token_type": "bearer", + "expires_in": 8 * 60 * 60, # 8 hours + "id_token": "FOOBAR", + "scope": "openid profile", + "refresh_token": "hf_oauth__refresh_token", + "expires_at": int(time.time()) + 8 * 60 * 60, # 8 hours + "userinfo": { + "sub": "0123456789", + "name": user["fullname"], + "preferred_username": user["name"], + "profile": f"https://huggingface.co/{user['name']}", + "picture": user["avatarUrl"], + "website": "", + "aud": "00000000-0000-0000-0000-000000000000", + "auth_time": 1691672844, + "nonce": "aaaaaaaaaaaaaaaaaaa", + "iat": 1691672844, + "exp": 1691676444, + "iss": "https://huggingface.co", + }, + } + + +def _get_oauth_uris(route_prefix: str = "/") -> tuple[str, str, str]: + route_prefix = route_prefix.strip("/") + if route_prefix: + route_prefix = f"/{route_prefix}" + return ( + f"{route_prefix}/oauth/huggingface/login", + f"{route_prefix}/oauth/huggingface/callback", + f"{route_prefix}/oauth/huggingface/logout", + ) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_snapshot_download.py b/venv/lib/python3.10/site-packages/huggingface_hub/_snapshot_download.py new file mode 100644 index 0000000000000000000000000000000000000000..59d5fa9239966b1271eac100e005bce3e8c564e7 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_snapshot_download.py @@ -0,0 +1,465 @@ +import os +from pathlib import Path +from typing import Iterable, List, Literal, Optional, Union, overload + +import httpx +from tqdm.auto import tqdm as base_tqdm +from tqdm.contrib.concurrent import thread_map + +from . 
import constants +from .errors import ( + DryRunError, + GatedRepoError, + HfHubHTTPError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name +from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo +from .utils import OfflineModeIsEnabled, filter_repo_objects, is_tqdm_disabled, logging, validate_hf_hub_args +from .utils import tqdm as hf_tqdm + + +logger = logging.get_logger(__name__) + +LARGE_REPO_THRESHOLD = 1000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough + + +@overload +def snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Optional[Union[dict, str]] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + force_download: bool = False, + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[type[base_tqdm]] = None, + headers: Optional[dict[str, str]] = None, + endpoint: Optional[str] = None, + dry_run: Literal[False] = False, +) -> str: ... 
+ + +@overload +def snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Optional[Union[dict, str]] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + force_download: bool = False, + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[type[base_tqdm]] = None, + headers: Optional[dict[str, str]] = None, + endpoint: Optional[str] = None, + dry_run: Literal[True] = True, +) -> list[DryRunFileInfo]: ... + + +@overload +def snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Optional[Union[dict, str]] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + force_download: bool = False, + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[type[base_tqdm]] = None, + headers: Optional[dict[str, str]] = None, + endpoint: Optional[str] = None, + dry_run: bool = False, +) -> Union[str, list[DryRunFileInfo]]: ... 
+ + +@validate_hf_hub_args +def snapshot_download( + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Optional[Union[dict, str]] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + force_download: bool = False, + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[type[base_tqdm]] = None, + headers: Optional[dict[str, str]] = None, + endpoint: Optional[str] = None, + dry_run: bool = False, +) -> Union[str, list[DryRunFileInfo]]: + """Download repo files. + + Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from + a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. You can also filter which files to download using + `allow_patterns` and `ignore_patterns`. + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this + option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` + to store some metadata related to the downloaded files. While this mechanism is not as robust as the main + cache-system, it's optimized for regularly pulling the latest version of a repository. + + An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly + configured. It is also not possible to filter which files to download when cloning a repository using git. 
+ + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded files will be placed under this directory. + library_name (`str`, *optional*): + The name of the library to which the object corresponds. + library_version (`str`, *optional*): + The version of the library. + user_agent (`str`, `dict`, *optional*): + The user-agent info in the form of a dictionary or a string. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `httpx.request`. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in the local cache. + token (`str`, `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the HuggingFace config + folder. + - If a string, it's used as the authentication token. + headers (`dict`, *optional*): + Additional headers to include in the request. Those headers take precedence over the others. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_patterns (`list[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are downloaded. + ignore_patterns (`list[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not downloaded. 
+ max_workers (`int`, *optional*): + Number of concurrent threads to download files (1 thread = 1 file download). + Defaults to 8. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. + Note that the `tqdm_class` is not passed to each individual download. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + dry_run (`bool`, *optional*, defaults to `False`): + If `True`, perform a dry run without actually downloading the files. Returns a list of + [`DryRunFileInfo`] objects containing information about what would be downloaded. + + Returns: + `str` or list of [`DryRunFileInfo`]: + - If `dry_run=False`: Local snapshot path. + - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` and the token cannot be found. + [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid. + """ + if cache_dir is None: + cache_dir = constants.HF_HUB_CACHE + if revision is None: + revision = constants.DEFAULT_REVISION + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if repo_type is None: + repo_type = "model" + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type: {repo_type}. 
Accepted repo types are: {str(constants.REPO_TYPES)}") + + storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) + + api = HfApi( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + endpoint=endpoint, + headers=headers, + token=token, + ) + + repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None + api_call_error: Optional[Exception] = None + if not local_files_only: + # try/except logic to handle different errors => taken from `hf_hub_download` + try: + # if we have internet connection we want to list files to download + repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision) + except httpx.ProxyError: + # Actually raise on proxy error + raise + except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error: + # Internet connection is down + # => will try to use local files only + api_call_error = error + pass + except RevisionNotFoundError: + # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted) + raise + except HfHubHTTPError as error: + # Multiple reasons for an http error: + # - Repository is private and invalid/missing token sent + # - Repository is gated and invalid/missing token sent + # - Hub is down (error 500 or 504) + # => let's switch to 'local_files_only=True' to check if the files are already cached. + # (if it's not the case, the error will be re-raised) + api_call_error = error + pass + + # At this stage, if `repo_info` is None it means either: + # - internet connection is down + # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True) + # - repo is private/gated and invalid/missing token sent + # - Hub is down + # => let's look if we can find the appropriate folder in the cache: + # - if the specified revision is a commit hash, look inside "snapshots". + # - f the specified revision is a branch or tag, look inside "refs". 
+ # => if local_dir is not None, we will return the path to the local folder if it exists. + if repo_info is None: + if dry_run: + raise DryRunError( + "Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token." + ) from api_call_error + + # Try to get which commit hash corresponds to the specified revision + commit_hash = None + if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + else: + ref_path = os.path.join(storage_folder, "refs", revision) + if os.path.exists(ref_path): + # retrieve commit_hash from refs file + with open(ref_path) as f: + commit_hash = f.read() + + # Try to locate snapshot folder for this commit hash + if commit_hash is not None and local_dir is None: + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(snapshot_folder): + # Snapshot folder exists => let's return it + # (but we can't check if all the files are actually there) + return snapshot_folder + + # If local_dir is not None, return it if it exists and is not empty + if local_dir is not None: + local_dir = Path(local_dir) + if local_dir.is_dir() and any(local_dir.iterdir()): + logger.warning( + f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})." + ) + return str(local_dir.resolve()) + # If we couldn't find the appropriate folder on disk, raise an error. + if local_files_only: + raise LocalEntryNotFoundError( + "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " + "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass " + "'local_files_only=False' as input." + ) + elif isinstance(api_call_error, OfflineModeIsEnabled): + raise LocalEntryNotFoundError( + "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " + "outgoing traffic has been disabled. 
To enable repo look-ups and downloads online, set " + "'HF_HUB_OFFLINE=0' as environment variable." + ) from api_call_error + elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)) or ( + isinstance(api_call_error, HfHubHTTPError) and api_call_error.response.status_code == 401 + ): + # Repo not found, gated, or specific authentication error => let's raise the actual error + raise api_call_error + else: + # Otherwise: most likely a connection issue or Hub downtime => let's warn the user + raise LocalEntryNotFoundError( + "An error happened while trying to locate the files on the Hub and we cannot find the appropriate" + " snapshot folder for the specified revision on the local disk. Please check your internet connection" + " and try again." + ) from api_call_error + + # At this stage, internet connection is up and running + # => let's download the files! + assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." + + # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files. + # In that case, we need to use the `list_repo_tree` method to prevent caching issues. + repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else [] + unreliable_nb_files = ( + repo_info.siblings is None or len(repo_info.siblings) == 0 or len(repo_info.siblings) > LARGE_REPO_THRESHOLD + ) + if unreliable_nb_files: + logger.info( + "Number of files in the repo is unreliable. Using `list_repo_tree` to ensure all files are listed." 
+ ) + repo_files = ( + f.rfilename + for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type) + if isinstance(f, RepoFile) + ) + + filtered_repo_files: Iterable[str] = filter_repo_objects( + items=repo_files, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + if not unreliable_nb_files: + filtered_repo_files = list(filtered_repo_files) + tqdm_desc = f"Fetching {len(filtered_repo_files)} files" + else: + tqdm_desc = "Fetching ... files" + if dry_run: + tqdm_desc = "[dry-run] " + tqdm_desc + + commit_hash = repo_info.sha + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. + if revision != commit_hash: + ref_path = os.path.join(storage_folder, "refs", revision) + try: + os.makedirs(os.path.dirname(ref_path), exist_ok=True) + with open(ref_path, "w") as f: + f.write(commit_hash) + except OSError as e: + logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.") + + results: List[Union[str, DryRunFileInfo]] = [] + + # User can use its own tqdm class or the default one from `huggingface_hub.utils` + tqdm_class = tqdm_class or hf_tqdm + + # Create a progress bar for the bytes downloaded + # This progress bar is shared across threads/files and gets updated each time we fetch + # metadata for a file. + bytes_progress = tqdm_class( + desc="Downloading (incomplete total...)", + disable=is_tqdm_disabled(log_level=logger.getEffectiveLevel()), + total=0, + initial=0, + unit="B", + unit_scale=True, + name="huggingface_hub.snapshot_download", + ) + + class _AggregatedTqdm: + """Fake tqdm object to aggregate progress into the parent `bytes_progress` bar. + + In practice the `_AggregatedTqdm` object won't be displayed, it's just used to update + the `bytes_progress` bar from each thread/file download. 
+ """ + + def __init__(self, *args, **kwargs): + # Adjust the total of the parent progress bar + total = kwargs.pop("total", None) + if total is not None: + bytes_progress.total += total + bytes_progress.refresh() + + # Adjust initial of the parent progress bar + initial = kwargs.pop("initial", 0) + if initial: + bytes_progress.update(initial) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def update(self, n: Optional[Union[int, float]] = 1) -> None: + bytes_progress.update(n) + + # we pass the commit_hash to hf_hub_download + # so no network call happens if we already + # have the file locally. + def _inner_hf_hub_download(repo_file: str) -> None: + results.append( + hf_hub_download( # type: ignore + repo_id, + filename=repo_file, + repo_type=repo_type, + revision=commit_hash, + endpoint=endpoint, + cache_dir=cache_dir, + local_dir=local_dir, + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + etag_timeout=etag_timeout, + force_download=force_download, + token=token, + headers=headers, + tqdm_class=_AggregatedTqdm, # type: ignore + dry_run=dry_run, + ) + ) + + thread_map( + _inner_hf_hub_download, + filtered_repo_files, + desc=tqdm_desc, + max_workers=max_workers, + tqdm_class=tqdm_class, + ) + + bytes_progress.set_description("Download complete") + + if dry_run: + assert all(isinstance(r, DryRunFileInfo) for r in results) + return results # type: ignore + + if local_dir is not None: + return str(os.path.realpath(local_dir)) + return snapshot_folder diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_space_api.py b/venv/lib/python3.10/site-packages/huggingface_hub/_space_api.py new file mode 100644 index 0000000000000000000000000000000000000000..4a15e870e4a9e31bb5e53465be759d3367b94600 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_space_api.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Optional + +from huggingface_hub.utils import parse_datetime + + +class SpaceStage(str, Enum): + """ + Enumeration of possible stage of a Space on the Hub. + + Value can be compared to a string: + ```py + assert SpaceStage.BUILDING == "BUILDING" + ``` + + Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url). + """ + + # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo) + NO_APP_FILE = "NO_APP_FILE" + CONFIG_ERROR = "CONFIG_ERROR" + BUILDING = "BUILDING" + BUILD_ERROR = "BUILD_ERROR" + RUNNING = "RUNNING" + RUNNING_BUILDING = "RUNNING_BUILDING" + RUNTIME_ERROR = "RUNTIME_ERROR" + DELETING = "DELETING" + STOPPED = "STOPPED" + PAUSED = "PAUSED" + + +class SpaceHardware(str, Enum): + """ + Enumeration of hardwares available to run your Space on the Hub. + + Value can be compared to a string: + ```py + assert SpaceHardware.CPU_BASIC == "cpu-basic" + ``` + + Taken from https://github.com/huggingface-internal/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts (private url). 
+ """ + + # CPU + CPU_BASIC = "cpu-basic" + CPU_UPGRADE = "cpu-upgrade" + CPU_XL = "cpu-xl" + + # ZeroGPU + ZERO_A10G = "zero-a10g" + + # GPU + T4_SMALL = "t4-small" + T4_MEDIUM = "t4-medium" + L4X1 = "l4x1" + L4X4 = "l4x4" + L40SX1 = "l40sx1" + L40SX4 = "l40sx4" + L40SX8 = "l40sx8" + A10G_SMALL = "a10g-small" + A10G_LARGE = "a10g-large" + A10G_LARGEX2 = "a10g-largex2" + A10G_LARGEX4 = "a10g-largex4" + A100_LARGE = "a100-large" + A100x4 = "a100x4" + A100x8 = "a100x8" + + +class SpaceStorage(str, Enum): + """ + Enumeration of persistent storage available for your Space on the Hub. + + Value can be compared to a string: + ```py + assert SpaceStorage.SMALL == "small" + ``` + + Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url). + """ + + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + + +@dataclass +class SpaceRuntime: + """ + Contains information about the current runtime of a Space. + + Args: + stage (`str`): + Current stage of the space. Example: RUNNING. + hardware (`str` or `None`): + Current hardware of the space. Example: "cpu-basic". Can be `None` if Space + is `BUILDING` for the first time. + requested_hardware (`str` or `None`): + Requested hardware. Can be different from `hardware` especially if the request + has just been made. Example: "t4-medium". Can be `None` if no hardware has + been requested yet. + sleep_time (`int` or `None`): + Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the + Space will never go to sleep if it's running on an upgraded hardware, while it will go to sleep after 48 + hours on a free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time. + raw (`dict`): + Raw response from the server. Contains more information about the Space + runtime like number of replicas, number of cpu, memory size,... 
+ """ + + stage: SpaceStage + hardware: Optional[SpaceHardware] + requested_hardware: Optional[SpaceHardware] + sleep_time: Optional[int] + storage: Optional[SpaceStorage] + raw: dict + + def __init__(self, data: dict) -> None: + self.stage = data["stage"] + self.hardware = data.get("hardware", {}).get("current") + self.requested_hardware = data.get("hardware", {}).get("requested") + self.sleep_time = data.get("gcTimeout") + self.storage = data.get("storage") + self.raw = data + + +@dataclass +class SpaceVariable: + """ + Contains information about the current variables of a Space. + + Args: + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + value (`str`): + Variable value. Example: `"the_model_repo_id"`. + description (`str` or None): + Description of the variable. Example: `"Model Repo ID of the implemented model"`. + updatedAt (`datetime` or None): + datetime of the last update of the variable (if the variable has been updated at least once). + """ + + key: str + value: str + description: Optional[str] + updated_at: Optional[datetime] + + def __init__(self, key: str, values: dict) -> None: + self.key = key + self.value = values["value"] + self.description = values.get("description") + updated_at = values.get("updatedAt") + self.updated_at = parse_datetime(updated_at) if updated_at is not None else None diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_tensorboard_logger.py b/venv/lib/python3.10/site-packages/huggingface_hub/_tensorboard_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..2783a250015afa99fc83e4fcd6484306b54af683 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_tensorboard_logger.py @@ -0,0 +1,190 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a logger to push training logs to the Hub, using Tensorboard.""" + +from pathlib import Path +from typing import Optional, Union + +from ._commit_scheduler import CommitScheduler +from .errors import EntryNotFoundError +from .repocard import ModelCard +from .utils import experimental + + +# Depending on user's setup, SummaryWriter can come either from 'tensorboardX' +# or from 'torch.utils.tensorboard'. Both are compatible so let's try to load +# from either of them. +try: + from tensorboardX import SummaryWriter as _RuntimeSummaryWriter + + is_summary_writer_available = True +except ImportError: + try: + from torch.utils.tensorboard import SummaryWriter as _RuntimeSummaryWriter + + is_summary_writer_available = True + except ImportError: + # Dummy class to avoid failing at import. Will raise on instance creation. + class _DummySummaryWriter: + pass + + _RuntimeSummaryWriter = _DummySummaryWriter # type: ignore[assignment] + is_summary_writer_available = False + + +class HFSummaryWriter(_RuntimeSummaryWriter): + """ + Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub. + + Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate + thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection + issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every` + minutes (default to every 5 minutes). + + > [!WARNING] + > `HFSummaryWriter` is experimental. 
Its API is subject to change in the future without prior notice. + + Args: + repo_id (`str`): + The id of the repo to which the logs will be pushed. + logdir (`str`, *optional*): + The directory where the logs will be written. If not specified, a local directory will be created by the + underlying `SummaryWriter` object. + commit_every (`int` or `float`, *optional*): + The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes. + squash_history (`bool`, *optional*): + Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is + useful to avoid degraded performances on the repo when it grows too large. + repo_type (`str`, *optional*): + The type of the repo to which the logs will be pushed. Defaults to "model". + repo_revision (`str`, *optional*): + The revision of the repo to which the logs will be pushed. Defaults to "main". + repo_private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + path_in_repo (`str`, *optional*): + The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/". + repo_allow_patterns (`list[str]` or `str`, *optional*): + A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the + [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. + repo_ignore_patterns (`list[str]` or `str`, *optional*): + A list of patterns to exclude in the upload. Check out the + [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. + token (`str`, *optional*): + Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more + details + kwargs: + Additional keyword arguments passed to `SummaryWriter`. 
+ + Examples: + ```diff + # Taken from https://pytorch.org/docs/stable/tensorboard.html + - from torch.utils.tensorboard import SummaryWriter + + from huggingface_hub import HFSummaryWriter + + import numpy as np + + - writer = SummaryWriter() + + writer = HFSummaryWriter(repo_id="username/my-trained-model") + + for n_iter in range(100): + writer.add_scalar('Loss/train', np.random.random(), n_iter) + writer.add_scalar('Loss/test', np.random.random(), n_iter) + writer.add_scalar('Accuracy/train', np.random.random(), n_iter) + writer.add_scalar('Accuracy/test', np.random.random(), n_iter) + ``` + + ```py + >>> from huggingface_hub import HFSummaryWriter + + # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager + >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger: + ... logger.add_scalar("a", 1) + ... logger.add_scalar("b", 2) + ``` + """ + + @experimental + def __new__(cls, *args, **kwargs) -> "HFSummaryWriter": + if not is_summary_writer_available: + raise ImportError( + "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade" + " tensorboardX` first." + ) + return super().__new__(cls) + + def __init__( + self, + repo_id: str, + *, + logdir: Optional[str] = None, + commit_every: Union[int, float] = 5, + squash_history: bool = False, + repo_type: Optional[str] = None, + repo_revision: Optional[str] = None, + repo_private: Optional[bool] = None, + path_in_repo: Optional[str] = "tensorboard", + repo_allow_patterns: Optional[Union[list[str], str]] = "*.tfevents.*", + repo_ignore_patterns: Optional[Union[list[str], str]] = None, + token: Optional[str] = None, + **kwargs, + ): + # Initialize SummaryWriter + super().__init__(logdir=logdir, **kwargs) + + # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it. + if not isinstance(self.logdir, str): + raise ValueError(f"`self.logdir` must be a string. 
Got '{self.logdir}' of type {type(self.logdir)}.") + + # Append logdir name to `path_in_repo` + if path_in_repo is None or path_in_repo == "": + path_in_repo = Path(self.logdir).name + else: + path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name + + # Initialize scheduler + self.scheduler = CommitScheduler( + folder_path=self.logdir, + path_in_repo=path_in_repo, + repo_id=repo_id, + repo_type=repo_type, + revision=repo_revision, + private=repo_private, + token=token, + allow_patterns=repo_allow_patterns, + ignore_patterns=repo_ignore_patterns, + every=commit_every, + squash_history=squash_history, + ) + + # Exposing some high-level info at root level + self.repo_id = self.scheduler.repo_id + self.repo_type = self.scheduler.repo_type + self.repo_revision = self.scheduler.revision + + # Add `hf-summary-writer` tag to the model card metadata + try: + card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type) + except EntryNotFoundError: + card = ModelCard("") + tags = card.data.get("tags", []) + if "hf-summary-writer" not in tags: + tags.append("hf-summary-writer") + card.data["tags"] = tags + card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Push to hub in a non-blocking way when exiting the logger's context manager.""" + super().__exit__(exc_type, exc_val, exc_tb) + future = self.scheduler.trigger() + future.result() diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_upload_large_folder.py b/venv/lib/python3.10/site-packages/huggingface_hub/_upload_large_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e269717892c0d6bbdd4fc30bd6705d844dedfd --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_upload_large_folder.py @@ -0,0 +1,765 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +import logging +import os +import queue +import shutil +import sys +import threading +import time +import traceback +from datetime import datetime +from pathlib import Path +from threading import Lock +from typing import TYPE_CHECKING, Any, Optional, Union +from urllib.parse import quote + +from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes +from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata +from .constants import DEFAULT_REVISION, REPO_TYPES +from .utils import DEFAULT_IGNORE_PATTERNS, _format_size, filter_repo_objects, tqdm +from .utils._runtime import is_xet_available +from .utils.sha import sha_fileobj + + +if TYPE_CHECKING: + from .hf_api import HfApi + +logger = logging.getLogger(__name__) + +WAITING_TIME_IF_NO_TASKS = 10 # seconds +MAX_NB_FILES_FETCH_UPLOAD_MODE = 100 +COMMIT_SIZE_SCALE: list[int] = [20, 50, 75, 100, 125, 200, 250, 400, 600, 1000] + +UPLOAD_BATCH_SIZE_XET = 256 # Max 256 files per upload batch for XET-enabled repos +UPLOAD_BATCH_SIZE_LFS = 1 # Otherwise, batches of 1 for regular LFS upload + +# Repository limits (from https://huggingface.co/docs/hub/repositories-recommendations) +MAX_FILES_PER_REPO = 100_000 # Recommended maximum number of files per repository +MAX_FILES_PER_FOLDER = 10_000 # Recommended maximum number of files per folder +MAX_FILE_SIZE_GB = 200 # Recommended maximum for 
individual file size (split larger files) +RECOMMENDED_FILE_SIZE_GB = 20 # Recommended maximum for individual file size + + +def _validate_upload_limits(paths_list: list[LocalUploadFilePaths]) -> None: + """ + Validate upload against repository limits and warn about potential issues. + + Args: + paths_list: List of file paths to be uploaded + + Warns about: + - Too many files in the repository (>100k) + - Too many entries (files or subdirectories) in a single folder (>10k) + - Files exceeding size limits (>20GB recommended, >200GB maximum) + """ + logger.info("Running validation checks on files to upload...") + + # Check 1: Total file count + if len(paths_list) > MAX_FILES_PER_REPO: + logger.warning( + f"You are about to upload {len(paths_list):,} files. " + f"This exceeds the recommended limit of {MAX_FILES_PER_REPO:,} files per repository.\n" + f"Consider:\n" + f" - Splitting your data into multiple repositories\n" + f" - Using fewer, larger files (e.g., parquet files)\n" + f" - See: https://huggingface.co/docs/hub/repositories-recommendations" + ) + + # Check 2: Files and subdirectories per folder + # Track immediate children (files and subdirs) for each folder + from collections import defaultdict + + entries_per_folder: dict[str, Any] = defaultdict(lambda: {"files": 0, "subdirs": set()}) + + for paths in paths_list: + path = Path(paths.path_in_repo) + parts = path.parts + + # Count this file in its immediate parent directory + parent = str(path.parent) if str(path.parent) != "." else "." + entries_per_folder[parent]["files"] += 1 + + # Track immediate subdirectories for each parent folder + # Walk through the path components to track parent-child relationships + for i, child in enumerate(parts[:-1]): + parent = "." 
if i == 0 else "/".join(parts[:i]) + entries_per_folder[parent]["subdirs"].add(child) + + # Check limits for each folder + for folder, data in entries_per_folder.items(): + file_count = data["files"] + subdir_count = len(data["subdirs"]) + total_entries = file_count + subdir_count + + if total_entries > MAX_FILES_PER_FOLDER: + folder_display = "root" if folder == "." else folder + logger.warning( + f"Folder '{folder_display}' contains {total_entries:,} entries " + f"({file_count:,} files and {subdir_count:,} subdirectories). " + f"This exceeds the recommended {MAX_FILES_PER_FOLDER:,} entries per folder.\n" + "Consider reorganising into sub-folders." + ) + + # Check 3: File sizes + large_files = [] + very_large_files = [] + + for paths in paths_list: + size = paths.file_path.stat().st_size + size_gb = size / 1_000_000_000 # Use decimal GB as per Hub limits + + if size_gb > MAX_FILE_SIZE_GB: + very_large_files.append((paths.path_in_repo, size_gb)) + elif size_gb > RECOMMENDED_FILE_SIZE_GB: + large_files.append((paths.path_in_repo, size_gb)) + + # Warn about very large files (>200GB) + if very_large_files: + files_str = "\n - ".join(f"{path}: {size:.1f}GB" for path, size in very_large_files[:5]) + more_str = f"\n ... and {len(very_large_files) - 5} more files" if len(very_large_files) > 5 else "" + logger.warning( + f"Found {len(very_large_files)} files exceeding the {MAX_FILE_SIZE_GB}GB recommended maximum:\n" + f" - {files_str}{more_str}\n" + f"Consider splitting these files into smaller chunks." + ) + + # Warn about large files (>20GB) + if large_files: + files_str = "\n - ".join(f"{path}: {size:.1f}GB" for path, size in large_files[:5]) + more_str = f"\n ... and {len(large_files) - 5} more files" if len(large_files) > 5 else "" + logger.warning( + f"Found {len(large_files)} files larger than {RECOMMENDED_FILE_SIZE_GB}GB (recommended limit):\n" + f" - {files_str}{more_str}\n" + f"Large files may slow down loading and processing." 
+ ) + + logger.info("Validation checks complete.") + + +def upload_large_folder_internal( + api: "HfApi", + repo_id: str, + folder_path: Union[str, Path], + *, + repo_type: str, # Repo type is required! + revision: Optional[str] = None, + private: Optional[bool] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + num_workers: Optional[int] = None, + print_report: bool = True, + print_report_every: int = 60, +): + """Upload a large folder to the Hub in the most resilient way possible. + + See [`HfApi.upload_large_folder`] for the full documentation. + """ + # 1. Check args and setup + if repo_type is None: + raise ValueError( + "For large uploads, `repo_type` is explicitly required. Please set it to `model`, `dataset` or `space`." + " If you are using the CLI, pass it as `--repo-type=model`." + ) + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if revision is None: + revision = DEFAULT_REVISION + + folder_path = Path(folder_path).expanduser().resolve() + if not folder_path.is_dir(): + raise ValueError(f"Provided path: '{folder_path}' is not a directory") + + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + ignore_patterns += DEFAULT_IGNORE_PATTERNS + + if num_workers is None: + nb_cores = os.cpu_count() or 1 + num_workers = max(nb_cores // 2, 1) # Use at most half of cpu cores + + # 2. 
Create repo if missing + repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True) + logger.info(f"Repo created: {repo_url}") + repo_id = repo_url.repo_id + + # Warn on too many commits + try: + commits = api.list_repo_commits(repo_id=repo_id, repo_type=repo_type, revision=revision) + commit_count = len(commits) + if commit_count > 500: + logger.warning( + f"\n{'=' * 80}\n" + f"WARNING: This repository has {commit_count} commits.\n" + f"Repositories with a large number of commits can experience performance issues.\n" + f"\n" + f"Consider squashing your commit history using `super_squash_history()`.\n" + "To do so, you need to stop this process, run the snippet below and restart the upload command." + f" from huggingface_hub import super_squash_history\n" + f" super_squash_history(repo_id='{repo_id}', repo_type='{repo_type}')\n" + f"\n" + f"Note: This is a non-revertible operation. See the documentation for more details:\n" + f"https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.HfApi.super_squash_history\n" + f"{'=' * 80}\n" + ) + except Exception as e: + # Don't fail the upload if we can't check commit count + logger.debug(f"Could not check commit count: {e}") + + # 2.1 Check if xet is enabled to set batch file upload size + upload_batch_size = UPLOAD_BATCH_SIZE_XET if is_xet_available() else UPLOAD_BATCH_SIZE_LFS + + # 3. 
List files to upload + filtered_paths_list = filter_repo_objects( + (path.relative_to(folder_path).as_posix() for path in folder_path.glob("**/*") if path.is_file()), + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + paths_list = [get_local_upload_paths(folder_path, relpath) for relpath in filtered_paths_list] + logger.info(f"Found {len(paths_list)} candidate files to upload") + + # Validate upload against repository limits + _validate_upload_limits(paths_list) + + logger.info("Starting upload...") + + # Read metadata for each file + items = [ + (paths, read_upload_metadata(folder_path, paths.path_in_repo)) + for paths in tqdm(paths_list, desc="Recovering from metadata files") + ] + + # 4. Start workers + status = LargeUploadStatus(items, upload_batch_size) + threads = [ + threading.Thread( + target=_worker_job, + kwargs={ + "status": status, + "api": api, + "repo_id": repo_id, + "repo_type": repo_type, + "revision": revision, + }, + ) + for _ in range(num_workers) + ] + + for thread in threads: + thread.start() + + # 5. 
Print regular reports + if print_report: + print("\n\n" + status.current_report()) + last_report_ts = time.time() + while True: + time.sleep(1) + if time.time() - last_report_ts >= print_report_every: + if print_report: + _print_overwrite(status.current_report()) + last_report_ts = time.time() + if status.is_done(): + logging.info("Is done: exiting main loop") + break + + for thread in threads: + thread.join() + + logger.info(status.current_report()) + logging.info("Upload is complete!") + + +#################### +# Logic to manage workers and synchronize tasks +#################### + + +class WorkerJob(enum.Enum): + SHA256 = enum.auto() + GET_UPLOAD_MODE = enum.auto() + PREUPLOAD_LFS = enum.auto() + COMMIT = enum.auto() + WAIT = enum.auto() # if no tasks are available but we don't want to exit + + +JOB_ITEM_T = tuple[LocalUploadFilePaths, LocalUploadFileMetadata] + + +class LargeUploadStatus: + """Contains information, queues and tasks for a large upload process.""" + + def __init__(self, items: list[JOB_ITEM_T], upload_batch_size: int = 1): + self.items = items + self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.queue_preupload_lfs: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.queue_commit: "queue.Queue[JOB_ITEM_T]" = queue.Queue() + self.lock = Lock() + + self.nb_workers_sha256: int = 0 + self.nb_workers_get_upload_mode: int = 0 + self.nb_workers_preupload_lfs: int = 0 + self.upload_batch_size: int = upload_batch_size + self.nb_workers_commit: int = 0 + self.nb_workers_waiting: int = 0 + self.last_commit_attempt: Optional[float] = None + + self._started_at = datetime.now() + self._chunk_idx: int = 1 + self._chunk_lock: Lock = Lock() + + # Setup queues + for item in self.items: + paths, metadata = item + if metadata.sha256 is None: + self.queue_sha256.put(item) + elif metadata.upload_mode is None: + self.queue_get_upload_mode.put(item) + elif metadata.upload_mode == 
"lfs" and not metadata.is_uploaded: + self.queue_preupload_lfs.put(item) + elif not metadata.is_committed: + self.queue_commit.put(item) + else: + logger.debug(f"Skipping file {paths.path_in_repo} (already uploaded and committed)") + + def target_chunk(self) -> int: + with self._chunk_lock: + return COMMIT_SIZE_SCALE[self._chunk_idx] + + def update_chunk(self, success: bool, nb_items: int, duration: float) -> None: + with self._chunk_lock: + if not success: + logger.warning(f"Failed to commit {nb_items} files at once. Will retry with less files in next batch.") + self._chunk_idx -= 1 + elif nb_items >= COMMIT_SIZE_SCALE[self._chunk_idx] and duration < 40: + logger.info(f"Successfully committed {nb_items} at once. Increasing the limit for next batch.") + self._chunk_idx += 1 + + self._chunk_idx = max(0, min(self._chunk_idx, len(COMMIT_SIZE_SCALE) - 1)) + + def current_report(self) -> str: + """Generate a report of the current status of the large upload.""" + nb_hashed = 0 + size_hashed = 0 + nb_preuploaded = 0 + nb_lfs = 0 + nb_lfs_unsure = 0 + size_preuploaded = 0 + nb_committed = 0 + size_committed = 0 + total_size = 0 + ignored_files = 0 + total_files = 0 + + with self.lock: + for _, metadata in self.items: + if metadata.should_ignore: + ignored_files += 1 + continue + total_size += metadata.size + total_files += 1 + if metadata.sha256 is not None: + nb_hashed += 1 + size_hashed += metadata.size + if metadata.upload_mode == "lfs": + nb_lfs += 1 + if metadata.upload_mode is None: + nb_lfs_unsure += 1 + if metadata.is_uploaded: + nb_preuploaded += 1 + size_preuploaded += metadata.size + if metadata.is_committed: + nb_committed += 1 + size_committed += metadata.size + total_size_str = _format_size(total_size) + + now = datetime.now() + now_str = now.strftime("%Y-%m-%d %H:%M:%S") + elapsed = now - self._started_at + elapsed_str = str(elapsed).split(".")[0] # remove milliseconds + + message = "\n" + "-" * 10 + message += f" {now_str} ({elapsed_str}) " + message += "-" 
* 10 + "\n" + + message += "Files: " + message += f"hashed {nb_hashed}/{total_files} ({_format_size(size_hashed)}/{total_size_str}) | " + message += f"pre-uploaded: {nb_preuploaded}/{nb_lfs} ({_format_size(size_preuploaded)}/{total_size_str})" + if nb_lfs_unsure > 0: + message += f" (+{nb_lfs_unsure} unsure)" + message += f" | committed: {nb_committed}/{total_files} ({_format_size(size_committed)}/{total_size_str})" + message += f" | ignored: {ignored_files}\n" + + message += "Workers: " + message += f"hashing: {self.nb_workers_sha256} | " + message += f"get upload mode: {self.nb_workers_get_upload_mode} | " + message += f"pre-uploading: {self.nb_workers_preupload_lfs} | " + message += f"committing: {self.nb_workers_commit} | " + message += f"waiting: {self.nb_workers_waiting}\n" + message += "-" * 51 + + return message + + def is_done(self) -> bool: + with self.lock: + return all(metadata.is_committed or metadata.should_ignore for _, metadata in self.items) + + +def _worker_job( + status: LargeUploadStatus, + api: "HfApi", + repo_id: str, + repo_type: str, + revision: str, +): + """ + Main process for a worker. The worker will perform tasks based on the priority list until all files are uploaded + and committed. If no tasks are available, the worker will wait for 10 seconds before checking again. + + If a task fails for any reason, the item(s) are put back in the queue for another worker to pick up. + + Read `upload_large_folder` docstring for more information on how tasks are prioritized. 
+ """ + while True: + next_job: Optional[tuple[WorkerJob, list[JOB_ITEM_T]]] = None + + # Determine next task + next_job = _determine_next_job(status) + if next_job is None: + return + job, items = next_job + + # Perform task + if job == WorkerJob.SHA256: + item = items[0] # single item + try: + _compute_sha256(item) + status.queue_get_upload_mode.put(item) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to compute sha256: {e}") + traceback.format_exc() + status.queue_sha256.put(item) + + with status.lock: + status.nb_workers_sha256 -= 1 + + elif job == WorkerJob.GET_UPLOAD_MODE: + try: + _get_upload_mode(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to get upload mode: {e}") + traceback.format_exc() + + # Items are either: + # - dropped (if should_ignore) + # - put in LFS queue (if LFS) + # - put in commit queue (if regular) + # - or put back (if error occurred). 
+ for item in items: + _, metadata = item + if metadata.should_ignore: + continue + if metadata.upload_mode == "lfs": + status.queue_preupload_lfs.put(item) + elif metadata.upload_mode == "regular": + status.queue_commit.put(item) + else: + status.queue_get_upload_mode.put(item) + + with status.lock: + status.nb_workers_get_upload_mode -= 1 + + elif job == WorkerJob.PREUPLOAD_LFS: + try: + _preupload_lfs(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) + for item in items: + status.queue_commit.put(item) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to preupload LFS: {e}") + traceback.format_exc() + for item in items: + status.queue_preupload_lfs.put(item) + + with status.lock: + status.nb_workers_preupload_lfs -= 1 + + elif job == WorkerJob.COMMIT: + start_ts = time.time() + success = True + try: + _commit(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Failed to commit: {e}") + traceback.format_exc() + for item in items: + status.queue_commit.put(item) + success = False + duration = time.time() - start_ts + status.update_chunk(success, len(items), duration) + with status.lock: + status.last_commit_attempt = time.time() + status.nb_workers_commit -= 1 + + elif job == WorkerJob.WAIT: + time.sleep(WAITING_TIME_IF_NO_TASKS) + with status.lock: + status.nb_workers_waiting -= 1 + + +def _determine_next_job(status: LargeUploadStatus) -> Optional[tuple[WorkerJob, list[JOB_ITEM_T]]]: + with status.lock: + # 1. 
Commit if more than 5 minutes since last commit attempt (and at least 1 file) + if ( + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.last_commit_attempt is not None + and time.time() - status.last_commit_attempt > 5 * 60 + ): + status.nb_workers_commit += 1 + logger.debug("Job: commit (more than 5 minutes since last commit attempt)") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 2. Commit if at least 100 files are ready to commit + elif status.nb_workers_commit == 0 and status.queue_commit.qsize() >= 150: + status.nb_workers_commit += 1 + logger.debug("Job: commit (>100 files ready)") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 3. Get upload mode if at least 100 files + elif status.queue_get_upload_mode.qsize() >= MAX_NB_FILES_FETCH_UPLOAD_MODE: + status.nb_workers_get_upload_mode += 1 + logger.debug(f"Job: get upload mode (>{MAX_NB_FILES_FETCH_UPLOAD_MODE} files ready)") + return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE)) + + # 4. Preupload LFS file if at least `status.upload_batch_size` files and no worker is preuploading LFS + elif status.queue_preupload_lfs.qsize() >= status.upload_batch_size and status.nb_workers_preupload_lfs == 0: + status.nb_workers_preupload_lfs += 1 + logger.debug("Job: preupload LFS (no other worker preuploading LFS)") + return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size)) + + # 5. Compute sha256 if at least 1 file and no worker is computing sha256 + elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0: + status.nb_workers_sha256 += 1 + logger.debug("Job: sha256 (no other worker computing sha256)") + return (WorkerJob.SHA256, _get_one(status.queue_sha256)) + + # 6. 
Get upload mode if at least 1 file and no worker is getting upload mode + elif status.queue_get_upload_mode.qsize() > 0 and status.nb_workers_get_upload_mode == 0: + status.nb_workers_get_upload_mode += 1 + logger.debug("Job: get upload mode (no other worker getting upload mode)") + return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE)) + + # 7. Preupload LFS file if at least `status.upload_batch_size` files + elif status.queue_preupload_lfs.qsize() >= status.upload_batch_size: + status.nb_workers_preupload_lfs += 1 + logger.debug("Job: preupload LFS") + return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size)) + + # 8. Compute sha256 if at least 1 file + elif status.queue_sha256.qsize() > 0: + status.nb_workers_sha256 += 1 + logger.debug("Job: sha256") + return (WorkerJob.SHA256, _get_one(status.queue_sha256)) + + # 9. Get upload mode if at least 1 file + elif status.queue_get_upload_mode.qsize() > 0: + status.nb_workers_get_upload_mode += 1 + logger.debug("Job: get upload mode") + return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE)) + + # 10. Preupload LFS file if at least 1 file + elif status.queue_preupload_lfs.qsize() > 0: + status.nb_workers_preupload_lfs += 1 + logger.debug("Job: preupload LFS") + return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size)) + + # 11. Commit if at least 1 file and 1 min since last commit attempt + elif ( + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.last_commit_attempt is not None + and time.time() - status.last_commit_attempt > 1 * 60 + ): + status.nb_workers_commit += 1 + logger.debug("Job: commit (1 min since last commit attempt)") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 12. Commit if at least 1 file all other queues are empty and all workers are waiting + # e.g. 
when it's the last commit + elif ( + status.nb_workers_commit == 0 + and status.queue_commit.qsize() > 0 + and status.queue_sha256.qsize() == 0 + and status.queue_get_upload_mode.qsize() == 0 + and status.queue_preupload_lfs.qsize() == 0 + and status.nb_workers_sha256 == 0 + and status.nb_workers_get_upload_mode == 0 + and status.nb_workers_preupload_lfs == 0 + ): + status.nb_workers_commit += 1 + logger.debug("Job: commit") + return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk())) + + # 13. If all queues are empty, exit + elif all(metadata.is_committed or metadata.should_ignore for _, metadata in status.items): + logger.info("All files have been processed! Exiting worker.") + return None + + # 14. If no task is available, wait + else: + status.nb_workers_waiting += 1 + logger.debug(f"No task available, waiting... ({WAITING_TIME_IF_NO_TASKS}s)") + return (WorkerJob.WAIT, []) + + +#################### +# Atomic jobs (sha256, get_upload_mode, preupload_lfs, commit) +#################### + + +def _compute_sha256(item: JOB_ITEM_T) -> None: + """Compute sha256 of a file and save it in metadata.""" + paths, metadata = item + if metadata.sha256 is None: + with paths.file_path.open("rb") as f: + metadata.sha256 = sha_fileobj(f).hex() + metadata.save(paths) + + +def _get_upload_mode(items: list[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: + """Get upload mode for each file and update metadata. + + Also receive info if the file should be ignored. 
+ """ + additions = [_build_hacky_operation(item) for item in items] + _fetch_upload_modes( + additions=additions, + repo_type=repo_type, + repo_id=repo_id, + headers=api._build_hf_headers(), + revision=quote(revision, safe=""), + endpoint=api.endpoint, + ) + for item, addition in zip(items, additions): + paths, metadata = item + metadata.upload_mode = addition._upload_mode + metadata.should_ignore = addition._should_ignore + metadata.remote_oid = addition._remote_oid + metadata.save(paths) + + +def _preupload_lfs(items: list[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: + """Preupload LFS files and update metadata.""" + additions = [_build_hacky_operation(item) for item in items] + api.preupload_lfs_files( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + additions=additions, + ) + + for paths, metadata in items: + metadata.is_uploaded = True + metadata.save(paths) + + +def _commit(items: list[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: + """Commit files to the repo.""" + additions = [_build_hacky_operation(item) for item in items] + api.create_commit( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + operations=additions, + commit_message="Add files using upload-large-folder tool", + ) + for paths, metadata in items: + metadata.is_committed = True + metadata.save(paths) + + +#################### +# Hacks with CommitOperationAdd to bypass checks/sha256 calculation +#################### + + +class HackyCommitOperationAdd(CommitOperationAdd): + def __post_init__(self) -> None: + if isinstance(self.path_or_fileobj, Path): + self.path_or_fileobj = str(self.path_or_fileobj) + + +def _build_hacky_operation(item: JOB_ITEM_T) -> HackyCommitOperationAdd: + paths, metadata = item + operation = HackyCommitOperationAdd(path_in_repo=paths.path_in_repo, path_or_fileobj=paths.file_path) + with paths.file_path.open("rb") as file: + sample = file.peek(512)[:512] + if 
metadata.sha256 is None: + raise ValueError("sha256 must have been computed by now!") + operation.upload_info = UploadInfo(sha256=bytes.fromhex(metadata.sha256), size=metadata.size, sample=sample) + operation._upload_mode = metadata.upload_mode # type: ignore[assignment] + operation._should_ignore = metadata.should_ignore + operation._remote_oid = metadata.remote_oid + return operation + + +#################### +# Misc helpers +#################### + + +def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> list[JOB_ITEM_T]: + return [queue.get()] + + +def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> list[JOB_ITEM_T]: + return [queue.get() for _ in range(min(queue.qsize(), n))] + + +def _print_overwrite(report: str) -> None: + """Print a report, overwriting the previous lines. + + Since tqdm in using `sys.stderr` to (re-)write progress bars, we need to use `sys.stdout` + to print the report. + + Note: works well only if no other process is writing to `sys.stdout`! + """ + report += "\n" + # Get terminal width + terminal_width = shutil.get_terminal_size().columns + + # Count number of lines that should be cleared + nb_lines = sum(len(line) // terminal_width + 1 for line in report.splitlines()) + + # Clear previous lines based on the number of lines in the report + for _ in range(nb_lines): + sys.stdout.write("\r\033[K") # Clear line + sys.stdout.write("\033[F") # Move cursor up one line + + # Print the new report, filling remaining space with whitespace + sys.stdout.write(report) + sys.stdout.write(" " * (terminal_width - len(report.splitlines()[-1]))) + sys.stdout.flush() diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_webhooks_payload.py b/venv/lib/python3.10/site-packages/huggingface_hub/_webhooks_payload.py new file mode 100644 index 0000000000000000000000000000000000000000..90f12425cbbf4a8fb279ee6b3fe7f594be88a8b1 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_webhooks_payload.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# 
Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains data structures to parse the webhooks payload.""" + +from typing import Literal, Optional + +from .utils import is_pydantic_available + + +if is_pydantic_available(): + from pydantic import BaseModel +else: + # Define a dummy BaseModel to avoid import errors when pydantic is not installed + # Import error will be raised when trying to use the class + + class BaseModel: # type: ignore [no-redef] + def __init__(self, *args, **kwargs) -> None: + raise ImportError( + "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that" + " should be installed separately. Please run `pip install --upgrade pydantic` and retry." + ) + + +# This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they +# are not in used anymore. To keep in sync when format is updated in +# https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link). 
+ + +WebhookEvent_T = Literal[ + "create", + "delete", + "move", + "update", +] +RepoChangeEvent_T = Literal[ + "add", + "move", + "remove", + "update", +] +RepoType_T = Literal[ + "dataset", + "model", + "space", +] +DiscussionStatus_T = Literal[ + "closed", + "draft", + "open", + "merged", +] +SupportedWebhookVersion = Literal[3] + + +class ObjectId(BaseModel): + id: str + + +class WebhookPayloadUrl(BaseModel): + web: str + api: Optional[str] = None + + +class WebhookPayloadMovedTo(BaseModel): + name: str + owner: ObjectId + + +class WebhookPayloadWebhook(ObjectId): + version: SupportedWebhookVersion + + +class WebhookPayloadEvent(BaseModel): + action: WebhookEvent_T + scope: str + + +class WebhookPayloadDiscussionChanges(BaseModel): + base: str + mergeCommitId: Optional[str] = None + + +class WebhookPayloadComment(ObjectId): + author: ObjectId + hidden: bool + content: Optional[str] = None + url: WebhookPayloadUrl + + +class WebhookPayloadDiscussion(ObjectId): + num: int + author: ObjectId + url: WebhookPayloadUrl + title: str + isPullRequest: bool + status: DiscussionStatus_T + changes: Optional[WebhookPayloadDiscussionChanges] = None + pinned: Optional[bool] = None + + +class WebhookPayloadRepo(ObjectId): + owner: ObjectId + head_sha: Optional[str] = None + name: str + private: bool + subdomain: Optional[str] = None + tags: Optional[list[str]] = None + type: Literal["dataset", "model", "space"] + url: WebhookPayloadUrl + + +class WebhookPayloadUpdatedRef(BaseModel): + ref: str + oldSha: Optional[str] = None + newSha: Optional[str] = None + + +class WebhookPayload(BaseModel): + event: WebhookPayloadEvent + repo: WebhookPayloadRepo + discussion: Optional[WebhookPayloadDiscussion] = None + comment: Optional[WebhookPayloadComment] = None + webhook: WebhookPayloadWebhook + movedTo: Optional[WebhookPayloadMovedTo] = None + updatedRefs: Optional[list[WebhookPayloadUpdatedRef]] = None diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/_webhooks_server.py 
b/venv/lib/python3.10/site-packages/huggingface_hub/_webhooks_server.py new file mode 100644 index 0000000000000000000000000000000000000000..601a55c3d2801964d4fb2172e18e7149d33cc0b2 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/_webhooks_server.py @@ -0,0 +1,376 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily.""" + +import atexit +import inspect +import os +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Optional + +from .utils import experimental, is_fastapi_available, is_gradio_available + + +if TYPE_CHECKING: + import gradio as gr + from fastapi import Request + +if is_fastapi_available(): + from fastapi import FastAPI, Request + from fastapi.responses import JSONResponse +else: + # Will fail at runtime if FastAPI is not available + FastAPI = Request = JSONResponse = None # type: ignore + + +_global_app: Optional["WebhooksServer"] = None +_is_local = os.environ.get("SPACE_ID") is None + + +@experimental +class WebhooksServer: + """ + The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks. + These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to + the app as a POST endpoint to the FastAPI router. 
Once all the webhooks are registered, the `launch` method has to be + called to start the app. + + It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic + model that contains all the information about the webhook event. The data will be parsed automatically for you. + + Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your + WebhooksServer and deploy it on a Space. + + > [!WARNING] + > `WebhooksServer` is experimental. Its API is subject to change in the future. + + > [!WARNING] + > You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`). + + Args: + ui (`gradio.Blocks`, optional): + A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions + about the configured webhooks is created. + webhook_secret (`str`, optional): + A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as + you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You + can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the + webhook endpoints are opened without any security. + + Example: + + ```python + import gradio as gr + from huggingface_hub import WebhooksServer, WebhookPayload + + with gr.Blocks() as ui: + ... + + app = WebhooksServer(ui=ui, webhook_secret="my_secret_key") + + @app.add_webhook("/say_hello") + async def hello(payload: WebhookPayload): + return {"message": "hello"} + + app.launch() + ``` + """ + + def __new__(cls, *args, **kwargs) -> "WebhooksServer": + if not is_gradio_available(): + raise ImportError( + "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`" + " first." + ) + if not is_fastapi_available(): + raise ImportError( + "You must have `fastapi` installed to use `WebhooksServer`. 
Please run `pip install --upgrade fastapi`" + " first." + ) + return super().__new__(cls) + + def __init__( + self, + ui: Optional["gr.Blocks"] = None, + webhook_secret: Optional[str] = None, + ) -> None: + self._ui = ui + + self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET") + self.registered_webhooks: dict[str, Callable] = {} + _warn_on_empty_secret(self.webhook_secret) + + def add_webhook(self, path: Optional[str] = None) -> Callable: + """ + Decorator to add a webhook to the [`WebhooksServer`] server. + + Args: + path (`str`, optional): + The URL path to register the webhook function. If not provided, the function name will be used as the + path. In any case, all webhooks are registered under `/webhooks`. + + Raises: + ValueError: If the provided path is already registered as a webhook. + + Example: + ```python + from huggingface_hub import WebhooksServer, WebhookPayload + + app = WebhooksServer() + + @app.add_webhook + async def trigger_training(payload: WebhookPayload): + if payload.repo.type == "dataset" and payload.event.action == "update": + # Trigger a training job if a dataset is updated + ... + + app.launch() + ``` + """ + # Usage: directly as decorator. Example: `@app.add_webhook` + if callable(path): + # If path is a function, it means it was used as a decorator without arguments + return self.add_webhook()(path) + + # Usage: provide a path. Example: `@app.add_webhook(...)` + @wraps(FastAPI.post) + def _inner_post(*args, **kwargs): + func = args[0] + abs_path = f"/webhooks/{(path or func.__name__).strip('/')}" + if abs_path in self.registered_webhooks: + raise ValueError(f"Webhook {abs_path} already exists.") + self.registered_webhooks[abs_path] = func + + return _inner_post + + def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None: + """Launch the Gradio app and register webhooks to the underlying FastAPI server. + + Input parameters are forwarded to Gradio when launching the app. 
+ """ + ui = self._ui or self._get_default_ui() + + # Start Gradio App + # - as non-blocking so that webhooks can be added afterwards + # - as shared if launch locally (to debug webhooks) + launch_kwargs.setdefault("share", _is_local) + self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs) + + # Register webhooks to FastAPI app + for path, func in self.registered_webhooks.items(): + # Add secret check if required + if self.webhook_secret is not None: + func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret) + + # Add route to FastAPI app + self.fastapi_app.post(path)(func) + + # Print instructions and block main thread + space_host = os.environ.get("SPACE_HOST") + url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url) + if url is None: + raise ValueError("Cannot find the URL of the app. Please provide a valid `ui` or update `gradio` version.") + url = url.strip("/") + message = "\nWebhooks are correctly setup and ready to use:" + message += "\n" + "\n".join(f" - POST {url}{webhook}" for webhook in self.registered_webhooks) + message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks." + print(message) + + if not prevent_thread_lock: + ui.block_thread() + + def _get_default_ui(self) -> "gr.Blocks": + """Default UI if not provided (lists webhooks and provides basic instructions).""" + import gradio as gr + + with gr.Blocks() as ui: + gr.Markdown("# This is an app to process 🤗 Webhooks") + gr.Markdown( + "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on" + " specific repos or to all repos belonging to particular set of users/organizations (not just your" + " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to" + " know more about webhooks on the Huggingface Hub." 
+ ) + gr.Markdown( + f"{len(self.registered_webhooks)} webhook(s) are registered:" + + "\n\n" + + "\n ".join( + f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})" + for webhook_path, webhook in self.registered_webhooks.items() + ) + ) + gr.Markdown( + "Go to https://huggingface.co/settings/webhooks to setup your webhooks." + + "\nYou app is running locally. Please look at the logs to check the full URL you need to set." + if _is_local + else ( + "\nThis app is running on a Space. You can find the corresponding URL in the options menu" + " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'." + ) + ) + return ui + + +@experimental +def webhook_endpoint(path: Optional[str] = None) -> Callable: + """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint. + + This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret), + you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using + this decorator multiple times. + + Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your + server and deploy it on a Space. + + > [!WARNING] + > `webhook_endpoint` is experimental. Its API is subject to change in the future. + + > [!WARNING] + > You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`). + + Args: + path (`str`, optional): + The URL path to register the webhook function. If not provided, the function name will be used as the path. + In any case, all webhooks are registered under `/webhooks`. + + Examples: + The default usage is to register a function as a webhook endpoint. The function name will be used as the path. + The server will be started automatically at exit (i.e. at the end of the script). 
+ + ```python + from huggingface_hub import webhook_endpoint, WebhookPayload + + @webhook_endpoint + async def trigger_training(payload: WebhookPayload): + if payload.repo.type == "dataset" and payload.event.action == "update": + # Trigger a training job if a dataset is updated + ... + + # Server is automatically started at the end of the script. + ``` + + Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you + are running it in a notebook. + + ```python + from huggingface_hub import webhook_endpoint, WebhookPayload + + @webhook_endpoint + async def trigger_training(payload: WebhookPayload): + if payload.repo.type == "dataset" and payload.event.action == "update": + # Trigger a training job if a dataset is updated + ... + + # Start the server manually + trigger_training.launch() + ``` + """ + if callable(path): + # If path is a function, it means it was used as a decorator without arguments + return webhook_endpoint()(path) + + @wraps(WebhooksServer.add_webhook) + def _inner(func: Callable) -> Callable: + app = _get_global_app() + app.add_webhook(path)(func) + if len(app.registered_webhooks) == 1: + # Register `app.launch` to run at exit (only once) + atexit.register(app.launch) + + @wraps(app.launch) + def _launch_now(): + # Run the app directly (without waiting atexit) + atexit.unregister(app.launch) + app.launch() + + func.launch = _launch_now # type: ignore + return func + + return _inner + + +def _get_global_app() -> WebhooksServer: + global _global_app + if _global_app is None: + _global_app = WebhooksServer() + return _global_app + + +def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None: + if webhook_secret is None: + print("Webhook secret is not defined. 
This means your webhook endpoints will be open to everyone.") + print( + "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: " + "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`" + ) + print( + "For more details about webhook secrets, please refer to" + " https://huggingface.co/docs/hub/webhooks#webhook-secret." + ) + else: + print("Webhook secret is correctly defined.") + + +def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str: + """Returns the anchor to a given webhook in the docs (experimental)""" + return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post" + + +def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable: + """Wraps a webhook function to check the webhook secret before calling the function. + + This is a hacky way to add the `request` parameter to the function signature. Since FastAPI relies on route + parameters to inject values into the function, we need to hack the function signature to retrieve the `Request` + object (and hence the headers). A far cleaner solution would be to use a middleware. However, since + `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by + Gradio internals (and not by us), we cannot add a middleware. + + This method is called only when a secret has been defined by the user. If a request is sent without the + "x-webhook-secret" header, the function will return a 401 error (unauthorized). If the header is sent but is incorrect, + the function will return a 403 error (forbidden). + + Inspired by https://stackoverflow.com/a/33112180.
+ """ + initial_sig = inspect.signature(func) + + @wraps(func) + async def _protected_func(request: Request, **kwargs): + request_secret = request.headers.get("x-webhook-secret") + if request_secret is None: + return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401) + if request_secret != webhook_secret: + return JSONResponse({"error": "Invalid webhook secret."}, status_code=403) + + # Inject `request` in kwargs if required + if "request" in initial_sig.parameters: + kwargs["request"] = request + + # Handle both sync and async routes + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + else: + return func(**kwargs) + + # Update signature to include request + if "request" not in initial_sig.parameters: + _protected_func.__signature__ = initial_sig.replace( # type: ignore + parameters=( + inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request), + ) + + tuple(initial_sig.parameters.values()) + ) + + # Return protected route + return _protected_func diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/community.py b/venv/lib/python3.10/site-packages/huggingface_hub/community.py new file mode 100644 index 0000000000000000000000000000000000000000..68ad181a81e96f4587b56e4757bae1ef20c59e89 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/community.py @@ -0,0 +1,363 @@ +""" +Data structures to interact with Discussions and Pull Requests on the Hub. + +See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) +for more information on Pull Requests, Discussions, and the community tab. +""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Literal, Optional, TypedDict, Union + +from . 
import constants +from .utils import parse_datetime + + +DiscussionStatus = Literal["open", "closed", "merged", "draft"] + + +@dataclass +class Discussion: + """ + A Discussion or Pull Request on the Hub. + + This dataclass is not intended to be instantiated directly. + + Attributes: + title (`str`): + The title of the Discussion / Pull Request + status (`str`): + The status of the Discussion / Pull Request. + It must be one of: + * `"open"` + * `"closed"` + * `"merged"` (only for Pull Requests ) + * `"draft"` (only for Pull Requests ) + num (`int`): + The number of the Discussion / Pull Request. + repo_id (`str`): + The id (`"{namespace}/{repo_name}"`) of the repo on which + the Discussion / Pull Request was open. + repo_type (`str`): + The type of the repo on which the Discussion / Pull Request was open. + Possible values are: `"model"`, `"dataset"`, `"space"`. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + is_pull_request (`bool`): + Whether or not this is a Pull Request. + created_at (`datetime`): + The `datetime` of creation of the Discussion / Pull Request. + endpoint (`str`): + Endpoint of the Hub. Default is https://huggingface.co. + git_reference (`str`, *optional*): + (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise. + url (`str`): + (property) URL of the discussion on the Hub. + """ + + title: str + status: DiscussionStatus + num: int + repo_id: str + repo_type: str + author: str + is_pull_request: bool + created_at: datetime + endpoint: str + + @property + def git_reference(self) -> Optional[str]: + """ + If this is a Pull Request , returns the git reference to which changes can be pushed. + Returns `None` otherwise. 
+ """ + if self.is_pull_request: + return f"refs/pr/{self.num}" + return None + + @property + def url(self) -> str: + """Returns the URL of the discussion on the Hub.""" + if self.repo_type is None or self.repo_type == constants.REPO_TYPE_MODEL: + return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}" + return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}" + + +@dataclass +class DiscussionWithDetails(Discussion): + """ + Subclass of [`Discussion`]. + + Attributes: + title (`str`): + The title of the Discussion / Pull Request + status (`str`): + The status of the Discussion / Pull Request. + It can be one of: + * `"open"` + * `"closed"` + * `"merged"` (only for Pull Requests ) + * `"draft"` (only for Pull Requests ) + num (`int`): + The number of the Discussion / Pull Request. + repo_id (`str`): + The id (`"{namespace}/{repo_name}"`) of the repo on which + the Discussion / Pull Request was open. + repo_type (`str`): + The type of the repo on which the Discussion / Pull Request was open. + Possible values are: `"model"`, `"dataset"`, `"space"`. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + is_pull_request (`bool`): + Whether or not this is a Pull Request. + created_at (`datetime`): + The `datetime` of creation of the Discussion / Pull Request. + events (`list` of [`DiscussionEvent`]) + The list of [`DiscussionEvents`] in this Discussion or Pull Request. + conflicting_files (`Union[list[str], bool, None]`, *optional*): + A list of conflicting files if this is a Pull Request. + `None` if `self.is_pull_request` is `False`. + `True` if there are conflicting files but the list can't be retrieved. + target_branch (`str`, *optional*): + The branch into which changes are to be merged if this is a + Pull Request . `None` if `self.is_pull_request` is `False`. 
+ merge_commit_oid (`str`, *optional*): + If this is a merged Pull Request , this is set to the OID / SHA of + the merge commit, `None` otherwise. + diff (`str`, *optional*): + The git diff if this is a Pull Request , `None` otherwise. + endpoint (`str`): + Endpoint of the Hub. Default is https://huggingface.co. + git_reference (`str`, *optional*): + (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise. + url (`str`): + (property) URL of the discussion on the Hub. + """ + + events: list["DiscussionEvent"] + conflicting_files: Union[list[str], bool, None] + target_branch: Optional[str] + merge_commit_oid: Optional[str] + diff: Optional[str] + + +class DiscussionEventArgs(TypedDict): + id: str + type: str + created_at: datetime + author: str + _event: dict + + +@dataclass +class DiscussionEvent: + """ + An event in a Discussion or Pull Request. + + Use concrete classes: + * [`DiscussionComment`] + * [`DiscussionStatusChange`] + * [`DiscussionCommit`] + * [`DiscussionTitleChange`] + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. + created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + """ + + id: str + type: str + created_at: datetime + author: str + + _event: dict + """Stores the original event data, in case we need to access it later.""" + + +@dataclass +class DiscussionComment(DiscussionEvent): + """A comment in a Discussion / Pull Request. + + Subclass of [`DiscussionEvent`]. + + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. 
+ created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + content (`str`): + The raw markdown content of the comment. Mentions, links and images are not rendered. + edited (`bool`): + Whether or not this comment has been edited. + hidden (`bool`): + Whether or not this comment has been hidden. + """ + + content: str + edited: bool + hidden: bool + + @property + def rendered(self) -> str: + """The rendered comment, as a HTML string""" + return self._event["data"]["latest"]["html"] + + @property + def last_edited_at(self) -> datetime: + """The last edit time, as a `datetime` object.""" + return parse_datetime(self._event["data"]["latest"]["updatedAt"]) + + @property + def last_edited_by(self) -> str: + """The username of the last editor (`"deleted"` if unavailable).""" + return self._event["data"]["latest"].get("author", {}).get("name", "deleted") + + @property + def edit_history(self) -> list[dict]: + """The edit history of the comment""" + return self._event["data"]["history"] + + @property + def number_of_edits(self) -> int: + return len(self.edit_history) + + +@dataclass +class DiscussionStatusChange(DiscussionEvent): + """A change of status in a Discussion / Pull Request. + + Subclass of [`DiscussionEvent`]. + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. + created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + new_status (`str`): + The status of the Discussion / Pull Request after the change.
+ It can be one of: + * `"open"` + * `"closed"` + * `"merged"` (only for Pull Requests ) + """ + + new_status: str + + +@dataclass +class DiscussionCommit(DiscussionEvent): + """A commit in a Pull Request. + + Subclass of [`DiscussionEvent`]. + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. + created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + summary (`str`): + The summary of the commit. + oid (`str`): + The OID / SHA of the commit, as a hexadecimal string. + """ + + summary: str + oid: str + + +@dataclass +class DiscussionTitleChange(DiscussionEvent): + """A rename event in a Discussion / Pull Request. + + Subclass of [`DiscussionEvent`]. + + Attributes: + id (`str`): + The ID of the event. An hexadecimal string. + type (`str`): + The type of the event. + created_at (`datetime`): + A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) + object holding the creation timestamp for the event. + author (`str`): + The username of the Discussion / Pull Request author. + Can be `"deleted"` if the user has been deleted since. + old_title (`str`): + The previous title for the Discussion / Pull Request. + new_title (`str`): + The new title. 
+ """ + + old_title: str + new_title: str + + +def deserialize_event(event: dict) -> DiscussionEvent: + """Instantiates a [`DiscussionEvent`] from a dict""" + event_id: str = event["id"] + event_type: str = event["type"] + created_at = parse_datetime(event["createdAt"]) + + common_args: DiscussionEventArgs = { + "id": event_id, + "type": event_type, + "created_at": created_at, + "author": event.get("author", {}).get("name", "deleted"), + "_event": event, + } + + if event_type == "comment": + return DiscussionComment( + **common_args, + edited=event["data"]["edited"], + hidden=event["data"]["hidden"], + content=event["data"]["latest"]["raw"], + ) + if event_type == "status-change": + return DiscussionStatusChange( + **common_args, + new_status=event["data"]["status"], + ) + if event_type == "commit": + return DiscussionCommit( + **common_args, + summary=event["data"]["subject"], + oid=event["data"]["oid"], + ) + if event_type == "title-change": + return DiscussionTitleChange( + **common_args, + old_title=event["data"]["from"], + new_title=event["data"]["to"], + ) + + return DiscussionEvent(**common_args) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/constants.py b/venv/lib/python3.10/site-packages/huggingface_hub/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..56ab0bfca70587f1c2f02f8d24fb5f991c3c1ded --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/constants.py @@ -0,0 +1,281 @@ +import os +import re +import typing +from typing import Literal, Optional + + +# Possible values for env variables + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + + +def _is_true(value: Optional[str]) -> bool: + if value is None: + return False + return value.upper() in ENV_VARS_TRUE_VALUES + + +def _as_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + return int(value) + + +# Constants for file downloads + 
+PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF_WEIGHTS_NAME = "model.ckpt" +FLAX_WEIGHTS_NAME = "flax_model.msgpack" +CONFIG_NAME = "config.json" +REPOCARD_NAME = "README.md" +EVAL_RESULTS_FOLDER = ".eval_results" +DEFAULT_ETAG_TIMEOUT = 10 +DEFAULT_DOWNLOAD_TIMEOUT = 10 +DEFAULT_REQUEST_TIMEOUT = 10 +DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024 +MAX_HTTP_DOWNLOAD_SIZE = 50 * 1000 * 1000 * 1000 # 50 GB + +# Constants for serialization + +PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin" # Unsafe pickle: use safetensors instead +SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors" +TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5" + +# Constants for safetensors repos + +SAFETENSORS_SINGLE_FILE = "model.safetensors" +SAFETENSORS_INDEX_FILE = "model.safetensors.index.json" +SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000 + +# Timeout of acquiring file lock and logging the attempt +FILELOCK_LOG_EVERY_SECONDS = 10 + +# Git-related constants + +DEFAULT_REVISION = "main" +REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}") + +HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/" + +_staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING")) + +_HF_DEFAULT_ENDPOINT = "https://huggingface.co" +_HF_DEFAULT_STAGING_ENDPOINT = "https://hub-ci.huggingface.co" +ENDPOINT = os.getenv("HF_ENDPOINT", _HF_DEFAULT_ENDPOINT).rstrip("/") +HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" + +if _staging_mode: + ENDPOINT = _HF_DEFAULT_STAGING_ENDPOINT + HUGGINGFACE_CO_URL_TEMPLATE = _HF_DEFAULT_STAGING_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" + +HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" +HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag" +HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size" +HUGGINGFACE_HEADER_X_BILL_TO = "X-HF-Bill-To" + +INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co") + +# See
https://huggingface.co/docs/inference-endpoints/index +INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2" +INFERENCE_CATALOG_ENDPOINT = "https://endpoints.huggingface.co/api/catalog" + +# See https://api.endpoints.huggingface.cloud/#post-/v2/endpoint/-namespace- +INFERENCE_ENDPOINT_IMAGE_KEYS = [ + "custom", + "huggingface", + "huggingfaceNeuron", + "llamacpp", + "tei", + "tgi", + "tgiNeuron", +] + +# Proxy for third-party providers +INFERENCE_PROXY_TEMPLATE = "https://router.huggingface.co/{provider}" + +REPO_ID_SEPARATOR = "--" +# ^ this substring is not allowed in repo_ids on hf.co +# and is the canonical one we use for serialization of repo ids elsewhere. + + +REPO_TYPE_DATASET = "dataset" +REPO_TYPE_SPACE = "space" +REPO_TYPE_MODEL = "model" +REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE] +SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"] + +REPO_TYPES_URL_PREFIXES = { + REPO_TYPE_DATASET: "datasets/", + REPO_TYPE_SPACE: "spaces/", +} +REPO_TYPES_MAPPING = { + "datasets": REPO_TYPE_DATASET, + "spaces": REPO_TYPE_SPACE, + "models": REPO_TYPE_MODEL, +} + +DiscussionTypeFilter = Literal["all", "discussion", "pull_request"] +DISCUSSION_TYPES: tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) +DiscussionStatusFilter = Literal["all", "open", "closed"] +DISCUSSION_STATUS: tuple[DiscussionTypeFilter, ...] 
= typing.get_args(DiscussionStatusFilter) + +# Webhook subscription types +WEBHOOK_DOMAIN_T = Literal["repo", "discussions"] + +# default cache +default_home = os.path.join(os.path.expanduser("~"), ".cache") +HF_HOME = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_HOME", + os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"), + ) + ) +) + +default_cache_path = os.path.join(HF_HOME, "hub") +default_assets_cache_path = os.path.join(HF_HOME, "assets") + +# Legacy env variables +HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path) +HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path) + +# New env variables +HF_HUB_CACHE = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_HUB_CACHE", + HUGGINGFACE_HUB_CACHE, + ) + ) +) +HF_ASSETS_CACHE = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_ASSETS_CACHE", + HUGGINGFACE_ASSETS_CACHE, + ) + ) +) + +HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE")) + + +def is_offline_mode() -> bool: + """Returns whether we are in offline mode for the Hub. + + When offline mode is enabled, all HTTP requests made with `get_session` will raise an `OfflineModeIsEnabled` exception. + + Example: + ```py + from huggingface_hub import is_offline_mode + + def list_files(repo_id: str): + if is_offline_mode(): + ... # list files from local cache (degraded experience but still functional) + else: + ... # list files from Hub (complete experience) + ``` + """ + return HF_HUB_OFFLINE + + +# File created to mark that the version check has been done. +# Check is performed once per 24 hours at most. +CHECK_FOR_UPDATE_DONE_PATH = os.path.join(HF_HOME, ".check_for_update_done") + +# If set, log level will be set to DEBUG and all requests made to the Hub will be logged +# as curl commands for reproducibility. 
+HF_DEBUG = _is_true(os.environ.get("HF_DEBUG")) + +# Opt-out from telemetry requests +HF_HUB_DISABLE_TELEMETRY = ( + _is_true(os.environ.get("HF_HUB_DISABLE_TELEMETRY")) # HF-specific env variable + or _is_true(os.environ.get("DISABLE_TELEMETRY")) + or _is_true(os.environ.get("DO_NOT_TRACK")) # https://consoledonottrack.com/ +) + +HF_TOKEN_PATH = os.path.expandvars( + os.path.expanduser( + os.getenv( + "HF_TOKEN_PATH", + os.path.join(HF_HOME, "token"), + ) + ) +) +HF_STORED_TOKENS_PATH = os.path.join(os.path.dirname(HF_TOKEN_PATH), "stored_tokens") + +if _staging_mode: + # In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens + # In practice in `huggingface_hub` tests, we monkeypatch these values with temporary directories. The following + # lines are only used in third-party libraries tests (e.g. `transformers`, `diffusers`, etc.). + _staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging") + HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub") + HF_TOKEN_PATH = os.path.join(_staging_home, "token") + +# Here, `True` will disable progress bars globally without possibility of enabling it +# programmatically. `False` will enable them without possibility of disabling them. +# If environment variable is not set (None), then the user is free to enable/disable +# them programmatically. +# TL;DR: env variable has priority over code +__HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS") +HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = ( + _is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None +) + +# Disable warning on machines that do not support symlinks (e.g. 
Windows non-developer) +HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING")) + +# Disable warning when using experimental features +HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING")) + +# Disable sending the cached token by default in all HTTP requests to the Hub +HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN")) + +HF_XET_HIGH_PERFORMANCE: bool = _is_true(os.environ.get("HF_XET_HIGH_PERFORMANCE")) + +# hf_transfer is not used anymore. Let's warn the user in case they set the env variable +if _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")) and not HF_XET_HIGH_PERFORMANCE: + import warnings + + warnings.warn( + "The `HF_HUB_ENABLE_HF_TRANSFER` environment variable is deprecated as 'hf_transfer' is not used anymore. " + "Please use `HF_XET_HIGH_PERFORMANCE` instead to enable high performance transfer with Xet. " + "Visit https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfxethighperformance for more details.", + DeprecationWarning, + ) + +# Used to override the etag timeout on a system level +HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT + +# Used to override the get request timeout on a system level +# Also used as a default timeout for other requests if not specified (kept the naming for legacy reasons) +HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT + +# Allows to add information about the requester in the user-agent (e.g. partner name) +HF_HUB_USER_AGENT_ORIGIN: Optional[str] = os.environ.get("HF_HUB_USER_AGENT_ORIGIN") + +# If OAuth didn't work after 2 redirects, there's likely a third-party cookie issue in the Space iframe view. +# In this case, we redirect the user to the non-iframe view.
+OAUTH_MAX_REDIRECTS = 2 + +# OAuth-related environment variables injected by the Space +OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID") +OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET") +OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES") +OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL") + +# Xet constants +HUGGINGFACE_HEADER_X_XET_ENDPOINT = "X-Xet-Cas-Url" +HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN = "X-Xet-Access-Token" +HUGGINGFACE_HEADER_X_XET_EXPIRATION = "X-Xet-Token-Expiration" +HUGGINGFACE_HEADER_X_XET_HASH = "X-Xet-Hash" +HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE = "X-Xet-Refresh-Route" +HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY = "xet-auth" + +default_xet_cache_path = os.path.join(HF_HOME, "xet") +HF_XET_CACHE = os.getenv("HF_XET_CACHE", default_xet_cache_path) +HF_HUB_DISABLE_XET: bool = _is_true(os.environ.get("HF_HUB_DISABLE_XET")) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/dataclasses.py b/venv/lib/python3.10/site-packages/huggingface_hub/dataclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..c59f951acc868ce27ee57e14730b6c84ec9d7952 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/dataclasses.py @@ -0,0 +1,629 @@ +import inspect +import sys +import types +from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass +from functools import lru_cache, wraps +from typing import ( + Annotated, + Any, + Callable, + ForwardRef, + Literal, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, + overload, +) + + +try: + # Python 3.11+ + from typing import NotRequired, Required # type: ignore +except ImportError: + try: + # In case typing_extensions is installed + from typing_extensions import NotRequired, Required # type: ignore + except ImportError: + # Fallback: create dummy types that will never match + Required = type("Required", (), {}) # type: ignore + NotRequired = type("NotRequired", (), {}) # type: ignore + +from .errors import ( + 
StrictDataclassClassValidationError, + StrictDataclassDefinitionError, + StrictDataclassFieldValidationError, +) + + +Validator_T = Callable[[Any], None] +T = TypeVar("T") +TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any]) + +_TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None) + + +# The overload decorator helps type checkers understand the different return types +@overload +def strict(cls: Type[T]) -> Type[T]: ... + + +@overload +def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ... + + +def strict( + cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False +) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + """ + Decorator to add strict validation to a dataclass. + + This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools + recognize the class as a dataclass. + + Can be used with or without arguments: + - `@strict` + - `@strict(accept_kwargs=True)` + + Args: + cls: + The class to convert to a strict dataclass. + accept_kwargs (`bool`, *optional*): + If True, allows arbitrary keyword arguments in `__init__`. Defaults to False. + + Returns: + The enhanced dataclass with strict validation on field assignment. + + Example: + ```py + >>> from dataclasses import dataclass + >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field + + >>> @as_validated_field + >>> def positive_int(value: int): + ... if not value >= 0: + ... raise ValueError(f"Value must be positive, got {value}") + + >>> @strict(accept_kwargs=True) + ... @dataclass + ... class User: + ... name: str + ... 
age: int = positive_int(default=10) + + # Initialize + >>> User(name="John") + User(name='John', age=10) + + # Extra kwargs are accepted + >>> User(name="John", age=30, lastname="Doe") + User(name='John', age=30, *lastname='Doe') + + # Invalid type => raises + >>> User(name="John", age="30") + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + TypeError: Field 'age' expected int, got str (value: '30') + + # Invalid value => raises + >>> User(name="John", age=-1) + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + ValueError: Value must be positive, got -1 + ``` + """ + + def wrap(cls: Type[T]) -> Type[T]: + if not hasattr(cls, "__dataclass_fields__"): + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' must be a dataclass before applying @strict." + ) + + # List and store validators + field_validators: dict[str, list[Validator_T]] = {} + for f in fields(cls): # type: ignore [arg-type] + validators = [] + validators.append(_create_type_validator(f)) + custom_validator = f.metadata.get("validator") + if custom_validator is not None: + if not isinstance(custom_validator, list): + custom_validator = [custom_validator] + for validator in custom_validator: + if not _is_validator(validator): + raise StrictDataclassDefinitionError( + f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument." 
+ ) + validators.extend(custom_validator) + field_validators[f.name] = validators + cls.__validators__ = field_validators # type: ignore + + # Override __setattr__ to validate fields on assignment + original_setattr = cls.__setattr__ + + def __strict_setattr__(self: Any, name: str, value: Any) -> None: + """Custom __setattr__ method for strict dataclasses.""" + # Run all validators + for validator in self.__validators__.get(name, []): + try: + validator(value) + except (ValueError, TypeError) as e: + raise StrictDataclassFieldValidationError(field=name, cause=e) from e + + # If validation passed, set the attribute + original_setattr(self, name, value) + + cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign] + + if accept_kwargs: + # (optional) Override __init__ to accept arbitrary keyword arguments + original_init = cls.__init__ + + @wraps(original_init) + def __init__(self, **kwargs: Any) -> None: + # Extract only the fields that are part of the dataclass + dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type] + standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields} + + # Call the original __init__ with standard fields + original_init(self, **standard_kwargs) + + # Pass any additional kwargs to `__post_init__` and let the object + # decide whether to set the attr or use for different purposes (e.g. 
BC checks) + additional_kwargs = {} + for name, value in kwargs.items(): + if name not in dataclass_fields: + additional_kwargs[name] = value + + self.__post_init__(**additional_kwargs) + + cls.__init__ = __init__ # type: ignore[method-assign] + + # Define a default __post_init__ if not defined + if not hasattr(cls, "__post_init__"): + + def __post_init__(self, **kwargs: Any) -> None: + """Default __post_init__ to accept additional kwargs.""" + for name, value in kwargs.items(): + setattr(self, name, value) + + cls.__post_init__ = __post_init__ # type: ignore + + # (optional) Override __repr__ to include additional kwargs + original_repr = cls.__repr__ + + @wraps(original_repr) + def __repr__(self) -> str: + # Call the original __repr__ to get the standard fields + standard_repr = original_repr(self) + + # Get additional kwargs + additional_kwargs = [ + # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass + f"*{k}={v!r}" + for k, v in self.__dict__.items() + if k not in cls.__dataclass_fields__ # type: ignore [attr-defined] + ] + additional_repr = ", ".join(additional_kwargs) + + # Combine both representations + return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr + + cls.__repr__ = __repr__ # type: ignore [method-assign] + + # List all public methods starting with `validate_` => class validators. + class_validators = [] + + for name in dir(cls): + if not name.startswith("validate_"): + continue + method = getattr(cls, name) + if not callable(method): + continue + if len(inspect.signature(method).parameters) != 1: + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument." + " Class validators must take only 'self' as an argument. Methods starting with 'validate_'" + " are considered to be class validators." 
+ ) + class_validators.append(method) + + cls.__class_validators__ = class_validators # type: ignore [attr-defined] + + # Add `validate` method to the class, but first check if it already exists + def validate(self: T) -> None: + """Run class validators on the instance.""" + for validator in cls.__class_validators__: # type: ignore [attr-defined] + try: + validator(self) + except (ValueError, TypeError) as e: + raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e + + # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class + # (in which case we just override it) + validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined] + + if hasattr(cls, "validate"): + if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined] + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' already implements a method called 'validate'." + " This method name is reserved when using the @strict decorator on a dataclass." + " If you want to keep your own method, please rename it." + ) + + cls.validate = validate # type: ignore + + # Run class validators after initialization + initial_init = cls.__init__ + + @wraps(initial_init) + def init_with_validate(self, *args, **kwargs) -> None: + """Run class validators after initialization.""" + initial_init(self, *args, **kwargs) # type: ignore [call-arg] + cls.validate(self) # type: ignore [attr-defined] + + setattr(cls, "__init__", init_with_validate) + + return cls + + # Return wrapped class or the decorator itself + return wrap(cls) if cls is not None else wrap + + +def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None: + """ + Validate that a dictionary conforms to the types defined in a TypedDict class. + + Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator. 
+ + Args: + schema (`type[TypedDictType]`): + The TypedDict class defining the expected structure and types. + data (`dict`): + The dictionary to validate. + + Raises: + `StrictDataclassFieldValidationError`: + If any field in the dictionary does not conform to the expected type. + + Example: + ```py + >>> from typing import Annotated, TypedDict + >>> from huggingface_hub.dataclasses import validate_typed_dict + + >>> def positive_int(value: int): + ... if not value >= 0: + ... raise ValueError(f"Value must be positive, got {value}") + + >>> class User(TypedDict): + ... name: str + ... age: Annotated[int, positive_int] + + >>> # Valid data + >>> validate_typed_dict(User, {"name": "John", "age": 30}) + + >>> # Invalid type for age + >>> validate_typed_dict(User, {"name": "John", "age": "30"}) + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + TypeError: Field 'age' expected int, got str (value: '30') + + >>> # Invalid value for age + >>> validate_typed_dict(User, {"name": "John", "age": -1}) + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + ValueError: Value must be positive, got -1 + ``` + """ + # Convert typed dict to dataclass + strict_cls = _build_strict_cls_from_typed_dict(schema) + + # Validate the data by instantiating the strict dataclass + strict_cls(**data) # will raise if validation fails + + +@lru_cache +def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type: + # Extract type hints from the TypedDict class + type_hints = _get_typed_dict_annotations(schema) + + # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired) + if not getattr(schema, "__total__", True): + for key, value in type_hints.items(): + origin = get_origin(value) + + if origin is Annotated: + base, *meta = get_args(value) + if not _is_required_or_notrequired(base): + base = NotRequired[base] + type_hints[key] = 
Annotated[tuple([base] + list(meta))] # type: ignore + elif not _is_required_or_notrequired(value): + type_hints[key] = NotRequired[value] + + # Convert type hints to dataclass fields + fields = [] + for key, value in type_hints.items(): + if get_origin(value) is Annotated: + base, *meta = get_args(value) + fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]}))) + else: + fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE))) + + # Create a strict dataclass from the TypedDict fields + return strict(make_dataclass(schema.__name__, fields)) + + +def _get_typed_dict_annotations(schema: type[TypedDictType]) -> dict[str, Any]: + """Extract type annotations from a TypedDict class.""" + try: + # Available in Python 3.14+ + import annotationlib + + return annotationlib.get_annotations(schema) + except ImportError: + return { + # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail). + # ForwardRefs are not validated by @strict anyway. + name: value if value is not None else type(None) + for name, value in schema.__dict__.get("__annotations__", {}).items() + } + + +def validated_field( + validator: Union[list[Validator_T], Validator_T], + default: Union[Any, _MISSING_TYPE] = MISSING, + default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[dict] = None, + **kwargs: Any, +) -> Any: + """ + Create a dataclass field with a custom validator. + + Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator. + + Args: + validator (`Callable` or `list[Callable]`): + A method that takes a value as input and raises ValueError/TypeError if the value is invalid. + Can be a list of validators to apply multiple checks. + **kwargs: + Additional arguments to pass to `dataclasses.field()`. 
+ + Returns: + A field with the validator attached in metadata + """ + if not isinstance(validator, list): + validator = [validator] + if metadata is None: + metadata = {} + metadata["validator"] = validator + return field( # type: ignore + default=default, # type: ignore [arg-type] + default_factory=default_factory, # type: ignore [arg-type] + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + **kwargs, + ) + + +def as_validated_field(validator: Validator_T): + """ + Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator). + + Args: + validator (`Callable`): + A method that takes a value as input and raises ValueError/TypeError if the value is invalid. + """ + + def _inner( + default: Union[Any, _MISSING_TYPE] = MISSING, + default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[dict] = None, + **kwargs: Any, + ): + return validated_field( + validator, + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + **kwargs, + ) + + return _inner + + +def type_validator(name: str, value: Any, expected_type: Any) -> None: + """Validate that 'value' matches 'expected_type'.""" + origin = get_origin(expected_type) + args = get_args(expected_type) + + if expected_type is Any: + return + elif validator := _BASIC_TYPE_VALIDATORS.get(origin): + validator(name, value, args) + elif isinstance(expected_type, type): # simple types + _validate_simple_type(name, value, expected_type) + elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str): + return + elif origin is Required: + if value is _TYPED_DICT_DEFAULT_VALUE: + raise TypeError(f"Field '{name}' is required but missing.") + type_validator(name, value, args[0]) + elif origin is NotRequired: + if value is 
_TYPED_DICT_DEFAULT_VALUE: + return + type_validator(name, value, args[0]) + else: + raise TypeError(f"Unsupported type for field '{name}': {expected_type}") + + +def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate that value matches one of the types in a Union.""" + errors = [] + for t in args: + try: + type_validator(name, value, t) + return # Valid if any type matches + except TypeError as e: + errors.append(str(e)) + + raise TypeError( + f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}" + ) + + +def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate Literal type.""" + if value not in args: + raise TypeError(f"Field '{name}' expected one of {args}, got {value}") + + +def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate list[T] type.""" + if not isinstance(value, list): + raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}") + + # Validate each item in the list + item_type = args[0] + for i, item in enumerate(value): + try: + type_validator(f"{name}[{i}]", item, item_type) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in list '{name}'") from e + + +def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate dict[K, V] type.""" + if not isinstance(value, dict): + raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}") + + # Validate keys and values + key_type, value_type = args + for k, v in value.items(): + try: + type_validator(f"{name}.key", k, key_type) + type_validator(f"{name}[{k!r}]", v, value_type) + except TypeError as e: + raise TypeError(f"Invalid key or value in dict '{name}'") from e + + +def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate Tuple type.""" + if not isinstance(value, tuple): + raise TypeError(f"Field '{name}' expected a tuple, got 
{type(value).__name__}") + + # Handle variable-length tuples: tuple[T, ...] + if len(args) == 2 and args[1] is Ellipsis: + for i, item in enumerate(value): + try: + type_validator(f"{name}[{i}]", item, args[0]) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e + # Handle fixed-length tuples: tuple[T1, T2, ...] + elif len(args) != len(value): + raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}") + else: + for i, (item, expected) in enumerate(zip(value, args)): + try: + type_validator(f"{name}[{i}]", item, expected) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e + + +def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate set[T] type.""" + if not isinstance(value, set): + raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}") + + # Validate each item in the set + item_type = args[0] + for i, item in enumerate(value): + try: + type_validator(f"{name} item", item, item_type) + except TypeError as e: + raise TypeError(f"Invalid item in set '{name}'") from e + + +def _validate_simple_type(name: str, value: Any, expected_type: type) -> None: + """Validate simple type (int, str, etc.).""" + if not isinstance(value, expected_type): + raise TypeError( + f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})" + ) + + +def _create_type_validator(field: Field) -> Validator_T: + """Create a type validator function for a field.""" + # Hacky: we cannot use a lambda here because of reference issues + + def validator(value: Any) -> None: + type_validator(field.name, value, field.type) + + return validator + + +def _is_validator(validator: Any) -> bool: + """Check if a function is a validator. + + A validator is a Callable that can be called with a single positional argument. + The validator can have more arguments with default values. 
+ + Basically, returns True if `validator(value)` is possible. + """ + if not callable(validator): + return False + + signature = inspect.signature(validator) + parameters = list(signature.parameters.values()) + if len(parameters) == 0: + return False + if parameters[0].kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.VAR_POSITIONAL, + ): + return False + for parameter in parameters[1:]: + if parameter.default == inspect.Parameter.empty: + return False + return True + + +def _is_required_or_notrequired(type_hint: Any) -> bool: + """Helper to check if a type is Required/NotRequired.""" + return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired)) + + +_BASIC_TYPE_VALIDATORS = { + Union: _validate_union, + Literal: _validate_literal, + list: _validate_list, + dict: _validate_dict, + tuple: _validate_tuple, + set: _validate_set, +} + +if sys.version_info >= (3, 10): + # TODO: make it first class citizen when bumping to Python 3.10+ + _BASIC_TYPE_VALIDATORS[types.UnionType] = _validate_union # x | y syntax, available only Python 3.10+ + + +__all__ = [ + "strict", + "validate_typed_dict", + "validated_field", + "Validator_T", + "StrictDataclassClassValidationError", + "StrictDataclassDefinitionError", + "StrictDataclassFieldValidationError", +] diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/errors.py b/venv/lib/python3.10/site-packages/huggingface_hub/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..6d917e98486bb4d573e31c1fedf56f96cf1b7234 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/errors.py @@ -0,0 +1,415 @@ +"""Contains all custom errors.""" + +from pathlib import Path +from typing import Optional, Union + +from httpx import HTTPError, Response + + +# CACHE ERRORS + + +class CacheNotFound(Exception): + """Exception thrown when the Huggingface cache is not found.""" + + cache_dir: Union[str, 
Path] + + def __init__(self, msg: str, cache_dir: Union[str, Path], *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.cache_dir = cache_dir + + +class CorruptedCacheException(Exception): + """Exception for any unexpected structure in the Huggingface cache-system.""" + + +# HEADERS ERRORS + + +class LocalTokenNotFoundError(EnvironmentError): + """Raised if local token is required but not found.""" + + +# HTTP ERRORS + + +class OfflineModeIsEnabled(ConnectionError): + """Raised when a request is made but `HF_HUB_OFFLINE=1` is set as environment variable.""" + + +class HfHubHTTPError(HTTPError, OSError): + """ + HTTPError to inherit from for any custom HTTP Error raised in HF Hub. + + Any HTTPError is converted at least into a `HfHubHTTPError`. If some information is + sent back by the server, it will be added to the error message. + + Added details: + - Request ID sourced from headers in order of precedence: "X-Request-Id", "X-Amzn-Trace-Id", "X-Amz-Cf-Id". + - Server error message from the header "X-Error-Message". + - Server error message if we can found one in the response body. + + Example: + ```py + import httpx + from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError + + response = get_session().post(...) 
+ try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + print(str(e)) # formatted message + e.request_id, e.server_message # details returned by server + + # Complete the error message with additional information once it's raised + e.append_to_message("\n`create_commit` expects the repository to exist.") + raise + ``` + """ + + def __init__( + self, + message: str, + *, + response: Response, + server_message: Optional[str] = None, + ): + self.request_id = ( + response.headers.get("x-request-id") + or response.headers.get("X-Amzn-Trace-Id") + or response.headers.get("x-amz-cf-id") + ) + self.server_message = server_message + self.response = response + self.request = response.request + super().__init__(message) + + def append_to_message(self, additional_message: str) -> None: + """Append additional information to the `HfHubHTTPError` initial message.""" + self.args = (self.args[0] + additional_message,) + self.args[1:] + + @classmethod + def _reconstruct_hf_hub_http_error( + cls, message: str, response: Response, server_message: Optional[str] + ) -> "HfHubHTTPError": + return cls(message, response=response, server_message=server_message) + + def __reduce_ex__(self, protocol): + """Fix pickling of Exception subclass with kwargs. 
We need to override __reduce_ex__ of the parent class""" + return (self.__class__._reconstruct_hf_hub_http_error, (str(self), self.response, self.server_message)) + + +# INFERENCE CLIENT ERRORS + + +class InferenceTimeoutError(HTTPError, TimeoutError): + """Error raised when a model is unavailable or the request times out.""" + + +# INFERENCE ENDPOINT ERRORS + + +class InferenceEndpointError(Exception): + """Generic exception when dealing with Inference Endpoints.""" + + +class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError): + """Exception for timeouts while waiting for Inference Endpoint.""" + + +# SAFETENSORS ERRORS + + +class SafetensorsParsingError(Exception): + """Raised when failing to parse a safetensors file metadata. + + This can be the case if the file is not a safetensors file or does not respect the specification. + """ + + +class NotASafetensorsRepoError(Exception): + """Raised when a repo is not a Safetensors repo i.e. doesn't have either a `model.safetensors` or a + `model.safetensors.index.json` file. + """ + + +# TEXT GENERATION ERRORS + + +class TextGenerationError(HTTPError): + """Generic error raised if text-generation went wrong.""" + + +# Text Generation Inference Errors +class ValidationError(TextGenerationError): + """Server-side validation error.""" + + +class GenerationError(TextGenerationError): + pass + + +class OverloadedError(TextGenerationError): + pass + + +class IncompleteGenerationError(TextGenerationError): + pass + + +class UnknownError(TextGenerationError): + pass + + +# VALIDATION ERRORS + + +class HFValidationError(ValueError): + """Generic exception thrown by `huggingface_hub` validators. + + Inherits from [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError). + """ + + +# FILE METADATA ERRORS + + +class DryRunError(OSError): + """Error triggered when a dry run is requested but cannot be performed (e.g. 
invalid repo).""" + + +class FileMetadataError(OSError): + """Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash). + + Inherits from `OSError` for backward compatibility. + """ + + +# REPOSITORY ERRORS + + +class RepositoryNotFoundError(HfHubHTTPError): + """ + Raised when trying to access a hf.co URL with an invalid repository name, or + with a private repo name the user does not have access to. + + Example: + + ```py + >>> from huggingface_hub import model_info + >>> model_info("") + (...) + huggingface_hub.errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP) + + Repository Not Found for url: https://huggingface.co/api/models/%3Cnon_existent_repository%3E. + Please make sure you specified the correct `repo_id` and `repo_type`. + If the repo is private, make sure you are authenticated. + Invalid username or password. + ``` + """ + + +class GatedRepoError(RepositoryNotFoundError): + """ + Raised when trying to access a gated repository for which the user is not on the + authorized list. + + Note: derives from `RepositoryNotFoundError` to ensure backward compatibility. + + Example: + + ```py + >>> from huggingface_hub import model_info + >>> model_info("") + (...) + huggingface_hub.errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa) + + Cannot access gated repo for url https://huggingface.co/api/models/ardent-figment/gated-model. + Access to model ardent-figment/gated-model is restricted and you are not in the authorized list. + Visit https://huggingface.co/ardent-figment/gated-model to ask for access. + ``` + """ + + +class DisabledRepoError(HfHubHTTPError): + """ + Raised when trying to access a repository that has been disabled by its author. + + Example: + + ```py + >>> from huggingface_hub import dataset_info + >>> dataset_info("laion/laion-art") + (...) + huggingface_hub.errors.DisabledRepoError: 403 Client Error. 
(Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8) + + Cannot access repository for url https://huggingface.co/api/datasets/laion/laion-art. + Access to this resource is disabled. + ``` + """ + + +# REVISION ERROR + + +class RevisionNotFoundError(HfHubHTTPError): + """ + Raised when trying to access a hf.co URL with a valid repository but an invalid + revision. + + Example: + + ```py + >>> from huggingface_hub import hf_hub_download + >>> hf_hub_download('bert-base-cased', 'config.json', revision='') + (...) + huggingface_hub.errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX) + + Revision Not Found for url: https://huggingface.co/bert-base-cased/resolve/%3Cnon-existent-revision%3E/config.json. + ``` + """ + + +# ENTRY ERRORS +class EntryNotFoundError(Exception): + """ + Raised when entry not found, either locally or remotely. + + Example: + + ```py + >>> from huggingface_hub import hf_hub_download + >>> hf_hub_download('bert-base-cased', '') + (...) + huggingface_hub.errors.RemoteEntryNotFoundError (...) + >>> hf_hub_download('bert-base-cased', '', local_files_only=True) + (...) + huggingface_hub.utils.errors.LocalEntryNotFoundError (...) + ``` + """ + + +class RemoteEntryNotFoundError(HfHubHTTPError, EntryNotFoundError): + """ + Raised when trying to access a hf.co URL with a valid repository and revision + but an invalid filename. + + Example: + + ```py + >>> from huggingface_hub import hf_hub_download + >>> hf_hub_download('bert-base-cased', '') + (...) + huggingface_hub.errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x) + + Entry Not Found for url: https://huggingface.co/bert-base-cased/resolve/main/%3Cnon-existent-file%3E. + ``` + """ + + +class LocalEntryNotFoundError(FileNotFoundError, EntryNotFoundError): + """ + Raised when trying to access a file or snapshot that is not on the disk when network is + disabled or unavailable (connection issue). 
The entry may exist on the Hub. + + Example: + + ```py + >>> from huggingface_hub import hf_hub_download + >>> hf_hub_download('bert-base-cased', '', local_files_only=True) + (...) + huggingface_hub.errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False. + ``` + """ + + def __init__(self, message: str): + super().__init__(message) + + +# REQUEST ERROR +class BadRequestError(HfHubHTTPError, ValueError): + """ + Raised by `hf_raise_for_status` when the server returns a HTTP 400 error. + + Example: + + ```py + >>> resp = httpx.post("hf.co/api/check", ...) + >>> hf_raise_for_status(resp, endpoint_name="check") + huggingface_hub.errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX) + ``` + """ + + +# DDUF file format ERROR + + +class DDUFError(Exception): + """Base exception for errors related to the DDUF format.""" + + +class DDUFCorruptedFileError(DDUFError): + """Exception thrown when the DDUF file is corrupted.""" + + +class DDUFExportError(DDUFError): + """Base exception for errors during DDUF export.""" + + +class DDUFInvalidEntryNameError(DDUFExportError): + """Exception thrown when the entry name is invalid.""" + + +# STRICT DATACLASSES ERRORS + + +class StrictDataclassError(Exception): + """Base exception for strict dataclasses.""" + + +class StrictDataclassDefinitionError(StrictDataclassError): + """Exception thrown when a strict dataclass is defined incorrectly.""" + + +class StrictDataclassFieldValidationError(StrictDataclassError): + """Exception thrown when a strict dataclass fails validation for a given field.""" + + def __init__(self, field: str, cause: Exception): + error_message = f"Validation error for field '{field}':" + error_message += f"\n {cause.__class__.__name__}: {cause}" + super().__init__(error_message) + + +class 
def _check_fastai_fastcore_versions(
    fastai_min_version: str = "2.4",
    fastcore_min_version: str = "1.3.27",
):
    """
    Checks that the installed fastai and fastcore versions are compatible for pickle serialization.

    Args:
        fastai_min_version (`str`, *optional*):
            The minimum fastai version supported.
        fastcore_min_version (`str`, *optional*):
            The minimum fastcore version supported.

    > [!TIP]
    > Raises the following error:
    >
    > - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
    >   if the fastai or fastcore libraries are not available or are of an invalid version.
    """
    installed_fastai = get_fastai_version()
    installed_fastcore = get_fastcore_version()

    # "N/A" is the value this check treats as "library not available". Each version must be
    # tested on its own: `(fastcore or fastai) == "N/A"` evaluates to the first truthy operand,
    # so a missing fastai went unnoticed whenever fastcore was installed and the function then
    # failed in `version.Version("N/A")` instead of raising the documented `ImportError`.
    if "N/A" in (installed_fastai, installed_fastcore):
        raise ImportError(
            f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
            f" required. Currently using fastai=={installed_fastai} and"
            f" fastcore=={installed_fastcore}."
        )

    current_fastai_version = version.Version(installed_fastai)
    current_fastcore_version = version.Version(installed_fastcore)

    if current_fastai_version < version.Version(fastai_min_version):
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
            f" fastai>={fastai_min_version} version, but you are using fastai version"
            f" {installed_fastai} which is incompatible. Upgrade with `pip install"
            " fastai==2.5.6`."
        )

    if current_fastcore_version < version.Version(fastcore_min_version):
        raise ImportError(
            "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
            f" fastcore>={fastcore_min_version} version, but you are using fastcore"
            f" version {installed_fastcore} which is incompatible. Upgrade with"
            " `pip install fastcore==1.3.27`."
        )
If `pyproject.toml` does not exist + or does not contain versions for fastai and fastcore, then it logs a warning. + + Args: + storage_folder (`str`): + Folder to look for the `pyproject.toml` file. + fastai_min_version (`str`, *optional*): + The minimum fastai version supported. + fastcore_min_version (`str`, *optional*): + The minimum fastcore version supported. + + > [!TIP] + > Raises the following errors: + > + > - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) + > if the `toml` module is not installed. + > - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) + > if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore. + """ + + try: + import toml + except ModuleNotFoundError: + raise ImportError( + "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module." + " Install it with `pip install toml`." + ) + + # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages. + if not os.path.isfile(f"{storage_folder}/pyproject.toml"): + logger.warning( + "There is no `pyproject.toml` in the repository that contains the fastai" + " `Learner`. The `pyproject.toml` would allow us to verify that your fastai" + " and fastcore versions are compatible with those of the model you want to" + " load." + ) + return + pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml") + + if "build-system" not in pyproject_toml.keys(): + logger.warning( + "There is no `build-system` section in the pyproject.toml of the repository" + " that contains the fastai `Learner`. The `build-system` would allow us to" + " verify that your fastai and fastcore versions are compatible with those" + " of the model you want to load." 
+ ) + return + build_system_toml = pyproject_toml["build-system"] + + if "requires" not in build_system_toml.keys(): + logger.warning( + "There is no `requires` section in the pyproject.toml of the repository" + " that contains the fastai `Learner`. The `requires` would allow us to" + " verify that your fastai and fastcore versions are compatible with those" + " of the model you want to load." + ) + return + package_versions = build_system_toml["requires"] + + # Extracts contains fastai and fastcore versions from `pyproject.toml` if available. + # If the package is specified but not the version (e.g. "fastai" instead of "fastai=2.4"), the default versions are the highest. + fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")] + if len(fastai_packages) == 0: + logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.") + # fastai_version is an empty string if not specified + else: + fastai_version = str(fastai_packages[0]).partition("=")[2] + if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version): + raise ImportError( + "`from_pretrained_fastai` requires" + f" fastai>={fastai_min_version} version but the model to load uses" + f" {fastai_version} which is incompatible." + ) + + fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")] + if len(fastcore_packages) == 0: + logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.") + # fastcore_version is an empty string if not specified + else: + fastcore_version = str(fastcore_packages[0]).partition("=")[2] + if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version): + raise ImportError( + "`from_pretrained_fastai` requires" + f" fastcore>={fastcore_min_version} version, but you are using fastcore" + f" version {fastcore_version} which is incompatible." 
def _create_model_card(repo_dir: Path):
    """
    Writes the default `README.md` model card into `repo_dir` unless one already exists.

    Args:
        repo_dir (`Path`):
            Directory where model card is created.
    """
    card_file = repo_dir / "README.md"
    if card_file.exists():
        # An existing model card is left untouched.
        return

    with card_file.open("w", encoding="utf-8") as handle:
        handle.write(README_TEMPLATE)
def _save_pretrained_fastai(
    learner,
    save_directory: Union[str, Path],
    config: Optional[dict[str, Any]] = None,
):
    """
    Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used.

    Besides the pickled model (`model.pkl`), the directory receives the optional config file, a
    model card and a `pyproject.toml` (the last two only if they do not exist yet).

    Args:
        learner (`Learner`):
            The `fastai.Learner` you'd like to save.
        save_directory (`str` or `Path`):
            Specific directory in which you want to save the fastai learner.
        config (`dict`, *optional*):
            Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.

    > [!TIP]
    > Raises the following errors:
    >
    > - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
    >   if the config file provided is not a dictionary.
    > - [`PicklingError`](https://docs.python.org/3/library/pickle.html#pickle.PicklingError)
    >   if `learner.export` fails to pickle the learner.
    """
    _check_fastai_fastcore_versions()

    os.makedirs(save_directory, exist_ok=True)

    # A user-provided config is written as-is, under `constants.CONFIG_NAME`.
    if config is not None:
        if not isinstance(config, dict):
            raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
        path = os.path.join(save_directory, constants.CONFIG_NAME)
        # Explicit encoding, like the other files written by this module.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f)

    _create_model_card(Path(save_directory))
    _create_model_pyproject(Path(save_directory))

    # learner.export saves the model in `self.path`.
    learner.path = Path(save_directory)
    try:
        learner.export(
            fname="model.pkl",
            pickle_protocol=DEFAULT_PROTOCOL,
        )
    except PicklingError as err:
        # Chain the original error so the real pickling failure stays visible in the traceback.
        raise PicklingError(
            "You are using a lambda function, i.e., an anonymous function. `pickle`"
            " cannot pickle function objects and requires that all functions have"
            " names. One possible solution is to name the function."
        ) from err
@validate_hf_hub_args
def push_to_hub_fastai(
    learner,
    *,
    repo_id: str,
    commit_message: str = "Push FastAI model using huggingface_hub.",
    private: Optional[bool] = None,
    token: Optional[str] = None,
    config: Optional[dict] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[list[str], str]] = None,
    ignore_patterns: Optional[Union[list[str], str]] = None,
    delete_patterns: Optional[Union[list[str], str]] = None,
    api_endpoint: Optional[str] = None,
):
    """
    Upload learner checkpoint files to the Hub.

    The repository is created if needed (`exist_ok=True`), the learner is saved into a temporary
    folder and that folder is uploaded in a single commit.

    Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
    `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
    details.

    Args:
        learner (`Learner`):
            The `fastai.Learner` you'd like to push to the Hub.
        repo_id (`str`):
            The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be
            your individual account or an organization to which you have write access (for example,
            'stanfordnlp/stanza-de').
        commit_message (`str`, *optional*):
            Message to commit while pushing. Defaults to `"Push FastAI model using huggingface_hub."`.
        private (`bool`, *optional*):
            Whether or not the repository created should be private.
            If `None` (default), the repo will be public unless the organization's default is private.
        token (`str`, *optional*):
            The Hugging Face account token to use as HTTP bearer authorization for remote files. It is
            forwarded to both the repository creation and the upload.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        branch (`str`, *optional*):
            The git branch on which to push the model. This defaults to
            the default branch as specified in your repository, which
            defaults to `"main"`.
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit.
            Defaults to `False`.
        allow_patterns (`list[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`list[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.
        delete_patterns (`list[str]` or `str`, *optional*):
            If provided, remote files matching any of the patterns will be deleted from the repo.
        api_endpoint (`str`, *optional*):
            The API endpoint to use when pushing the model to the hub.

    Returns:
        The url of the commit of your model in the given repository.

    > [!TIP]
    > Raises the following error:
    >
    > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
    >   if the user is not logged in to the Hugging Face Hub.
    """
    _check_fastai_fastcore_versions()

    hub_client = HfApi(endpoint=api_endpoint)
    # Keep the repo id returned by the Hub rather than the one passed in.
    created_repo = hub_client.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
    repo_id = created_repo.repo_id

    # Push the files to the repo in a single commit
    with SoftTemporaryDirectory() as workdir:
        export_dir = Path(workdir) / repo_id
        _save_pretrained_fastai(learner, export_dir, config=config)
        commit = hub_client.upload_folder(
            repo_id=repo_id,
            token=token,
            folder_path=export_dir,
            commit_message=commit_message,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
    return commit
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
    """Return whether the symlinks are supported on the machine.

    Symlink support can differ from one mounted disk to another, so the check is made
    inside the precise cache folder by trying to create a symlink in a temporary
    sub-directory. The answer is stored per directory and reused on later calls.

    Args:
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Defaults to the HF cache
            directory (`constants.HF_HUB_CACHE`).

    Returns: [bool] Whether symlinks are supported in the directory.
    """
    target_dir = constants.HF_HUB_CACHE if cache_dir is None else cache_dir
    # Absolute, resolved path so that one directory always maps to one cache key.
    target_dir = str(Path(target_dir).expanduser().resolve())

    # Already probed: reuse the stored answer.
    if target_dir in _are_symlinks_supported_in_dir:
        return _are_symlinks_supported_in_dir[target_dir]

    # Optimistic default, overwritten below if the symlink cannot be created.
    _are_symlinks_supported_in_dir[target_dir] = True

    os.makedirs(target_dir, exist_ok=True)
    with SoftTemporaryDirectory(dir=target_dir) as tmpdir:
        probe_src = Path(tmpdir) / "dummy_file_src"
        probe_src.touch()
        probe_dst = Path(tmpdir) / "dummy_file_dst"

        # Relative source path as in `_create_symlink`
        relative_src = os.path.relpath(probe_src, start=os.path.dirname(probe_dst))
        try:
            os.symlink(relative_src, probe_dst)
        except OSError:
            # Likely running on Windows
            _are_symlinks_supported_in_dir[target_dir] = False

            if not constants.HF_HUB_DISABLE_SYMLINKS_WARNING:
                message_parts = [
                    "`huggingface_hub` cache-system uses symlinks by default to"
                    " efficiently store duplicated files but your machine does not"
                    f" support them in {target_dir}. Caching files will still work"
                    " but in a degraded version that might require more space on"
                    " your disk. This warning can be disabled by setting the"
                    " `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For"
                    " more details, see"
                    " https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
                ]
                if os.name == "nt":
                    message_parts.append(
                        "\nTo support symlinks on Windows, you either need to"
                        " activate Developer Mode or to run Python as an"
                        " administrator. In order to activate developer mode,"
                        " see this article:"
                        " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
                    )
                warnings.warn("".join(message_parts))

    return _are_symlinks_supported_in_dir[target_dir]
@validate_hf_hub_args
def hf_hub_url(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    endpoint: Optional[str] = None,
) -> str:
    """Construct the URL of a file from the given information.

    The resolved address can either be a huggingface.co-hosted url, or a link to
    Cloudfront (a Content Delivery Network, or CDN) for large files which are
    more than a few MBs.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) name and a repo name separated
            by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the repo. When set
            (and not empty), it is prepended to `filename`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `constants.DEFAULT_REVISION`.
        endpoint (`str`, *optional*):
            If provided and the built URL starts with `constants.ENDPOINT`, that
            prefix is replaced by `endpoint`.

    Raises:
        `ValueError`: if `repo_type` is not one of `constants.REPO_TYPES`.

    Example:

    ```python
    >>> from huggingface_hub import hf_hub_url

    >>> hf_hub_url(
    ...     repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin"
    ... )
    'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin'
    ```

    > [!TIP]
    > Notes:
    >
    > Cloudfront is replicated over the globe so downloads are way faster for
    > the end user (and it also lowers our bandwidth costs).
    >
    > Cloudfront aggressively caches files by default (default TTL is 24
    > hours), however this is not an issue here because we implement a
    > git-based versioning system on huggingface.co, which means that we store
    > the files on S3/Cloudfront in a content-addressable way (i.e., the file
    > name is its hash). Using content-addressable filenames means cache can't
    > ever be stale.
    >
    > In terms of client-side caching from this library, we base our caching
    > on the objects' entity tag (`ETag`), which is an identifier of a
    > specific version of a resource [1]_. An object's ETag is: its git-sha1
    > if stored in git, or its sha256 if stored in git-lfs.

    References:

    - [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
    """
    # An empty subfolder is treated exactly like no subfolder.
    if subfolder is not None and subfolder != "":
        filename = f"{subfolder}/{filename}"

    if repo_type not in constants.REPO_TYPES:
        raise ValueError("Invalid repo type")

    if repo_type in constants.REPO_TYPES_URL_PREFIXES:
        repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id

    effective_revision = constants.DEFAULT_REVISION if revision is None else revision
    # The revision is fully percent-encoded (`safe=""`); the filename keeps its `/` separators.
    url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
        repo_id=repo_id,
        revision=quote(effective_revision, safe=""),
        filename=quote(filename),
    )

    # Update endpoint if provided
    if endpoint is None or not url.startswith(constants.ENDPOINT):
        return url
    return endpoint + url[len(constants.ENDPOINT) :]
+ """ + # if `retry_on_errors=False`, disable all retries for fast fallback to cache + no_retry_kwargs: dict[str, Any] = ( + {} if retry_on_errors else {"retry_on_exceptions": (), "retry_on_status_codes": ()} + ) + + while True: + response = http_backoff( + method=method, + url=url, + **httpx_kwargs, + follow_redirects=False, + **no_retry_kwargs, + ) + hf_raise_for_status(response) + + # Check if response is a relative redirect + if 300 <= response.status_code <= 399: + parsed_target = urlparse(response.headers["Location"]) + if parsed_target.netloc == "": + # Relative redirect -> update URL and retry + url = urlparse(url)._replace(path=parsed_target.path).geturl() + continue + + # Break if no relative redirect + break + + return response + + +def _get_file_length_from_http_response(response: httpx.Response) -> Optional[int]: + """ + Get the length of the file from the HTTP response headers. + + This function extracts the file size from the HTTP response headers, either from the + `Content-Range` or `Content-Length` header, if available (in that order). + + Args: + response (`httpx.Response`): + The HTTP response object. + + Returns: + `int` or `None`: The length of the file in bytes, or None if not available. + """ + + # If HTTP response contains compressed body (e.g. gzip), the `Content-Length` header will + # contain the length of the compressed body, not the uncompressed file size. + # And at the start of transmission there's no way to know the uncompressed file size for gzip, + # thus we return None in that case. 
+ content_encoding = response.headers.get("Content-Encoding", "identity").lower() + if content_encoding != "identity": + # gzip/br/deflate/zstd etc + return None + + content_range = response.headers.get("Content-Range") + if content_range is not None: + return int(content_range.rsplit("/")[-1]) + + content_length = response.headers.get("Content-Length") + if content_length is not None: + return int(content_length) + + return None + + +@validate_hf_hub_args +def http_get( + url: str, + temp_file: BinaryIO, + *, + resume_size: int = 0, + headers: Optional[dict[str, Any]] = None, + expected_size: Optional[int] = None, + displayed_filename: Optional[str] = None, + tqdm_class: Optional[type[base_tqdm]] = None, + _nb_retries: int = 5, + _tqdm_bar: Optional[tqdm] = None, +) -> None: + """ + Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub. + + If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a + transient error (network outage?). We log a warning message and try to resume the download a few times before + giving up. The method gives up after 5 attempts if no new data has being received from the server. + + Args: + url (`str`): + The URL of the file to download. + temp_file (`BinaryIO`): + The file-like object where to save the file. + resume_size (`int`, *optional*): + The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a + positive number, the download will resume at the given position. + headers (`dict`, *optional*): + Dictionary of HTTP Headers to send with the request. + expected_size (`int`, *optional*): + The expected size of the file to download. If set, the download will raise an error if the size of the + received content is different from the expected one. + displayed_filename (`str`, *optional*): + The filename of the file that is being downloaded. Value is used only to display a nice progress bar. 
If + not set, the filename is guessed from the URL or the `Content-Disposition` header. + """ + if expected_size is not None and resume_size == expected_size: + # If the file is already fully downloaded, we don't need to download it again. + return + + initial_headers = headers + headers = copy.deepcopy(headers) or {} + if resume_size > 0: + headers["Range"] = _adjust_range_header(headers.get("Range"), resume_size) + elif expected_size and expected_size > constants.MAX_HTTP_DOWNLOAD_SIZE: + # Any files over 50GB will not be available through basic http requests. + raise ValueError( + "The file is too large to be downloaded using the regular download method. " + " Install `hf_xet` with `pip install hf_xet` for xet-powered downloads." + ) + + with http_stream_backoff( + method="GET", + url=url, + headers=headers, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + retry_on_exceptions=(), + retry_on_status_codes=(429,), + ) as response: + hf_raise_for_status(response) + + # If we requested a Range but got 200 back, the server ignored our Range header + # (e.g. CloudFront with Accept-Encoding: gzip). Reset file to avoid corruption. 
+ if resume_size > 0 and response.status_code == 200: + temp_file.seek(0) + temp_file.truncate() + resume_size = 0 + + total: Optional[int] = _get_file_length_from_http_response(response) + + if displayed_filename is None: + displayed_filename = url + content_disposition = response.headers.get("Content-Disposition") + if content_disposition is not None: + match = HEADER_FILENAME_PATTERN.search(content_disposition) + if match is not None: + # Means file is on CDN + displayed_filename = match.groupdict()["filename"] + + # Truncate filename if too long to display + if len(displayed_filename) > 40: + displayed_filename = f"(…){displayed_filename[-40:]}" + + consistency_error_message = ( + f"Consistency check failed: file should be of size {expected_size} but has size" + f" {{actual_size}} ({displayed_filename}).\nThis is usually due to network issues while downloading the file." + " Please retry with `force_download=True`." + ) + progress_cm = _get_progress_bar_context( + desc=displayed_filename, + log_level=logger.getEffectiveLevel(), + total=total, + initial=resume_size, + name="huggingface_hub.http_get", + tqdm_class=tqdm_class, + _tqdm_bar=_tqdm_bar, + ) + + with progress_cm as progress: + new_resume_size = resume_size + try: + for chunk in response.iter_bytes(chunk_size=constants.DOWNLOAD_CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + new_resume_size += len(chunk) + # Some data has been downloaded from the server so we reset the number of retries. + _nb_retries = 5 + except (httpx.ConnectError, httpx.TimeoutException) as e: + # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely + # a transient error (network outage?). We log a warning message and try to resume the download a few times + # before giving up. Tre retry mechanism is basic but should be enough in most cases. 
+ if _nb_retries <= 0: + logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e)) + raise + logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e)) + time.sleep(1) + return http_get( + url=url, + temp_file=temp_file, + resume_size=new_resume_size, + headers=initial_headers, + expected_size=expected_size, + tqdm_class=tqdm_class, + _nb_retries=_nb_retries - 1, + _tqdm_bar=_tqdm_bar, + ) + + if expected_size is not None and expected_size != temp_file.tell(): + raise EnvironmentError( + consistency_error_message.format( + actual_size=temp_file.tell(), + ) + ) + + +def xet_get( + *, + incomplete_path: Path, + xet_file_data: XetFileData, + headers: dict[str, str], + expected_size: Optional[int] = None, + displayed_filename: Optional[str] = None, + tqdm_class: Optional[type[base_tqdm]] = None, + _tqdm_bar: Optional[tqdm] = None, +) -> None: + """ + Download a file using Xet storage service. + + Args: + incomplete_path (`Path`): + The path to the file to download. + xet_file_data (`XetFileData`): + The file metadata needed to make the request to the xet storage service. + headers (`dict[str, str]`): + The headers to send to the xet storage service. + expected_size (`int`, *optional*): + The expected size of the file to download. If set, the download will raise an error if the size of the + received content is different from the expected one. + displayed_filename (`str`, *optional*): + The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If + not set, the filename is guessed from the URL or the `Content-Disposition` header. + + **How it works:** + The file download system uses Xet storage, which is a content-addressable storage system that breaks files into chunks + for efficient storage and transfer. 
+ + `hf_xet.download_files` manages downloading files by: + - Taking a list of files to download (each with its unique content hash) + - Connecting to a storage server (CAS server) that knows how files are chunked + - Using authentication to ensure secure access + - Providing progress updates during download + + Authentication works by regularly refreshing access tokens through `refresh_xet_connection_info` to maintain a valid + connection to the storage server. + + The download process works like this: + 1. Create a local cache folder at `~/.cache/huggingface/xet/chunk-cache` to store reusable file chunks + 2. Download files in parallel: + 2.1. Prepare to write the file to disk + 2.2. Ask the server "how is this file split into chunks?" using the file's unique hash + The server responds with: + - Which chunks make up the complete file + - Where each chunk can be downloaded from + 2.3. For each needed chunk: + - Checks if we already have it in our local cache + - If not, download it from cloud storage (S3) + - Save it to cache for future use + - Assemble the chunks in order to recreate the original file + + """ + try: + from hf_xet import PyXetDownloadInfo, download_files # type: ignore[no-redef] + except ImportError: + raise ValueError( + "To use optimized download using Xet storage, you need to install the hf_xet package. " + 'Try `pip install "huggingface_hub[hf_xet]"` or `pip install hf_xet`.' 
def _normalize_etag(etag: Optional[str]) -> Optional[str]:
    """Normalize ETag HTTP header, so it can be used to create nice filepaths.

    The HTTP spec allows two forms of ETag:
      ETag: W/"<etag_value>"
      ETag: "<etag_value>"

    For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
    more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.

    Args:
        etag (`str`, *optional*): HTTP header

    Returns:
        `str` or `None`: string that can be used as a nice directory name.
        Returns `None` if input is None.
    """
    if etag is None:
        return None
    # Only the exact weak-validator marker `W/` is dropped. `str.lstrip("W/")` treats its argument as a set of
    # characters and would also eat leading "W" or "/" characters belonging to an unquoted value (e.g. "Wabc").
    return etag.removeprefix("W/").strip('"')
+ + In case symlinks are not supported, a warning message is displayed to the user once when loading `huggingface_hub`. + The warning message can be disabled with the `DISABLE_SYMLINKS_WARNING` environment variable. + """ + try: + os.remove(dst) + except OSError: + pass + + abs_src = os.path.abspath(os.path.expanduser(src)) + abs_dst = os.path.abspath(os.path.expanduser(dst)) + abs_dst_folder = os.path.dirname(abs_dst) + + # Use relative_dst in priority + try: + relative_src = os.path.relpath(abs_src, abs_dst_folder) + except ValueError: + # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink to a + # local_dir instead of within the cache directory. + # See https://docs.python.org/3/library/os.path.html#os.path.relpath + relative_src = None + + try: + commonpath = os.path.commonpath([abs_src, abs_dst]) + _support_symlinks = are_symlinks_supported(commonpath) + except ValueError: + # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos. + # See https://docs.python.org/3/library/os.path.html#os.path.commonpath + _support_symlinks = os.name != "nt" + except PermissionError: + # Permission error means src and dst are not in the same volume (e.g. destination path has been provided + # by the user via `local_dir`. Let's test symlink support there) + _support_symlinks = are_symlinks_supported(abs_dst_folder) + except OSError as e: + # OS error (errno=30) means that the commonpath is readonly on Linux/MacOS. + if e.errno == errno.EROFS: + _support_symlinks = are_symlinks_supported(abs_dst_folder) + else: + raise + + # Symlinks are supported => let's create a symlink. 
def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
    """Record which commit hash a revision (tag, branch or truncated commit hash) currently resolves to.

    The mapping is stored as a text file at `<storage_folder>/refs/<revision>` whose content is `commit_hash`.
    Nothing is written when `revision` already equals `commit_hash`, or when the ref file already holds that hash.
    """
    if revision == commit_hash:
        # A full commit hash needs no ref entry.
        return

    ref_file = Path(storage_folder, "refs", revision)
    ref_file.parent.mkdir(parents=True, exist_ok=True)

    # Skip the write when the stored value is already correct: rewriting an unchanged ref could raise a useless
    # error for users who have the repo cached but no write access to the cache folder.
    # See https://github.com/huggingface/huggingface_hub/issues/1216.
    if ref_file.exists() and ref_file.read_text() == commit_hash:
        return
    ref_file.write_text(commit_hash)
@overload
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    user_agent: Union[dict, str, None] = None,
    force_download: bool = False,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    tqdm_class: Optional[type[base_tqdm]] = None,
    dry_run: Literal[False] = False,
) -> str: ...


@overload
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    user_agent: Union[dict, str, None] = None,
    force_download: bool = False,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    tqdm_class: Optional[type[base_tqdm]] = None,
    dry_run: Literal[True] = True,
) -> DryRunFileInfo: ...


@overload
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    user_agent: Union[dict, str, None] = None,
    force_download: bool = False,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    tqdm_class: Optional[type[base_tqdm]] = None,
    dry_run: bool = False,
) -> Union[str, DryRunFileInfo]: ...


@validate_hf_hub_args
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    user_agent: Union[dict, str, None] = None,
    force_download: bool = False,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    tqdm_class: Optional[type[base_tqdm]] = None,
    dry_run: bool = False,
) -> Union[str, DryRunFileInfo]:
    """Fetch one file from a Hub repository, reusing the local copy when it is already available.

    Unless `local_dir` is given, the file is stored in the shared cache, which is organised as follows:
        - the cache directory holds one subfolder per repo_id (namespaced by repo type)
        - each repo folder contains:
            - `refs`: the latest known revision => commit_hash pairs
            - `blobs`: the actual file contents (named after their git-sha or sha256, depending on whether they
              are LFS files or not)
            - `snapshots`: one subfolder per commit, holding the subset of files that have been resolved at that
              commit. Each filename is a symlink to the corresponding blob.

    ```
    [  96]  .
    └── [ 160]  models--julien-c--EsperBERTo-small
        ├── [ 160]  blobs
        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
        ├── [  96]  refs
        │   └── [  40]  main
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
    ```

    When `local_dir` is provided, the repo's file structure is replicated in that location instead. In this mode
    the `cache_dir` is not used as the destination and a `.cache/huggingface/` folder is created at the root of
    `local_dir` to hold metadata about the downloaded files. This mechanism is less robust than the main
    cache-system, but it is optimized for regularly pulling the latest version of a repository.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            A folder inside the repo. When set to a non-empty value, the file requested is
            `"{subfolder}/{filename}"`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            A Git revision id: a branch name, a tag, or a commit hash. Falls back to
            `constants.DEFAULT_REVISION` when not set.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Falls back to `constants.HF_HUB_CACHE`
            when not set.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded file will be placed under this directory.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in
            the local cache.
        etag_timeout (`float`, *optional*, defaults to `constants.DEFAULT_ETAG_TIMEOUT`):
            When fetching ETag, how many seconds to wait for the server to send data before giving up.
            If `constants.HF_HUB_ETAG_TIMEOUT` differs from the default, that value takes precedence
            over the one passed here.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        headers (`dict`, *optional*):
            Additional headers to be sent with the request.
        endpoint (`str`, *optional*):
            Forwarded unchanged to the download helpers. Presumably the base URL of the Hub to
            contact - confirm against those helpers.
        tqdm_class (`tqdm`, *optional*):
            If provided, overwrites the default behavior for the progress bar. Passed
            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
            Defaults to the custom HF progress bar that can be disabled by setting
            `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
        dry_run (`bool`, *optional*, defaults to `False`):
            If `True`, perform a dry run without actually downloading the file. Returns a
            [`DryRunFileInfo`] object containing information about what would be downloaded.

    Returns:
        `str` or [`DryRunFileInfo`]:
            - If `dry_run=False`: Local path of file or if networking is off, last version of file cached on disk.
            - If `dry_run=True`: A [`DryRunFileInfo`] object containing download information.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`~utils.RemoteEntryNotFoundError`]
            If the file to download cannot be found.
        [`~utils.LocalEntryNotFoundError`]
            If network is disabled or unavailable and file is not found in cache.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` but the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If some parameter value is invalid.

    """
    # A customized HF_HUB_ETAG_TIMEOUT setting wins over whatever the caller passed.
    if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT:
        etag_timeout = constants.HF_HUB_ETAG_TIMEOUT

    # Fill in defaults for unset arguments.
    revision = constants.DEFAULT_REVISION if revision is None else revision
    repo_type = "model" if repo_type is None else repo_type
    cache_dir = constants.HF_HUB_CACHE if cache_dir is None else cache_dir

    # The helpers below work with plain strings for both locations.
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if isinstance(local_dir, Path):
        local_dir = str(local_dir)

    # An empty subfolder is the same as no subfolder. The join uses a forward slash on purpose:
    # the result is used to create a URL, not a local path.
    if subfolder is not None and subfolder != "":
        filename = f"{subfolder}/{filename}"

    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    hf_headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        headers=headers,
    )

    if local_dir is None:
        # No explicit destination => store the file in the cache layout described above.
        return _hf_hub_download_to_cache_dir(
            cache_dir=cache_dir,
            repo_id=repo_id,
            repo_type=repo_type,
            filename=filename,
            revision=revision,
            endpoint=endpoint,
            etag_timeout=etag_timeout,
            headers=hf_headers,
            token=token,
            force_download=force_download,
            local_files_only=local_files_only,
            tqdm_class=tqdm_class,
            dry_run=dry_run,
        )

    # Explicit destination => replicate the repo structure under `local_dir`.
    return _hf_hub_download_to_local_dir(
        local_dir=local_dir,
        cache_dir=cache_dir,
        repo_id=repo_id,
        repo_type=repo_type,
        filename=filename,
        revision=revision,
        endpoint=endpoint,
        etag_timeout=etag_timeout,
        headers=hf_headers,
        token=token,
        force_download=force_download,
        local_files_only=local_files_only,
        tqdm_class=tqdm_class,
        dry_run=dry_run,
    )
+ """ + locks_dir = os.path.join(cache_dir, ".locks") + storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) + + # cross-platform transcription of filename, to be used as a local file path. + relative_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if relative_filename.startswith("..\\") or "\\..\\" in relative_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository" + " owner to rename this file." + ) + + # if user provides a commit_hash and they already have the file on disk, shortcut everything. + if REGEX_COMMIT_HASH.match(revision): + pointer_path = _get_pointer_path(storage_folder, revision, relative_filename) + if os.path.exists(pointer_path): + if dry_run: + return DryRunFileInfo( + commit_hash=revision, + file_size=os.path.getsize(pointer_path), + filename=filename, + is_cached=True, + local_path=pointer_path, + will_download=force_download, + ) + if not force_download: + return pointer_path + + # Try to get metadata (etag, commit_hash, url, size) from the server. + # If we can't, a HEAD request error is returned. + (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = _get_metadata_or_catch_error( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + endpoint=endpoint, + etag_timeout=etag_timeout, + headers=headers, + token=token, + local_files_only=local_files_only, + storage_folder=storage_folder, + relative_filename=relative_filename, + ) + + # etag can be None for several reasons: + # 1. we passed local_files_only. + # 2. we don't have a connection + # 3. Hub is down (HTTP 500, 503, 504) + # 4. repo is not found -for example private or gated- and invalid/missing token sent + # 5. Hub is blocked by a firewall or proxy is not set correctly. + # => Try to get the last downloaded one from the specified revision. 
+ # + # If the specified revision is a commit hash, look inside "snapshots". + # If the specified revision is a branch or tag, look inside "refs". + if head_call_error is not None: + # Couldn't make a HEAD call => let's try to find a local file + if not force_download: + commit_hash = None + if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + else: + ref_path = os.path.join(storage_folder, "refs", revision) + if os.path.isfile(ref_path): + with open(ref_path) as f: + commit_hash = f.read() + + # Return pointer file if exists + if commit_hash is not None: + pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename) + if os.path.exists(pointer_path): + if dry_run: + return DryRunFileInfo( + commit_hash=commit_hash, + file_size=os.path.getsize(pointer_path), + filename=filename, + is_cached=True, + local_path=pointer_path, + will_download=force_download, + ) + if not force_download: + return pointer_path + + if isinstance(head_call_error, _DEFAULT_RETRY_ON_EXCEPTIONS) or ( + isinstance(head_call_error, HfHubHTTPError) + and head_call_error.response.status_code in _DEFAULT_RETRY_ON_STATUS_CODES + ): + logger.info("No local file found. Retrying..") + (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = ( + _get_metadata_or_catch_error( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + endpoint=endpoint, + etag_timeout=_ETAG_RETRY_TIMEOUT, + headers=headers, + token=token, + local_files_only=local_files_only, + storage_folder=storage_folder, + relative_filename=relative_filename, + retry_on_errors=True, + ) + ) + + # If still error, raise + if head_call_error is not None: + _raise_on_head_call_error(head_call_error, force_download, local_files_only) + + # From now on, etag, commit_hash, url and size are not None. 
+ assert etag is not None, "etag must have been retrieved from server" + assert commit_hash is not None, "commit_hash must have been retrieved from server" + assert url_to_download is not None, "file location must have been retrieved from server" + assert expected_size is not None, "expected_size must have been retrieved from server" + blob_path = os.path.join(storage_folder, "blobs", etag) + pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename) + + if dry_run: + is_cached = os.path.exists(pointer_path) or os.path.exists(blob_path) + return DryRunFileInfo( + commit_hash=commit_hash, + file_size=expected_size, + filename=filename, + is_cached=is_cached, + local_path=pointer_path, + will_download=force_download or not is_cached, + ) + + os.makedirs(os.path.dirname(blob_path), exist_ok=True) + os.makedirs(os.path.dirname(pointer_path), exist_ok=True) + + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. + _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash) + + # Prevent parallel downloads of the same file with a lock. + # etag could be duplicated across repos, + lock_path = os.path.join(locks_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type), f"{etag}.lock") + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it as an extended path by using the "\\?\" prefix. 
def _hf_hub_download_to_local_dir(
    *,
    # Destination
    local_dir: Union[str, Path],
    # File info
    repo_id: str,
    repo_type: str,
    filename: str,
    revision: str,
    # HTTP info
    endpoint: Optional[str],
    etag_timeout: float,
    headers: dict[str, str],
    token: Union[bool, str, None],
    # Additional options
    cache_dir: str,
    force_download: bool,
    local_files_only: bool,
    tqdm_class: Optional[type[base_tqdm]],
    dry_run: bool,
) -> Union[str, DryRunFileInfo]:
    """Download a given file to a local folder, if not already present.

    Method should not be called directly. Please use `hf_hub_download` instead.

    The function goes through the following steps and returns at the first one that applies:
        1. `revision` is a commit hash, the local file exists and its stored metadata records the same commit.
        2. The metadata request failed but a local file exists: that file is used (with a warning).
        3. The local file exists and is up-to-date: its stored etag matches the remote one or, when no metadata
           is stored and the etag is a sha256, the file's hash matches the etag.
        4. The file is found in the cache under `cache_dir`: it is copied to `local_dir`.
        5. Otherwise the file is downloaded.

    Returns:
        The path of the file inside `local_dir` as a string, or a `DryRunFileInfo` when `dry_run` is `True`.

    Raises:
        Whatever `_raise_on_head_call_error` raises when the file metadata could not be fetched and no local
        file can be used instead.
    """
    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it as an extended path by using the "\\?\" prefix.
    # A path that already carries the prefix is left untouched (same guard as in the cache-dir code path),
    # otherwise the prefix would be added twice.
    if os.name == "nt":
        abs_local_dir = os.path.abspath(local_dir)
        if len(abs_local_dir) > 255 and not abs_local_dir.startswith("\\\\?\\"):
            local_dir = "\\\\?\\" + abs_local_dir
    local_dir = Path(local_dir)
    paths = get_local_download_paths(local_dir=local_dir, filename=filename)
    local_metadata = read_download_metadata(local_dir=local_dir, filename=filename)

    # Local file exists + metadata exists + commit_hash matches => return file
    if (
        REGEX_COMMIT_HASH.match(revision)
        and paths.file_path.is_file()
        and local_metadata is not None
        and local_metadata.commit_hash == revision
    ):
        local_file = str(paths.file_path)
        if dry_run:
            return DryRunFileInfo(
                commit_hash=revision,
                file_size=os.path.getsize(local_file),
                filename=filename,
                is_cached=True,
                local_path=local_file,
                will_download=force_download,
            )
        if not force_download:
            return local_file

    # Local file doesn't exist or commit_hash doesn't match => we need the etag
    (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = _get_metadata_or_catch_error(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
        etag_timeout=etag_timeout,
        headers=headers,
        token=token,
        local_files_only=local_files_only,
    )

    if head_call_error is not None:
        # No HEAD call but local file exists => default to local file
        if paths.file_path.is_file():
            if dry_run or not force_download:
                logger.warning(
                    f"Couldn't access the Hub to check for update but local file already exists. Defaulting to existing file. (error: {head_call_error})"
                )
            local_path = str(paths.file_path)
            if dry_run and local_metadata is not None:
                return DryRunFileInfo(
                    commit_hash=local_metadata.commit_hash,
                    file_size=os.path.getsize(local_path),
                    filename=filename,
                    is_cached=True,
                    local_path=local_path,
                    will_download=force_download,
                )
            if not force_download:
                return local_path
        elif not force_download:
            # Nothing usable locally: retry the metadata call once when the failure is one of the retryable
            # exception types or HTTP status codes.
            if isinstance(head_call_error, _DEFAULT_RETRY_ON_EXCEPTIONS) or (
                isinstance(head_call_error, HfHubHTTPError)
                and head_call_error.response.status_code in _DEFAULT_RETRY_ON_STATUS_CODES
            ):
                logger.info("No local file found. Retrying..")
                (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = (
                    _get_metadata_or_catch_error(
                        repo_id=repo_id,
                        filename=filename,
                        repo_type=repo_type,
                        revision=revision,
                        endpoint=endpoint,
                        etag_timeout=_ETAG_RETRY_TIMEOUT,
                        headers=headers,
                        token=token,
                        local_files_only=local_files_only,
                        retry_on_errors=True,
                    )
                )

        # If still error, raise
        if head_call_error is not None:
            _raise_on_head_call_error(head_call_error, force_download, local_files_only)

    # From now on, etag, commit_hash, url and size are not None.
    assert etag is not None, "etag must have been retrieved from server"
    assert commit_hash is not None, "commit_hash must have been retrieved from server"
    assert url_to_download is not None, "file location must have been retrieved from server"
    assert expected_size is not None, "expected_size must have been retrieved from server"

    # Local file exists => check if it's up-to-date
    if not force_download and paths.file_path.is_file():
        # etag matches => update metadata and return file
        if local_metadata is not None and local_metadata.etag == etag:
            write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
            if dry_run:
                return DryRunFileInfo(
                    commit_hash=commit_hash,
                    file_size=expected_size,
                    filename=filename,
                    is_cached=True,
                    local_path=str(paths.file_path),
                    will_download=False,
                )
            return str(paths.file_path)

        # metadata is outdated + etag is a sha256
        # => means it's an LFS file (large)
        # => let's compute local hash and compare
        # => if match, update metadata and return file
        if local_metadata is None and REGEX_SHA256.match(etag) is not None:
            with open(paths.file_path, "rb") as f:
                file_hash = sha_fileobj(f).hex()
            if file_hash == etag:
                write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
                if dry_run:
                    return DryRunFileInfo(
                        commit_hash=commit_hash,
                        file_size=expected_size,
                        filename=filename,
                        is_cached=True,
                        local_path=str(paths.file_path),
                        will_download=False,
                    )
                return str(paths.file_path)

    # Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache)

    # If we are lucky enough, the file is already in the cache => copy it
    if not force_download:
        cached_path = try_to_load_from_cache(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            revision=commit_hash,
            repo_type=repo_type,
        )
        if isinstance(cached_path, str):
            # NOTE(review): the copy and the metadata write below also happen when `dry_run` is True.
            with WeakFileLock(paths.lock_path):
                paths.file_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(cached_path, paths.file_path)
            write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
            if dry_run:
                return DryRunFileInfo(
                    commit_hash=commit_hash,
                    file_size=expected_size,
                    filename=filename,
                    is_cached=True,
                    local_path=str(paths.file_path),
                    will_download=False,
                )
            return str(paths.file_path)

    if dry_run:
        is_cached = paths.file_path.is_file()
        return DryRunFileInfo(
            commit_hash=commit_hash,
            file_size=expected_size,
            filename=filename,
            is_cached=is_cached,
            local_path=str(paths.file_path),
            will_download=force_download or not is_cached,
        )

    # Otherwise, let's download the file!
    with WeakFileLock(paths.lock_path):
        paths.file_path.unlink(missing_ok=True)  # delete outdated file first
        _download_to_tmp_and_move(
            incomplete_path=paths.incomplete_path(etag),
            destination_path=paths.file_path,
            url_to_download=url_to_download,
            headers=headers,
            expected_size=expected_size,
            filename=filename,
            force_download=force_download,
            etag=etag,
            xet_file_data=xet_file_data,
            tqdm_class=tqdm_class,
        )

    write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag)
    return str(paths.file_path)
+ revision (`str`, *optional*): + The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is + provided either. + repo_type (`str`, *optional*): + The type of the repository. Will default to `"model"`. + + Returns: + `Optional[str]` or `_CACHED_NO_EXIST`: + Will return `None` if the file was not cached. Otherwise: + - The exact path to the cached file if it's found in the cache + - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was + cached. + + Example: + + ```python + from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST + + filepath = try_to_load_from_cache() + if isinstance(filepath, str): + # file exists and is cached + ... + elif filepath is _CACHED_NO_EXIST: + # non-existence of file is cached + ... + else: + # file is not cached + ... + ``` + """ + if revision is None: + revision = "main" + if repo_type is None: + repo_type = "model" + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type: {repo_type}. 
Accepted repo types are: {str(constants.REPO_TYPES)}") + if cache_dir is None: + cache_dir = constants.HF_HUB_CACHE + + object_id = repo_id.replace("/", "--") + repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}") + if not os.path.isdir(repo_cache): + # No cache for this model + return None + + refs_dir = os.path.join(repo_cache, "refs") + snapshots_dir = os.path.join(repo_cache, "snapshots") + no_exist_dir = os.path.join(repo_cache, ".no_exist") + + # Resolve refs (for instance to convert main to the associated commit sha) + if os.path.isdir(refs_dir): + revision_file = os.path.join(refs_dir, revision) + if os.path.isfile(revision_file): + with open(revision_file) as f: + revision = f.read() + + # Check if file is cached as "no_exist" + if os.path.isfile(os.path.join(no_exist_dir, revision, filename)): + return _CACHED_NO_EXIST + + # Check if revision folder exists + if not os.path.exists(snapshots_dir): + return None + cached_shas = os.listdir(snapshots_dir) + if revision not in cached_shas: + # No cache for this revision and we won't try to return a random revision + return None + + # Check if file exists in cache + cached_file = os.path.join(snapshots_dir, revision, filename) + return cached_file if os.path.isfile(cached_file) else None + + +@validate_hf_hub_args +def get_hf_file_metadata( + url: str, + token: Union[bool, str, None] = None, + timeout: Optional[float] = constants.HF_HUB_ETAG_TIMEOUT, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[dict, str, None] = None, + headers: Optional[dict[str, str]] = None, + endpoint: Optional[str] = None, + retry_on_errors: bool = False, +) -> HfFileMetadata: + """Fetch metadata of a file versioned on the Hub for a given url. + + Args: + url (`str`): + File url, for example returned by [`hf_hub_url`]. + token (`str` or `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the HuggingFace config + folder. 
+ - If `False` or `None`, no token is provided. + - If a string, it's used as the authentication token. + timeout (`float`, *optional*, defaults to 10): + How many seconds to wait for the server to send metadata before giving up. + library_name (`str`, *optional*): + The name of the library to which the object corresponds. + library_version (`str`, *optional*): + The version of the library. + user_agent (`dict`, `str`, *optional*): + The user-agent info in the form of a dictionary or a string. + headers (`dict`, *optional*): + Additional headers to be sent with the request. + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + retry_on_errors (`bool`, *optional*, defaults to `False`): + Whether to retry on errors (429, 5xx, timeout, network errors). + If False, no retry for fast fallback to local cache. + + Returns: + A [`HfFileMetadata`] object containing metadata such as location, etag, size and + commit_hash. + """ + hf_headers = build_hf_headers( + token=token, + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + headers=headers, + ) + hf_headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file + + # Retrieve metadata + response = _httpx_follow_relative_redirects( + method="HEAD", url=url, headers=hf_headers, timeout=timeout, retry_on_errors=retry_on_errors + ) + hf_raise_for_status(response) + + # Return + return HfFileMetadata( + commit_hash=response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT), + # We favor a custom header indicating the etag of the linked resource, and we fall back to the regular etag header. + etag=_normalize_etag( + response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag") + ), + # Either from response headers (if redirected) or defaults to request url + # Do not use directly `url` as we might have followed relative redirects. 
+ location=response.headers.get("Location") or str(response.request.url), # type: ignore + size=_int_or_none( + response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or response.headers.get("Content-Length") + ), + xet_file_data=parse_xet_file_data_from_response(response, endpoint=endpoint), # type: ignore + ) + + +def _get_metadata_or_catch_error( + *, + repo_id: str, + filename: str, + repo_type: str, + revision: str, + endpoint: Optional[str], + etag_timeout: Optional[float], + headers: dict[str, str], # mutated inplace! + token: Union[bool, str, None], + local_files_only: bool, + relative_filename: Optional[str] = None, # only used to store `.no_exists` in cache + storage_folder: Optional[str] = None, # only used to store `.no_exists` in cache + retry_on_errors: bool = False, +) -> Union[ + # Either an exception is caught and returned + tuple[None, None, None, None, None, Exception], + # Or the metadata is returned as + # `(url_to_download, etag, commit_hash, expected_size, xet_file_data, None)` + tuple[str, str, str, int, Optional[XetFileData], None], +]: + """Get metadata for a file on the Hub, safely handling network issues. + + Returns either the etag, commit_hash and expected size of the file, or the error + raised while fetching the metadata. + + NOTE: This function mutates `headers` inplace! It removes the `authorization` header + if the file is a LFS blob and the domain of the url is different from the + domain of the location (typically an S3 bucket). + """ + if local_files_only: + return ( + None, + None, + None, + None, + None, + OfflineModeIsEnabled( + f"Cannot access file since 'local_files_only=True' as been set. 
(repo_id: {repo_id}, repo_type: {repo_type}, revision: {revision}, filename: {filename})" + ), + ) + + url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint) + url_to_download: str = url + etag: Optional[str] = None + commit_hash: Optional[str] = None + expected_size: Optional[int] = None + head_error_call: Optional[Exception] = None + xet_file_data: Optional[XetFileData] = None + + # Try to get metadata from the server. + # Do not raise yet if the file is not found or not accessible. + if not local_files_only: + try: + try: + metadata = get_hf_file_metadata( + url=url, + timeout=etag_timeout, + headers=headers, + token=token, + endpoint=endpoint, + retry_on_errors=retry_on_errors, + ) + except RemoteEntryNotFoundError as http_error: + if storage_folder is not None and relative_filename is not None: + # Cache the non-existence of the file + commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT) + if commit_hash is not None: + no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename + try: + no_exist_file_path.parent.mkdir(parents=True, exist_ok=True) + no_exist_file_path.touch() + except OSError as e: + logger.error( + f"Could not cache non-existence of file. Will ignore error and continue. Error: {e}" + ) + _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash) + raise + + # Commit hash must exist + commit_hash = metadata.commit_hash + if commit_hash is None: + raise FileMetadataError( + "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue" + " prevents you from downloading resources from https://huggingface.co. Please check your firewall" + " and proxy settings and make sure your SSL certificates are updated." + ) + + # Etag must exist + # If we don't have any of those, raise an error. 
+ etag = metadata.etag + if etag is None: + raise FileMetadataError( + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + ) + + # Size must exist + expected_size = metadata.size + if expected_size is None: + raise FileMetadataError("Distant resource does not have a Content-Length.") + + xet_file_data = metadata.xet_file_data + + # In case of a redirect, save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + # + # If url domain is different => we are downloading from a CDN => url is signed => don't send auth + # If url domain is the same => redirect due to repo rename AND downloading a regular file => keep auth + if xet_file_data is None and url != metadata.location: + url_to_download = metadata.location + if urlparse(url).netloc != urlparse(metadata.location).netloc: + # Remove authorization header when downloading a LFS blob + headers.pop("authorization", None) + except httpx.ProxyError: + # Actually raise on proxy error + raise + except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error: + # Otherwise, our Internet connection is down. + # etag is None + head_error_call = error + except (RevisionNotFoundError, RemoteEntryNotFoundError): + # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted) + raise + except HfHubHTTPError as error: + # Multiple reasons for an http error: + # - Repository is private and invalid/missing token sent + # - Repository is gated and invalid/missing token sent + # - Hub is down (error 500 or 504) + # => let's switch to 'local_files_only=True' to check if the files are already cached. 
+ # (if it's not the case, the error will be re-raised) + head_error_call = error + except FileMetadataError as error: + # Multiple reasons for a FileMetadataError: + # - Wrong network configuration (proxy, firewall, SSL certificates) + # - Inconsistency on the Hub + # => let's switch to 'local_files_only=True' to check if the files are already cached. + # (if it's not the case, the error will be re-raised) + head_error_call = error + + if not (local_files_only or etag is not None or head_error_call is not None): + raise RuntimeError("etag is empty due to uncovered problems") + + return (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_error_call) # type: ignore [return-value] + + +def _raise_on_head_call_error(head_call_error: Exception, force_download: bool, local_files_only: bool) -> NoReturn: + """Raise an appropriate error when the HEAD call failed and we cannot locate a local file.""" + # No head call => we cannot force download. + if force_download: + if local_files_only: + raise ValueError("Cannot pass 'force_download=True' and 'local_files_only=True' at the same time.") + elif isinstance(head_call_error, OfflineModeIsEnabled): + raise ValueError("Cannot pass 'force_download=True' when offline mode is enabled.") from head_call_error + else: + raise ValueError("Force download failed due to the above error.") from head_call_error + + # No head call + couldn't find an appropriate file on disk => raise an error. + if local_files_only: + raise LocalEntryNotFoundError( + "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable" + " hf.co look-ups and downloads online, set 'local_files_only' to False." 
+ ) + elif isinstance(head_call_error, (RepositoryNotFoundError, GatedRepoError)) or ( + isinstance(head_call_error, HfHubHTTPError) and head_call_error.response.status_code == 401 + ): + # Repo not found or gated => let's raise the actual error + # Unauthorized => likely a token issue => let's raise the actual error + raise head_call_error + else: + # Otherwise: most likely a connection issue or Hub downtime => let's warn the user + raise LocalEntryNotFoundError( + "An error happened while trying to locate the file on the Hub and we cannot find the requested files" + " in the local cache. Please check your connection and try again or make sure your Internet connection" + " is on." + ) from head_call_error + + +def _download_to_tmp_and_move( + incomplete_path: Path, + destination_path: Path, + url_to_download: str, + headers: dict[str, str], + expected_size: Optional[int], + filename: str, + force_download: bool, + etag: Optional[str], + xet_file_data: Optional[XetFileData], + tqdm_class: Optional[type[base_tqdm]] = None, +) -> None: + """Download content from a URL to a destination path. + + Internal logic: + - return early if file is already downloaded + - resume download if possible (from incomplete file) + - do not resume download if `force_download=True` + - check disk space before downloading + - download content to a temporary file + - set correct permissions on temporary file + - move the temporary file to the destination path + + Both `incomplete_path` and `destination_path` must be on the same volume to avoid a local copy. + """ + if destination_path.exists() and not force_download: + # Do nothing if already exists (except if force_download=True) + return + + if incomplete_path.exists() and force_download: + # By default, we will try to resume the download if possible. + # However, if the user has set `force_download=True`, then we should + # not resume the download => delete the incomplete file. 
+ logger.debug(f"Removing incomplete file '{incomplete_path}' (force_download=True)") + incomplete_path.unlink(missing_ok=True) + + with incomplete_path.open("ab") as f: + resume_size = f.tell() + message = f"Downloading '{filename}' to '{incomplete_path}'" + if resume_size > 0 and expected_size is not None: + message += f" (resume from {resume_size}/{expected_size})" + logger.debug(message) + + if expected_size is not None: # might be None if HTTP header not set correctly + # Check disk space in both tmp and destination path + _check_disk_space(expected_size, incomplete_path.parent) + _check_disk_space(expected_size, destination_path.parent) + + if xet_file_data is not None and is_xet_available(): + logger.debug("Xet Storage is enabled for this repo. Downloading file from Xet Storage..") + xet_get( + incomplete_path=incomplete_path, + xet_file_data=xet_file_data, + headers=headers, + expected_size=expected_size, + displayed_filename=filename, + tqdm_class=tqdm_class, + ) + else: + if xet_file_data is not None and not constants.HF_HUB_DISABLE_XET: + logger.warning( + "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. " + "Falling back to regular HTTP download. " + "For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`" + ) + + http_get( + url_to_download, + f, + resume_size=resume_size, + headers=headers, + expected_size=expected_size, + tqdm_class=tqdm_class, + ) + + logger.debug(f"Download complete. Moving file to {destination_path}") + _chmod_and_move(incomplete_path, destination_path) + + +def _int_or_none(value: Optional[str]) -> Optional[int]: + try: + return int(value) # type: ignore + except (TypeError, ValueError): + return None + + +def _chmod_and_move(src: Path, dst: Path) -> None: + """Set correct permission before moving a blob from tmp directory to cache dir. 
+ + Do not take into account the `umask` from the process as there is no convenient way + to get it that is thread-safe. + + See: + - About umask: https://docs.python.org/3/library/os.html#os.umask + - Thread-safety: https://stackoverflow.com/a/70343066 + - About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591 + - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141 + - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215 + """ + # Get umask by creating a temporary file in the cached repo folder. + tmp_file = dst.parent.parent / f"tmp_{uuid.uuid4()}" + try: + tmp_file.touch() + cache_dir_mode = Path(tmp_file).stat().st_mode + os.chmod(str(src), stat.S_IMODE(cache_dir_mode)) + except OSError as e: + logger.warning( + f"Could not set the permissions on the file '{src}'. Error: {e}.\nContinuing without setting permissions." + ) + finally: + try: + tmp_file.unlink() + except OSError: + # fails if `tmp_file.touch()` failed => do nothing + # See https://github.com/huggingface/huggingface_hub/issues/2359 + pass + + shutil.move(str(src), str(dst), copy_function=_copy_no_matter_what) + + +def _copy_no_matter_what(src: str, dst: str) -> None: + """Copy file from src to dst. + + If `shutil.copy2` fails, fallback to `shutil.copyfile`. + """ + try: + # Copy file with metadata and permission + # Can fail e.g. 
if dst is an S3 mount + shutil.copy2(src, dst) + except OSError: + # Copy only file content + shutil.copyfile(src, dst) + + +def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str: + # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks + snapshot_path = os.path.join(storage_folder, "snapshots") + pointer_path = os.path.join(snapshot_path, revision, relative_filename) + if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents: + raise ValueError( + "Invalid pointer path: cannot create pointer path in snapshot folder if" + f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and" + f" `relative_filename='{relative_filename}'`." + ) + return pointer_path diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/hf_api.py b/venv/lib/python3.10/site-packages/huggingface_hub/hf_api.py new file mode 100644 index 0000000000000000000000000000000000000000..292114b6db16f0eacd3c2c8fc49d23a1f868b535 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/hf_api.py @@ -0,0 +1,11533 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import base64 +import inspect +import itertools +import json +import re +import struct +import time +import warnings +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import asdict, dataclass, field +from datetime import datetime +from functools import wraps +from itertools import islice +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + Callable, + Iterable, + Iterator, + Literal, + Optional, + Type, + TypeVar, + Union, + overload, +) +from urllib.parse import quote + +import httpcore +import httpx +from tqdm.auto import tqdm as base_tqdm +from tqdm.contrib.concurrent import thread_map + +from . import constants +from ._commit_api import ( + CommitOperation, + CommitOperationAdd, + CommitOperationCopy, + CommitOperationDelete, + _fetch_files_to_copy, + _fetch_upload_modes, + _prepare_commit_payload, + _upload_files, + _warn_on_overwriting_operations, +) +from ._eval_results import EvalResultEntry, parse_eval_result_entries +from ._inference_endpoints import InferenceEndpoint, InferenceEndpointScalingMetric, InferenceEndpointType +from ._jobs_api import JobHardware, JobInfo, JobSpec, ScheduledJobInfo, _create_job_spec +from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable +from ._upload_large_folder import upload_large_folder_internal +from .community import ( + Discussion, + DiscussionComment, + DiscussionStatusChange, + DiscussionTitleChange, + DiscussionWithDetails, + deserialize_event, +) +from .errors import ( + BadRequestError, + GatedRepoError, + HfHubHTTPError, + LocalTokenNotFoundError, + RemoteEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url +from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData +from .utils import ( + DEFAULT_IGNORE_PATTERNS, + 
NotASafetensorsRepoError, + SafetensorsFileMetadata, + SafetensorsParsingError, + SafetensorsRepoMetadata, + TensorInfo, + build_hf_headers, + chunk_iterable, + experimental, + filter_repo_objects, + fix_hf_endpoint_in_url, + get_session, + get_token, + hf_raise_for_status, + logging, + paginate, + parse_datetime, + validate_hf_hub_args, +) +from .utils import tqdm as hf_tqdm +from .utils._auth import _get_token_from_environment, _get_token_from_file, _get_token_from_google_colab +from .utils._deprecation import _deprecate_arguments +from .utils._typing import CallableT +from .utils._verification import collect_local_files, resolve_local_root, verify_maps +from .utils.endpoint_helpers import _is_emission_within_threshold + + +if TYPE_CHECKING: + from .inference._providers import PROVIDER_T + from .utils._verification import FolderVerification + +R = TypeVar("R") # Return type +CollectionItemType_T = Literal["model", "dataset", "space", "paper", "collection"] +CollectionSort_T = Literal["lastModified", "trending", "upvotes"] + +ExpandModelProperty_T = Literal[ + "author", + "baseModels", + "cardData", + "childrenModelCount", + "config", + "createdAt", + "disabled", + "downloads", + "downloadsAllTime", + "evalResults", + "gated", + "gguf", + "inference", + "inferenceProviderMapping", + "lastModified", + "library_name", + "likes", + "mask_token", + "model-index", + "pipeline_tag", + "private", + "resourceGroup", + "safetensors", + "sha", + "siblings", + "spaces", + "tags", + "transformersInfo", + "trendingScore", + "usedStorage", + "widgetData", +] + +ExpandDatasetProperty_T = Literal[ + "author", + "cardData", + "citation", + "createdAt", + "description", + "disabled", + "downloads", + "downloadsAllTime", + "gated", + "lastModified", + "likes", + "paperswithcode_id", + "private", + "resourceGroup", + "sha", + "siblings", + "tags", + "trendingScore", + "usedStorage", +] + +ExpandSpaceProperty_T = Literal[ + "author", + "cardData", + "createdAt", + "datasets", + 
"disabled", + "lastModified", + "likes", + "models", + "private", + "resourceGroup", + "runtime", + "sdk", + "sha", + "siblings", + "subdomain", + "tags", + "trendingScore", + "usedStorage", +] + +ModelSort_T = Literal["created_at", "downloads", "last_modified", "likes", "trending_score"] +DatasetSort_T = Literal["created_at", "downloads", "last_modified", "likes", "trending_score"] +SpaceSort_T = Literal["created_at", "last_modified", "likes", "trending_score"] +DailyPapersSort_T = Literal["publishedAt", "trending"] + +USERNAME_PLACEHOLDER = "hf_user" +_REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") +_REGEX_HTTP_PROTOCOL = re.compile(r"https?://") + +_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = ( + "\nNote: Creating a commit assumes that the repo already exists on the" + " Huggingface Hub. Please use `create_repo` if it's not the case." +) +_AUTH_CHECK_NO_REPO_ERROR_MESSAGE = ( + "\nNote: The repository either does not exist or you do not have access rights." + " Please check the repository ID and your access permissions." + " If this is a private repository, ensure that your token is correct." +) +logger = logging.get_logger(__name__) + + +def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tuple[Optional[str], Optional[str], str]: + """ + Returns the repo type and ID from a huggingface.co URL linking to a + repository + + Args: + hf_id (`str`): + An URL or ID of a repository on the HF hub. Accepted values are: + + - https://huggingface.co/// + - https://huggingface.co// + - hf://// + - hf:/// + - // + - / + - + hub_url (`str`, *optional*): + The URL of the HuggingFace Hub, defaults to https://huggingface.co + + Returns: + A tuple with three items: repo_type (`str` or `None`), namespace (`str` or + `None`) and repo_id (`str`). + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If URL cannot be parsed. 
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `repo_type` is unknown. + """ + input_hf_id = hf_id + + # Get the hub_url (with or without protocol) + full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT + hub_url_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", full_hub_url) + + # Check if hf_id is a URL containing the hub_url (check both with and without protocol) + hf_id_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", hf_id) + is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id + + HFFS_PREFIX = "hf://" + if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists + hf_id = hf_id[len(HFFS_PREFIX) :] + + # If it's a URL, strip the endpoint prefix to get the path + if is_hf_url: + # Remove protocol if present + hf_id_normalized = _REGEX_HTTP_PROTOCOL.sub("", hf_id) + + # Remove the hub_url prefix to get the relative path + if hf_id_normalized.startswith(hub_url_without_protocol): + # Strip the hub URL and any leading slashes + hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/") + + url_segments = hf_id.split("/") + is_hf_id = len(url_segments) <= 3 + + namespace: Optional[str] + if is_hf_url: + # For URLs, we need to extract repo_type, namespace, repo_id + # Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id + + if len(url_segments) >= 3: + # Check if first segment is a repo type + if url_segments[0] in constants.REPO_TYPES_MAPPING: + repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] + namespace = url_segments[1] + repo_id = url_segments[2] + else: + # First segment is namespace + namespace = url_segments[0] + repo_id = url_segments[1] + repo_type = None + elif len(url_segments) == 2: + namespace = url_segments[0] + repo_id = url_segments[1] + + # Check if namespace is actually a repo type mapping + if namespace in constants.REPO_TYPES_MAPPING: + # Mean canonical dataset or model + repo_type = 
constants.REPO_TYPES_MAPPING[namespace] + namespace = None + else: + repo_type = None + else: + # Single segment + repo_id = url_segments[0] + namespace = None + repo_type = None + elif is_hf_id: + if len(url_segments) == 3: + # Passed // or // + repo_type, namespace, repo_id = url_segments[-3:] + elif len(url_segments) == 2: + if url_segments[0] in constants.REPO_TYPES_MAPPING: + # Passed '' or 'datasets/' for a canonical model or dataset + repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] + namespace = None + repo_id = hf_id.split("/")[-1] + else: + # Passed / or / + namespace, repo_id = hf_id.split("/")[-2:] + repo_type = None + else: + # Passed + repo_id = url_segments[0] + namespace, repo_type = None, None + else: + raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") + + # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) + if repo_type in constants.REPO_TYPES_MAPPING: + repo_type = constants.REPO_TYPES_MAPPING[repo_type] + if repo_type == "": + repo_type = None + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')") + + return repo_type, namespace, repo_id + + +@dataclass +class LastCommitInfo(dict): + oid: str + title: str + date: datetime + + def __post_init__(self): # hack to make LastCommitInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class BlobLfsInfo(dict): + size: int + sha256: str + pointer_size: int + + def __post_init__(self): # hack to make BlobLfsInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class BlobSecurityInfo(dict): + safe: bool # duplicate information with "status" field, keeping it for backward compatibility + status: str + av_scan: Optional[dict] + pickle_import_scan: Optional[dict] + + def __post_init__(self): # hack to make BlogSecurityInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class TransformersInfo(dict): + auto_model: str + 
custom_class: Optional[str] = None + # possible `pipeline_tag` values: https://github.com/huggingface/huggingface.js/blob/3ee32554b8620644a6287e786b2a83bf5caf559c/packages/tasks/src/pipelines.ts#L72 + pipeline_tag: Optional[str] = None + processor: Optional[str] = None + + def __post_init__(self): # hack to make TransformersInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class SafeTensorsInfo(dict): + parameters: dict[str, int] + total: int + + def __post_init__(self): # hack to make SafeTensorsInfo backward compatible + self.update(asdict(self)) + + +@dataclass +class CommitInfo(str): + """Data structure containing information about a newly created commit. + + Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`], + [`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific + to `str` is deprecated. + + Attributes: + commit_url (`str`): + Url where to find the commit. + + commit_message (`str`): + The summary (first line) of the commit that has been created. + + commit_description (`str`): + Description of the commit that has been created. Can be empty. + + oid (`str`): + Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. + + pr_url (`str`, *optional*): + Url to the PR that has been created, if any. Populated when `create_pr=True` + is passed. + + pr_revision (`str`, *optional*): + Revision of the PR that has been created, if any. Populated when + `create_pr=True` is passed. Example: `"refs/pr/1"`. + + pr_num (`int`, *optional*): + Number of the PR discussion that has been created, if any. Populated when + `create_pr=True` is passed. Can be passed as `discussion_num` in + [`get_discussion_details`]. Example: `1`. + + repo_url (`RepoUrl`): + Repo URL of the commit containing info like repo_id, repo_type, etc. 
+ """ + + commit_url: str + commit_message: str + commit_description: str + oid: str + _endpoint: Optional[str] = field(default=None, repr=False) + pr_url: Optional[str] = None + + # Computed from `commit_url` in `__post_init__` + repo_url: RepoUrl = field(init=False) + + # Computed from `pr_url` in `__post_init__` + pr_revision: Optional[str] = field(init=False) + pr_num: Optional[int] = field(init=False) + + def __new__(cls, *args, commit_url: str, **kwargs): + return str.__new__(cls, commit_url) + + def __post_init__(self): + """Populate pr-related fields after initialization. + + See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing. + """ + # Repo info + self.repo_url = RepoUrl(self.commit_url.split("/commit/")[0], endpoint=self._endpoint) + + # PR info + if self.pr_url is not None: + self.pr_revision = _parse_revision_from_pr_url(self.pr_url) + self.pr_num = int(self.pr_revision.split("/")[-1]) + else: + self.pr_revision = None + self.pr_num = None + + +@dataclass +class AccessRequest: + """Data structure containing information about a user access request. + + Attributes: + username (`str`): + Username of the user who requested access. + fullname (`str`): + Fullname of the user who requested access. + email (`Optional[str]`): + Email of the user who requested access. + Can only be `None` in the /accepted list if the user was granted access manually. + timestamp (`datetime`): + Timestamp of the request. + status (`Literal["pending", "accepted", "rejected"]`): + Status of the request. Can be one of `["pending", "accepted", "rejected"]`. + fields (`dict[str, Any]`, *optional*): + Additional fields filled by the user in the gate form. 
+ """ + + username: str + fullname: str + email: Optional[str] + timestamp: datetime + status: Literal["pending", "accepted", "rejected"] + + # Additional fields filled by the user in the gate form + fields: Optional[dict[str, Any]] = None + + +@dataclass +class WebhookWatchedItem: + """Data structure containing information about the items watched by a webhook. + + Attributes: + type (`Literal["dataset", "model", "org", "space", "user"]`): + Type of the item to be watched. Can be one of `["dataset", "model", "org", "space", "user"]`. + name (`str`): + Name of the item to be watched. Can be the username, organization name, model name, dataset name or space name. + """ + + type: Literal["dataset", "model", "org", "space", "user"] + name: str + + +@dataclass +class WebhookInfo: + """Data structure containing information about a webhook. + + One of `url` or `job` is specified, but not both. + + Attributes: + id (`str`): + ID of the webhook. + url (`str`, *optional*): + URL of the webhook. + job (`JobSpec`, *optional*): + Specifications of the Job to trigger. + watched (`list[WebhookWatchedItem]`): + List of items watched by the webhook, see [`WebhookWatchedItem`]. + domains (`list[WEBHOOK_DOMAIN_T]`): + List of domains the webhook is watching. Can be one of `["repo", "discussions"]`. + secret (`str`, *optional*): + Secret of the webhook. + disabled (`bool`): + Whether the webhook is disabled or not. + """ + + id: str + url: Optional[str] + job: Optional[JobSpec] + watched: list[WebhookWatchedItem] + domains: list[constants.WEBHOOK_DOMAIN_T] + secret: Optional[str] + disabled: bool + + +class RepoUrl(str): + """Subclass of `str` describing a repo URL on the Hub. + + `RepoUrl` is returned by `HfApi.create_repo`. It inherits from `str` for backward + compatibility. 
At initialization, the URL is parsed to populate properties: + - endpoint (`str`) + - namespace (`Optional[str]`) + - repo_name (`str`) + - repo_id (`str`) + - repo_type (`Literal["model", "dataset", "space"]`) + - url (`str`) + + Args: + url (`Any`): + String value of the repo url. + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + + Example: + ```py + >>> RepoUrl('https://huggingface.co/gpt2') + RepoUrl('https://huggingface.co/gpt2', endpoint='https://huggingface.co', repo_type='model', repo_id='gpt2') + + >>> RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co') + RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co', repo_type='dataset', repo_id='dummy_user/dummy_dataset') + + >>> RepoUrl('hf://datasets/my-user/my-dataset') + RepoUrl('hf://datasets/my-user/my-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='user/dataset') + + >>> HfApi.create_repo("dummy_model") + RepoUrl('https://huggingface.co/Wauplin/dummy_model', endpoint='https://huggingface.co', repo_type='model', repo_id='Wauplin/dummy_model') + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If URL cannot be parsed. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `repo_type` is unknown. 
+ """ + + def __new__(cls, url: Any, endpoint: Optional[str] = None): + url = fix_hf_endpoint_in_url(url, endpoint=endpoint) + return super(RepoUrl, cls).__new__(cls, url) + + def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: + super().__init__() + # Parse URL + self.endpoint = endpoint or constants.ENDPOINT + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(self, hub_url=self.endpoint) + + # Populate fields + self.namespace = namespace + self.repo_name = repo_name + self.repo_id = repo_name if namespace is None else f"{namespace}/{repo_name}" + self.repo_type = repo_type or constants.REPO_TYPE_MODEL + self.url = str(self) # just in case it's needed + + def __repr__(self) -> str: + return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')" + + +@dataclass +class RepoSibling: + """ + Contains basic information about a repo file inside a repo on the Hub. + + > [!TIP] + > All attributes of this class are optional except `rfilename`. This is because only the file names are returned when + > listing repositories on the Hub (with [`list_models`], [`list_datasets`] or [`list_spaces`]). If you need more + > information like file size, blob id or lfs details, you must request them specifically from one repo at a time + > (using [`model_info`], [`dataset_info`] or [`space_info`]) as it adds more constraints on the backend server to + > retrieve these. + + Attributes: + rfilename (str): + file name, relative to the repo root. + size (`int`, *optional*): + The file's size, in bytes. This attribute is defined when `files_metadata` argument of [`repo_info`] is set + to `True`. It's `None` otherwise. + blob_id (`str`, *optional*): + The file's git OID. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to + `True`. It's `None` otherwise. + lfs (`BlobLfsInfo`, *optional*): + The file's LFS metadata. 
This attribute is defined when`files_metadata` argument of [`repo_info`] is set to + `True` and the file is stored with Git LFS. It's `None` otherwise. + """ + + rfilename: str + size: Optional[int] = None + blob_id: Optional[str] = None + lfs: Optional[BlobLfsInfo] = None + + +@dataclass +class RepoFile: + """ + Contains information about a file on the Hub. + + Attributes: + path (str): + file path relative to the repo root. + size (`int`): + The file's size, in bytes. + blob_id (`str`): + The file's git OID. + lfs (`BlobLfsInfo`, *optional*): + The file's LFS metadata. + last_commit (`LastCommitInfo`, *optional*): + The file's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + security (`BlobSecurityInfo`, *optional*): + The file's security scan metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + """ + + path: str + size: int + blob_id: str + lfs: Optional[BlobLfsInfo] = None + last_commit: Optional[LastCommitInfo] = None + security: Optional[BlobSecurityInfo] = None + + def __init__(self, **kwargs): + self.path = kwargs.pop("path") + self.size = kwargs.pop("size") + self.blob_id = kwargs.pop("oid") + lfs = kwargs.pop("lfs", None) + if lfs is not None: + lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"]) + self.lfs = lfs + last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) + if last_commit is not None: + last_commit = LastCommitInfo( + oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) + ) + self.last_commit = last_commit + security = kwargs.pop("securityFileStatus", None) + if security is not None: + safe = security["status"] == "safe" + security = BlobSecurityInfo( + safe=safe, + status=security["status"], + av_scan=security["avScan"], + pickle_import_scan=security["pickleImportScan"], + ) + self.security = security + + # backwards 
compatibility + self.rfilename = self.path + self.lastCommit = self.last_commit + + +@dataclass +class RepoFolder: + """ + Contains information about a folder on the Hub. + + Attributes: + path (str): + folder path relative to the repo root. + tree_id (`str`): + The folder's git OID. + last_commit (`LastCommitInfo`, *optional*): + The folder's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + """ + + path: str + tree_id: str + last_commit: Optional[LastCommitInfo] = None + + def __init__(self, **kwargs): + self.path = kwargs.pop("path") + self.tree_id = kwargs.pop("oid") + last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) + if last_commit is not None: + last_commit = LastCommitInfo( + oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) + ) + self.last_commit = last_commit + + +@dataclass +class InferenceProviderMapping: + provider: "PROVIDER_T" # Provider name + hf_model_id: str # ID of the model on the Hugging Face Hub + provider_id: str # ID of the model on the provider's side + status: Literal["error", "live", "staging"] + task: str + + adapter: Optional[str] = None + adapter_weights_path: Optional[str] = None + type: Optional[Literal["single-model", "tag-filter"]] = None + + def __init__(self, **kwargs): + self.provider = kwargs.pop("provider") + self.hf_model_id = kwargs.pop("hf_model_id") + self.provider_id = kwargs.pop("providerId") + self.status = kwargs.pop("status") + self.task = kwargs.pop("task") + + self.adapter = kwargs.pop("adapter", None) + self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None) + self.type = kwargs.pop("type", None) + self.__dict__.update(**kwargs) + + +@dataclass +class ModelInfo: + """ + Contains information about a model on the Hub. This object is returned by [`model_info`] and [`list_models`]. + + > [!TIP] + > Most attributes of this class are optional. 
This is because the data returned by the Hub depends on the query made. + > In general, the more specific the query, the more information is returned. On the contrary, when listing models + > using [`list_models`] only a subset of the attributes are returned. + + Attributes: + id (`str`): + ID of model. + author (`str`, *optional*): + Author of the model. + sha (`str`, *optional*): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + disabled (`bool`, *optional*): + Is the repo disabled. + downloads (`int`): + Number of downloads of the model over the last 30 days. + downloads_all_time (`int`): + Cumulated number of downloads of the model since its creation. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + gguf (`dict`, *optional*): + GGUF information of the model. + inference (`Literal["warm"]`, *optional*): + Status of the model on Inference Providers. Warm if the model is served by at least one provider. + inference_provider_mapping (`list[InferenceProviderMapping]`, *optional*): + A list of [`InferenceProviderMapping`] ordered after the user's provider order. + likes (`int`): + Number of likes of the model. + library_name (`str`, *optional*): + Library associated with the model. + tags (`list[str]`): + List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub + (e.g. supported libraries, model's arXiv). + pipeline_tag (`str`, *optional*): + Pipeline tag associated with the model. + mask_token (`str`, *optional*): + Mask token used by the model. + widget_data (`Any`, *optional*): + Widget data associated with the model. 
+ model_index (`dict`, *optional*): + Model index for evaluation. + config (`dict`, *optional*): + Model configuration. + transformers_info (`TransformersInfo`, *optional*): + Transformers-specific info (auto class, processor, etc.) associated with the model. + trending_score (`int`, *optional*): + Trending score of the model. + card_data (`ModelCardData`, *optional*): + Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object. + siblings (`list[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model. + spaces (`list[str]`, *optional*): + List of spaces using the model. + safetensors (`SafeTensorsInfo`, *optional*): + Model's safetensors information. + security_repo_status (`dict`, *optional*): + Model's security scan status. + eval_results (`list[EvalResultEntry]`, *optional*): + Model's evaluation results. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: Optional[bool] + disabled: Optional[bool] + downloads: Optional[int] + downloads_all_time: Optional[int] + gated: Optional[Literal["auto", "manual", False]] + gguf: Optional[dict] + inference: Optional[Literal["warm"]] + inference_provider_mapping: Optional[list[InferenceProviderMapping]] + likes: Optional[int] + library_name: Optional[str] + tags: Optional[list[str]] + pipeline_tag: Optional[str] + mask_token: Optional[str] + card_data: Optional[ModelCardData] + widget_data: Optional[Any] + model_index: Optional[dict] + config: Optional[dict] + transformers_info: Optional[TransformersInfo] + trending_score: Optional[int] + siblings: Optional[list[RepoSibling]] + spaces: Optional[list[str]] + safetensors: Optional[SafeTensorsInfo] + security_repo_status: Optional[dict] + eval_results: Optional[list[EvalResultEntry]] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) 
+ last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + self.private = kwargs.pop("private", None) + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.downloads = kwargs.pop("downloads", None) + self.downloads_all_time = kwargs.pop("downloadsAllTime", None) + self.likes = kwargs.pop("likes", None) + self.library_name = kwargs.pop("library_name", None) + self.gguf = kwargs.pop("gguf", None) + + self.inference = kwargs.pop("inference", None) + + # little hack to simplify Inference Providers logic and make it backward and forward compatible + # right now, API returns a dict on model_info and a list on list_models. Let's harmonize to list. + mapping = kwargs.pop("inferenceProviderMapping", None) + if isinstance(mapping, list): + self.inference_provider_mapping = [ + InferenceProviderMapping(**{**value, "hf_model_id": self.id}) for value in mapping + ] + elif isinstance(mapping, dict): + self.inference_provider_mapping = [ + InferenceProviderMapping(**{**value, "hf_model_id": self.id, "provider": provider}) + for provider, value in mapping.items() + ] + elif mapping is None: + self.inference_provider_mapping = None + else: + raise ValueError( + f"Unexpected type for `inferenceProviderMapping`. Expecting `dict` or `list`. Got {mapping}." 
+ ) + + self.tags = kwargs.pop("tags", None) + self.pipeline_tag = kwargs.pop("pipeline_tag", None) + self.mask_token = kwargs.pop("mask_token", None) + self.trending_score = kwargs.pop("trendingScore", None) + + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + ModelCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + + self.widget_data = kwargs.pop("widgetData", None) + self.model_index = kwargs.pop("model-index", None) or kwargs.pop("model_index", None) + self.config = kwargs.pop("config", None) + transformers_info = kwargs.pop("transformersInfo", None) or kwargs.pop("transformers_info", None) + self.transformers_info = TransformersInfo(**transformers_info) if transformers_info else None + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings is not None + else None + ) + self.spaces = kwargs.pop("spaces", None) + safetensors = kwargs.pop("safetensors", None) + self.safetensors = ( + SafeTensorsInfo( + parameters=safetensors["parameters"], + total=safetensors["total"], + ) + if safetensors + else None + ) + self.security_repo_status = kwargs.pop("securityRepoStatus", None) + eval_results = kwargs.pop("evalResults", None) + self.eval_results = parse_eval_result_entries(eval_results) if eval_results else None + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.transformersInfo = self.transformers_info + self.__dict__.update(**kwargs) + + +@dataclass +class DatasetInfo: + """ + Contains information about a dataset on the Hub. 
This object is returned by [`dataset_info`] and [`list_datasets`]. + + > [!TIP] + > Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + > In general, the more specific the query, the more information is returned. On the contrary, when listing datasets + > using [`list_datasets`] only a subset of the attributes are returned. + + Attributes: + id (`str`): + ID of dataset. + author (`str`): + Author of the dataset. + sha (`str`): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + disabled (`bool`, *optional*): + Is the repo disabled. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + downloads (`int`): + Number of downloads of the dataset over the last 30 days. + downloads_all_time (`int`): + Cumulated number of downloads of the model since its creation. + likes (`int`): + Number of likes of the dataset. + tags (`list[str]`): + List of tags of the dataset. + card_data (`DatasetCardData`, *optional*): + Model Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object. + siblings (`list[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset. + paperswithcode_id (`str`, *optional*): + Papers with code ID of the dataset. + trending_score (`int`, *optional*): + Trending score of the dataset. 
+ """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: Optional[bool] + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + downloads: Optional[int] + downloads_all_time: Optional[int] + likes: Optional[int] + paperswithcode_id: Optional[str] + tags: Optional[list[str]] + trending_score: Optional[int] + card_data: Optional[DatasetCardData] + siblings: Optional[list[RepoSibling]] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + self.private = kwargs.pop("private", None) + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.downloads = kwargs.pop("downloads", None) + self.downloads_all_time = kwargs.pop("downloadsAllTime", None) + self.likes = kwargs.pop("likes", None) + self.paperswithcode_id = kwargs.pop("paperswithcode_id", None) + self.tags = kwargs.pop("tags", None) + self.trending_score = kwargs.pop("trendingScore", None) + + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + DatasetCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in 
siblings + ] + if siblings is not None + else None + ) + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.__dict__.update(**kwargs) + + +@dataclass +class SpaceInfo: + """ + Contains information about a Space on the Hub. This object is returned by [`space_info`] and [`list_spaces`]. + + > [!TIP] + > Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + > In general, the more specific the query, the more information is returned. On the contrary, when listing spaces + > using [`list_spaces`] only a subset of the attributes are returned. + + Attributes: + id (`str`): + ID of the Space. + author (`str`, *optional*): + Author of the Space. + sha (`str`, *optional*): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + disabled (`bool`, *optional*): + Is the Space disabled. + host (`str`, *optional*): + Host URL of the Space. + subdomain (`str`, *optional*): + Subdomain of the Space. + likes (`int`): + Number of likes of the Space. + tags (`list[str]`): + List of tags of the Space. + siblings (`list[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space. + card_data (`SpaceCardData`, *optional*): + Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object. + runtime (`SpaceRuntime`, *optional*): + Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object. + sdk (`str`, *optional*): + SDK used by the Space. 
+ models (`list[str]`, *optional*): + List of models used by the Space. + datasets (`list[str]`, *optional*): + List of datasets used by the Space. + trending_score (`int`, *optional*): + Trending score of the Space. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: Optional[bool] + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + host: Optional[str] + subdomain: Optional[str] + likes: Optional[int] + sdk: Optional[str] + tags: Optional[list[str]] + siblings: Optional[list[RepoSibling]] + trending_score: Optional[int] + card_data: Optional[SpaceCardData] + runtime: Optional[SpaceRuntime] + models: Optional[list[str]] + datasets: Optional[list[str]] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + self.private = kwargs.pop("private", None) + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.host = kwargs.pop("host", None) + self.subdomain = kwargs.pop("subdomain", None) + self.likes = kwargs.pop("likes", None) + self.sdk = kwargs.pop("sdk", None) + self.tags = kwargs.pop("tags", None) + self.trending_score = kwargs.pop("trendingScore", None) + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + SpaceCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + 
blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings is not None + else None + ) + runtime = kwargs.pop("runtime", None) + self.runtime = SpaceRuntime(runtime) if runtime else None + self.models = kwargs.pop("models", None) + self.datasets = kwargs.pop("datasets", None) + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.__dict__.update(**kwargs) + + +@dataclass +class CollectionItem: + """ + Contains information about an item of a Collection (model, dataset, Space, paper or collection). + + Attributes: + item_object_id (`str`): + Unique ID of the item in the collection. + item_id (`str`): + ID of the underlying object on the Hub. Can be either a repo_id, a paper id or a collection slug. + e.g. `"jbilcke-hf/ai-comic-factory"`, `"2307.09288"`, `"celinah/cerebras-function-calling-682607169c35fbfa98b30b9a"`. + item_type (`str`): + Type of the underlying object. Can be one of `"model"`, `"dataset"`, `"space"`, `"paper"` or `"collection"`. + position (`int`): + Position of the item in the collection. + note (`str`, *optional*): + Note associated with the item, as plain text. 
+ """ + + item_object_id: str # id in database + item_id: str # repo_id or paper id + item_type: str + position: int + note: Optional[str] = None + + def __init__( + self, + _id: str, + id: str, + type: CollectionItemType_T, + position: int, + note: Optional[dict] = None, + **kwargs, + ) -> None: + self.item_object_id: str = _id # id in database + self.item_id: str = id # repo_id or paper id + # if the item is a collection, override item_id with the slug + slug = kwargs.get("slug") + if slug is not None: + self.item_id = slug # collection slug + self.item_type: CollectionItemType_T = type + self.position: int = position + self.note: str = note["text"] if note is not None else None + + +@dataclass +class Collection: + """ + Contains information about a Collection on the Hub. + + Attributes: + slug (`str`): + Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection. E.g. `"Recent models"`. + owner (`str`): + Owner of the collection. E.g. `"TheBloke"`. + items (`list[CollectionItem]`): + List of items in the collection. + last_updated (`datetime`): + Date of the last update of the collection. + position (`int`): + Position of the collection in the list of collections of the owner. + private (`bool`): + Whether the collection is private or not. + theme (`str`): + Theme of the collection. E.g. `"green"`. + upvotes (`int`): + Number of upvotes of the collection. + description (`str`, *optional*): + Description of the collection, as plain text. + url (`str`): + (property) URL of the collection on the Hub. 
+ """ + + slug: str + title: str + owner: str + items: list[CollectionItem] + last_updated: datetime + position: int + private: bool + theme: str + upvotes: int + description: Optional[str] = None + + def __init__(self, **kwargs) -> None: + self.slug = kwargs.pop("slug") + self.title = kwargs.pop("title") + self.owner = kwargs.pop("owner") + self.items = [CollectionItem(**item) for item in kwargs.pop("items")] + self.last_updated = parse_datetime(kwargs.pop("lastUpdated")) + self.position = kwargs.pop("position") + self.private = kwargs.pop("private") + self.theme = kwargs.pop("theme") + self.upvotes = kwargs.pop("upvotes") + self.description = kwargs.pop("description", None) + endpoint = kwargs.pop("endpoint", None) + if endpoint is None: + endpoint = constants.ENDPOINT + self._url = f"{endpoint}/collections/{self.slug}" + + @property + def url(self) -> str: + """Returns the URL of the collection on the Hub.""" + return self._url + + +@dataclass +class GitRefInfo: + """ + Contains information about a git reference for a repo on the Hub. + + Attributes: + name (`str`): + Name of the reference (e.g. tag name or branch name). + ref (`str`): + Full git ref on the Hub (e.g. `"refs/heads/main"` or `"refs/tags/v1.0"`). + target_commit (`str`): + OID of the target commit for the ref (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) + """ + + name: str + ref: str + target_commit: str + + +@dataclass +class GitRefs: + """ + Contains information about all git references for a repo on the Hub. + + Object is returned by [`list_repo_refs`]. + + Attributes: + branches (`list[GitRefInfo]`): + A list of [`GitRefInfo`] containing information about branches on the repo. + converts (`list[GitRefInfo]`): + A list of [`GitRefInfo`] containing information about "convert" refs on the repo. + Converts are refs used (internally) to push preprocessed data in Dataset repos. + tags (`list[GitRefInfo]`): + A list of [`GitRefInfo`] containing information about tags on the repo. 
+ pull_requests (`list[GitRefInfo]`, *optional*): + A list of [`GitRefInfo`] containing information about pull requests on the repo. + Only returned if `include_prs=True` is set. + """ + + branches: list[GitRefInfo] + converts: list[GitRefInfo] + tags: list[GitRefInfo] + pull_requests: Optional[list[GitRefInfo]] = None + + +@dataclass +class GitCommitInfo: + """ + Contains information about a git commit for a repo on the Hub. Check out [`list_repo_commits`] for more details. + + Attributes: + commit_id (`str`): + OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) + authors (`list[str]`): + List of authors of the commit. + created_at (`datetime`): + Datetime when the commit was created. + title (`str`): + Title of the commit. This is a free-text value entered by the authors. + message (`str`): + Description of the commit. This is a free-text value entered by the authors. + formatted_title (`str`): + Title of the commit formatted as HTML. Only returned if `formatted=True` is set. + formatted_message (`str`): + Description of the commit formatted as HTML. Only returned if `formatted=True` is set. + """ + + commit_id: str + + authors: list[str] + created_at: datetime + title: str + message: str + + formatted_title: Optional[str] + formatted_message: Optional[str] + + +@dataclass +class UserLikes: + """ + Contains information about a user likes on the Hub. + + Attributes: + user (`str`): + Name of the user for which we fetched the likes. + total (`int`): + Total number of likes. + datasets (`list[str]`): + List of datasets liked by the user (as repo_ids). + models (`list[str]`): + List of models liked by the user (as repo_ids). + spaces (`list[str]`): + List of spaces liked by the user (as repo_ids). + """ + + # Metadata + user: str + total: int + + # User likes + datasets: list[str] + models: list[str] + spaces: list[str] + + +@dataclass +class Organization: + """ + Contains information about an organization on the Hub. 
+ + Attributes: + avatar_url (`str`): + URL of the organization's avatar. + name (`str`): + Name of the organization on the Hub (unique). + fullname (`str`): + Organization's full name. + details (`str`, *optional*): + Organization's description. + is_verified (`bool`, *optional*): + Whether the organization is verified. + is_following (`bool`, *optional*): + Whether the authenticated user follows this organization. + num_users (`int`, *optional*): + Number of members in the organization. + num_models (`int`, *optional*): + Number of models owned by the organization. + num_spaces (`int`, *optional*): + Number of Spaces owned by the organization. + num_datasets (`int`, *optional*): + Number of datasets owned by the organization. + num_followers (`int`, *optional*): + Number of followers of the organization. + num_papers (`int`, *optional*): + Number of papers authored by the organization. + plan (`str`, *optional*): + The organization's plan (e.g., "enterprise", "team"). + """ + + avatar_url: str + name: str + fullname: str + details: Optional[str] = None + is_verified: Optional[bool] = None + is_following: Optional[bool] = None + num_users: Optional[int] = None + num_models: Optional[int] = None + num_spaces: Optional[int] = None + num_datasets: Optional[int] = None + num_followers: Optional[int] = None + num_papers: Optional[int] = None + plan: Optional[str] = None + + def __init__(self, **kwargs) -> None: + self.avatar_url = kwargs.pop("avatarUrl", "") + self.name = kwargs.pop("name", "") + self.fullname = kwargs.pop("fullname", "") + self.details = kwargs.pop("details", None) + self.is_verified = kwargs.pop("isVerified", None) + self.is_following = kwargs.pop("isFollowing", None) + self.num_users = kwargs.pop("numUsers", None) + self.num_models = kwargs.pop("numModels", None) + self.num_spaces = kwargs.pop("numSpaces", None) + self.num_datasets = kwargs.pop("numDatasets", None) + self.num_followers = kwargs.pop("numFollowers", None) + self.num_papers = 
kwargs.pop("numPapers", None) + self.plan = kwargs.pop("plan", None) + + # forward compatibility + self.__dict__.update(**kwargs) + + +@dataclass +class User: + """ + Contains information about a user on the Hub. + + Attributes: + username (`str`): + Name of the user on the Hub (unique). + fullname (`str`): + User's full name. + avatar_url (`str`): + URL of the user's avatar. + details (`str`, *optional*): + User's details. + is_following (`bool`, *optional*): + Whether the authenticated user is following this user. + is_pro (`bool`, *optional*): + Whether the user is a pro user. + num_models (`int`, *optional*): + Number of models created by the user. + num_datasets (`int`, *optional*): + Number of datasets created by the user. + num_spaces (`int`, *optional*): + Number of spaces created by the user. + num_discussions (`int`, *optional*): + Number of discussions initiated by the user. + num_papers (`int`, *optional*): + Number of papers authored by the user. + num_upvotes (`int`, *optional*): + Number of upvotes received by the user. + num_likes (`int`, *optional*): + Number of likes given by the user. + num_following (`int`, *optional*): + Number of users this user is following. + num_followers (`int`, *optional*): + Number of users following this user. + orgs (list of [`Organization`]): + List of organizations the user is part of. 
+ """ + + # Metadata + username: str + fullname: str + avatar_url: str + details: Optional[str] = None + is_following: Optional[bool] = None + is_pro: Optional[bool] = None + num_models: Optional[int] = None + num_datasets: Optional[int] = None + num_spaces: Optional[int] = None + num_discussions: Optional[int] = None + num_papers: Optional[int] = None + num_upvotes: Optional[int] = None + num_likes: Optional[int] = None + num_following: Optional[int] = None + num_followers: Optional[int] = None + orgs: list[Organization] = field(default_factory=list) + + def __init__(self, **kwargs) -> None: + self.username = kwargs.pop("user", "") + self.fullname = kwargs.pop("fullname", "") + self.avatar_url = kwargs.pop("avatarUrl", "") + self.is_following = kwargs.pop("isFollowing", None) + self.is_pro = kwargs.pop("isPro", None) + self.details = kwargs.pop("details", None) + self.num_models = kwargs.pop("numModels", None) + self.num_datasets = kwargs.pop("numDatasets", None) + self.num_spaces = kwargs.pop("numSpaces", None) + self.num_discussions = kwargs.pop("numDiscussions", None) + self.num_papers = kwargs.pop("numPapers", None) + self.num_upvotes = kwargs.pop("numUpvotes", None) + self.num_likes = kwargs.pop("numLikes", None) + self.num_following = kwargs.pop("numFollowing", None) + self.num_followers = kwargs.pop("numFollowers", None) + self.user_type = kwargs.pop("type", None) + self.orgs = [Organization(**org) for org in kwargs.pop("orgs", [])] + + # forward compatibility + self.__dict__.update(**kwargs) + + +@dataclass +class PaperAuthor: + """ + Contains information about a paper author on the Hub. + + Attributes: + name (`str`): + Name of the author. + user (`User`, *optional*): + Information about the author as a [`User`] object. + status (`str`, *optional*): + Status of the author on the Hub. + status_last_changed_at (`datetime`, *optional*): + Date when the status of the author changed. + hidden (`bool`, *optional*): + Whether the author is hidden on the Hub. 
+ """ + + name: str + user: Optional[User] + status: Optional[str] + status_last_changed_at: Optional[datetime] + hidden: Optional[bool] + + def __init__(self, **kwargs) -> None: + self.name = kwargs.pop("name", "") + user = kwargs.pop("user", None) + self.user = User(**user) if user else None + self.status = kwargs.pop("status", None) + status_last_changed_at = kwargs.pop("statusLastChangedAt", None) + self.status_last_changed_at = parse_datetime(status_last_changed_at) if status_last_changed_at else None + self.hidden = kwargs.pop("hidden", None) + + self.__dict__.update(**kwargs) + + +@dataclass +class PaperInfo: + """ + Contains information about a paper on the Hub. + + Attributes: + id (`str`): + arXiv paper ID. + authors (`list[PaperAuthor]`, *optional*): + Authors of the paper. + published_at (`datetime`, *optional*): + Date paper published. + title (`str`, *optional*): + Title of the paper. + summary (`str`, *optional*): + Summary of the paper. + upvotes (`int`, *optional*): + Number of upvotes for the paper on the Hub. + discussion_id (`str`, *optional*): + Discussion ID for the paper on the Hub. + source (`str`, *optional*): + Source of the paper. + comments (`int`, *optional*): + Number of comments for the paper on the Hub. + submitted_at (`datetime`, *optional*): + Date paper appeared in daily papers on the Hub. + submitted_by (`User`, *optional*): + Information about who submitted the daily paper. + ai_summary (`str`, *optional*): + AI summary of the paper. + ai_keywords (`list[str]`, *optional*): + AI keywords of the paper. + organization (`Organization`, *optional*): + Information about the organization associated with the paper. + project_page (`str`, *optional*): + URL of the project page for the paper. + github_repo (`str`, *optional*): + URL of the GitHub repository for the paper. + github_stars (`int`, *optional*): + Number of stars of the GitHub repository for the paper. 
+ """ + + id: str + authors: Optional[list[PaperAuthor]] + published_at: Optional[datetime] + title: Optional[str] + summary: Optional[str] + upvotes: Optional[int] + discussion_id: Optional[str] + source: Optional[str] + comments: Optional[int] + submitted_at: Optional[datetime] + submitted_by: Optional[User] + ai_summary: Optional[str] + ai_keywords: Optional[list[str]] + organization: Optional[Organization] + project_page: Optional[str] + github_repo: Optional[str] + github_stars: Optional[int] + + def __init__(self, **kwargs) -> None: + paper = kwargs.pop("paper", {}) + self.id = kwargs.pop("id", None) or paper.pop("id", None) + authors = paper.pop("authors", None) or kwargs.pop("authors", None) + self.authors = [PaperAuthor(**author) for author in authors] if authors else None + published_at = paper.pop("publishedAt", None) or kwargs.pop("publishedAt", None) + self.published_at = parse_datetime(published_at) if published_at else None + self.title = kwargs.pop("title", None) + self.source = kwargs.pop("source", None) + self.summary = paper.pop("summary", None) or kwargs.pop("summary", None) + self.upvotes = paper.pop("upvotes", None) or kwargs.pop("upvotes", None) + self.discussion_id = paper.pop("discussionId", None) or kwargs.pop("discussionId", None) + self.comments = kwargs.pop("numComments", 0) + submitted_at = kwargs.pop("publishedAt", None) or kwargs.pop("submittedOnDailyAt", None) + self.submitted_at = parse_datetime(submitted_at) if submitted_at else None + submitted_by = kwargs.pop("submittedBy", None) or kwargs.pop("submittedOnDailyBy", None) + self.submitted_by = User(**submitted_by) if submitted_by else None + self.ai_summary = kwargs.pop("ai_summary", None) + self.ai_keywords = kwargs.pop("ai_keywords", None) + organization = kwargs.pop("organization", None) + self.organization = Organization(**organization) if organization else None + self.project_page = kwargs.pop("projectPage", None) + self.github_repo = kwargs.pop("githubRepo", None) + 
self.github_stars = kwargs.pop("githubStars", None) + + # forward compatibility + self.__dict__.update(**kwargs) + + +@dataclass +class LFSFileInfo: + """ + Contains information about a file stored as LFS on a repo on the Hub. + + Used in the context of listing and permanently deleting LFS files from a repo to free-up space. + See [`list_lfs_files`] and [`permanently_delete_lfs_files`] for more details. + + Git LFS files are tracked using SHA-256 object IDs, rather than file paths, to optimize performance + This approach is necessary because a single object can be referenced by multiple paths across different commits, + making it impractical to search and resolve these connections. Check out [our documentation](https://huggingface.co/docs/hub/storage-limits#advanced-track-lfs-file-references) + to learn how to know which filename(s) is(are) associated with each SHA. + + Attributes: + file_oid (`str`): + SHA-256 object ID of the file. This is the identifier to pass when permanently deleting the file. + filename (`str`): + Possible filename for the LFS object. See the note above for more information. + oid (`str`): + OID of the LFS object. + pushed_at (`datetime`): + Date the LFS object was pushed to the repo. + ref (`str`, *optional*): + Ref where the LFS object has been pushed (if any). + size (`int`): + Size of the LFS object. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> lfs_files = api.list_lfs_files("username/my-cool-repo") + + # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. + # e.g. 
select only LFS files in the "checkpoints" folder + >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) + + # Permanently delete LFS files + >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) + ``` + """ + + file_oid: str + filename: str + oid: str + pushed_at: datetime + ref: Optional[str] + size: int + + def __init__(self, **kwargs) -> None: + self.file_oid = kwargs.pop("fileOid") + self.filename = kwargs.pop("filename") + self.oid = kwargs.pop("oid") + self.pushed_at = parse_datetime(kwargs.pop("pushedAt")) + self.ref = kwargs.pop("ref", None) + self.size = kwargs.pop("size") + + # forward compatibility + self.__dict__.update(**kwargs) + + +def future_compatible(fn: CallableT) -> CallableT: + """Wrap a method of `HfApi` to handle `run_as_future=True`. + + A method flagged as "future_compatible" will be called in a thread if `run_as_future=True` and return a + `concurrent.futures.Future` instance. Otherwise, it will be called normally and return the result. 
+ """ + sig = inspect.signature(fn) + args_params = list(sig.parameters)[1:] # remove "self" from list + + @wraps(fn) + def _inner(self, *args, **kwargs): + # Get `run_as_future` value if provided (default to False) + if "run_as_future" in kwargs: + run_as_future = kwargs["run_as_future"] + kwargs["run_as_future"] = False # avoid recursion error + else: + run_as_future = False + for param, value in zip(args_params, args): + if param == "run_as_future": + run_as_future = value + break + + # Call the function in a thread if `run_as_future=True` + if run_as_future: + return self.run_as_future(fn, self, *args, **kwargs) + + # Otherwise, call the function normally + return fn(self, *args, **kwargs) + + _inner.is_future_compatible = True # type: ignore + return _inner # type: ignore + + +def _get_safetensors_metadata_size(size_bytes: bytes, filename: str, context_msg: str) -> int: + """ + Parse and validate safetensors metadata size from the first 8 bytes. + + This is a shared helper function used by both remote and local safetensors parsing. + + Args: + size_bytes: First 8 bytes of the safetensors file. + filename: Filename for error messages. + context_msg: Additional context for error messages. + + Returns: + The metadata size as an integer. + + Raises: + SafetensorsParsingError: If size_bytes is too short or metadata size exceeds limit. + """ + if len(size_bytes) < 8: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' ({context_msg}): file is too small to be a valid " + "safetensors file." + ) + + metadata_size = struct.unpack(" constants.SAFETENSORS_MAX_HEADER_LENGTH: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' ({context_msg}): safetensors header is too big. " + f"Maximum supported size is {constants.SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})." 
+ ) + + return metadata_size + + +def _parse_safetensors_header(metadata_as_bytes: bytes, filename: str, context_msg: str) -> SafetensorsFileMetadata: + """ + Parse safetensors metadata from raw header bytes. + + This is a shared helper function used by both remote and local safetensors parsing. + + Args: + metadata_as_bytes: Raw bytes of the JSON metadata header (without the 8-byte size prefix). + filename: Filename for error messages. + context_msg: Additional context for error messages (e.g., repo info or local path). + + Returns: + SafetensorsFileMetadata object. + + Raises: + SafetensorsParsingError: If the header cannot be parsed. + """ + # Parse json header + try: + metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore")) + except json.JSONDecodeError as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' ({context_msg}): header is not json-encoded string. " + "Please make sure this is a correctly formatted safetensors file." + ) from e + + try: + return SafetensorsFileMetadata( + metadata=metadata_as_dict.get("__metadata__", {}), + tensors={ + key: TensorInfo( + dtype=tensor["dtype"], + shape=tensor["shape"], + data_offsets=tuple(tensor["data_offsets"]), # type: ignore + ) + for key, tensor in metadata_as_dict.items() + if key != "__metadata__" + }, + ) + except (KeyError, IndexError) as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' ({context_msg}): header format not recognized. " + "Please make sure this is a correctly formatted safetensors file." + ) from e + + +class HfApi: + """ + Client to interact with the Hugging Face Hub via HTTP. + + The client is initialized with some high-level settings used in all requests + made to the Hub (HF endpoint, authentication, user agents...). Using the `HfApi` + client is preferred but not mandatory as all of its public methods are exposed + directly at the root of `huggingface_hub`. 
+ + Args: + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. Will be added to + the user-agent header. Example: `"transformers"`. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. Will be added + to the user-agent header. Example: `"4.24.0"`. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. It will + be completed with information about the installed packages. + headers (`dict`, *optional*): + Additional headers to be sent with each request. Example: `{"X-My-Header": "value"}`. + Headers passed here are taking precedence over the default headers. + """ + + def __init__( + self, + endpoint: Optional[str] = None, + token: Union[str, bool, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[dict, str, None] = None, + headers: Optional[dict[str, str]] = None, + ) -> None: + self.endpoint = endpoint if endpoint is not None else constants.ENDPOINT + self.token = token + self.library_name = library_name + self.library_version = library_version + self.user_agent = user_agent + self.headers = headers + self._thread_pool: Optional[ThreadPoolExecutor] = None + + # /whoami-v2 is the only endpoint for which we may want to cache results + self._whoami_cache: dict[str, dict] = {} + + def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: + """ + Run a method in the background and return a Future instance. + + The main goal is to run methods without blocking the main thread (e.g. 
to push data during a training). + Background jobs are queued to preserve order but are not ran in parallel. If you need to speed-up your scripts + by parallelizing lots of call to the API, you must setup and use your own [ThreadPoolExecutor](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor). + + Note: Most-used methods like [`upload_file`], [`upload_folder`] and [`create_commit`] have a `run_as_future: bool` + argument to directly call them in the background. This is equivalent to calling `api.run_as_future(...)` on them + but less verbose. + + Args: + fn (`Callable`): + The method to run in the background. + *args, **kwargs: + Arguments with which the method will be called. + + Return: + `Future`: a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) instance to + get the result of the task. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> future = api.run_as_future(api.whoami) # instant + >>> future.done() + False + >>> future.result() # wait until complete and return result + (...) + >>> future.done() + True + ``` + """ + if self._thread_pool is None: + self._thread_pool = ThreadPoolExecutor(max_workers=1) + self._thread_pool + return self._thread_pool.submit(fn, *args, **kwargs) + + @validate_hf_hub_args + def whoami(self, token: Union[bool, str, None] = None, *, cache: bool = False) -> dict: + """ + Call HF API to know "whoami". + + If passing `cache=True`, the result will be cached for subsequent calls for the duration of the Python process. This is useful if you plan to call + `whoami` multiple times as this endpoint is heavily rate-limited for security reasons. + + Args: + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ cache (`bool`, *optional*): + Whether to cache the result of the `whoami` call for subsequent calls. + If an error occurs during the first call, it won't be cached. + Defaults to `False`. + """ + # Get the effective token using the helper function get_token + token = self.token if token is None else token + if token is False: + raise ValueError("Cannot use `token=False` with `whoami` method as it requires authentication.") + if token is True or token is None: + token = get_token() + if token is None: + raise LocalTokenNotFoundError( + "Token is required to call the /whoami-v2 endpoint, but no token found. You must provide a token or be logged in to " + "Hugging Face with `hf auth login` or `huggingface_hub.login`. See https://huggingface.co/settings/tokens." + ) + + if cache and (cached_token := self._whoami_cache.get(token)): + return cached_token + + # Call Hub + output = self._inner_whoami(token=token) + + # Cache result and return + if cache: + self._whoami_cache[token] = output + return output + + def _inner_whoami(self, token: str) -> dict: + r = get_session().get( + f"{self.endpoint}/api/whoami-v2", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(r) + except HfHubHTTPError as e: + if e.response.status_code == 401: + error_message = "Invalid user token." + # Check which token is the effective one and generate the error message accordingly + if token == _get_token_from_google_colab(): + error_message += " The token from Google Colab vault is invalid. Please update it from the UI." + elif token == _get_token_from_environment(): + error_message += ( + " The token from HF_TOKEN environment variable is invalid. " + "Note that HF_TOKEN takes precedence over `hf auth login`." + ) + elif token == _get_token_from_file(): + error_message += " The token stored is invalid. Please run `hf auth login` to update it." 
+ raise HfHubHTTPError(error_message, response=e.response) from e + if e.response.status_code == 429: + error_message = ( + "You've hit the rate limit for the /whoami-v2 endpoint, which is intentionally strict for security reasons." + " If you're calling it often, consider caching the response with `whoami(..., cache=True)`." + ) + raise HfHubHTTPError(error_message, response=e.response) from e + raise + return r.json() + + def get_model_tags(self) -> dict: + """ + List all valid model tags as a nested namespace object + """ + path = f"{self.endpoint}/api/models-tags-by-type" + r = get_session().get(path) + hf_raise_for_status(r) + return r.json() + + def get_dataset_tags(self) -> dict: + """ + List all valid dataset tags as a nested namespace object. + """ + path = f"{self.endpoint}/api/datasets-tags-by-type" + r = get_session().get(path) + hf_raise_for_status(r) + return r.json() + + @_deprecate_arguments(version="1.5", deprecated_args=["direction"], custom_message="Sorting is always descending.") + @validate_hf_hub_args + def list_models( + self, + *, + # Search-query parameter + filter: Union[str, Iterable[str], None] = None, + author: Optional[str] = None, + apps: Optional[Union[str, list[str]]] = None, + gated: Optional[bool] = None, + inference: Optional[Literal["warm"]] = None, + inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", list["PROVIDER_T"]]] = None, + model_name: Optional[str] = None, + trained_dataset: Optional[Union[str, list[str]]] = None, + search: Optional[str] = None, + pipeline_tag: Optional[str] = None, + emissions_thresholds: Optional[tuple[float, float]] = None, + # Sorting and pagination parameters + sort: Optional[ModelSort_T] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + # Additional data to fetch + expand: Optional[list[ExpandModelProperty_T]] = None, + full: Optional[bool] = None, + cardData: bool = False, + fetch_config: bool = False, + token: Union[bool, str, None] = None, + ) -> 
Iterable[ModelInfo]: + """ + List models hosted on the Huggingface Hub, given some filters. + + Args: + filter (`str` or `Iterable[str]`, *optional*): + A string or list of string to filter models on the Hub. + Models can be filtered by library, language, task, tags, and more. + author (`str`, *optional*): + A string which identify the author (user or organization) of the + returned models. + apps (`str` or `List`, *optional*): + A string or list of strings to filter models on the Hub that + support the specified apps. Example values include `"ollama"` or `["ollama", "vllm"]`. + gated (`bool`, *optional*): + A boolean to filter models on the Hub that are gated or not. By default, all models are returned. + If `gated=True` is passed, only gated models are returned. + If `gated=False` is passed, only non-gated models are returned. + inference (`Literal["warm"]`, *optional*): + If "warm", filter models on the Hub currently served by at least one provider. + inference_provider (`Literal["all"]` or `str`, *optional*): + A string to filter models on the Hub that are served by a specific provider. + Pass `"all"` to get all models served by at least one provider. + model_name (`str`, *optional*): + A string that contain complete or partial names for models on the + Hub, such as "bert" or "bert-base-cased" + trained_dataset (`str` or `List`, *optional*): + A string tag or a list of string tags of the trained dataset for a + model on the Hub. + search (`str`, *optional*): + A string that will be contained in the returned model ids. + pipeline_tag (`str`, *optional*): + A string pipeline tag to filter models on the Hub by, such as `summarization`. + emissions_thresholds (`Tuple`, *optional*): + A tuple of two ints or floats representing a minimum and maximum + carbon footprint to filter the resulting models with in grams. + sort (`ModelSort_T`, *optional*): + The key with which to sort the resulting models. 
Possible values are "created_at", "downloads", + "last_modified", "likes" and "trending_score". + direction (`Literal[-1]` or `int`, *optional*): + Deprecated. This parameter is not used and will be removed in version 1.5. + limit (`int`, *optional*): + The limit on the number of models fetched. Leaving this option + to `None` fetches all models. + expand (`list[ExpandModelProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `full`, `cardData` or `fetch_config` are passed. + Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"evalResults"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, and `"resourceGroup"`. + full (`bool`, *optional*): + Whether to fetch all model data, including the `last_modified`, + the `sha`, the files and the `tags`. This is set to `True` by + default when using a filter. + cardData (`bool`, *optional*): + Whether to grab the metadata for the model as well. Can contain + useful information such as carbon emissions, metrics, and + datasets trained on. + fetch_config (`bool`, *optional*): + Whether to fetch the model configs as well. This is not included + in `full` due to its size. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + + Returns: + `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects. 
+ + Example: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + # List all models + >>> api.list_models() + + # List text classification models + >>> api.list_models(filter="text-classification") + + # List models from the KerasHub library + >>> api.list_models(filter="keras-hub") + + # List models served by Cohere + >>> api.list_models(inference_provider="cohere") + + # List models with "bert" in their name + >>> api.list_models(search="bert") + + # List models with "bert" in their name and pushed by google + >>> api.list_models(search="bert", author="google") + ``` + """ + if expand and (full or cardData or fetch_config): + raise ValueError("`expand` cannot be used if `full`, `cardData` or `fetch_config` are passed.") + + if emissions_thresholds is not None and not cardData: + raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.") + + path = f"{self.endpoint}/api/models" + headers = self._build_hf_headers(token=token) + params: dict[str, Any] = {} + + # Build the filter list + filter_list: list[str] = [] + if filter: + filter_list.extend([filter] if isinstance(filter, str) else filter) + if trained_dataset: + datasets = [trained_dataset] if isinstance(trained_dataset, str) else trained_dataset + filter_list.extend(f"dataset:{d}" if not d.startswith("dataset:") else d for d in datasets) + if len(filter_list) > 0: + params["filter"] = filter_list + + # Handle other query params + if author: + params["author"] = author + if apps: + if isinstance(apps, str): + apps = [apps] + params["apps"] = apps + if gated is not None: + params["gated"] = gated + if inference is not None: + params["inference"] = inference + if inference_provider is not None: + params["inference_provider"] = inference_provider + if pipeline_tag: + params["pipeline_tag"] = pipeline_tag + search_list = [] + if model_name: + search_list.append(model_name) + if search: + search_list.append(search) + if len(search_list) > 0: + params["search"] 
= search_list + if sort is not None: + params["sort"] = ( + "lastModified" + if sort == "last_modified" + else "trendingScore" + if sort == "trending_score" + else "createdAt" + if sort == "created_at" + else sort + ) + if direction is not None: + params["direction"] = direction + if limit is not None: + params["limit"] = limit + + # Request additional data + if full: + params["full"] = True + if fetch_config: + params["config"] = True + if cardData: + params["cardData"] = True + if expand: + params["expand"] = expand + + # `items` is a generator + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + model_info = ModelInfo(**item) + if emissions_thresholds is None or _is_emission_within_threshold(model_info, *emissions_thresholds): + yield model_info + + @_deprecate_arguments(version="1.5", deprecated_args=["direction"], custom_message="Sorting is always descending.") + @validate_hf_hub_args + def list_datasets( + self, + *, + # Search-query parameter + filter: Union[str, Iterable[str], None] = None, + author: Optional[str] = None, + benchmark: Optional[Union[Literal[True], Literal["official"], str]] = None, + dataset_name: Optional[str] = None, + gated: Optional[bool] = None, + language_creators: Optional[Union[str, list[str]]] = None, + language: Optional[Union[str, list[str]]] = None, + multilinguality: Optional[Union[str, list[str]]] = None, + size_categories: Optional[Union[str, list[str]]] = None, + task_categories: Optional[Union[str, list[str]]] = None, + task_ids: Optional[Union[str, list[str]]] = None, + search: Optional[str] = None, + # Sorting and pagination parameters + sort: Optional[DatasetSort_T] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + # Additional data to fetch + expand: Optional[list[ExpandDatasetProperty_T]] = None, + full: Optional[bool] 
= None, + token: Union[bool, str, None] = None, + ) -> Iterable[DatasetInfo]: + """ + List datasets hosted on the Huggingface Hub, given some filters. + + Args: + filter (`str` or `Iterable[str]`, *optional*): + A string or list of string to filter datasets on the hub. + author (`str`, *optional*): + A string which identify the author of the returned datasets. + benchmark (`True`, `"official"`, `str`, *optional*): + Filter datasets by benchmark. Can be `True` or `"official"` to return official benchmark datasets. + For future-compatibility, can also be a string representing the benchmark name (currently only "official" is supported). + dataset_name (`str`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub by its name, such as `SQAC` or `wikineural` + gated (`bool`, *optional*): + A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned. + If `gated=True` is passed, only gated datasets are returned. + If `gated=False` is passed, only non-gated datasets are returned. + language_creators (`str` or `List`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub with how the data was curated, such as `crowdsourced` or + `machine_generated`. + language (`str` or `List`, *optional*): + A string or list of strings representing a two-character language to + filter datasets by on the Hub. + multilinguality (`str` or `List`, *optional*): + A string or list of strings representing a filter for datasets that + contain multiple languages. 
+ size_categories (`str` or `List`, *optional*): + A string or list of strings that can be used to identify datasets on + the Hub by the size of the dataset such as `100K>> from huggingface_hub import HfApi + + >>> api = HfApi() + + # List all datasets + >>> api.list_datasets() + + + # List only the text classification datasets + >>> api.list_datasets(filter="task_categories:text-classification") + + + # List only the datasets in russian for language modeling + >>> api.list_datasets( + ... filter=("language:ru", "task_ids:language-modeling") + ... ) + + # List FiftyOne datasets (identified by the tag "fiftyone" in dataset card) + >>> api.list_datasets(tags="fiftyone") + ``` + + Example usage with the `search` argument: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + # List all datasets with "text" in their name + >>> api.list_datasets(search="text") + + # List all datasets with "text" in their name made by google + >>> api.list_datasets(search="text", author="google") + ``` + """ + if expand and full: + raise ValueError("`expand` cannot be used if `full` is passed.") + + path = f"{self.endpoint}/api/datasets" + headers = self._build_hf_headers(token=token) + params: dict[str, Any] = {} + + # Build `filter` list + filter_list = [] + if filter is not None: + if isinstance(filter, str): + filter_list.append(filter) + else: + filter_list.extend(filter) + for key, value in ( + ("language_creators", language_creators), + ("language", language), + ("multilinguality", multilinguality), + ("size_categories", size_categories), + ("task_categories", task_categories), + ("task_ids", task_ids), + ): + if value: + if isinstance(value, str): + value = [value] + for value_item in value: + if not value_item.startswith(f"{key}:"): + data = f"{key}:{value_item}" + else: + data = value_item + filter_list.append(data) + if benchmark is not None: + if benchmark is True: # alias for official benchmark + benchmark = "official" + 
filter_list.append(f"benchmark:{benchmark}") + if len(filter_list) > 0: + params["filter"] = filter_list + + # Handle other query params + if author: + params["author"] = author + if gated is not None: + params["gated"] = gated + search_list = [] + if dataset_name: + search_list.append(dataset_name) + if search: + search_list.append(search) + if len(search_list) > 0: + params["search"] = search_list + if sort is not None: + params["sort"] = ( + "lastModified" + if sort == "last_modified" + else "trendingScore" + if sort == "trending_score" + else "createdAt" + if sort == "created_at" + else sort + ) + if direction is not None: + params["direction"] = direction + if limit is not None: + params["limit"] = limit + + # Request additional data + if expand: + params["expand"] = expand + if full: + params["full"] = True + + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + yield DatasetInfo(**item) + + @_deprecate_arguments(version="1.5", deprecated_args=["direction"], custom_message="Sorting is always descending.") + @validate_hf_hub_args + def list_spaces( + self, + *, + # Search-query parameter + filter: Union[str, Iterable[str], None] = None, + author: Optional[str] = None, + search: Optional[str] = None, + datasets: Union[str, Iterable[str], None] = None, + models: Union[str, Iterable[str], None] = None, + linked: bool = False, + # Sorting and pagination parameters + sort: Optional[SpaceSort_T] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + # Additional data to fetch + expand: Optional[list[ExpandSpaceProperty_T]] = None, + full: Optional[bool] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[SpaceInfo]: + """ + List spaces hosted on the Huggingface Hub, given some filters. 
+ + Args: + filter (`str` or `Iterable`, *optional*): + A string tag or list of tags that can be used to identify Spaces on the Hub. + author (`str`, *optional*): + A string which identify the author of the returned Spaces. + search (`str`, *optional*): + A string that will be contained in the returned Spaces. + datasets (`str` or `Iterable`, *optional*): + Whether to return Spaces that make use of a dataset. + The name of a specific dataset can be passed as a string. + models (`str` or `Iterable`, *optional*): + Whether to return Spaces that make use of a model. + The name of a specific model can be passed as a string. + linked (`bool`, *optional*): + Whether to return Spaces that make use of either a model or a dataset. + sort (`SpaceSort_T`, *optional*): + The key with which to sort the resulting spaces. Possible values are "created_at", "last_modified", + "likes" and "trending_score". + direction (`Literal[-1]` or `int`, *optional*): + Deprecated. This parameter is not used and will be removed in version 1.5. + limit (`int`, *optional*): + The limit on the number of Spaces fetched. Leaving this option + to `None` fetches all Spaces. + expand (`list[ExpandSpaceProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `full` is passed. + Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, and `"resourceGroup"`. + full (`bool`, *optional*): + Whether to fetch all Spaces data, including the `last_modified`, `siblings` + and `card_data` fields. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[SpaceInfo]`: an iterable of [`huggingface_hub.hf_api.SpaceInfo`] objects. + """ + if expand and full: + raise ValueError("`expand` cannot be used if `full` is passed.") + + path = f"{self.endpoint}/api/spaces" + headers = self._build_hf_headers(token=token) + params: dict[str, Any] = {} + if filter is not None: + params["filter"] = filter + if author is not None: + params["author"] = author + if search is not None: + params["search"] = search + if sort is not None: + params["sort"] = ( + "lastModified" + if sort == "last_modified" + else "trendingScore" + if sort == "trending_score" + else "createdAt" + if sort == "created_at" + else sort + ) + if direction is not None: + params["direction"] = direction + if limit is not None: + params["limit"] = limit + if linked: + params["linked"] = True + if datasets is not None: + params["datasets"] = datasets + if models is not None: + params["models"] = models + + # Request additional data + if expand: + params["expand"] = expand + if full: + params["full"] = True + + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + yield SpaceInfo(**item) + + @validate_hf_hub_args + def unlike( + self, + repo_id: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Unlike a given repo on the Hub (e.g. remove from favorite list). + + To prevent spam usage, it is not possible to `like` a repository from a script. + + See also [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to unlike. Example: `"user/my-cool-model"`. 
+ + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if unliking a dataset or space, `None` or + `"model"` if unliking a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + + Example: + ```python + >>> from huggingface_hub import list_liked_repos, unlike + >>> "gpt2" in list_liked_repos().models # we assume you have already liked gpt2 + True + >>> unlike("gpt2") + >>> "gpt2" in list_liked_repos().models + False + ``` + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + response = get_session().delete( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(response) + + @validate_hf_hub_args + def list_liked_repos( + self, + user: Optional[str] = None, + *, + token: Union[bool, str, None] = None, + ) -> UserLikes: + """ + List all public repos liked by a user on huggingface.co. + + This list is public so token is optional. If `user` is not passed, it defaults to + the logged in user. + + See also [`unlike`]. + + Args: + user (`str`, *optional*): + Name of the user for which you want to fetch the likes. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`UserLikes`]: object containing the user name and 3 lists of repo ids (1 for + models, 1 for datasets and 1 for Spaces). 
+ + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `user` is not passed and no token found (either from argument or from machine). + + Example: + ```python + >>> from huggingface_hub import list_liked_repos + + >>> likes = list_liked_repos("julien-c") + + >>> likes.user + "julien-c" + + >>> likes.models + ["osanseviero/streamlit_1.15", "Xhaheen/ChatGPT_HF", ...] + ``` + """ + # User is either provided explicitly or retrieved from current token. + if user is None: + me = self.whoami(token=token) + if me["type"] == "user": + user = me["name"] + else: + raise ValueError( + "Cannot list liked repos. You must provide a 'user' as input or be logged in as a user." + ) + + path = f"{self.endpoint}/api/users/{user}/likes" + headers = self._build_hf_headers(token=token) + + likes = list(paginate(path, params={}, headers=headers)) + # Looping over a list of items similar to: + # { + # 'createdAt': '2021-09-09T21:53:27.000Z', + # 'repo': { + # 'name': 'PaddlePaddle/PaddleOCR', + # 'type': 'space' + # } + # } + # Let's loop 3 times over the received list. Less efficient but more straightforward to read. + return UserLikes( + user=user, + total=len(likes), + models=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "model"], + datasets=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "dataset"], + spaces=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "space"], + ) + + @validate_hf_hub_args + def list_repo_likers( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[User]: + """ + List all users who liked a given repo on the hugging Face Hub. + + See also [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to retrieve . Example: `"user/my-cool-model"`. + + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: + `Iterable[User]`: an iterable of [`huggingface_hub.hf_api.User`] objects. + """ + + # Construct the API endpoint + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/likers" + for liker in paginate(path, params={}, headers=self._build_hf_headers(token=token)): + yield User(username=liker["user"], fullname=liker["fullname"], avatar_url=liker["avatarUrl"]) + + @validate_hf_hub_args + def model_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + securityStatus: Optional[bool] = None, + files_metadata: bool = False, + expand: Optional[list[ExpandModelProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> ModelInfo: + """ + Get info on one specific model on huggingface.co + + Model can be private if you pass an acceptable token or are logged in. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the model repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + securityStatus (`bool`, *optional*): + Whether to retrieve the security status from the model + repository as well. The security status will be returned in the `security_repo_status` field. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. 
+ expand (`list[ExpandModelProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `securityStatus` or `files_metadata` are passed. + Possible values are `"author"`, `"baseModels"`, `"cardData"`, `"childrenModelCount"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"evalResults"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"usedStorage"`, and `"resourceGroup"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`huggingface_hub.hf_api.ModelInfo`]: The model repository information. + + > [!TIP] + > Raises the following errors: + > + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + > - [`~utils.RevisionNotFoundError`] + > If the revision to download from cannot be found. 
+ """ + if expand and (securityStatus or files_metadata): + raise ValueError("`expand` cannot be used if `securityStatus` or `files_metadata` are set.") + + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/models/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}") + ) + params: dict = {} + if securityStatus: + params["securityStatus"] = True + if files_metadata: + params["blobs"] = True + if expand: + params["expand"] = expand + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return ModelInfo(**data) + + @validate_hf_hub_args + def dataset_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + expand: Optional[list[ExpandDatasetProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> DatasetInfo: + """ + Get info on one specific dataset on huggingface.co. + + Dataset can be private if you pass an acceptable token. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the dataset repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + expand (`list[ExpandDatasetProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `files_metadata` is passed. 
+ Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"`, `"tags"`, `"trendingScore"`,`"usedStorage"`, and `"resourceGroup"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`hf_api.DatasetInfo`]: The dataset repository information. + + > [!TIP] + > Raises the following errors: + > + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + > - [`~utils.RevisionNotFoundError`] + > If the revision to download from cannot be found. + """ + if expand and files_metadata: + raise ValueError("`expand` cannot be used if `files_metadata` is set.") + + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/datasets/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") + ) + params: dict = {} + if files_metadata: + params["blobs"] = True + if expand: + params["expand"] = expand + + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return DatasetInfo(**data) + + @validate_hf_hub_args + def space_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + expand: Optional[list[ExpandSpaceProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> SpaceInfo: + """ + Get info on one specific Space on huggingface.co. 
+ + Space can be private if you pass an acceptable token. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the space repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + expand (`list[ExpandSpaceProperty_T]`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `full` is passed. + Possible values are `"author"`, `"cardData"`, `"createdAt"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, and `"resourceGroup"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`~hf_api.SpaceInfo`]: The space repository information. + + > [!TIP] + > Raises the following errors: + > + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + > - [`~utils.RevisionNotFoundError`] + > If the revision to download from cannot be found. 
+ """ + if expand and files_metadata: + raise ValueError("`expand` cannot be used if `files_metadata` is set.") + + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/spaces/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") + ) + params: dict = {} + if files_metadata: + params["blobs"] = True + if expand: + params["expand"] = expand + + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return SpaceInfo(**data) + + @validate_hf_hub_args + def repo_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + expand: Optional[Union[ExpandModelProperty_T, ExpandDatasetProperty_T, ExpandSpaceProperty_T]] = None, + token: Union[bool, str, None] = None, + ) -> Union[ModelInfo, DatasetInfo, SpaceInfo]: + """ + Get the info object for a given repo of a given type. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the repository from which to get the + information. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + expand (`ExpandModelProperty_T` or `ExpandDatasetProperty_T` or `ExpandSpaceProperty_T`, *optional*): + List properties to return in the response. When used, only the properties in the list will be returned. + This parameter cannot be used if `files_metadata` is passed. + For an exhaustive list of available properties, check out [`model_info`], [`dataset_info`] or [`space_info`]. 
+ files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Union[SpaceInfo, DatasetInfo, ModelInfo]`: The repository information, as a + [`huggingface_hub.hf_api.DatasetInfo`], [`huggingface_hub.hf_api.ModelInfo`] + or [`huggingface_hub.hf_api.SpaceInfo`] object. + + > [!TIP] + > Raises the following errors: + > + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + > - [`~utils.RevisionNotFoundError`] + > If the revision to download from cannot be found. + """ + if repo_type is None or repo_type == "model": + method = self.model_info + elif repo_type == "dataset": + method = self.dataset_info # type: ignore + elif repo_type == "space": + method = self.space_info # type: ignore + else: + raise ValueError("Unsupported repo type.") + return method( + repo_id, + revision=revision, + token=token, + timeout=timeout, + expand=expand, # type: ignore[arg-type] + files_metadata=files_metadata, + ) + + @validate_hf_hub_args + def repo_exists( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> bool: + """ + Checks if a repository exists on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. 
+ token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + True if the repository exists, False otherwise. + + Examples: + ```py + >>> from huggingface_hub import repo_exists + >>> repo_exists("google/gemma-7b") + True + >>> repo_exists("google/not-a-repo") + False + ``` + """ + try: + self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) + return True + except GatedRepoError: + return True # we don't have access but it exists + except RepositoryNotFoundError: + return False + + @validate_hf_hub_args + def revision_exists( + self, + repo_id: str, + revision: str, + *, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> bool: + """ + Checks if a specific revision exists on a repo on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`): + The revision of the repository to check. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + True if the repository and the revision exists, False otherwise. 
+ + Examples: + ```py + >>> from huggingface_hub import revision_exists + >>> revision_exists("google/gemma-7b", "float16") + True + >>> revision_exists("google/gemma-7b", "not-a-revision") + False + ``` + """ + try: + self.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) + return True + except RevisionNotFoundError: + return False + except RepositoryNotFoundError: + return False + + @validate_hf_hub_args + def file_exists( + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> bool: + """ + Checks if a file exists in a repository on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + filename (`str`): + The name of the file to check, for example: + `"config.json"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + True if the file exists, False otherwise. 
+ + Examples: + ```py + >>> from huggingface_hub import file_exists + >>> file_exists("bigcode/starcoder", "config.json") + True + >>> file_exists("bigcode/starcoder", "not-a-file") + False + >>> file_exists("bigcode/not-a-repo", "config.json") + False + ``` + """ + url = hf_hub_url( + repo_id=repo_id, repo_type=repo_type, revision=revision, filename=filename, endpoint=self.endpoint + ) + try: + if token is None: + token = self.token + get_hf_file_metadata(url, token=token) + return True + except GatedRepoError: # raise specifically on gated repo + raise + except (RepositoryNotFoundError, RemoteEntryNotFoundError, RevisionNotFoundError): + return False + + @validate_hf_hub_args + def list_repo_files( + self, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> list[str]: + """ + Get the list of files in a given repo. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + revision (`str`, *optional*): + The revision of the repository from which to get the information. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to + a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `list[str]`: the list of files in a given repository. 
+ """ + return [ + f.rfilename + for f in self.list_repo_tree( + repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type, token=token + ) + if isinstance(f, RepoFile) + ] + + @validate_hf_hub_args + def list_repo_tree( + self, + repo_id: str, + path_in_repo: Optional[str] = None, + *, + recursive: bool = False, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> Iterable[Union[RepoFile, RepoFolder]]: + """ + List a repo tree's files and folders and get information about them. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + path_in_repo (`str`, *optional*): + Relative path of the tree (folder) in the repo, for example: + `"checkpoints/1fec34a/results"`. Will default to the root tree (folder) of the repository. + recursive (`bool`, *optional*, defaults to `False`): + Whether to list tree's files and folders recursively. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the tree's files and folders (e.g. last commit and files' security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the tree. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the tree (`"model"`, `"dataset"` or `"space"`. + Defaults to `"model"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Returns: + `Iterable[Union[RepoFile, RepoFolder]]`: + The information about the tree's files and folders, as an iterable of [`RepoFile`] and [`RepoFolder`] objects. The order of the files and folders is + not guaranteed. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + [`~utils.RemoteEntryNotFoundError`]: + If the tree (folder) does not exist (error 404) on the repo. + + Examples: + + Get information about a repo's tree. + ```py + >>> from huggingface_hub import list_repo_tree + >>> repo_tree = list_repo_tree("lysandre/arxiv-nlp") + >>> repo_tree + + >>> list(repo_tree) + [ + RepoFile(path='.gitattributes', size=391, blob_id='ae8c63daedbd4206d7d40126955d4e6ab1c80f8f', lfs=None, last_commit=None, security=None), + RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None), + RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None), + RepoFile( + path='flax_model.msgpack', size=497764107, blob_id='8095a62ccb4d806da7666fcda07467e2d150218e', + lfs={'size': 497764107, 'sha256': 'd88b0d6a6ff9c3f8151f9d3228f57092aaea997f09af009eefd7373a77b5abb9', 'pointer_size': 134}, last_commit=None, security=None + ), + RepoFile(path='merges.txt', size=456318, blob_id='226b0752cac7789c48f0cb3ec53eda48b7be36cc', lfs=None, last_commit=None, security=None), + RepoFile( + path='pytorch_model.bin', size=548123560, blob_id='64eaa9c526867e404b68f2c5d66fd78e27026523', + lfs={'size': 548123560, 'sha256': '9be78edb5b928eba33aa88f431551348f7466ba9f5ef3daf1d552398722a5436', 'pointer_size': 134}, last_commit=None, security=None + ), + RepoFile(path='vocab.json', size=898669, blob_id='b00361fece0387ca34b4b8b8539ed830d644dbeb', lfs=None, 
last_commit=None, security=None)] + ] + ``` + + Get even more information about a repo's tree (last commit and files' security scan results) + ```py + >>> from huggingface_hub import list_repo_tree + >>> repo_tree = list_repo_tree("prompthero/openjourney-v4", expand=True) + >>> list(repo_tree) + [ + RepoFolder( + path='feature_extractor', + tree_id='aa536c4ea18073388b5b0bc791057a7296a00398', + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) + } + ), + RepoFolder( + path='safety_checker', + tree_id='65aef9d787e5557373fdf714d6c34d4fcdd70440', + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) + } + ), + RepoFile( + path='model_index.json', + size=582, + blob_id='d3d7c1e8c3e78eeb1640b8e2041ee256e24c9ee1', + lfs=None, + last_commit={ + 'oid': 'b195ed2d503f3eb29637050a886d77bd81d35f0e', + 'title': 'Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`. (#54)', + 'date': datetime.datetime(2023, 5, 15, 21, 41, 59, tzinfo=datetime.timezone.utc) + }, + security={ + 'safe': True, + 'av_scan': {'virusFound': False, 'virusNames': None}, + 'pickle_import_scan': None + } + ) + ... 
+ ] + ``` + """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + encoded_path_in_repo = "/" + quote(path_in_repo, safe="") if path_in_repo else "" + tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path_in_repo}" + for path_info in paginate(path=tree_url, headers=headers, params={"recursive": recursive, "expand": expand}): + yield (RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info)) + + @validate_hf_hub_args + def verify_repo_checksums( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + local_dir: Optional[Union[str, Path]] = None, + cache_dir: Optional[Union[str, Path]] = None, + token: Union[str, bool, None] = None, + ) -> "FolderVerification": + """ + Verify local files for a repo against Hub checksums. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + repo_type (`str`, *optional*): + The type of the repository from which to get the tree (`"model"`, `"dataset"` or `"space"`. + Defaults to `"model"`. + revision (`str`, *optional*): + The revision of the repository from which to get the tree. Defaults to `"main"` branch. + local_dir (`str` or `Path`, *optional*): + The local directory to verify. + cache_dir (`str` or `Path`, *optional*): + The cache directory to verify. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`FolderVerification`]: a structured result containing the verification details. 
+ + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + + """ + + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + if local_dir is not None and cache_dir is not None: + raise ValueError("Pass either `local_dir` or `cache_dir`, not both.") + + root, remote_revision = resolve_local_root( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + cache_dir=Path(cache_dir) if cache_dir is not None else None, + local_dir=Path(local_dir) if local_dir is not None else None, + ) + local_by_path = collect_local_files(root) + + # get remote entries (only files, not folders) + remote_by_path: dict[str, RepoFile] = {} + for entry in self.list_repo_tree( + repo_id=repo_id, recursive=True, revision=remote_revision, repo_type=repo_type, token=token + ): + if isinstance(entry, RepoFile): + remote_by_path[entry.path] = entry + + return verify_maps( + remote_by_path=remote_by_path, + local_by_path=local_by_path, + revision=remote_revision, + verified_path=root, + ) + + @validate_hf_hub_args + def list_repo_refs( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + include_pull_requests: bool = False, + token: Union[str, bool, None] = None, + ) -> GitRefs: + """ + Get the list of refs of a given repo (both tags and branches). + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing refs from a dataset or a Space, + `None` or `"model"` if listing from a model. Default is `None`. + include_pull_requests (`bool`, *optional*): + Whether to include refs from pull requests in the list. Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> api.list_repo_refs("gpt2") + GitRefs(branches=[GitRefInfo(name='main', ref='refs/heads/main', target_commit='e7da7f221d5bf496a48136c0cd264e630fe9fcc8')], converts=[], tags=[]) + + >>> api.list_repo_refs("bigcode/the-stack", repo_type='dataset') + GitRefs( + branches=[ + GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), + GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') + ], + converts=[], + tags=[ + GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') + ] + ) + ``` + + Returns: + [`GitRefs`]: object containing all information about branches and tags for a + repo on the Hub. 
+ """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + response = get_session().get( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs", + headers=self._build_hf_headers(token=token), + params={"include_prs": 1} if include_pull_requests else {}, + ) + hf_raise_for_status(response) + data = response.json() + + def _format_as_git_ref_info(item: dict) -> GitRefInfo: + return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"]) + + return GitRefs( + branches=[_format_as_git_ref_info(item) for item in data["branches"]], + converts=[_format_as_git_ref_info(item) for item in data["converts"]], + tags=[_format_as_git_ref_info(item) for item in data["tags"]], + pull_requests=[_format_as_git_ref_info(item) for item in data["pullRequests"]] + if include_pull_requests + else None, + ) + + @validate_hf_hub_args + def list_repo_commits( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + revision: Optional[str] = None, + formatted: bool = False, + ) -> list[GitCommitInfo]: + """ + Get the list of commits of a given revision for a repo on the Hub. + + Commits are sorted by date (last commit first). + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if + listing from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + formatted (`bool`): + Whether to return the HTML-formatted title and description of the commits. Defaults to False. 
+ + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + + # Commits are sorted by date (last commit first) + >>> initial_commit = api.list_repo_commits("gpt2")[-1] + + # Initial commit is always a system commit containing the `.gitattributes` file. + >>> initial_commit + GitCommitInfo( + commit_id='9b865efde13a30c13e0a33e536cf3e4a5a9d71d8', + authors=['system'], + created_at=datetime.datetime(2019, 2, 18, 10, 36, 15, tzinfo=datetime.timezone.utc), + title='initial commit', + message='', + formatted_title=None, + formatted_message=None + ) + + # Create an empty branch by deriving from initial commit + >>> api.create_branch("gpt2", "new_empty_branch", revision=initial_commit.commit_id) + ``` + + Returns: + list[[`GitCommitInfo`]]: list of objects containing information about the commits for a repo on the Hub. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + + # Paginate over results and return the list of commits. 
+ return [ + GitCommitInfo( + commit_id=item["id"], + authors=[author["user"] for author in item["authors"]], + created_at=parse_datetime(item["date"]), + title=item["title"], + message=item["message"], + formatted_title=item.get("formatted", {}).get("title"), + formatted_message=item.get("formatted", {}).get("message"), + ) + for item in paginate( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/commits/{revision}", + headers=self._build_hf_headers(token=token), + params={"expand[]": "formatted"} if formatted else {}, + ) + ] + + @validate_hf_hub_args + def get_paths_info( + self, + repo_id: str, + paths: Union[list[str], str], + *, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> list[Union[RepoFile, RepoFolder]]: + """ + Get information about a repo's paths. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + paths (`Union[list[str], str]`, *optional*): + The paths to get information about. If a path do not exist, it is ignored without raising + an exception. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the paths (e.g. last commit and files' security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`. + Defaults to `"model"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `list[Union[RepoFile, RepoFolder]]`: + The information about the paths, as a list of [`RepoFile`] and [`RepoFolder`] objects. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + + Example: + ```py + >>> from huggingface_hub import get_paths_info + >>> paths_info = get_paths_info("allenai/c4", ["README.md", "en"], repo_type="dataset") + >>> paths_info + [ + RepoFile(path='README.md', size=2379, blob_id='f84cb4c97182890fc1dbdeaf1a6a468fd27b4fff', lfs=None, last_commit=None, security=None), + RepoFolder(path='en', tree_id='dc943c4c40f53d02b31ced1defa7e5f438d5862e', last_commit=None) + ] + ``` + """ + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + response = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}", + data={ + "paths": paths if isinstance(paths, list) else [paths], + "expand": expand, + }, + headers=headers, + ) + hf_raise_for_status(response) + paths_info = response.json() + return [ + RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info) + for path_info in paths_info + ] + + @validate_hf_hub_args + def super_squash_history( + self, + repo_id: str, + *, + branch: Optional[str] = None, + commit_message: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ) -> None: + """Squash commit history on a branch for a repo on the Hub. 
+ + Squashing the repo history is useful when you know you'll make hundreds of commits and you don't want to + clutter the history. Squashing commits can only be performed from the head of a branch. + + > [!WARNING] + > Once squashed, the commit history cannot be retrieved. This is a non-revertible operation. + + > [!WARNING] + > Once the history of a branch has been squashed, it is not possible to merge it back into another branch since + > their history will have diverged. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + branch (`str`, *optional*): + The branch to squash. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The commit message to use for the squashed commit. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if + listing from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If the branch to squash cannot be found. + [`~utils.BadRequestError`]: + If invalid reference for a branch. You cannot squash history on tags. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + + # Create repo + >>> repo_id = api.create_repo("test-squash").repo_id + + # Make a lot of commits. 
+ >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"content") + >>> api.upload_file(repo_id=repo_id, path_in_repo="lfs.bin", path_or_fileobj=b"content") + >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"another_content") + + # Squash history + >>> api.super_squash_history(repo_id=repo_id) + ``` + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + if branch is None: + branch = constants.DEFAULT_REVISION + + # Prepare request + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/super-squash/{quote(branch, safe='')}" + headers = self._build_hf_headers(token=token) + commit_message = commit_message or f"Super-squash branch '{branch}' using huggingface_hub" + + # Super-squash + response = get_session().post(url=url, headers=headers, json={"message": commit_message}) + hf_raise_for_status(response) + + @validate_hf_hub_args + def list_lfs_files( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[LFSFileInfo]: + """ + List all LFS files in a repo on the Hub. + + This is primarily useful to count how much storage a repo is using and to eventually clean up large files + with [`permanently_delete_lfs_files`]. Note that this would be a permanent action that will affect all commits + referencing this deleted files and that cannot be undone. + + Args: + repo_id (`str`): + The repository for which you are listing LFS files. + repo_type (`str`, *optional*): + Type of repository. Set to `"dataset"` or `"space"` if listing from a dataset or space, `None` or + `"model"` if listing from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[LFSFileInfo]`: An iterator of [`LFSFileInfo`] objects. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> lfs_files = api.list_lfs_files("username/my-cool-repo") + + # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. + # e.g. select only LFS files in the "checkpoints" folder + >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) + + # Permanently delete LFS files + >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) + ``` + """ + # Prepare request + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/lfs-files" + headers = self._build_hf_headers(token=token) + + # Paginate over LFS items + for item in paginate(url, params={}, headers=headers): + yield LFSFileInfo(**item) + + @validate_hf_hub_args + def permanently_delete_lfs_files( + self, + repo_id: str, + lfs_files: Iterable[LFSFileInfo], + *, + rewrite_history: bool = True, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Permanently delete LFS files from a repo on the Hub. + + > [!WARNING] + > This is a permanent action that will affect all commits referencing the deleted files and might corrupt your + > repository. This is a non-revertible operation. Use it only if you know what you are doing. + + Args: + repo_id (`str`): + The repository for which you are listing LFS files. + lfs_files (`Iterable[LFSFileInfo]`): + An iterable of [`LFSFileInfo`] items to permanently delete from the repo. Use [`list_lfs_files`] to list + all LFS files from a repo. 
+ rewrite_history (`bool`, *optional*, default to `True`): + Whether to rewrite repository history to remove file pointers referencing the deleted LFS files (recommended). + repo_type (`str`, *optional*): + Type of repository. Set to `"dataset"` or `"space"` if listing from a dataset or space, `None` or + `"model"` if listing from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> lfs_files = api.list_lfs_files("username/my-cool-repo") + + # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. + # e.g. select only LFS files in the "checkpoints" folder + >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) + + # Permanently delete LFS files + >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) + ``` + """ + # Prepare request + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/lfs-files/batch" + headers = self._build_hf_headers(token=token) + + # Delete LFS items by batches of 1000 + for batch in chunk_iterable(lfs_files, 1000): + shas = [item.file_oid for item in batch] + if len(shas) == 0: + return + payload = { + "deletions": { + "sha": shas, + "rewriteHistory": rewrite_history, + } + } + response = get_session().post(url, headers=headers, json=payload) + hf_raise_for_status(response) + + @validate_hf_hub_args + def create_repo( + self, + repo_id: str, + *, + token: Union[str, bool, None] = None, + private: Optional[bool] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + resource_group_id: 
Optional[str] = None, + space_sdk: Optional[str] = None, + space_hardware: Optional[SpaceHardware] = None, + space_storage: Optional[SpaceStorage] = None, + space_sleep_time: Optional[int] = None, + space_secrets: Optional[list[dict[str, str]]] = None, + space_variables: Optional[list[dict[str, str]]] = None, + ) -> RepoUrl: + """Create an empty repo on the HuggingFace Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo already exists. + resource_group_id (`str`, *optional*): + Resource group in which to create the repo. Resource groups is only available for Enterprise Hub organizations and + allow to define which members of the organization can access the resource. The ID of a resource group + can be found in the URL of the resource's page on the Hub (e.g. `"66670e5163145ca562cb1988"`). + To learn more about resource groups, see https://huggingface.co/docs/hub/en/security-resource-groups. + space_sdk (`str`, *optional*): + Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", "docker", or "static". + space_hardware (`SpaceHardware` or `str`, *optional*): + Choice of Hardware if repo_type is "space". 
See [`SpaceHardware`] for a complete list. + space_storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + space_sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + space_secrets (`list[dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + space_variables (`list[dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + """ + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + path = f"{self.endpoint}/api/repos/create" + + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + + json: dict[str, Any] = {"name": name, "organization": organization} + if private is not None: + json["private"] = private + if repo_type is not None: + json["type"] = repo_type + if repo_type == "space": + if space_sdk is None: + raise ValueError( + "No space_sdk provided. 
`create_repo` expects space_sdk to be one" + f" of {constants.SPACES_SDK_TYPES} when repo_type is 'space'`" + ) + if space_sdk not in constants.SPACES_SDK_TYPES: + raise ValueError(f"Invalid space_sdk. Please choose one of {constants.SPACES_SDK_TYPES}.") + json["sdk"] = space_sdk + + if space_sdk is not None and repo_type != "space": + warnings.warn("Ignoring provided space_sdk because repo_type is not 'space'.") + + function_args = [ + "space_hardware", + "space_storage", + "space_sleep_time", + "space_secrets", + "space_variables", + ] + json_keys = ["hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] + values = [space_hardware, space_storage, space_sleep_time, space_secrets, space_variables] + + if repo_type == "space": + json.update({k: v for k, v in zip(json_keys, values) if v is not None}) + else: + provided_space_args = [key for key, value in zip(function_args, values) if value is not None] + + if provided_space_args: + warnings.warn(f"Ignoring provided {', '.join(provided_space_args)} because repo_type is not 'space'.") + + if resource_group_id is not None: + json["resourceGroupId"] = resource_group_id + + headers = self._build_hf_headers(token=token) + while True: + r = get_session().post(path, headers=headers, json=json) + if r.status_code == 409 and "Cannot create repo: another conflicting operation is in progress" in r.text: + # Since https://github.com/huggingface/moon-landing/pull/7272 (private repo), it is not possible to + # concurrently create repos on the Hub for a same user. This is rarely an issue, except when running + # tests. To avoid any inconvenience, we retry to create the repo for this specific error. + # NOTE: This could have being fixed directly in the tests but adding it here should fixed CIs for all + # dependent libraries. + # NOTE: If a fix is implemented server-side, we should be able to remove this retry mechanism. + logger.debug("Create repo failed due to a concurrency issue. 
Retrying...") + continue + break + + try: + hf_raise_for_status(r) + except HfHubHTTPError as err: + if exist_ok and err.response.status_code == 409: + # Repo already exists and `exist_ok=True` + pass + elif exist_ok and err.response.status_code == 403: + # No write permission on the namespace but repo might already exist + try: + self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) + if repo_type is None or repo_type == constants.REPO_TYPE_MODEL: + return RepoUrl(f"{self.endpoint}/{repo_id}") + return RepoUrl(f"{self.endpoint}/{constants.REPO_TYPES_URL_PREFIXES[repo_type]}{repo_id}") + except HfHubHTTPError: + raise err + else: + raise + + d = r.json() + return RepoUrl(d["url"], endpoint=self.endpoint) + + @validate_hf_hub_args + def delete_repo( + self, + repo_id: str, + *, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + missing_ok: bool = False, + ) -> None: + """ + Delete a repo from the HuggingFace Hub. CAUTION: this is irreversible. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. + missing_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo does not exist. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to delete from cannot be found and `missing_ok` is set to False (default). 
+ """ + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + path = f"{self.endpoint}/api/repos/delete" + + if repo_type not in constants.REPO_TYPES: + raise ValueError("Invalid repo type") + + json = {"name": name, "organization": organization} + if repo_type is not None: + json["type"] = repo_type + + headers = self._build_hf_headers(token=token) + r = get_session().request("DELETE", path, headers=headers, json=json) + try: + hf_raise_for_status(r) + except RepositoryNotFoundError: + if not missing_ok: + raise + + @validate_hf_hub_args + def update_repo_settings( + self, + repo_id: str, + *, + gated: Optional[Literal["auto", "manual", False]] = None, + private: Optional[bool] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Update the settings of a repository, including gated access and visibility. + + To give more control over how repos are used, the Hub allows repo authors to enable + access requests for their repos, and also to set the visibility of the repo to private. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a /. + gated (`Literal["auto", "manual", False]`, *optional*): + The gated status for the repository. If set to `None` (default), the `gated` setting of the repository won't be updated. + * "auto": The repository is gated, and access requests are automatically approved or denied based on predefined criteria. + * "manual": The repository is gated, and access requests require manual approval. + * False : The repository is not gated, and anyone can access it. + private (`bool`, *optional*): + Whether the repository should be private. + token (`Union[str, bool, None]`, *optional*): + A valid user access token (string). Defaults to the locally saved token, + which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). 
+ To disable authentication, pass False. + repo_type (`str`, *optional*): + The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`). + Defaults to `"model"`. + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If gated is not one of "auto", "manual", or False. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If repo_type is not one of the values in constants.REPO_TYPES. + [`~utils.HfHubHTTPError`]: + If the request to the Hugging Face Hub API fails. + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + """ + + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL # default repo type + + # Prepare the JSON payload for the PUT request + payload: dict = {} + + if gated is not None: + if gated not in ["auto", "manual", False]: + raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.") + payload["gated"] = gated + + if private is not None: + payload["private"] = private + + if len(payload) == 0: + raise ValueError("At least one setting must be updated.") + + # Build headers + headers = self._build_hf_headers(token=token) + + r = get_session().put( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", + headers=headers, + json=payload, + ) + hf_raise_for_status(r) + + def move_repo( + self, + from_id: str, + to_id: str, + *, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, + ): + """ + Moving a repository from namespace1/repo_name1 to namespace2/repo_name2 + + Note there are certain limitations. 
For more information about moving + repositories, please see + https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo. + + Args: + from_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. Original repository identifier. + to_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. Final repository identifier. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + > [!TIP] + > Raises the following errors: + > + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + if len(from_id.split("/")) != 2: + raise ValueError(f"Invalid repo_id: {from_id}. It should have a namespace (:namespace:/:repo_name:)") + + if len(to_id.split("/")) != 2: + raise ValueError(f"Invalid repo_id: {to_id}. It should have a namespace (:namespace:/:repo_name:)") + + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL # Hub won't accept `None`. + + json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type} + + path = f"{self.endpoint}/api/repos/move" + headers = self._build_hf_headers(token=token) + r = get_session().post(path, headers=headers, json=json) + try: + hf_raise_for_status(r) + except HfHubHTTPError as e: + e.append_to_message( + "\nFor additional documentation please see" + " https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo." 
# Typing overload: synchronous call (`run_as_future=False`, the default) returns a `CommitInfo`.
@overload
def create_commit(  # type: ignore
    self,
    repo_id: str,
    operations: Iterable[CommitOperation],
    *,
    commit_message: str,
    commit_description: Optional[str] = None,
    token: Union[str, bool, None] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    create_pr: Optional[bool] = None,
    num_threads: int = 5,
    parent_commit: Optional[str] = None,
    run_as_future: Literal[False] = ...,
) -> CommitInfo: ...

# Typing overload: `run_as_future=True` returns a `Future` wrapping the `CommitInfo`.
@overload
def create_commit(
    self,
    repo_id: str,
    operations: Iterable[CommitOperation],
    *,
    commit_message: str,
    commit_description: Optional[str] = None,
    token: Union[str, bool, None] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    create_pr: Optional[bool] = None,
    num_threads: int = 5,
    parent_commit: Optional[str] = None,
    run_as_future: Literal[True] = ...,
) -> Future[CommitInfo]: ...

@validate_hf_hub_args
@future_compatible
def create_commit(
    self,
    repo_id: str,
    operations: Iterable[CommitOperation],
    *,
    commit_message: str,
    commit_description: Optional[str] = None,
    token: Union[str, bool, None] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    create_pr: Optional[bool] = None,
    num_threads: int = 5,
    parent_commit: Optional[str] = None,
    run_as_future: bool = False,
) -> Union[CommitInfo, Future[CommitInfo]]:
    """
    Creates a commit in the given repo, deleting & uploading files as needed.

    The method validates its arguments, pre-uploads LFS files, resolves the files to copy, and then
    sends all operations to the Hub as a single NDJSON payload. If every operation turns out to be a
    no-op (unchanged additions / identical copies), no commit is created and information about the
    latest commit on `revision` is returned instead.

    > [!WARNING]
    > The input list of `CommitOperation` will be mutated during the commit process. Do not reuse the same objects
    > for multiple commits.

    > [!WARNING]
    > `create_commit` assumes that the repo already exists on the Hub. If you get a
    > Client error 404, please make sure you are authenticated and that `repo_id` and
    > `repo_type` are set correctly. If repo does not exist, create it first using
    > [`~hf_api.create_repo`].

    > [!WARNING]
    > `create_commit` is limited to 25k LFS files and a 1GB payload for regular files.

    Args:
        repo_id (`str`):
            The repository in which the commit will be created, for example:
            `"username/custom_transformers"`

        operations (`Iterable` of [`~hf_api.CommitOperation`]):
            An iterable of operations to include in the commit, either:

                - [`~hf_api.CommitOperationAdd`] to upload a file
                - [`~hf_api.CommitOperationDelete`] to delete a file
                - [`~hf_api.CommitOperationCopy`] to copy a file

            Operation objects will be mutated to include information relative to the upload. Do not reuse the
            same objects for multiple commits.

        commit_message (`str`):
            The summary (first line) of the commit that will be created. Must not be empty.

        commit_description (`str`, *optional*):
            The description of the commit that will be created. Defaults to an empty string.

        token (`bool` or `str`, *optional*):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is
            `None`.

        revision (`str`, *optional*):
            The git revision to commit from. Defaults to the head of the `"main"` branch.

        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request with that commit. Defaults to `False`.
            If `revision` is not set, PR is opened against the `"main"` branch. If
            `revision` is set and is a branch, PR is opened against this branch. If
            `revision` is set and is not a branch name (example: a commit oid), a
            `RevisionNotFoundError` is returned by the server.

        num_threads (`int`, *optional*):
            Number of concurrent threads for uploading files. Defaults to 5.
            Setting it to 2 means at most 2 files will be uploaded concurrently.

        parent_commit (`str`, *optional*):
            The OID / SHA of the parent commit, as a hexadecimal string.
            Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`,
            the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr`
            is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit`
            ensures the repo has not changed before committing the changes, and can be especially useful
            if the repo is updated / committed to concurrently.
        run_as_future (`bool`, *optional*):
            Whether or not to run this method in the background. Background jobs are run sequentially without
            blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects)
            object. Defaults to `False`.

    Returns:
        [`CommitInfo`] or `Future`:
            Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit
            url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will
            contain the result when executed.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If commit message is empty.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If parent commit is not a valid commit OID.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `repo_type` is not one of the values in `constants.REPO_TYPES`.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If one of the `CommitOperationAdd` objects has already been committed.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If a README.md file with an invalid metadata section is committed. In this case, the commit will fail
            early, before trying to upload any file.
        [`~utils.RepositoryNotFoundError`]:
            If repository is not found (error 404): wrong repo_id/repo_type, private
            but not authenticated or repo does not exist.

    Note:
        Earlier versions of this docstring also listed a `ValueError` when `create_pr` is `True` and
        `revision` is neither `None` nor `"main"`. No such check is performed in this method body
        (TODO confirm whether it is enforced elsewhere or is simply outdated).
    """
    # ---- 1. Validate and normalize arguments -------------------------------------------------
    if parent_commit is not None and not constants.REGEX_COMMIT_OID.fullmatch(parent_commit):
        raise ValueError(
            f"`parent_commit` is not a valid commit OID. It must match the following regex: {constants.REGEX_COMMIT_OID}"
        )

    if commit_message is None or len(commit_message) == 0:
        raise ValueError("`commit_message` can't be empty, please pass a value.")

    commit_description = commit_description if commit_description is not None else ""
    repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
    # Keep both forms of the revision: the unquoted one is passed to other methods/helpers,
    # the URL-quoted one is only used to build the commit endpoint URL below.
    unquoted_revision = revision or constants.DEFAULT_REVISION
    revision = quote(unquoted_revision, safe="")
    create_pr = create_pr if create_pr is not None else False

    headers = self._build_hf_headers(token=token)

    # ---- 2. Classify operations --------------------------------------------------------------
    # Materialize the iterable once: it is traversed several times below.
    operations = list(operations)
    additions = [op for op in operations if isinstance(op, CommitOperationAdd)]
    copies = [op for op in operations if isinstance(op, CommitOperationCopy)]
    nb_additions = len(additions)
    nb_copies = len(copies)
    # Anything that is neither an addition nor a copy is counted as a deletion.
    nb_deletions = len(operations) - nb_additions - nb_copies

    # An addition object can only be part of one commit (flag is set at the end of this method).
    for addition in additions:
        if addition._is_committed:
            raise ValueError(
                f"CommitOperationAdd {addition} has already being committed and cannot be reused. Please create a"
                " new CommitOperationAdd object if you want to create a new commit."
            )

    # Data files pushed to a non-dataset repo are most likely a `repo_type` mistake: warn, don't fail.
    if repo_type != "dataset":
        for addition in additions:
            if addition.path_in_repo.endswith((".arrow", ".parquet")):
                warnings.warn(
                    f"It seems that you are about to commit a data file ({addition.path_in_repo}) to a {repo_type}"
                    " repository. You are sure this is intended? If you are trying to upload a dataset, please"
                    " set `repo_type='dataset'` or `--repo-type=dataset` in a CLI."
                )

    logger.debug(
        f"About to commit to the hub: {len(additions)} addition(s), {len(copies)} copie(s) and"
        f" {nb_deletions} deletion(s)."
    )

    # If updating a README.md file, make sure the metadata format is valid
    # It's better to fail early than to fail after all the files have been uploaded.
    for addition in additions:
        if addition.path_in_repo == "README.md":
            with addition.as_file() as file:
                content = file.read().decode()
            self._validate_yaml(content, repo_type=repo_type, token=token)
            # Skip other additions after `README.md` has been processed
            break

    # If updating twice the same file or update then delete a file in a single commit
    _warn_on_overwriting_operations(operations)

    # ---- 3. Upload LFS content and resolve copies --------------------------------------------
    self.preupload_lfs_files(
        repo_id=repo_id,
        additions=additions,
        token=token,
        repo_type=repo_type,
        revision=unquoted_revision,  # first-class methods take unquoted revision
        create_pr=create_pr,
        num_threads=num_threads,
        free_memory=False,  # do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for "normal" users
    )

    files_to_copy = _fetch_files_to_copy(
        copies=copies,
        repo_type=repo_type,
        repo_id=repo_id,
        headers=headers,
        revision=unquoted_revision,
        endpoint=self.endpoint,
    )
    # ---- 4. Detect no-op operations ----------------------------------------------------------
    # Remove no-op operations (files that have not changed)
    # NOTE(review): `operations_without_no_op` is only used for the log message and for the
    # empty-commit check below. The commit payload is still built from the full `operations`
    # list, so unchanged files are sent to the Hub along with the rest of the commit.
    operations_without_no_op = []
    for operation in operations:
        if (
            isinstance(operation, CommitOperationAdd)
            and operation._remote_oid is not None
            and operation._remote_oid == operation._local_oid
        ):
            # File already exists on the Hub and has not changed: we can skip it.
            logger.debug(f"Skipping upload for '{operation.path_in_repo}' as the file has not changed.")
            continue
        if (
            isinstance(operation, CommitOperationCopy)
            and operation._dest_oid is not None
            and operation._dest_oid == operation._src_oid
        ):
            # Source and destination files are identical - skip
            logger.debug(
                f"Skipping copy for '{operation.src_path_in_repo}' -> '{operation.path_in_repo}' as the content of the source file is the same as the destination file."
            )
            continue
        operations_without_no_op.append(operation)
    if len(operations) != len(operations_without_no_op):
        logger.info(
            f"Removing {len(operations) - len(operations_without_no_op)} file(s) from commit that have not changed."
        )

    # Return early if empty commit
    if len(operations_without_no_op) == 0:
        logger.warning("No files have been modified since last commit. Skipping to prevent empty commit.")

        # Get latest commit info
        try:
            info = self.repo_info(repo_id=repo_id, repo_type=repo_type, revision=unquoted_revision, token=token)
        except RepositoryNotFoundError as e:
            e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE)
            raise

        # Return commit info based on latest commit
        # Non-model repos get a `/{repo_type}s` prefix in the commit URL; model repos get none.
        url_prefix = self.endpoint
        if repo_type is not None and repo_type != constants.REPO_TYPE_MODEL:
            url_prefix = f"{url_prefix}/{repo_type}s"
        return CommitInfo(
            commit_url=f"{url_prefix}/{repo_id}/commit/{info.sha}",
            commit_message=commit_message,
            commit_description=commit_description,
            oid=info.sha,  # type: ignore[arg-type]
            _endpoint=self.endpoint,
        )

    # ---- 5. Build and send the commit --------------------------------------------------------
    commit_payload = _prepare_commit_payload(
        operations=operations,
        files_to_copy=files_to_copy,
        commit_message=commit_message,
        commit_description=commit_description,
        parent_commit=parent_commit,
    )
    commit_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/commit/{revision}"

    # Serialize the payload as newline-delimited JSON: one JSON document per line.
    def _payload_as_ndjson() -> Iterable[bytes]:
        for item in commit_payload:
            yield json.dumps(item).encode()
            yield b"\n"

    # `**headers` comes last, so a "Content-Type" already present in the built headers wins.
    headers = {
        # See https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073
        "Content-Type": "application/x-ndjson",
        **headers,
    }
    # The whole NDJSON body is joined in memory before being sent.
    data = b"".join(_payload_as_ndjson())
    params = {"create_pr": "1"} if create_pr else None

    try:
        commit_resp = get_session().post(url=commit_url, headers=headers, content=data, params=params)
        hf_raise_for_status(commit_resp, endpoint_name="commit")
    except RepositoryNotFoundError as e:
        e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE)
        raise
    except RemoteEntryNotFoundError as e:
        # Enrich the error only when it is likely caused by a file/folder mix-up in a delete
        # operation; the exception is re-raised in every case.
        if nb_deletions > 0 and "A file with this name doesn't exist" in str(e):
            e.append_to_message(
                "\nMake sure to differentiate file and folder paths in delete"
                " operations with a trailing '/' or using `is_folder=True/False`."
            )
        raise

    # Mark additions as committed (cannot be reused in another commit)
    for addition in additions:
        addition._is_committed = True

    commit_data = commit_resp.json()
    return CommitInfo(
        commit_url=commit_data["commitUrl"],
        commit_message=commit_message,
        commit_description=commit_description,
        oid=commit_data["commitOid"],
        # "pullRequestUrl" is only read from the response when a PR was requested.
        pr_url=commit_data["pullRequestUrl"] if create_pr else None,
        _endpoint=self.endpoint,
    )
+ + Args: + repo_id (`str`): + The repository in which you will commit the files, for example: `"username/custom_transformers"`. + + operations (`Iterable` of [`CommitOperationAdd`]): + The list of files to upload. Warning: the objects in this list will be mutated to include information + relative to the upload. Do not reuse the same objects for multiple commits. + + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + The type of repository to upload to (e.g. `"model"` -default-, `"dataset"` or `"space"`). + + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + + create_pr (`boolean`, *optional*): + Whether or not you plan to create a Pull Request with that commit. Defaults to `False`. + + num_threads (`int`, *optional*): + Number of concurrent threads for uploading files. Defaults to 5. + Setting it to 2 means at most 2 files will be uploaded concurrently. + + gitignore_content (`str`, *optional*): + The content of the `.gitignore` file to know which files should be ignored. The order of priority + is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present + in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub + (if any). + + Example: + ```py + >>> from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit, create_repo + + >>> repo_id = create_repo("test_preupload").repo_id + + # Generate and preupload LFS files one by one + >>> operations = [] # List of all `CommitOperationAdd` objects that will be generated + >>> for i in range(5): + ... content = ... # generate binary content + ... 
addition = CommitOperationAdd(path_in_repo=f"shard_{i}_of_5.bin", path_or_fileobj=content) + ... preupload_lfs_files(repo_id, additions=[addition]) # upload + free memory + ... operations.append(addition) + + # Create commit + >>> create_commit(repo_id, operations=operations, commit_message="Commit all shards") + ``` + """ + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + create_pr = create_pr if create_pr is not None else False + headers = self._build_hf_headers(token=token) + + # Check if a `gitignore` file is being committed to the Hub. + additions = list(additions) + if gitignore_content is None: + for addition in additions: + if addition.path_in_repo == ".gitignore": + with addition.as_file() as f: + gitignore_content = f.read().decode() + break + + # Filter out already uploaded files + new_additions = [addition for addition in additions if not addition._is_uploaded] + + # Check which new files are LFS + # For some items, we might have already fetched the upload mode (in case of upload_large_folder) + additions_no_upload_mode = [addition for addition in new_additions if addition._upload_mode is None] + if len(additions_no_upload_mode) > 0: + try: + _fetch_upload_modes( + additions=additions_no_upload_mode, + repo_type=repo_type, + repo_id=repo_id, + headers=headers, + revision=revision, + endpoint=self.endpoint, + create_pr=create_pr or False, + gitignore_content=gitignore_content, + ) + except RepositoryNotFoundError as e: + e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) + raise + + # Filter out regular files + new_lfs_additions = [addition for addition in new_additions if addition._upload_mode == "lfs"] + + # Filter out files listed in .gitignore + new_lfs_additions_to_upload = [] + for addition in 
# Typing overload: synchronous call (`run_as_future=False`, the default) returns a `CommitInfo`.
@overload
def upload_file(  # type: ignore
    self,
    *,
    path_or_fileobj: Union[str, Path, bytes, BinaryIO],
    path_in_repo: str,
    repo_id: str,
    token: Union[str, bool, None] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    commit_message: Optional[str] = None,
    commit_description: Optional[str] = None,
    create_pr: Optional[bool] = None,
    parent_commit: Optional[str] = None,
    run_as_future: Literal[False] = ...,
) -> CommitInfo: ...

# Typing overload: `run_as_future=True` returns a `Future` wrapping the `CommitInfo`.
@overload
def upload_file(
    self,
    *,
    path_or_fileobj: Union[str, Path, bytes, BinaryIO],
    path_in_repo: str,
    repo_id: str,
    token: Union[str, bool, None] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    commit_message: Optional[str] = None,
    commit_description: Optional[str] = None,
    create_pr: Optional[bool] = None,
    parent_commit: Optional[str] = None,
    run_as_future: Literal[True] = ...,
) -> Future[CommitInfo]: ...

@validate_hf_hub_args
@future_compatible
def upload_file(
    self,
    *,
    path_or_fileobj: Union[str, Path, bytes, BinaryIO],
    path_in_repo: str,
    repo_id: str,
    token: Union[str, bool, None] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    commit_message: Optional[str] = None,
    commit_description: Optional[str] = None,
    create_pr: Optional[bool] = None,
    parent_commit: Optional[str] = None,
    run_as_future: bool = False,
) -> Union[CommitInfo, Future[CommitInfo]]:
    """
    Upload a single local file (up to 50 GB) to the given repo.

    The file is wrapped in one `CommitOperationAdd` and committed through `create_commit`, i.e. over
    HTTP; neither git nor git-lfs needs to be installed.

    Args:
        path_or_fileobj (`str`, `Path`, `bytes`, or `IO`):
            Path to a file on the local machine or binary data stream /
            fileobj / buffer.
        path_in_repo (`str`):
            Relative filepath in the repo, for example:
            `"checkpoints/1fec34a/weights.bin"`
        repo_id (`str`):
            The repository to which the file will be uploaded, for example:
            `"username/custom_transformers"`
        token (`bool` or `str`, *optional*):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is
            `None`.
        revision (`str`, *optional*):
            The git revision to commit from. Defaults to the head of the `"main"` branch.
        commit_message (`str`, *optional*):
            The summary / title / first line of the generated commit. Defaults to:
            `f"Upload {path_in_repo} with huggingface_hub"`
        commit_description (`str` *optional*):
            The description of the generated commit
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request with that commit. Defaults to `False`.
            If `revision` is not set, PR is opened against the `"main"` branch. If
            `revision` is set and is a branch, PR is opened against this branch. If
            `revision` is set and is not a branch name (example: a commit oid), a
            `RevisionNotFoundError` is returned by the server.
        parent_commit (`str`, *optional*):
            The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
            If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
            If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
            Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
            especially useful if the repo is updated / committed to concurrently.
        run_as_future (`bool`, *optional*):
            Whether or not to run this method in the background. Background jobs are run sequentially without
            blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects)
            object. Defaults to `False`.

    Returns:
        [`CommitInfo`] or `Future`:
            Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit
            url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will
            contain the result when executed.

    > [!TIP]
    > Raises the following errors:
    >
    >     - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
    >       if the HuggingFace API returned an error
    >     - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
    >       if some parameter value is invalid
    >     - [`~utils.RepositoryNotFoundError`]
    >       If the repository cannot be found. This may be because it doesn't exist,
    >       or because it is set to `private` and you do not have access.
    >     - [`~utils.RevisionNotFoundError`]
    >       If the revision cannot be found.

    > [!WARNING]
    > `upload_file` assumes that the repo already exists on the Hub. If you get a
    > Client error 404, please make sure you are authenticated and that `repo_id` and
    > `repo_type` are set correctly. If repo does not exist, create it first using
    > [`~hf_api.create_repo`].

    Example:

    ```python
    >>> from huggingface_hub import upload_file

    >>> with open("./local/filepath", "rb") as fileobj:
    ...     upload_file(
    ...         path_or_fileobj=fileobj,
    ...         path_in_repo="remote/file/path.h5",
    ...         repo_id="username/my-dataset",
    ...         repo_type="dataset",
    ...         token="my_token",
    ...     )

    >>> upload_file(
    ...     path_or_fileobj=".\\\\local\\\\file\\\\path",
    ...     path_in_repo="remote/file/path.h5",
    ...     repo_id="username/my-model",
    ...     token="my_token",
    ... )

    >>> upload_file(
    ...     path_or_fileobj=".\\\\local\\\\file\\\\path",
    ...     path_in_repo="remote/file/path.h5",
    ...     repo_id="username/my-model",
    ...     token="my_token",
    ...     create_pr=True,
    ... )
    ```
    """
    # Reject unknown repo types before building any operation.
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")

    # Fall back to a generated commit title when the caller did not supply one.
    if commit_message is None:
        commit_message = f"Upload {path_in_repo} with huggingface_hub"

    # A file upload is a commit made of exactly one "add" operation.
    single_addition = CommitOperationAdd(path_or_fileobj=path_or_fileobj, path_in_repo=path_in_repo)

    # `run_as_future` is not forwarded here; presumably it is consumed by the
    # `future_compatible` decorator on this method (TODO confirm).
    return self.create_commit(
        repo_id=repo_id,
        repo_type=repo_type,
        operations=[single_addition],
        commit_message=commit_message,
        commit_description=commit_description,
        token=token,
        revision=revision,
        create_pr=create_pr,
        parent_commit=parent_commit,
    )
+ + @validate_hf_hub_args + @future_compatible + def upload_folder( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, + run_as_future: bool = False, + ) -> Union[CommitInfo, Future[CommitInfo]]: + """ + Upload a local folder to the given repo. The upload is done through a HTTP requests, and doesn't require git or + git-lfs to be installed. + + The structure of the folder will be preserved. Files with the same name already present in the repository will + be overwritten. Others will be left untouched. + + Use the `allow_patterns` and `ignore_patterns` arguments to specify which files to upload. These parameters + accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as + documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). If both `allow_patterns` and + `ignore_patterns` are provided, both constraints apply. By default, all files from the folder are uploaded. + + Use the `delete_patterns` argument to specify remote files you want to delete. Input type is the same as for + `allow_patterns` (see above). If `path_in_repo` is also provided, the patterns are matched against paths + relative to this folder. For example, `upload_folder(..., path_in_repo="experiment", delete_patterns="logs/*")` + will delete any remote file under `./experiment/logs/`. Note that the `.gitattributes` file will not be deleted + even if it matches the patterns. + + Any `.git/` folder present in any subdirectory will be ignored. 
However, please be aware that the `.gitignore` + file is not taken into account. + + Uses `HfApi.create_commit` under the hood. + + Args: + repo_id (`str`): + The repository to which the file will be uploaded, for example: + `"username/custom_transformers"` + folder_path (`str` or `Path`): + Path to the folder to upload on the local file system + path_in_repo (`str`, *optional*): + Relative path of the directory in the repo, for example: + `"checkpoints/1fec34a/results"`. Will default to the root folder of the repository. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to: + `f"Upload {path_in_repo} with huggingface_hub"` + commit_description (`str` *optional*): + The description of the generated commit + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not + set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened + against this branch. If `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. 
+ If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + allow_patterns (`list[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are uploaded. + ignore_patterns (`list[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not uploaded. + delete_patterns (`list[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo while committing + new files. This is useful if you don't know which files have already been uploaded. + Note: to avoid discrepancies the `.gitattributes` file is not deleted even if it matches the pattern. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. + + Returns: + [`CommitInfo`] or `Future`: + Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit + url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will + contain the result when executed. + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + + > [!WARNING] + > `upload_folder` assumes that the repo already exists on the Hub. If you get a Client error 404, please make + > sure you are authenticated and that `repo_id` and `repo_type` are set correctly. 
If repo does not exist, create + > it first using [`~hf_api.create_repo`]. + + > [!TIP] + > When dealing with a large folder (thousands of files or hundreds of GB), we recommend using [`~hf_api.upload_large_folder`] instead. + + Example: + + ```python + # Upload checkpoints folder except the log files + >>> upload_folder( + ... folder_path="local/checkpoints", + ... path_in_repo="remote/experiment/checkpoints", + ... repo_id="username/my-dataset", + ... repo_type="datasets", + ... token="my_token", + ... ignore_patterns="**/logs/*.txt", + ... ) + + # Upload checkpoints folder including logs while deleting existing logs from the repo + # Useful if you don't know exactly which log files have already being pushed + >>> upload_folder( + ... folder_path="local/checkpoints", + ... path_in_repo="remote/experiment/checkpoints", + ... repo_id="username/my-dataset", + ... repo_type="datasets", + ... token="my_token", + ... delete_patterns="**/logs/*.txt", + ... ) + + # Upload checkpoints folder while creating a PR + >>> upload_folder( + ... folder_path="local/checkpoints", + ... path_in_repo="remote/experiment/checkpoints", + ... repo_id="username/my-dataset", + ... repo_type="datasets", + ... token="my_token", + ... create_pr=True, + ... ) + ``` + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + + # By default, upload folder to the root directory in repo. 
+ if path_in_repo is None: + path_in_repo = "" + + # Do not upload .git folder + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + ignore_patterns += DEFAULT_IGNORE_PATTERNS + + delete_operations = self._prepare_folder_deletions( + repo_id=repo_id, + repo_type=repo_type, + revision=constants.DEFAULT_REVISION if create_pr else revision, + token=token, + path_in_repo=path_in_repo, + delete_patterns=delete_patterns, + ) + add_operations = self._prepare_upload_folder_additions( + folder_path, + path_in_repo, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + token=token, + repo_type=repo_type, + ) + + # Optimize operations: if some files will be overwritten, we don't need to delete them first + if len(add_operations) > 0: + added_paths = set(op.path_in_repo for op in add_operations) + delete_operations = [ + delete_op for delete_op in delete_operations if delete_op.path_in_repo not in added_paths + ] + commit_operations = delete_operations + add_operations + + commit_message = commit_message or "Upload folder using huggingface_hub" + + return self.create_commit( + repo_type=repo_type, + repo_id=repo_id, + operations=commit_operations, + commit_message=commit_message, + commit_description=commit_description, + token=token, + revision=revision, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + @validate_hf_hub_args + def delete_file( + self, + path_in_repo: str, + repo_id: str, + *, + token: Union[str, bool, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ) -> CommitInfo: + """ + Deletes a file in the given repo. 
+ + Args: + path_in_repo (`str`): + Relative filepath in the repo, for example: + `"checkpoints/1fec34a/weights.bin"` + repo_id (`str`): + The repository from which the file will be deleted, for example: + `"username/custom_transformers"` + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or + space, `None` or `"model"` if in a model. Default is `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to + `f"Delete {path_in_repo} with huggingface_hub"`. + commit_description (`str` *optional*) + The description of the generated commit + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. 
+ + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + > - [`~utils.RevisionNotFoundError`] + > If the revision to download from cannot be found. + > - [`~utils.EntryNotFoundError`] + > If the file to download cannot be found. + + """ + commit_message = ( + commit_message if commit_message is not None else f"Delete {path_in_repo} with huggingface_hub" + ) + + operations = [CommitOperationDelete(path_in_repo=path_in_repo)] + + return self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + token=token, + operations=operations, + revision=revision, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + @validate_hf_hub_args + def delete_files( + self, + repo_id: str, + delete_patterns: list[str], + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ) -> CommitInfo: + """ + Delete files from a repository on the Hub. + + If a folder path is provided, the entire folder is deleted as well as + all files it contained. + + Args: + repo_id (`str`): + The repository from which the folder will be deleted, for example: + `"username/custom_transformers"` + delete_patterns (`list[str]`): + List of files or folders to delete. Each string can either be + a file path, a folder path, or a wildcard pattern. 
Patterns are Standard + Wildcards (globbing patterns) as documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). + The pattern matching is based on [`fnmatch`](https://docs.python.org/3/library/fnmatch.html). + Note that `fnmatch` matches `*` across path boundaries, unlike traditional Unix shell globbing. + E.g. `["file.txt", "folder/", "data/*.parquet"]` + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + to the stored token. + repo_type (`str`, *optional*): + Type of the repo to delete files from. Can be `"model"`, + `"dataset"` or `"space"`. Defaults to `"model"`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary (first line) of the generated commit. Defaults to + `f"Delete files using huggingface_hub"`. + commit_description (`str` *optional*) + The description of the generated commit. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. 
+ Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + """ + operations = self._prepare_folder_deletions( + repo_id=repo_id, repo_type=repo_type, delete_patterns=delete_patterns, path_in_repo="", revision=revision + ) + + if commit_message is None: + commit_message = f"Delete files {' '.join(delete_patterns)} with huggingface_hub" + + return self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + token=token, + operations=operations, + revision=revision, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + @validate_hf_hub_args + def delete_folder( + self, + path_in_repo: str, + repo_id: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ) -> CommitInfo: + """ + Deletes a folder in the given repo. + + Simple wrapper around [`create_commit`] method. + + Args: + path_in_repo (`str`): + Relative folder path in the repo, for example: `"checkpoints/1fec34a"`. + repo_id (`str`): + The repository from which the folder will be deleted, for example: + `"username/custom_transformers"` + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + to the stored token. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the folder is in a dataset or + space, `None` or `"model"` if in a model. Default is `None`. + revision (`str`, *optional*): + The git revision to commit from. 
Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to + `f"Delete folder {path_in_repo} with huggingface_hub"`. + commit_description (`str` *optional*) + The description of the generated commit. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + """ + return self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + token=token, + operations=[CommitOperationDelete(path_in_repo=path_in_repo, is_folder=True)], + revision=revision, + commit_message=( + commit_message if commit_message is not None else f"Delete folder {path_in_repo} with huggingface_hub" + ), + commit_description=commit_description, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + def upload_large_folder( + self, + repo_id: str, + folder_path: Union[str, Path], + *, + repo_type: str, # Repo type is required! 
+ revision: Optional[str] = None, + private: Optional[bool] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + num_workers: Optional[int] = None, + print_report: bool = True, + print_report_every: int = 60, + ) -> None: + """Upload a large folder to the Hub in the most resilient way possible. + + Several workers are started to upload files in an optimized way. Before being committed to a repo, files must be + hashed and be pre-uploaded if they are LFS files. Workers will perform these tasks for each file in the folder. + At each step, some metadata information about the upload process is saved in the folder under `.cache/.huggingface/` + to be able to resume the process if interrupted. The whole process might result in several commits. + + Args: + repo_id (`str`): + The repository to which the file will be uploaded. + E.g. `"HuggingFaceTB/smollm-corpus"`. + folder_path (`str` or `Path`): + Path to the folder to upload on the local file system. + repo_type (`str`): + Type of the repository. Must be one of `"model"`, `"dataset"` or `"space"`. + Unlike in all other `HfApi` methods, `repo_type` is explicitly required here. This is to avoid + any mistake when uploading a large folder to the Hub, and therefore prevent from having to re-upload + everything. + revision (`str`, `optional`): + The branch to commit to. If not provided, the `main` branch will be used. + private (`bool`, `optional`): + Whether the repository should be private. + If `None` (default), the repo will be public unless the organization's default is private. + allow_patterns (`list[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are uploaded. + ignore_patterns (`list[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not uploaded. + num_workers (`int`, *optional*): + Number of workers to start. Defaults to half of CPU cores (minimum 1). 
+ A higher number of workers may speed up the process if your machine allows it. However, on machines with a + slower connection, it is recommended to keep the number of workers low to ensure better resumability. + Indeed, partially uploaded files will have to be completely re-uploaded if the process is interrupted. + print_report (`bool`, *optional*): + Whether to print a report of the upload progress. Defaults to True. + Report is printed to `sys.stdout` every X seconds (60 by defaults) and overwrites the previous report. + print_report_every (`int`, *optional*): + Frequency at which the report is printed. Defaults to 60 seconds. + + > [!TIP] + > A few things to keep in mind: + > - Repository limits still apply: https://huggingface.co/docs/hub/repositories-recommendations + > - Do not start several processes in parallel. + > - You can interrupt and resume the process at any time. + > - Do not upload the same folder to several repositories. If you need to do so, you must delete the local `.cache/.huggingface/` folder first. + + > [!WARNING] + > While being much more robust to upload large folders, `upload_large_folder` is more limited than [`upload_folder`] feature-wise. In practice: + > - you cannot set a custom `path_in_repo`. If you want to upload to a subfolder, you need to set the proper structure locally. + > - you cannot set a custom `commit_message` and `commit_description` since multiple commits are created. + > - you cannot delete from the repo while uploading. Please make a separate commit first. + > - you cannot create a PR directly. Please create a PR first (from the UI or using [`create_pull_request`]) and then commit to it by passing `revision`. + + **Technical details:** + + `upload_large_folder` process is as follow: + 1. (Check parameters and setup.) + 2. Create repo if missing. + 3. List local files to upload. + 4. 
Run validation checks and display warnings if repository limits might be exceeded: + - Warns if the total number of files exceeds 100k (recommended limit). + - Warns if any folder contains more than 10k files (recommended limit). + - Warns about files larger than 20GB (recommended) or 50GB (hard limit). + 5. Start workers. Workers can perform the following tasks: + - Hash a file. + - Get upload mode (regular or LFS) for a list of files. + - Pre-upload an LFS file. + - Commit a bunch of files. + Once a worker finishes a task, it will move on to the next task based on the priority list (see below) until + all files are uploaded and committed. + 6. While workers are up, regularly print a report to sys.stdout. + + Order of priority: + 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file). + 2. Commit if at least 150 files are ready to commit. + 3. Get upload mode if at least 10 files have been hashed. + 4. Pre-upload LFS file if at least 1 file and no worker is pre-uploading. + 5. Hash file if at least 1 file and no worker is hashing. + 6. Get upload mode if at least 1 file and no worker is getting upload mode. + 7. Pre-upload LFS file if at least 1 file. + 8. Hash file if at least 1 file to hash. + 9. Get upload mode if at least 1 file to get upload mode. + 10. Commit if at least 1 file to commit and at least 1 min since last commit attempt. + 11. Commit if at least 1 file to commit and all other queues are empty. + + Special rules: + - Only one worker can commit at a time. + - If no tasks are available, the worker waits for 10 seconds before checking again. 
+ """ + return upload_large_folder_internal( + self, + repo_id=repo_id, + folder_path=folder_path, + repo_type=repo_type, + revision=revision, + private=private, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + num_workers=num_workers, + print_report=print_report, + print_report_every=print_report_every, + ) + + @validate_hf_hub_args + def get_hf_file_metadata( + self, + *, + url: str, + token: Union[bool, str, None] = None, + timeout: Optional[float] = constants.HF_HUB_ETAG_TIMEOUT, + ) -> HfFileMetadata: + """Fetch metadata of a file versioned on the Hub for a given url. + + Args: + url (`str`): + File url, for example returned by [`hf_hub_url`]. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + timeout (`float`, *optional*, defaults to 10): + How many seconds to wait for the server to send metadata before giving up. + + Returns: + A [`HfFileMetadata`] object containing metadata such as location, etag, size and commit_hash. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. 
+ token = self.token + + return get_hf_file_metadata( + url=url, + token=token, + timeout=timeout, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + endpoint=self.endpoint, + ) + + @overload + def hf_hub_download( + self, + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + force_download: bool = False, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + token: Union[bool, str, None] = None, + local_files_only: bool = False, + tqdm_class: Optional[type[base_tqdm]] = None, + dry_run: Literal[False] = False, + ) -> str: ... + + @overload + def hf_hub_download( + self, + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + force_download: bool = False, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + token: Union[bool, str, None] = None, + local_files_only: bool = False, + tqdm_class: Optional[type[base_tqdm]] = None, + dry_run: Literal[True], + ) -> DryRunFileInfo: ... + + @validate_hf_hub_args + def hf_hub_download( + self, + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + force_download: bool = False, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + token: Union[bool, str, None] = None, + local_files_only: bool = False, + tqdm_class: Optional[type[base_tqdm]] = None, + dry_run: bool = False, + ) -> Union[str, DryRunFileInfo]: + """Download a given file if it's not already present in the local cache. 
+ + The new cache file layout looks like this: + - The cache directory contains one subfolder per repo_id (namespaced by repo type) + - inside each repo folder: + - refs is a list of the latest known revision => commit_hash pairs + - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on + whether they're LFS files or not) + - snapshots contains one subfolder per commit, each "commit" contains the subset of the files + that have been resolved at that particular commit. Each filename is a symlink to the blob + at that particular commit. + + ``` + [ 96] . + └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + ``` + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this + option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` + to store some metadata related to the downloaded files. While this mechanism is not as robust as the main + cache-system, it's optimized for regularly pulling the latest version of a repository. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. 
+ subfolder (`str`, *optional*): + An optional value corresponding to a folder inside the repository. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded file will be placed under this directory. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in + the local cache. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `httpx.request`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + dry_run (`bool`, *optional*, defaults to `False`): + If `True`, perform a dry run without actually downloading the file. Returns a + [`DryRunFileInfo`] object containing information about what would be downloaded. 
+ + Returns: + `str` or [`DryRunFileInfo`]: + - If `dry_run=False`: Local path of file or if networking is off, last version of file cached on disk. + - If `dry_run=True`: A [`DryRunFileInfo`] object containing download information. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + [`~utils.RemoteEntryNotFoundError`] + If the file to download cannot be found. + [`~utils.LocalEntryNotFoundError`] + If network is disabled or unavailable and file is not found in cache. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` but the token cannot be found. + [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) + If ETag cannot be determined. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If some parameter value is invalid. + """ + from .file_download import hf_hub_download + + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. 
+ token = self.token + + return hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + library_name=self.library_name, + library_version=self.library_version, + cache_dir=cache_dir, + local_dir=local_dir, + user_agent=self.user_agent, + force_download=force_download, + etag_timeout=etag_timeout, + token=token, + headers=self.headers, + local_files_only=local_files_only, + tqdm_class=tqdm_class, + dry_run=dry_run, + ) + + @validate_hf_hub_args + def snapshot_download( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, + force_download: bool = False, + token: Union[bool, str, None] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[type[base_tqdm]] = None, + dry_run: bool = False, + ) -> Union[str, list[DryRunFileInfo]]: + """Download repo files. + + Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from + a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. You can also filter which files to download using + `allow_patterns` and `ignore_patterns`. + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. 
When using this + option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` + to store some metadata related to the downloaded files.While this mechanism is not as robust as the main + cache-system, it's optimized for regularly pulling the latest version of a repository. + + An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly + configured. It is also not possible to filter which files to download when cloning a repository using git. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded files will be placed under this directory. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `httpx.request`. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in the local cache. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. 
+ allow_patterns (`list[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are downloaded. + ignore_patterns (`list[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not downloaded. + max_workers (`int`, *optional*): + Number of concurrent threads to download files (1 thread = 1 file download). + Defaults to 8. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. + Note that the `tqdm_class` is not passed to each individual download. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + dry_run (`bool`, *optional*, defaults to `False`): + If `True`, perform a dry run without actually downloading the files. Returns a list of + [`DryRunFileInfo`] objects containing information about what would be downloaded. + + Returns: + `str` or list of [`DryRunFileInfo`]: + - If `dry_run=False`: Folder path of the repo snapshot. + - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information. + + Raises: + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` and the token cannot be found. + [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid. 
+ """ + from ._snapshot_download import snapshot_download + + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + + return snapshot_download( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + cache_dir=cache_dir, + local_dir=local_dir, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + etag_timeout=etag_timeout, + force_download=force_download, + token=token, + local_files_only=local_files_only, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + max_workers=max_workers, + tqdm_class=tqdm_class, + headers=self.headers, + dry_run=dry_run, + ) + + def get_safetensors_metadata( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> SafetensorsRepoMetadata: + """ + Parse metadata for a safetensors repo on the Hub. + + We first check if the repo has a single safetensors file or a sharded safetensors repo. If it's a single + safetensors file, we parse the metadata from this file. If it's a sharded safetensors repo, we parse the + metadata from the index file and then parse the metadata from each shard. + + To parse metadata from a single safetensors file, use [`parse_safetensors_file_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a + model. Default is `None`. + revision (`str`, *optional*): + The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the + head of the `"main"` branch. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`SafetensorsRepoMetadata`]: information related to safetensors repo. + + Raises: + [`NotASafetensorsRepoError`] + If the repo is not a safetensors repo i.e. doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file. + [`SafetensorsParsingError`] + If a safetensors file header couldn't be parsed correctly. + + Example: + ```py + # Parse repo with single weights file + >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m") + >>> metadata + SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...}, + files_metadata={'model.safetensors': SafetensorsFileMetadata(...)} + ) + >>> metadata.files_metadata["model.safetensors"].metadata + {'format': 'pt'} + + # Parse repo with sharded model + >>> metadata = get_safetensors_metadata("bigscience/bloom") + Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s] + >>> metadata + SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...}) + >>> len(metadata.files_metadata) + 72 # All safetensors files have been fetched + + # Parse repo with sharded model + >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5") + NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files. 
+ ``` + """ + if self.file_exists( # Single safetensors file => non-sharded model + repo_id=repo_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ): + file_metadata = self.parse_safetensors_file_metadata( + repo_id=repo_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ) + return SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={ + tensor_name: constants.SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys() + }, + files_metadata={constants.SAFETENSORS_SINGLE_FILE: file_metadata}, + ) + elif self.file_exists( # Multiple safetensors files => sharded with index + repo_id=repo_id, + filename=constants.SAFETENSORS_INDEX_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ): + # Fetch index + index_file = self.hf_hub_download( + repo_id=repo_id, + filename=constants.SAFETENSORS_INDEX_FILE, + repo_type=repo_type, + revision=revision, + token=token, + ) + with open(index_file) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + + # Fetch metadata per shard + files_metadata = {} + + def _parse(filename: str) -> None: + files_metadata[filename] = self.parse_safetensors_file_metadata( + repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, token=token + ) + + thread_map( + _parse, + set(weight_map.values()), + desc="Parse safetensors files", + tqdm_class=hf_tqdm, + ) + + return SafetensorsRepoMetadata( + metadata=index.get("metadata", None), + sharded=True, + weight_map=weight_map, + files_metadata=files_metadata, + ) + else: + # Not a safetensors repo + raise NotASafetensorsRepoError( + f"'{repo_id}' is not a safetensors repo. Couldn't find '{constants.SAFETENSORS_INDEX_FILE}' or '{constants.SAFETENSORS_SINGLE_FILE}' files." 
+ ) + + def parse_safetensors_file_metadata( + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> SafetensorsFileMetadata: + """ + Parse metadata from a safetensors file on the Hub. + + To parse metadata from all safetensors files in a repo at once, use [`get_safetensors_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a + model. Default is `None`. + revision (`str`, *optional*): + The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the + head of the `"main"` branch. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`SafetensorsFileMetadata`]: information related to a safetensors file. + + Raises: + [`NotASafetensorsRepoError`]: + If the repo is not a safetensors repo i.e. doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file. + [`SafetensorsParsingError`]: + If a safetensors file header couldn't be parsed correctly. + """ + url = hf_hub_url( + repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, endpoint=self.endpoint + ) + _headers = self._build_hf_headers(token=token) + + context_msg = f"repo '{repo_id}', revision '{revision or constants.DEFAULT_REVISION}'" + + # 1. 
Fetch first 100kb + # Empirically, 97% of safetensors files have a metadata size < 100kb (over the top 1000 models on the Hub). + # We assume fetching 100kb is faster than making 2 GET requests. Therefore we always fetch the first 100kb to + # avoid the 2nd GET in most cases. + # See https://github.com/huggingface/huggingface_hub/pull/1855#discussion_r1404286419. + response = get_session().get(url, headers={**_headers, "range": "bytes=0-100000"}) + hf_raise_for_status(response) + + # 2. Parse and validate metadata size using shared helper + metadata_size = _get_safetensors_metadata_size(response.content[:8], filename, context_msg) + + # 3.a. Get metadata from payload + if metadata_size <= 100000: + metadata_as_bytes = response.content[8 : 8 + metadata_size] + else: # 3.b. Request full metadata + response = get_session().get(url, headers={**_headers, "range": f"bytes=8-{metadata_size + 7}"}) + hf_raise_for_status(response) + metadata_as_bytes = response.content + + # 4. Parse json header using shared helper + return _parse_safetensors_header(metadata_as_bytes, filename, context_msg) + + @validate_hf_hub_args + def create_branch( + self, + repo_id: str, + *, + branch: str, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Create a new branch for a repo on the Hub, starting from the specified revision (defaults to `main`). + To find a revision suiting your needs, you can use [`list_repo_refs`] or [`list_repo_commits`]. + + Args: + repo_id (`str`): + The repository in which the branch will be created. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to create. + + revision (`str`, *optional*): + The git revision to create the branch from. It can be a branch name or + the OID/SHA of a commit, as a hexadecimal string. Defaults to the head + of the `"main"` branch. 
+ + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if creating a branch on a dataset or + space, `None` or `"model"` if tagging a model. Default is `None`. + + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if branch already exists. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.BadRequestError`]: + If invalid reference for a branch. Ex: `refs/pr/5` or 'refs/foo/bar'. + [`~utils.HfHubHTTPError`]: + If the branch already exists on the repo (error 409) and `exist_ok` is + set to `False`. + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + branch = quote(branch, safe="") + + # Prepare request + branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" + headers = self._build_hf_headers(token=token) + payload = {} + if revision is not None: + payload["startingPoint"] = revision + + # Create branch + response = get_session().post(url=branch_url, headers=headers, json=payload) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + if exist_ok and e.response.status_code == 409: + return + elif exist_ok and e.response.status_code == 403: + # No write permission on the namespace but branch might already exist + try: + refs = self.list_repo_refs(repo_id=repo_id, repo_type=repo_type, token=token) + for branch_ref in refs.branches: + if branch_ref.name == branch: + return # Branch already exists => do not raise + except HfHubHTTPError: + pass # We raise the original error if the branch does not exist + raise + + @validate_hf_hub_args + def 
delete_branch( + self, + repo_id: str, + *, + branch: str, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Delete a branch from a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a branch will be deleted. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to delete. + + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if creating a branch on a dataset or + space, `None` or `"model"` if tagging a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.HfHubHTTPError`]: + If trying to delete a protected branch. Ex: `main` cannot be deleted. + [`~utils.HfHubHTTPError`]: + If trying to delete a branch that does not exist. + + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + branch = quote(branch, safe="") + + # Prepare request + branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" + headers = self._build_hf_headers(token=token) + + # Delete branch + response = get_session().delete(url=branch_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def create_tag( + self, + repo_id: str, + *, + tag: str, + tag_message: Optional[str] = None, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Tag a given commit of a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a commit will be tagged. + Example: `"user/my-cool-model"`. 
+ + tag (`str`): + The name of the tag to create. + + tag_message (`str`, *optional*): + The description of the tag to create. + + revision (`str`, *optional*): + The git revision to tag. It can be a branch name or the OID/SHA of a + commit, as a hexadecimal string. Shorthands (7 first characters) are + also supported. Defaults to the head of the `"main"` branch. + + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or + space, `None` or `"model"` if tagging a model. Default is + `None`. + + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if tag already exists. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + [`~utils.HfHubHTTPError`]: + If the branch already exists on the repo (error 409) and `exist_ok` is + set to `False`. 
+ """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}" + headers = self._build_hf_headers(token=token) + payload = {"tag": tag} + if tag_message is not None: + payload["message"] = tag_message + + # Tag + response = get_session().post(url=tag_url, headers=headers, json=payload) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + if not (e.response.status_code == 409 and exist_ok): + raise + + @validate_hf_hub_args + def delete_tag( + self, + repo_id: str, + *, + tag: str, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Delete a tag from a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a tag will be deleted. + Example: `"user/my-cool-model"`. + + tag (`str`): + The name of the tag to delete. + + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or + `"model"` if tagging a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If tag is not found. 
+ """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + tag = quote(tag, safe="") + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{tag}" + headers = self._build_hf_headers(token=token) + + # Un-tag + response = get_session().delete(url=tag_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def get_full_repo_name( + self, + model_id: str, + *, + organization: Optional[str] = None, + token: Union[bool, str, None] = None, + ): + """ + Returns the repository name for a given model ID and optional + organization. + + Args: + model_id (`str`): + The name of the model. + organization (`str`, *optional*): + If passed, the repository name will be in the organization + namespace instead of the user namespace. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `str`: The repository name in the user's namespace + ({username}/{model_id}) if no organization is passed, and under the + organization namespace ({organization}/{model_id}) otherwise. + """ + if organization is None: + if "/" in model_id: + username = model_id.split("/")[0] + else: + username = self.whoami(token=token)["name"] # type: ignore + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + @validate_hf_hub_args + def get_repo_discussions( + self, + repo_id: str, + *, + author: Optional[str] = None, + discussion_type: Optional[constants.DiscussionTypeFilter] = None, + discussion_status: Optional[constants.DiscussionStatusFilter] = None, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterator[Discussion]: + """ + Fetches Discussions and Pull Requests for the given repo. 
+ + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + author (`str`, *optional*): + Pass a value to filter by discussion author. `None` means no filter. + Default is `None`. + discussion_type (`str`, *optional*): + Set to `"pull_request"` to fetch only pull requests, `"discussion"` + to fetch only discussions. Set to `"all"` or `None` to fetch both. + Default is `None`. + discussion_status (`str`, *optional*): + Set to `"open"` (respectively `"closed"`) to fetch only open + (respectively closed) discussions. Set to `"all"` or `None` + to fetch both. + Default is `None`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if fetching from a dataset or + space, `None` or `"model"` if fetching from a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterator[Discussion]`: An iterator of [`Discussion`] objects. + + Example: + Collecting all discussions of a repo in a list: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) + ``` + + Iterating over discussions of a repo: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> for discussion in get_repo_discussions(repo_id="bert-base-uncased"): + ... 
print(discussion.num, discussion.title) + ``` + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in constants.DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {constants.DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in constants.DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {constants.DISCUSSION_STATUS}") + + headers = self._build_hf_headers(token=token) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions" + + params: dict[str, Union[str, int]] = {} + if discussion_type is not None: + params["type"] = discussion_type + if discussion_status is not None: + params["status"] = discussion_status + if author is not None: + params["author"] = author + + def _fetch_discussion_page(page_index: int): + params["p"] = page_index + resp = get_session().get(path, headers=headers, params=params) + hf_raise_for_status(resp) + paginated_discussions = resp.json() + total = paginated_discussions["count"] + start = paginated_discussions["start"] + discussions = paginated_discussions["discussions"] + has_next = (start + len(discussions)) < total + return discussions, has_next + + has_next, page_index = True, 0 + + while has_next: + discussions, has_next = _fetch_discussion_page(page_index=page_index) + for discussion in discussions: + yield Discussion( + title=discussion["title"], + num=discussion["num"], + author=discussion.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion["createdAt"]), + status=discussion["status"], + repo_id=discussion["repo"]["name"], + repo_type=discussion["repo"]["type"], + is_pull_request=discussion["isPullRequest"], + endpoint=self.endpoint, + ) + page_index = page_index + 1 + + @validate_hf_hub_args + def 
get_discussion_details( + self, + repo_id: str, + discussion_num: int, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> DiscussionWithDetails: + """Fetches a Discussion's / Pull Request 's details from the Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`DiscussionWithDetails`] + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. 
+ """ + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" + headers = self._build_hf_headers(token=token) + resp = get_session().get(path, params={"diff": "1"}, headers=headers) + hf_raise_for_status(resp) + + discussion_details = resp.json() + is_pull_request = discussion_details["isPullRequest"] + + target_branch = discussion_details["changes"]["base"] if is_pull_request else None + conflicting_files = discussion_details["filesWithConflicts"] if is_pull_request else None + merge_commit_oid = discussion_details["changes"].get("mergeCommitId", None) if is_pull_request else None + + return DiscussionWithDetails( + title=discussion_details["title"], + num=discussion_details["num"], + author=discussion_details.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion_details["createdAt"]), + status=discussion_details["status"], + repo_id=discussion_details["repo"]["name"], + repo_type=discussion_details["repo"]["type"], + is_pull_request=discussion_details["isPullRequest"], + events=[deserialize_event(evt) for evt in discussion_details["events"]], + conflicting_files=conflicting_files, + target_branch=target_branch, + merge_commit_oid=merge_commit_oid, + diff=discussion_details.get("diff"), + endpoint=self.endpoint, + ) + + @validate_hf_hub_args + def create_discussion( + self, + repo_id: str, + title: str, + *, + token: Union[bool, str, None] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + pull_request: bool = False, + ) -> DiscussionWithDetails: + """Creates a Discussion or Pull Request. 
+ + Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + pull_request (`bool`, *optional*): + Whether to create a Pull Request or discussion. If `True`, creates a Pull Request. + If `False`, creates a discussion. Defaults to `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. 
This may be because it doesn't exist, + > or because it is set to `private` and you do not have access.""" + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + if description is not None: + description = description.strip() + description = ( + description + if description + else ( + f"{'Pull Request' if pull_request else 'Discussion'} opened with the" + " [huggingface_hub Python" + " library](https://huggingface.co/docs/huggingface_hub)" + ) + ) + + headers = self._build_hf_headers(token=token) + resp = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions", + json={ + "title": title.strip(), + "description": description, + "pullRequest": pull_request, + }, + headers=headers, + ) + hf_raise_for_status(resp) + num = resp.json()["num"] + return self.get_discussion_details( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=num, + token=token, + ) + + @validate_hf_hub_args + def create_pull_request( + self, + repo_id: str, + title: str, + *, + token: Union[bool, str, None] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionWithDetails: + """Creates a Pull Request . Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]; + + This is a wrapper around [`HfApi.create_discussion`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access.""" + return self.create_discussion( + repo_id=repo_id, + title=title, + token=token, + description=description, + repo_type=repo_type, + pull_request=True, + ) + + def _post_discussion_changes( + self, + *, + repo_id: str, + discussion_num: int, + resource: str, + body: Optional[dict] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> httpx.Response: + """Internal utility to POST changes to a Discussion or Pull Request""" + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + repo_id = f"{repo_type}s/{repo_id}" + + path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" 
+ + headers = self._build_hf_headers(token=token) + resp = get_session().post(path, headers=headers, json=body) + hf_raise_for_status(resp) + return resp + + @validate_hf_hub_args + def comment_discussion( + self, + repo_id: str, + discussion_num: int, + comment: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Creates a new comment on the given Discussion. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment (`str`): + The content of the comment to create. Comments support markdown formatting. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionComment`]: the newly created comment + + + Examples: + ```python + + >>> comment = \"\"\" + ... Hello @otheruser! + ... + ... # This is a title + ... + ... **This is bold**, *this is italic* and ~this is strikethrough~ + ... And [this](http://url) is a link + ... \"\"\" + + >>> HfApi().comment_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... comment=comment + ... ) + # DiscussionComment(id='deadbeef0000000', type='comment', ...) 
+ + ``` + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="comment", + body={"comment": comment}, + ) + return deserialize_event(resp.json()["newMessage"]) # type: ignore + + @validate_hf_hub_args + def rename_discussion( + self, + repo_id: str, + discussion_num: int, + new_title: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> DiscussionTitleChange: + """Renames a Discussion. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + new_title (`str`): + The new title for the discussion + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionTitleChange`]: the title change event + + + Examples: + ```python + >>> new_title = "New title, fixing a typo" + >>> HfApi().rename_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... 
new_title=new_title + ... ) + # DiscussionTitleChange(id='deadbeef0000000', type='title-change', ...) + + ``` + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="title", + body={"title": new_title}, + ) + return deserialize_event(resp.json()["newTitle"]) # type: ignore + + @validate_hf_hub_args + def change_discussion_status( + self, + repo_id: str, + discussion_num: int, + new_status: Literal["open", "closed"], + *, + token: Union[bool, str, None] = None, + comment: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionStatusChange: + """Closes or re-opens a Discussion or Pull Request. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + new_status (`str`): + The new status for the discussion, either `"open"` or `"closed"`. + comment (`str`, *optional*): + An optional comment to post with the status change. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionStatusChange`]: the status change event + + + Examples: + ```python + >>> new_title = "New title, fixing a typo" + >>> HfApi().rename_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... new_title=new_title + ... ) + # DiscussionStatusChange(id='deadbeef0000000', type='status-change', ...) + + ``` + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + if new_status not in ["open", "closed"]: + raise ValueError("Invalid status, valid statuses are: 'open' and 'closed'") + body: dict[str, str] = {"status": new_status} + if comment and comment.strip(): + body["comment"] = comment.strip() + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="status", + body=body, + ) + return deserialize_event(resp.json()["newStatus"]) # type: ignore + + @validate_hf_hub_args + def merge_pull_request( + self, + repo_id: str, + discussion_num: int, + *, + token: Union[bool, str, None] = None, + comment: Optional[str] = None, + repo_type: Optional[str] = None, + ): + """Merges a Pull Request. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . 
Must be a strictly positive integer. + comment (`str`, *optional*): + An optional comment to post with the status change. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionStatusChange`]: the status change event + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="merge", + body={"comment": comment.strip()} if comment and comment.strip() else None, + ) + + @validate_hf_hub_args + def edit_discussion_comment( + self, + repo_id: str, + discussion_num: int, + comment_id: str, + new_content: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Edits a comment on a Discussion / Pull Request. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment_id (`str`): + The ID of the comment to edit. 
+ new_content (`str`): + The new content of the comment. Comments support markdown formatting. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionComment`]: the edited comment + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource=f"comment/{comment_id.lower()}/edit", + body={"content": new_content}, + ) + return deserialize_event(resp.json()["updatedComment"]) # type: ignore + + @validate_hf_hub_args + def hide_discussion_comment( + self, + repo_id: str, + discussion_num: int, + comment_id: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Hides a comment on a Discussion / Pull Request. + + > [!WARNING] + > Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . 
Must be a strictly positive integer. + comment_id (`str`): + The ID of the comment to edit. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionComment`]: the hidden comment + + > [!TIP] + > Raises the following errors: + > + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the HuggingFace API returned an error + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if some parameter value is invalid + > - [`~utils.RepositoryNotFoundError`] + > If the repository to download from cannot be found. This may be because it doesn't exist, + > or because it is set to `private` and you do not have access. + """ + warnings.warn( + "Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible.", + UserWarning, + ) + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource=f"comment/{comment_id.lower()}/hide", + ) + return deserialize_event(resp.json()["updatedComment"]) # type: ignore + + @validate_hf_hub_args + def add_space_secret( + self, + repo_id: str, + key: str, + value: str, + *, + description: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """Adds or updates a secret in a Space. + + Secrets allow to set secret keys or tokens to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + + Args: + repo_id (`str`): + ID of the repo to update. 
Example: `"bigcode/in-the-stack"`. + key (`str`): + Secret key. Example: `"GITHUB_API_KEY"` + value (`str`): + Secret value. Example: `"your_github_api_key"`. + description (`str`, *optional*): + Secret description. Example: `"Github API key to access the Github API"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + payload = {"key": key, "value": value} + if description is not None: + payload["description"] = description + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/secrets", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + + @validate_hf_hub_args + def delete_space_secret(self, repo_id: str, key: str, *, token: Union[bool, str, None] = None) -> None: + """Deletes a secret from a Space. + + Secrets allow to set secret keys or tokens to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Secret key. Example: `"GITHUB_API_KEY"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ """ + r = get_session().request( + "DELETE", + f"{self.endpoint}/api/spaces/{repo_id}/secrets", + headers=self._build_hf_headers(token=token), + json={"key": key}, + ) + hf_raise_for_status(r) + + @validate_hf_hub_args + def get_space_variables(self, repo_id: str, *, token: Union[bool, str, None] = None) -> dict[str, SpaceVariable]: + """Gets all variables from a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to query. Example: `"bigcode/in-the-stack"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + r = get_session().get( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def add_space_variable( + self, + repo_id: str, + key: str, + value: str, + *, + description: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> dict[str, SpaceVariable]: + """Adds or updates a variable in a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + value (`str`): + Variable value. Example: `"the_model_repo_id"`. + description (`str`): + Description of the variable. Example: `"Model Repo ID of the implemented model"`. 
+ token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + payload = {"key": key, "value": value} + if description is not None: + payload["description"] = description + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def delete_space_variable( + self, repo_id: str, key: str, *, token: Union[bool, str, None] = None + ) -> dict[str, SpaceVariable]: + """Deletes a variable from a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + r = get_session().request( + "DELETE", + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + json={"key": key}, + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def get_space_runtime(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime: + """Gets runtime information about a Space. + + Args: + repo_id (`str`): + ID of the repo to update. 
Example: `"bigcode/in-the-stack"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + """ + r = get_session().get( + f"{self.endpoint}/api/spaces/{repo_id}/runtime", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def request_space_hardware( + self, + repo_id: str, + hardware: SpaceHardware, + *, + token: Union[bool, str, None] = None, + sleep_time: Optional[int] = None, + ) -> SpaceRuntime: + """Request new hardware for a Space. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + hardware (`str` or [`SpaceHardware`]): + Hardware on which to run the Space. Example: `"t4-medium"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + > [!TIP] + > It is also possible to request hardware directly when creating the Space repo! See [`create_repo`] for details. 
+ """ + if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + payload: dict[str, Any] = {"flavor": hardware} + if sleep_time is not None: + payload["sleepTimeSeconds"] = sleep_time + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/hardware", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def set_space_sleep_time( + self, repo_id: str, sleep_time: int, *, token: Union[bool, str, None] = None + ) -> SpaceRuntime: + """Set a custom sleep time for a Space running on upgraded hardware.. + + Your Space will go to sleep after X seconds of inactivity. You are not billed when your Space is in "sleep" + mode. If a new visitor lands on your Space, it will "wake it up". Only upgraded hardware can have a + configurable sleep time. To know more about the sleep stage, please refer to + https://huggingface.co/docs/hub/spaces-gpus#sleep-time. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to pause (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + > [!TIP] + > It is also possible to set a custom sleep time when requesting hardware with [`request_space_hardware`]. + """ + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/sleeptime", + headers=self._build_hf_headers(token=token), + json={"seconds": sleep_time}, + ) + hf_raise_for_status(r) + runtime = SpaceRuntime(r.json()) + + hardware = runtime.requested_hardware or runtime.hardware + if hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + return runtime + + @validate_hf_hub_args + def pause_space(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime: + """Pause your Space. + + A paused Space stops executing until manually restarted by its owner. This is different from the sleeping + state in which free Spaces go after 48h of inactivity. Paused time is not billed to your account, no matter the + hardware you've selected. To restart your Space, use [`restart_space`] and go to your Space settings page. + + For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause). + + Args: + repo_id (`str`): + ID of the Space to pause. Example: `"Salesforce/BLIP2"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`SpaceRuntime`]: Runtime information about your Space including `stage=PAUSED` and requested hardware. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you + are not authenticated. + [`~utils.HfHubHTTPError`]: + 403 Forbidden: only the owner of a Space can pause it. If you want to manage a Space that you don't + own, either ask the owner by opening a Discussion or duplicate the Space. + [`~utils.BadRequestError`]: + If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide + a static Space, you can set it to private. + """ + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/pause", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def restart_space( + self, repo_id: str, *, token: Union[bool, str, None] = None, factory_reboot: bool = False + ) -> SpaceRuntime: + """Restart your Space. + + This is the only way to programmatically restart a Space if you've put it on Pause (see [`pause_space`]). You + must be the owner of the Space to restart it. If you are using an upgraded hardware, your account will be + billed as soon as the Space is restarted. You can trigger a restart no matter the current state of a Space. + + For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause). + + Args: + repo_id (`str`): + ID of the Space to restart. Example: `"Salesforce/BLIP2"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + factory_reboot (`bool`, *optional*): + If `True`, the Space will be rebuilt from scratch without caching any requirements. + + Returns: + [`SpaceRuntime`]: Runtime information about your Space. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you + are not authenticated. + [`~utils.HfHubHTTPError`]: + 403 Forbidden: only the owner of a Space can restart it. If you want to restart a Space that you don't + own, either ask the owner by opening a Discussion or duplicate the Space. + [`~utils.BadRequestError`]: + If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide + a static Space, you can set it to private. + """ + params = {} + if factory_reboot: + params["factory"] = "true" + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/restart", headers=self._build_hf_headers(token=token), params=params + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def duplicate_space( + self, + from_id: str, + to_id: Optional[str] = None, + *, + private: Optional[bool] = None, + token: Union[bool, str, None] = None, + exist_ok: bool = False, + hardware: Optional[SpaceHardware] = None, + storage: Optional[SpaceStorage] = None, + sleep_time: Optional[int] = None, + secrets: Optional[list[dict[str, str]]] = None, + variables: Optional[list[dict[str, str]]] = None, + ) -> RepoUrl: + """Duplicate a Space. + + Programmatically duplicate a Space. The new Space will be created in your account and will be in the same state + as the original Space (running or paused). You can duplicate a Space no matter the current state of a Space. 
+ + Args: + from_id (`str`): + ID of the Space to duplicate. Example: `"pharma/CLIP-Interrogator"`. + to_id (`str`, *optional*): + ID of the new Space. Example: `"dog/CLIP-Interrogator"`. If not provided, the new Space will have the same + name as the original Space, but in your account. + private (`bool`, *optional*): + Whether the new Space should be private or not. Defaults to the same privacy as the original Space. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo already exists. + hardware (`SpaceHardware` or `str`, *optional*): + Choice of Hardware. Example: `"t4-medium"`. See [`SpaceHardware`] for a complete list. + storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + secrets (`list[dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + variables (`list[dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. 
Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If one of `from_id` or `to_id` cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`HfHubHTTPError`]: + If the HuggingFace API returned an error + + Example: + ```python + >>> from huggingface_hub import duplicate_space + + # Duplicate a Space to your account + >>> duplicate_space("multimodalart/dreambooth-training") + RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) + + # Can set custom destination id and visibility flag. + >>> duplicate_space("multimodalart/dreambooth-training", to_id="my-dreambooth", private=True) + RepoUrl('https://huggingface.co/spaces/nateraw/my-dreambooth',...) + ``` + """ + # Parse to_id if provided + parsed_to_id = RepoUrl(to_id) if to_id is not None else None + + # Infer target repo_id + to_namespace = ( # set namespace manually or default to username + parsed_to_id.namespace + if parsed_to_id is not None and parsed_to_id.namespace is not None + else self.whoami(token)["name"] + ) + to_repo_name = parsed_to_id.repo_name if to_id is not None else RepoUrl(from_id).repo_name # type: ignore + + # repository must be a valid repo_id (namespace/repo_name). 
+ payload: dict[str, Any] = {"repository": f"{to_namespace}/{to_repo_name}"} + + keys = ["private", "hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] + values = [private, hardware, storage, sleep_time, secrets, variables] + payload.update({k: v for k, v in zip(keys, values) if v is not None}) + + if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + + r = get_session().post( + f"{self.endpoint}/api/spaces/{from_id}/duplicate", + headers=self._build_hf_headers(token=token), + json=payload, + ) + + try: + hf_raise_for_status(r) + except HfHubHTTPError as err: + if exist_ok and err.response.status_code == 409: + # Repo already exists and `exist_ok=True` + pass + else: + raise + + return RepoUrl(r.json()["url"], endpoint=self.endpoint) + + @validate_hf_hub_args + def request_space_storage( + self, + repo_id: str, + storage: SpaceStorage, + *, + token: Union[bool, str, None] = None, + ) -> SpaceRuntime: + """Request persistent storage for a Space. + + Args: + repo_id (`str`): + ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`. + storage (`str` or [`SpaceStorage`]): + Storage tier. Either 'small', 'medium', or 'large'. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + > [!TIP] + > It is not possible to decrease persistent storage after its granted. 
To do so, you must delete it + > via [`delete_space_storage`]. + """ + payload: dict[str, SpaceStorage] = {"tier": storage} + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/storage", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def delete_space_storage( + self, + repo_id: str, + *, + token: Union[bool, str, None] = None, + ) -> SpaceRuntime: + """Delete persistent storage for a Space. + + Args: + repo_id (`str`): + ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + Raises: + [`BadRequestError`] + If space has no persistent storage. + + """ + r = get_session().delete( + f"{self.endpoint}/api/spaces/{repo_id}/storage", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + ####################### + # Inference Endpoints # + ####################### + + def list_inference_endpoints( + self, namespace: Optional[str] = None, *, token: Union[bool, str, None] = None + ) -> list[InferenceEndpoint]: + """Lists all inference endpoints for the given namespace. + + Args: + namespace (`str`, *optional*): + The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints + from all namespaces (i.e. personal namespace and all orgs the user belongs to). + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + list[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> api.list_inference_endpoints() + [InferenceEndpoint(name='my-endpoint', ...), ...] + ``` + """ + # Special case: list all endpoints for all namespaces the user has access to + if namespace == "*": + user = self.whoami(token=token) + + # List personal endpoints first + endpoints: list[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token)) + + # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access) + for org in user.get("orgs", []): + try: + endpoints += list_inference_endpoints(namespace=org["name"], token=token) + except HfHubHTTPError as error: + if error.response.status_code == 401: # Either no billing or user don't have access) + logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error) + pass + + return endpoints + + # Normal case: list endpoints for a specific namespace + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return [ + InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token) + for endpoint in response.json()["items"] + ] + + def create_inference_endpoint( + self, + name: str, + *, + repository: str, + framework: str, + accelerator: str, + instance_size: str, + instance_type: str, + region: str, + vendor: str, + account_id: Optional[str] = None, + min_replica: int = 1, + max_replica: int = 1, + scaling_metric: 
Optional[InferenceEndpointScalingMetric] = None, + scaling_threshold: Optional[float] = None, + scale_to_zero_timeout: Optional[int] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[dict] = None, + env: Optional[dict[str, str]] = None, + secrets: Optional[dict[str, str]] = None, + type: InferenceEndpointType = InferenceEndpointType.PROTECTED, + domain: Optional[str] = None, + path: Optional[str] = None, + cache_http_responses: Optional[bool] = None, + tags: Optional[list[str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Create a new Inference Endpoint. + + Args: + name (`str`): + The unique name for the new Inference Endpoint. + repository (`str`): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`): + The machine learning framework used for the model (e.g. `"custom"`). + accelerator (`str`): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + region (`str`): + The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`). + vendor (`str`): + The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`). + account_id (`str`, *optional*): + The account ID used to link a VPC to a private Inference Endpoint (if applicable). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. To enable + scaling to zero, set this value to 0 and adjust `scale_to_zero_timeout` accordingly. Defaults to 1. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1. 
+ scaling_metric (`str` or [`InferenceEndpointScalingMetric `], *optional*): + The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided. Defaults to + None (meaning: let the HF Endpoints service specify the metric). + scaling_threshold (`float`, *optional*): + The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided. + Defaults to None (meaning: let the HF Endpoints service specify the threshold). + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero, or no scaling to zero if + set to None and `min_replica` is not 0. Defaults to None. + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + env (`dict[str, str]`, *optional*): + Non-secret environment variables to inject in the container environment. + secrets (`dict[str, str]`, *optional*): + Secret values to inject in the container environment. + type ([`InferenceEndpointType]`, *optional*): + The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`. + domain (`str`, *optional*): + The custom domain for the Inference Endpoint deployment, if setup the inference endpoint will be available at this domain (e.g. `"my-new-domain.cool-website.woof"`). + path (`str`, *optional*): + The custom path to the deployed model, should start with a `/` (e.g. `"/models/google-bert/bert-base-uncased"`). + cache_http_responses (`bool`, *optional*): + Whether to cache HTTP responses from the Inference Endpoint. Defaults to `False`. 
+ tags (`list[str]`, *optional*): + A list of tags to associate with the Inference Endpoint. + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the updated Inference Endpoint. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.create_inference_endpoint( + ... "my-endpoint-name", + ... repository="gpt2", + ... framework="pytorch", + ... task="text-generation", + ... accelerator="cpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="x2", + ... instance_type="intel-icl", + ... ) + >>> endpoint + InferenceEndpoint(name='my-endpoint-name', status="pending",...) + + # Run inference on the endpoint + >>> endpoint.client.text_generation(...) + "..." + ``` + + ```python + # Start an Inference Endpoint running Zephyr-7b-beta on TGI + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.create_inference_endpoint( + ... "aws-zephyr-7b-beta-0486", + ... repository="HuggingFaceH4/zephyr-7b-beta", + ... framework="pytorch", + ... task="text-generation", + ... accelerator="gpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="x1", + ... instance_type="nvidia-a10g", + ... env={ + ... "MAX_BATCH_PREFILL_TOKENS": "2048", + ... "MAX_INPUT_LENGTH": "1024", + ... "MAX_TOTAL_TOKENS": "1512", + ... "MODEL_ID": "/repository" + ... }, + ... custom_image={ + ... "health_route": "/health", + ... "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", + ... }, + ... 
secrets={"MY_SECRET_KEY": "secret_value"}, + ... tags=["dev", "text-generation"], + ... ) + ``` + + ```python + # Start an Inference Endpoint running ProsusAI/finbert while scaling to zero in 15 minutes + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.create_inference_endpoint( + ... "finbert-classifier", + ... repository="ProsusAI/finbert", + ... framework="pytorch", + ... task="text-classification", + ... min_replica=0, + ... scale_to_zero_timeout=15, + ... accelerator="cpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="x2", + ... instance_type="intel-icl", + ... ) + >>> endpoint.wait(timeout=300) + # Run inference on the endpoint + >>> endpoint.client.text_generation(...) + TextClassificationOutputElement(label='positive', score=0.8983615040779114) + ``` + + """ + namespace = namespace or self._get_namespace(token=token) + + if custom_image is not None: + image = ( + custom_image + if next(iter(custom_image)) in constants.INFERENCE_ENDPOINT_IMAGE_KEYS + else {"custom": custom_image} + ) + else: + image = {"huggingface": {}} + + payload: dict = { + "accountId": account_id, + "compute": { + "accelerator": accelerator, + "instanceSize": instance_size, + "instanceType": instance_type, + "scaling": { + "maxReplica": max_replica, + "minReplica": min_replica, + "scaleToZeroTimeout": scale_to_zero_timeout, + }, + }, + "model": { + "framework": framework, + "repository": repository, + "revision": revision, + "task": task, + "image": image, + }, + "name": name, + "provider": { + "region": region, + "vendor": vendor, + }, + "type": type, + } + if scaling_metric: + payload["compute"]["scaling"]["measure"] = {scaling_metric: scaling_threshold} # type: ignore + if env: + payload["model"]["env"] = env + if secrets: + payload["model"]["secrets"] = secrets + if domain is not None or path is not None: + payload["route"] = {} + if domain is not None: + payload["route"]["domain"] = domain + if path is not 
None: + payload["route"]["path"] = path + if cache_http_responses is not None: + payload["cacheHttpResponses"] = cache_http_responses + if tags is not None: + payload["tags"] = tags + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + @experimental + @validate_hf_hub_args + def create_inference_endpoint_from_catalog( + self, + repo_id: str, + *, + name: Optional[str] = None, + token: Union[bool, str, None] = None, + namespace: Optional[str] = None, + ) -> InferenceEndpoint: + """Create a new Inference Endpoint from a model in the Hugging Face Inference Catalog. + + The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference + and for which default configurations have been tested. See https://endpoints.huggingface.co/catalog for a list + of available models in the catalog. + + Args: + repo_id (`str`): + The ID of the model in the catalog to deploy as an Inference Endpoint. + name (`str`, *optional*): + The unique name for the new Inference Endpoint. If not provided, a random name will be generated. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace. + + Returns: + [`InferenceEndpoint`]: information about the new Inference Endpoint. + + > [!WARNING] + > `create_inference_endpoint_from_catalog` is experimental. Its API is subject to change in the future. Please provide feedback + > if you have any suggestions or requests. 
+ """ + token = token or self.token or get_token() + payload: dict = { + "namespace": namespace or self._get_namespace(token=token), + "repoId": repo_id, + } + if name is not None: + payload["endpointName"] = name + + response = get_session().post( + f"{constants.INFERENCE_CATALOG_ENDPOINT}/deploy", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + data = response.json()["endpoint"] + return InferenceEndpoint.from_raw(data, namespace=data["name"], token=token) + + @experimental + @validate_hf_hub_args + def list_inference_catalog(self, *, token: Union[bool, str, None] = None) -> list[str]: + """List models available in the Hugging Face Inference Catalog. + + The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference + and for which default configurations have been tested. See https://endpoints.huggingface.co/catalog for a list + of available models in the catalog. + + Use [`create_inference_endpoint_from_catalog`] to deploy a model from the catalog. + + Args: + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + + Returns: + List[`str`]: A list of model IDs available in the catalog. + > [!WARNING] + > `list_inference_catalog` is experimental. Its API is subject to change in the future. Please provide feedback + > if you have any suggestions or requests. + """ + response = get_session().get( + f"{constants.INFERENCE_CATALOG_ENDPOINT}/repo-list", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + return response.json()["models"] + + def get_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Get information about an Inference Endpoint. 
+ + Args: + name (`str`): + The name of the Inference Endpoint to retrieve information about. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the requested Inference Endpoint. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) + + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) 
+ ``` + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def update_inference_endpoint( + self, + name: str, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + scale_to_zero_timeout: Optional[int] = None, + scaling_metric: Optional[InferenceEndpointScalingMetric] = None, + scaling_threshold: Optional[float] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[dict] = None, + env: Optional[dict[str, str]] = None, + secrets: Optional[dict[str, str]] = None, + # Route update + domain: Optional[str] = None, + path: Optional[str] = None, + # Other + cache_http_responses: Optional[bool] = None, + tags: Optional[list[str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Update an Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, the route, or any combination. + All arguments are optional but at least one must be provided. + + For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`]. + + Args: + name (`str`): + The name of the Inference Endpoint to update. + + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). 
+ instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero. + scaling_metric (`str` or [`InferenceEndpointScalingMetric `], *optional*): + The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided. + Defaults to None. + scaling_threshold (`float`, *optional*): + The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided. + Defaults to None. + repository (`str`, *optional*): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`, *optional*): + The machine learning framework used for the model (e.g. `"custom"`). + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + env (`dict[str, str]`, *optional*): + Non-secret environment variables to inject in the container environment + secrets (`dict[str, str]`, *optional*): + Secret values to inject in the container environment. + + domain (`str`, *optional*): + The custom domain for the Inference Endpoint deployment, if setup the inference endpoint will be available at this domain (e.g. `"my-new-domain.cool-website.woof"`). 
+ path (`str`, *optional*): + The custom path to the deployed model, should start with a `/` (e.g. `"/models/google-bert/bert-base-uncased"`). + + cache_http_responses (`bool`, *optional*): + Whether to cache HTTP responses from the Inference Endpoint. + tags (`list[str]`, *optional*): + A list of tags to associate with the Inference Endpoint. + + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the updated Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + # Populate only the fields that are not None + payload: dict = defaultdict(lambda: defaultdict(dict)) + if accelerator is not None: + payload["compute"]["accelerator"] = accelerator + if instance_size is not None: + payload["compute"]["instanceSize"] = instance_size + if instance_type is not None: + payload["compute"]["instanceType"] = instance_type + if max_replica is not None: + payload["compute"]["scaling"]["maxReplica"] = max_replica + if min_replica is not None: + payload["compute"]["scaling"]["minReplica"] = min_replica + if scale_to_zero_timeout is not None: + payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout + if scaling_metric: + payload["compute"]["scaling"]["measure"] = {scaling_metric: scaling_threshold} + if repository is not None: + payload["model"]["repository"] = repository + if framework is not None: + payload["model"]["framework"] = framework + if revision is not None: + payload["model"]["revision"] = revision + if task is not None: + payload["model"]["task"] = task + if custom_image is not None: + 
payload["model"]["image"] = {"custom": custom_image} + if env is not None: + payload["model"]["env"] = env + if secrets is not None: + payload["model"]["secrets"] = secrets + if domain is not None: + payload["route"]["domain"] = domain + if path is not None: + payload["route"]["path"] = path + if cache_http_responses is not None: + payload["cacheHttpResponses"] = cache_http_responses + if tags is not None: + payload["tags"] = tags + + response = get_session().put( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def delete_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """Delete an Inference Endpoint. + + This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`]. + + For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`]. + + Args: + name (`str`): + The name of the Inference Endpoint to delete. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ """ + namespace = namespace or self._get_namespace(token=token) + response = get_session().delete( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + def pause_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Pause an Inference Endpoint. + + A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`]. + This is different than scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`], which + would be automatically restarted when a request is made to it. + + For convenience, you can also pause an Inference Endpoint using [`pause_inference_endpoint`]. + + Args: + name (`str`): + The name of the Inference Endpoint to pause. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the paused Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def resume_inference_endpoint( + self, + name: str, + *, + namespace: Optional[str] = None, + running_ok: bool = True, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Resume an Inference Endpoint. 
+ + For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`]. + + Args: + name (`str`): + The name of the Inference Endpoint to resume. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the resumed Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(response) + except HfHubHTTPError as error: + # If already running (and it's ok), then fetch current status and return + if running_ok and error.response.status_code == 400 and "already running" in error.response.text: + return self.get_inference_endpoint(name, namespace=namespace, token=token) + # Otherwise, raise the error + raise + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def scale_to_zero_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Scale Inference Endpoint to zero. + + An Inference Endpoint scaled to zero will not be charged. It will be resume on the next request to it, with a + cold start delay. 
This is different than pausing the Inference Endpoint with [`pause_inference_endpoint`], which + would require a manual resume with [`resume_inference_endpoint`]. + + For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`]. + + Args: + name (`str`): + The name of the Inference Endpoint to scale to zero. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def _get_namespace(self, token: Union[bool, str, None] = None) -> str: + """Get the default namespace for the current user.""" + me = self.whoami(token=token) + if me["type"] == "user": + return me["name"] + else: + raise ValueError( + "Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a" + " user." 
+ ) + + ######################## + # Collection Endpoints # + ######################## + @validate_hf_hub_args + def list_collections( + self, + *, + owner: Union[list[str], str, None] = None, + item: Union[list[str], str, None] = None, + sort: Optional[CollectionSort_T] = None, + limit: Optional[int] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[Collection]: + """List collections on the Huggingface Hub, given some filters. + + > [!WARNING] + > When listing collections, the item list per collection is truncated to 4 items maximum. To retrieve all items + > from a collection, you must use [`get_collection`]. + + Args: + owner (`list[str]` or `str`, *optional*): + Filter by owner's username. + item (`list[str]` or `str`, *optional*): + Filter collections containing a particular items. Example: `"models/teknium/OpenHermes-2.5-Mistral-7B"`, `"datasets/squad"` or `"papers/2311.12983"`. + sort (`Literal["lastModified", "trending", "upvotes"]`, *optional*): + Sort collections by last modified, trending or upvotes. + limit (`int`, *optional*): + Maximum number of collections to be returned. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[Collection]`: an iterable of [`Collection`] objects. 
+ """ + # Construct the API endpoint + path = f"{self.endpoint}/api/collections" + headers = self._build_hf_headers(token=token) + params: dict = {} + if owner is not None: + params.update({"owner": owner}) + if item is not None: + params.update({"item": item}) + if sort is not None: + params.update({"sort": sort}) + if limit is not None: + params.update({"limit": limit}) + + # Paginate over the results until limit is reached + items = paginate(path, headers=headers, params=params) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + + # Parse as Collection and return + for position, collection_data in enumerate(items): + yield Collection(position=position, **collection_data) + + def get_collection(self, collection_slug: str, *, token: Union[bool, str, None] = None) -> Collection: + """Gets information about a Collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection of the Hub. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import get_collection + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + >>> collection.title + 'Recent models' + >>> len(collection.items) + 37 + >>> collection.items[0] + CollectionItem( + item_object_id='651446103cd773a050bf64c2', + item_id='TheBloke/U-Amethyst-20B-AWQ', + item_type='model', + position=88, + note=None + ) + ``` + """ + r = get_session().get( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def create_collection( + self, + title: str, + *, + namespace: Optional[str] = None, + description: Optional[str] = None, + private: bool = False, + exists_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> Collection: + """Create a new Collection on the Hub. + + Args: + title (`str`): + Title of the collection to create. Example: `"Recent models"`. + namespace (`str`, *optional*): + Namespace of the collection to create (username or org). Will default to the owner name. + description (`str`, *optional*): + Description of the collection to create. + private (`bool`, *optional*): + Whether the collection should be private or not. Defaults to `False` (i.e. public collection). + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if collection already exists. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import create_collection + >>> collection = create_collection( + ... title="ICCV 2023", + ... 
description="Portfolio of models, papers and demos I presented at ICCV 2023", + ... ) + >>> collection.slug + "username/iccv-2023-64f9a55bb3115b4f513ec026" + ``` + """ + if namespace is None: + namespace = self.whoami(token)["name"] + + payload = { + "title": title, + "namespace": namespace, + "private": private, + } + if description is not None: + payload["description"] = description + + r = get_session().post( + f"{self.endpoint}/api/collections", headers=self._build_hf_headers(token=token), json=payload + ) + try: + hf_raise_for_status(r) + except HfHubHTTPError as err: + if exists_ok and err.response.status_code == 409: + # Collection already exists and `exists_ok=True` + slug = r.json()["slug"] + return self.get_collection(slug, token=token) + else: + raise + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def update_collection_metadata( + self, + collection_slug: str, + *, + title: Optional[str] = None, + description: Optional[str] = None, + position: Optional[int] = None, + private: Optional[bool] = None, + theme: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Collection: + """Update metadata of a collection on the Hub. + + All arguments are optional. Only provided metadata will be updated. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection to update. + description (`str`, *optional*): + Description of the collection to update. + position (`int`, *optional*): + New position of the collection in the list of collections of the user. + private (`bool`, *optional*): + Whether the collection should be private or not. + theme (`str`, *optional*): + Theme of the collection on the Hub. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import update_collection_metadata + >>> collection = update_collection_metadata( + ... collection_slug="username/iccv-2023-64f9a55bb3115b4f513ec026", + ... title="ICCV Oct. 2023" + ... description="Portfolio of models, datasets, papers and demos I presented at ICCV Oct. 2023", + ... private=False, + ... theme="pink", + ... ) + >>> collection.slug + "username/iccv-oct-2023-64f9a55bb3115b4f513ec026" + # ^collection slug got updated but not the trailing ID + ``` + """ + payload = { + "position": position, + "private": private, + "theme": theme, + "title": title, + "description": description, + } + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + return Collection(**{**r.json()["data"], "endpoint": self.endpoint}) + + def delete_collection( + self, collection_slug: str, *, missing_ok: bool = False, token: Union[bool, str, None] = None + ) -> None: + """Delete a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to delete. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if collection doesn't exists. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Example: + + ```py + >>> from huggingface_hub import delete_collection + >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) + ``` + + > [!WARNING] + > This is a non-revertible action. A deleted collection cannot be restored. + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + try: + hf_raise_for_status(r) + except HfHubHTTPError as err: + if missing_ok and err.response.status_code == 404: + # Collection doesn't exists and `missing_ok=True` + return + else: + raise + + def add_collection_item( + self, + collection_slug: str, + item_id: str, + item_type: CollectionItemType_T, + *, + note: Optional[str] = None, + exists_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> Collection: + """Add an item to a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_id (`str`): + Id of the item to add to the collection. Use the repo_id for repos/spaces/datasets, + the paper id for papers, or the slug of another collection (e.g. `"moonshotai/kimi-k2"`). + item_type (`str`): + Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"`, `"paper"` or `"collection"`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if item already exists. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Raises: + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. 
This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HfHubHTTPError`]: + HTTP 404 if the item you try to add to the collection does not exist on the Hub. + [`HfHubHTTPError`]: + HTTP 409 if the item you try to add to the collection is already in the collection (and exists_ok=False) + + Example: + + ```py + >>> from huggingface_hub import add_collection_item + >>> collection = add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="pierre-loic/climate-news-articles", + ... item_type="dataset" + ... ) + >>> collection.items[-1].item_id + "pierre-loic/climate-news-articles" + # ^item got added to the collection on last position + + # Add item with a note + >>> add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="datasets/climate_fever", + ... item_type="dataset" + ... note="This dataset adopts the FEVER methodology that consists of 1,535 real-world claims regarding climate-change collected on the internet." + ... ) + (...) + ``` + """ + payload: dict[str, Any] = {"item": {"id": item_id, "type": item_type}} + if note is not None: + payload["note"] = note + r = get_session().post( + f"{self.endpoint}/api/collections/{collection_slug}/items", + headers=self._build_hf_headers(token=token), + json=payload, + ) + try: + hf_raise_for_status(r) + except HfHubHTTPError as err: + if exists_ok and err.response.status_code == 409: + # Item already exists and `exists_ok=True` + return self.get_collection(collection_slug, token=token) + else: + raise + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def update_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + note: Optional[str] = None, + position: Optional[int] = None, + token: Union[bool, str, None] = None, + ) -> None: + """Update an item in a collection. 
+ + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + position (`int`, *optional*): + New position of the item in the collection. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + + ```py + >>> from huggingface_hub import get_collection, update_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Update item based on its ID (add note + update position) + >>> update_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... item_object_id=collection.items[-1].item_object_id, + ... note="Newly updated model!" + ... position=0, + ... ) + ``` + """ + payload = {"position": position, "note": note} + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + + def delete_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + missing_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> None: + """Delete an item from a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. 
Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if item doesn't exists. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + + ```py + >>> from huggingface_hub import get_collection, delete_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Delete item based on its ID + >>> delete_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... item_object_id=collection.items[-1].item_object_id, + ... ) + ``` + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(r) + except HfHubHTTPError as err: + if missing_ok and err.response.status_code == 404: + # Item already deleted and `missing_ok=True` + return + else: + raise + + ########################## + # Manage access requests # + ########################## + + @validate_hf_hub_args + def list_pending_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> Iterable[AccessRequest]: + """ + Get pending access requests for a given gated repo. + + A pending request means the user has requested access to the repo but the request has not been processed yet. + If the approval mode is automatic, this list should be empty. 
Pending requests can be accepted or rejected + using [`accept_access_request`] and [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_pending_access_requests, accept_access_request + + # List pending requests + >>> requests = list(list_pending_access_requests("meta-llama/Llama-2-7b")) + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem 🤗', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='pending', + fields=None, + ), + ... 
+ ] + + # Accept Clem's request + >>> accept_access_request("meta-llama/Llama-2-7b", "clem") + ``` + """ + yield from self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def list_accepted_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> Iterable[AccessRequest]: + """ + Get accepted access requests for a given gated repo. + + An accepted request means the user has requested access to the repo and the request has been accepted. The user + can download any file of the repo. If the approval mode is automatic, this list should contains by default all + requests. Accepted requests can be cancelled or rejected at any time using [`cancel_access_request`] and + [`reject_access_request`]. A cancelled request will go back to the pending list while a rejected request will + go to the rejected list. In both cases, the user will lose access to the repo. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. 
+ [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_accepted_access_requests + + >>> requests = list(list_accepted_access_requests("meta-llama/Llama-2-7b")) + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem 🤗', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='accepted', + fields=None, + ), + ... + ] + ``` + """ + yield from self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def list_rejected_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> Iterable[AccessRequest]: + """ + Get rejected access requests for a given gated repo. + + A rejected request means the user has requested access to the repo and the request has been explicitly rejected + by a repo owner (either you or another user from your organization). The user cannot download any file of the + repo. Rejected requests can be accepted or cancelled at any time using [`accept_access_request`] and + [`cancel_access_request`]. A cancelled request will go back to the pending list while an accepted request will + go to the accepted list. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_rejected_access_requests + + >>> requests = list(list_rejected_access_requests("meta-llama/Llama-2-7b")) + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem 🤗', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='rejected', + fields=None, + ), + ... 
+ ] + ``` + """ + yield from self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token) + + def _list_access_requests( + self, + repo_id: str, + status: Literal["accepted", "rejected", "pending"], + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[AccessRequest]: + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + for request in paginate( + f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield AccessRequest( + username=request["user"]["user"], + fullname=request["user"]["fullname"], + email=request["user"].get("email"), + status=request["status"], + timestamp=parse_datetime(request["timestamp"]), + fields=request.get("fields"), # only if custom fields in form + ) + + @validate_hf_hub_args + def cancel_access_request( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """ + Cancel an access request from a user for a given gated repo. + + A cancelled request will go back to the pending list and the user will lose access to the repo. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to cancel access request for. + user (`str`): + The username of the user which access request should be cancelled. + repo_type (`str`, *optional*): + The type of the repo to cancel access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). 
+ To disable authentication, pass `False`. + + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HfHubHTTPError`]: + HTTP 404 if the user does not exist on the Hub. + [`HfHubHTTPError`]: + HTTP 404 if the user access request cannot be found. + [`HfHubHTTPError`]: + HTTP 404 if the user access request is already in the pending list. + """ + self._handle_access_request(repo_id, user, "pending", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def accept_access_request( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """ + Accept an access request from a user for a given gated repo. + + Once the request is accepted, the user will be able to download any file of the repo and access the community + tab. If the approval mode is automatic, you don't have to accept requests manually. An accepted request can be + cancelled or rejected at any time using [`cancel_access_request`] and [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to accept access request for. + user (`str`): + The username of the user which access request should be accepted. + repo_type (`str`, *optional*): + The type of the repo to accept access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HfHubHTTPError`]: + HTTP 404 if the user does not exist on the Hub. + [`HfHubHTTPError`]: + HTTP 404 if the user access request cannot be found. + [`HfHubHTTPError`]: + HTTP 404 if the user access request is already in the accepted list. + """ + self._handle_access_request(repo_id, user, "accepted", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def reject_access_request( + self, + repo_id: str, + user: str, + *, + repo_type: Optional[str] = None, + rejection_reason: Optional[str], + token: Union[bool, str, None] = None, + ) -> None: + """ + Reject an access request from a user for a given gated repo. + + A rejected request will go to the rejected list. The user cannot download any file of the repo. Rejected + requests can be accepted or cancelled at any time using [`accept_access_request`] and [`cancel_access_request`]. + A cancelled request will go back to the pending list while an accepted request will go to the accepted list. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to reject access request for. + user (`str`): + The username of the user which access request should be rejected. + repo_type (`str`, *optional*): + The type of the repo to reject access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + rejection_reason (`str`, *optional*): + Optional rejection reason that will be visible to the user (max 200 characters). + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HfHubHTTPError`]: + HTTP 404 if the user does not exist on the Hub. + [`HfHubHTTPError`]: + HTTP 404 if the user access request cannot be found. + [`HfHubHTTPError`]: + HTTP 404 if the user access request is already in the rejected list. + """ + self._handle_access_request( + repo_id, user, "rejected", repo_type=repo_type, rejection_reason=rejection_reason, token=token + ) + + @validate_hf_hub_args + def _handle_access_request( + self, + repo_id: str, + user: str, + status: Literal["accepted", "rejected", "pending"], + repo_type: Optional[str] = None, + rejection_reason: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + payload = {"user": user, "status": status} + + if rejection_reason is not None: + if status != "rejected": + raise ValueError("`rejection_reason` can only be passed when rejecting an access request.") + payload["rejectionReason"] = rejection_reason + + response = get_session().post( + f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/handle", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + @validate_hf_hub_args + def grant_access( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """ + Grant 
access to a user for a given gated repo. + + Granting access don't require for the user to send an access request by themselves. The user is automatically + added to the accepted list meaning they can download the files You can revoke the granted access at any time + using [`cancel_access_request`] or [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to grant access to. + user (`str`): + The username of the user to grant access. + repo_type (`str`, *optional*): + The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + [`HfHubHTTPError`]: + HTTP 400 if the repo is not gated. + [`HfHubHTTPError`]: + HTTP 400 if the user already has access to the repo. + [`HfHubHTTPError`]: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HfHubHTTPError`]: + HTTP 404 if the user does not exist on the Hub. 
+ """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + response = get_session().post( + f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/grant", + headers=self._build_hf_headers(token=token), + json={"user": user}, + ) + hf_raise_for_status(response) + return response.json() + + ################### + # Manage webhooks # + ################### + + @validate_hf_hub_args + def get_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: + """Get a webhook by its id. + + Args: + webhook_id (`str`): + The unique identifier of the webhook to get. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the webhook. 
+ + Example: + ```python + >>> from huggingface_hub import get_webhook + >>> webhook = get_webhook("654bbbc16f2ec14d77f109cc") + >>> print(webhook) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + job=None, + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + secret="my-secret", + domains=["repo", "discussion"], + disabled=False, + ) + ``` + """ + response = get_session().get( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data.get("url"), + job=JobSpec(**webhook_data["job"]) if webhook_data.get("job") else None, + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def list_webhooks(self, *, token: Union[bool, str, None] = None) -> list[WebhookInfo]: + """List all configured webhooks. + + Args: + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `list[WebhookInfo]`: + List of webhook info objects. 
+ + Example: + ```python + >>> from huggingface_hub import list_webhooks + >>> webhooks = list_webhooks() + >>> len(webhooks) + 2 + >>> webhooks[0] + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + secret="my-secret", + domains=["repo", "discussion"], + disabled=False, + ) + ``` + """ + response = get_session().get( + f"{constants.ENDPOINT}/api/settings/webhooks", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhooks_data = response.json() + + return [ + WebhookInfo( + id=webhook["id"], + url=webhook.get("url"), + job=JobSpec(**webhook["job"]) if webhook.get("job") else None, + watched=[WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook["watched"]], + domains=webhook["domains"], + secret=webhook.get("secret"), + disabled=webhook["disabled"], + ) + for webhook in webhooks_data + ] + + @validate_hf_hub_args + def create_webhook( + self, + *, + url: Optional[str] = None, + job_id: Optional[str] = None, + watched: list[Union[dict, WebhookWatchedItem]], + domains: Optional[list[constants.WEBHOOK_DOMAIN_T]] = None, + secret: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> WebhookInfo: + """Create a new webhook. + + The webhook can either send a payload to a URL, or trigger a Job to run on Hugging Face infrastructure. + This function should be called with one of `url` or `job_id`, but not both. + + Args: + url (`str`): + URL to send the payload to. + job_id (`str`): + ID of the source Job to trigger with the webhook payload in the environment variable WEBHOOK_PAYLOAD. + Additional environment variables are available for convenience: WEBHOOK_REPO_ID, WEBHOOK_REPO_TYPE and WEBHOOK_SECRET. + watched (`list[WebhookWatchedItem]`): + List of [`WebhookWatchedItem`] to be watched by the webhook. 
It can be users, orgs, models, datasets or spaces. + Watched items can also be provided as plain dictionaries. + domains (`list[Literal["repo", "discussion"]]`, optional): + List of domains to watch. It can be "repo", "discussion" or both. + secret (`str`, optional): + A secret to sign the payload with. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the newly created webhook. + + Example: + + Create a webhook that sends a payload to a URL + ```python + >>> from huggingface_hub import create_webhook + >>> payload = create_webhook( + ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], + ... url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + ... domains=["repo", "discussion"], + ... secret="my-secret", + ... ) + >>> print(payload) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + job=None, + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=False, + ) + ``` + + Run a Job and then create a webhook that triggers this Job + ```python + >>> from huggingface_hub import create_webhook, run_job + >>> job = run_job( + ... image="ubuntu", + ... command=["bash", "-c", r"echo An event occured in $WEBHOOK_REPO_ID: $WEBHOOK_PAYLOAD"], + ... ) + >>> payload = create_webhook( + ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], + ... job_id=job.id, + ... domains=["repo", "discussion"], + ... secret="my-secret", + ... 
) + >>> print(payload) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url=None, + job=JobSpec( + docker_image='ubuntu', + space_id=None, + command=['bash', '-c', 'echo An event occured in $WEBHOOK_REPO_ID: $WEBHOOK_PAYLOAD'], + arguments=[], + environment={}, + secrets=[], + flavor='cpu-basic', + timeout=None, + tags=None, + arch=None + ), + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=False, + ) + ``` + """ + watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] + + post_webhooks_json = {"watched": watched_dicts, "domains": domains, "secret": secret} + if url is not None and job_id is not None: + raise ValueError("Set `url` or `job_id` but not both.") + elif url is not None: + post_webhooks_json["url"] = url + elif job_id is not None: + post_webhooks_json["jobSourceId"] = job_id + else: + raise ValueError("Missing argument for webhook: `url` or `job_id`.") + + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks", + json=post_webhooks_json, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data.get("url"), + job=JobSpec(**webhook_data["job"]) if webhook_data.get("job") else None, + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def update_webhook( + self, + webhook_id: str, + *, + url: Optional[str] = None, + watched: Optional[list[Union[dict, WebhookWatchedItem]]] = None, + domains: Optional[list[constants.WEBHOOK_DOMAIN_T]] = None, + secret: Optional[str] = None, + 
token: Union[bool, str, None] = None, + ) -> WebhookInfo: + """Update an existing webhook. + + Args: + webhook_id (`str`): + The unique identifier of the webhook to be updated. + url (`str`, optional): + The URL to which the payload will be sent. + watched (`list[WebhookWatchedItem]`, optional): + List of items to watch. It can be users, orgs, models, datasets, or spaces. + Refer to [`WebhookWatchedItem`] for more details. Watched items can also be provided as plain dictionaries. + domains (`list[Literal["repo", "discussion"]]`, optional): + The domains to watch. This can include "repo", "discussion", or both. + secret (`str`, optional): + A secret to sign the payload with, providing an additional layer of security. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the updated webhook. + + Example: + ```python + >>> from huggingface_hub import update_webhook + >>> updated_payload = update_webhook( + ... webhook_id="654bbbc16f2ec14d77f109cc", + ... url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], + ... domains=["repo"], + ... secret="my-secret", + ... 
) + >>> print(updated_payload) + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + job=None, + url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo"], + secret="my-secret", + disabled=False, + ``` + """ + if watched is None: + watched = [] + watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] + + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data.get("url"), + job=JobSpec(**webhook_data["job"]) if webhook_data.get("job") else None, + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def enable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: + """Enable a webhook (makes it "active"). + + Args: + webhook_id (`str`): + The unique identifier of the webhook to enable. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the enabled webhook. 
+ + Example: + ```python + >>> from huggingface_hub import enable_webhook + >>> enabled_webhook = enable_webhook("654bbbc16f2ec14d77f109cc") + >>> enabled_webhook + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + job=None, + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=False, + ) + ``` + """ + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/enable", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data.get("url"), + job=JobSpec(**webhook_data["job"]) if webhook_data.get("job") else None, + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def disable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: + """Disable a webhook (makes it "disabled"). + + Args: + webhook_id (`str`): + The unique identifier of the webhook to disable. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`WebhookInfo`]: + Info about the disabled webhook. 
+ + Example: + ```python + >>> from huggingface_hub import disable_webhook + >>> disabled_webhook = disable_webhook("654bbbc16f2ec14d77f109cc") + >>> disabled_webhook + WebhookInfo( + id="654bbbc16f2ec14d77f109cc", + url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", + jon=None, + watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], + domains=["repo", "discussion"], + secret="my-secret", + disabled=True, + ) + ``` + """ + response = get_session().post( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/disable", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + webhook_data = response.json()["webhook"] + + watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] + + webhook = WebhookInfo( + id=webhook_data["id"], + url=webhook_data.get("url"), + job=JobSpec(**webhook_data["job"]) if webhook_data.get("job") else None, + watched=watched_items, + domains=webhook_data["domains"], + secret=webhook_data.get("secret"), + disabled=webhook_data["disabled"], + ) + + return webhook + + @validate_hf_hub_args + def delete_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> None: + """Delete a webhook. + + Args: + webhook_id (`str`): + The unique identifier of the webhook to delete. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended + method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Returns: + `None` + + Example: + ```python + >>> from huggingface_hub import delete_webhook + >>> delete_webhook("654bbbc16f2ec14d77f109cc") + ``` + """ + response = get_session().delete( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + ############# + # Internals # + ############# + + def _build_hf_headers( + self, + token: Union[bool, str, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[dict, str, None] = None, + ) -> dict[str, str]: + """ + Alias for [`build_hf_headers`] that uses the token from [`HfApi`] client + when `token` is not provided. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + return build_hf_headers( + token=token, + library_name=library_name or self.library_name, + library_version=library_version or self.library_version, + user_agent=user_agent or self.user_agent, + headers=self.headers, + ) + + def _prepare_folder_deletions( + self, + repo_id: str, + repo_type: Optional[str], + revision: Optional[str], + path_in_repo: str, + delete_patterns: Optional[Union[list[str], str]], + token: Union[bool, str, None] = None, + ) -> list[CommitOperationDelete]: + """Generate the list of Delete operations for a commit to delete files from a repo. + + List remote files and match them against the `delete_patterns` constraints. Returns a list of [`CommitOperationDelete`] + with the matching items. + + Note: `.gitattributes` file is essential to make a repo work properly on the Hub. This file will always be + kept even if it matches the `delete_patterns` constraints. 
+ """ + if delete_patterns is None: + # If no delete patterns, no need to list and filter remote files + return [] + + # List remote files + filenames = self.list_repo_files(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) + + # Compute relative path in repo + if path_in_repo and path_in_repo not in (".", "./"): + path_in_repo = path_in_repo.strip("/") + "/" # harmonize + relpath_to_abspath = { + file[len(path_in_repo) :]: file for file in filenames if file.startswith(path_in_repo) + } + else: + relpath_to_abspath = {file: file for file in filenames} + + # Apply filter on relative paths and return + return [ + CommitOperationDelete(path_in_repo=relpath_to_abspath[relpath], is_folder=False) + for relpath in filter_repo_objects(relpath_to_abspath.keys(), allow_patterns=delete_patterns) + if relpath_to_abspath[relpath] != ".gitattributes" + ] + + def _prepare_upload_folder_additions( + self, + folder_path: Union[str, Path], + path_in_repo: str, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> list[CommitOperationAdd]: + """Generate the list of Add operations for a commit to upload a folder. + + Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist) + constraints are discarded. + """ + + folder_path = Path(folder_path).expanduser().resolve() + if not folder_path.is_dir(): + raise ValueError(f"Provided path: '{folder_path}' is not a directory") + + # List files from folder + relpath_to_abspath = { + path.relative_to(folder_path).as_posix(): path + for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic + if path.is_file() + } + + # Filter files + # Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering. 
+ filtered_repo_objects = list( + filter_repo_objects( + relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns + ) + ) + + prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else "" + + # If updating a README.md file, make sure the metadata format is valid + # It's better to fail early than to fail after all the files have been hashed. + if "README.md" in filtered_repo_objects: + self._validate_yaml( + content=relpath_to_abspath["README.md"].read_text(encoding="utf8"), + repo_type=repo_type, + token=token, + ) + if len(filtered_repo_objects) > 30: + log = logger.warning if len(filtered_repo_objects) > 200 else logger.info + log( + "It seems you are trying to upload a large folder at once. This might take some time and then fail if " + "the folder is too large. For such cases, it is recommended to upload in smaller batches or to use " + "`HfApi().upload_large_folder(...)`/`hf upload-large-folder` instead. For more details, " + "check out https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#upload-a-large-folder." + ) + + logger.info(f"Start hashing {len(filtered_repo_objects)} files.") + operations = [ + CommitOperationAdd( + path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk + path_in_repo=prefix + relpath, # "absolute" path in repo + ) + for relpath in filtered_repo_objects + ] + logger.info(f"Finished hashing {len(filtered_repo_objects)} files.") + return operations + + def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None): + """ + Validate YAML from `README.md`, used before file hashing and upload. + + Args: + content (`str`): + Content of `README.md` to validate. + repo_type (`str`, *optional*): + The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`bool` or `str`, *optional*): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if YAML is invalid + """ + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL + headers = self._build_hf_headers(token=token) + + response = get_session().post( + f"{self.endpoint}/api/validate-yaml", + json={"content": content, "repoType": repo_type}, + headers=headers, + ) + # Handle warnings (example: empty metadata) + response_content = response.json() + message = "\n".join([f"- {warning.get('message')}" for warning in response_content.get("warnings", [])]) + if message: + warnings.warn(f"Warnings while validating metadata in README.md:\n{message}") + + # Raise on errors + try: + hf_raise_for_status(response) + except BadRequestError as e: + errors = response_content.get("errors", []) + message = "\n".join([f"- {error.get('message')}" for error in errors]) + raise ValueError(f"Invalid metadata in README.md.\n{message}") from e + + def get_user_overview(self, username: str, token: Union[bool, str, None] = None) -> User: + """ + Get an overview of a user on the Hub. + + Args: + username (`str`): + Username of the user to get an overview of. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `User`: A [`User`] object with the user's overview. + + Raises: + [`HfHubHTTPError`]: + HTTP 404 If the user does not exist on the Hub. 
+ """ + r = get_session().get( + f"{constants.ENDPOINT}/api/users/{username}/overview", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return User(**r.json()) + + @validate_hf_hub_args + def get_organization_overview(self, organization: str, token: Union[bool, str, None] = None) -> Organization: + """ + Get an overview of an organization on the Hub. + + Args: + organization (`str`): + Name of the organization to get an overview of. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved token, which is the recommended method + for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Organization`: An [`Organization`] object with the organization's overview. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 If the organization does not exist on the Hub. + """ + r = get_session().get( + f"{constants.ENDPOINT}/api/organizations/{organization}/overview", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return Organization(**r.json()) + + @validate_hf_hub_args + def list_organization_followers(self, organization: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + List followers of an organization on the Hub. + + Args: + organization (`str`): + Name of the organization to get the followers of. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the followers of the organization. + + Raises: + [`HfHubHTTPError`]: + HTTP 404 If the organization does not exist on the Hub. 
+ + """ + for follower in paginate( + path=f"{constants.ENDPOINT}/api/organizations/{organization}/followers", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**follower) + + def list_organization_members(self, organization: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + List of members of an organization on the Hub. + + Args: + organization (`str`): + Name of the organization to get the members of. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the members of the organization. + + Raises: + [`HfHubHTTPError`]: + HTTP 404 If the organization does not exist on the Hub. + + """ + for member in paginate( + path=f"{constants.ENDPOINT}/api/organizations/{organization}/members", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**member) + + def list_user_followers(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + Get the list of followers of a user on the Hub. + + Args: + username (`str`): + Username of the user to get the followers of. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the followers of the user. + + Raises: + [`HfHubHTTPError`]: + HTTP 404 If the user does not exist on the Hub. 
+ + """ + for follower in paginate( + path=f"{constants.ENDPOINT}/api/users/{username}/followers", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**follower) + + def list_user_following(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + Get the list of users followed by a user on the Hub. + + Args: + username (`str`): + Username of the user to get the users followed by. + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the users followed by the user. + + Raises: + [`HfHubHTTPError`]: + HTTP 404 If the user does not exist on the Hub. + + """ + for followed_user in paginate( + path=f"{constants.ENDPOINT}/api/users/{username}/following", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**followed_user) + + def list_papers( + self, + *, + query: Optional[str] = None, + limit: Optional[int] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[PaperInfo]: + """ + List daily papers on the Hugging Face Hub given a search query. + + Args: + query (`str`, *optional*): + A search query string to find papers. + If provided, returns papers that match the query. + limit (`int`, *optional*): + The maximum number of papers to return. + token (Union[bool, str, None], *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects. 
+ + Example: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + # List all papers with "attention" in their title + >>> api.list_papers(query="attention") + ``` + """ + path = f"{self.endpoint}/api/papers/search" + params: dict[str, Any] = {} + if query: + params["q"] = query + if limit is not None: + params["limit"] = limit + r = get_session().get( + path, + params=params, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + for paper in r.json(): + yield PaperInfo(**paper) + + def paper_info(self, id: str) -> PaperInfo: + """ + Get information for a paper on the Hub. + + Args: + id (`str`, **optional**): + ArXiv id of the paper. + + Returns: + `PaperInfo`: A `PaperInfo` object. + + Raises: + [`HfHubHTTPError`]: + HTTP 404 If the paper does not exist on the Hub. + """ + path = f"{self.endpoint}/api/papers/{id}" + r = get_session().get(path) + hf_raise_for_status(r) + return PaperInfo(**r.json()) + + def list_daily_papers( + self, + *, + date: Optional[str] = None, + token: Union[bool, str, None] = None, + week: Optional[str] = None, + month: Optional[str] = None, + submitter: Optional[str] = None, + sort: Optional[DailyPapersSort_T] = None, + p: Optional[int] = None, + limit: Optional[int] = None, + ) -> Iterable[PaperInfo]: + """ + List the daily papers published on a given date on the Hugging Face Hub. + + Args: + date (`str`, *optional*): + Date in ISO format (YYYY-MM-DD) for which to fetch daily papers. + Defaults to most recent ones. + token (Union[bool, str, None], *optional*): + A valid user access token (string). Defaults to the locally saved + token. To disable authentication, pass `False`. + week (`str`, *optional*): + Week in ISO format (YYYY-Www) for which to fetch daily papers. Example, `2025-W09`. + month (`str`, *optional*): + Month in ISO format (YYYY-MM) for which to fetch daily papers. Example, `2025-02`. + submitter (`str`, *optional*): + Username of the submitter to filter daily papers. 
+ sort (`Literal["publishedAt", "trending"]`, *optional*): + Sort order for the daily papers. Can be either by `publishedAt` or by `trending`. + Defaults to `"publishedAt"` + p (`int`, *optional*): + Page number for pagination. Defaults to 0. + limit (`int`, *optional*): + Limit of papers to fetch. Defaults to 50. + + Returns: + `Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects. + + Example: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + >>> list(api.list_daily_papers(date="2025-10-29")) + ``` + """ + path = f"{self.endpoint}/api/daily_papers" + + params = { + k: v + for k, v in { + "p": p, + "limit": limit, + "sort": sort, + "date": date, + "week": week, + "month": month, + "submitter": submitter, + }.items() + if v is not None + } + + r = get_session().get(path, params=params, headers=self._build_hf_headers(token=token)) + hf_raise_for_status(r) + for paper in r.json(): + yield PaperInfo(**paper) + + def auth_check( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """ + Check if the provided user token has access to a specific repository on the Hugging Face Hub. + + This method verifies whether the user, authenticated via the provided token, has access to the specified + repository. If the repository is not found or if the user lacks the required permissions to access it, + the method raises an appropriate exception. + + Args: + repo_id (`str`): + The repository to check for access. Format should be `"user/repo_name"`. + Example: `"user/my-cool-model"`. + + repo_type (`str`, *optional*): + The type of the repository. Should be one of `"model"`, `"dataset"`, or `"space"`. + If not specified, the default is `"model"`. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. 
+ Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Raises: + [`~utils.RepositoryNotFoundError`]: + Raised if the repository does not exist, is private, or the user does not have access. This can + occur if the `repo_id` or `repo_type` is incorrect or if the repository is private but the user + is not authenticated. + + [`~utils.GatedRepoError`]: + Raised if the repository exists but is gated and the user is not authorized to access it. + + Example: + Check if the user has access to a repository: + + ```python + >>> from huggingface_hub import auth_check + >>> from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + + try: + auth_check("user/my-cool-model") + except GatedRepoError: + # Handle gated repository error + print("You do not have permission to access this gated repository.") + except RepositoryNotFoundError: + # Handle repository not found error + print("The repository was not found or you do not have access.") + ``` + + In this example: + - If the user has access, the method completes successfully. + - If the repository is gated or does not exist, appropriate exceptions are raised, allowing the user + to handle them accordingly. 
+ """ + headers = self._build_hf_headers(token=token) + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/auth-check" + r = get_session().get(path, headers=headers) + hf_raise_for_status(r) + + def run_job( + self, + *, + image: str, + command: list[str], + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + labels: Optional[dict[str, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Run compute Jobs on Hugging Face infrastructure. + + Args: + image (`str`): + The Docker image to use. + Examples: `"ubuntu"`, `"python:3.12"`, `"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"`. + Example with an image from a Space: `"hf.co/spaces/lhoestq/duckdb"`. + + command (`list[str]`): + The command to run. Example: `["echo", "hello"]`. + + env (`dict[str, Any]`, *optional*): + Defines the environment variables for the Job. + + secrets (`dict[str, Any]`, *optional*): + Defines the secret environment variables for the Job. + + flavor (`str`, *optional*): + Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values. + Defaults to `"cpu-basic"`. + + timeout (`Union[int, float, str]`, *optional*): + Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days). + Example: `300` or `"5m"` for 5 minutes. + + labels (`dict[str, str]`, *optional*): + Labels to attach to the job (key-value pairs). + + namespace (`str`, *optional*): + The namespace where the Job will be created. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. 
If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + Run your first Job: + + ```python + >>> from huggingface_hub import run_job + >>> run_job(image="python:3.12", command=["python", "-c" ,"print('Hello from HF compute!')"]) + ``` + + Run a GPU Job: + + ```python + >>> from huggingface_hub import run_job + >>> image = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel" + >>> command = ["python", "-c", "import torch; print(f"This code ran with the following GPU: {torch.cuda.get_device_name()}")"] + >>> run_job(image=image, command=command, flavor="a10g-small") + ``` + + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + job_spec = _create_job_spec( + image=image, + command=command, + env=env, + secrets=secrets, + flavor=flavor, + timeout=timeout, + labels=labels, + ) + response = get_session().post( + f"{self.endpoint}/api/jobs/{namespace}", + json=job_spec, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + job_info = response.json() + return JobInfo(**job_info, endpoint=self.endpoint) + + def _fetch_running_job_sse( + self, + *, + job_id: str, + route: str, + timeout: int, + skip_previous_events_on_retry: bool, + double_check_job_has_finished_on_status_code_or_error: tuple[Union[int, Type[Exception]], ...], + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[dict[str, Any]]: + if namespace is None: + namespace = self.whoami(token=token)["name"] + # We don't use http_backoff since we need to check ourselves if the job is still running + nb_tries = 0 + max_retries = 5 + min_wait_time = 1 + max_wait_time = 10 + sleep_time = 0 + start_event_idx = 0 + error_to_retry = None + while True: + if error_to_retry is not None: + logger.warning(f"'{error_to_retry}' thrown while requesting jobs 
/{route} for {job_id=}") + logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") + error_to_retry = None + time.sleep(sleep_time) + try: + with get_session().stream( + "GET", + f"{self.endpoint}/api/jobs/{namespace}/{job_id}/{route}", + headers=self._build_hf_headers(token=token), + timeout=timeout, + ) as response: + if response.status_code == 200: + event_idx = -1 + for line in response.iter_lines(): + if line and line.startswith("data: {"): + event_idx += 1 + if event_idx >= start_event_idx: + if skip_previous_events_on_retry: + start_event_idx += 1 + yield json.loads(line[len("data: ") :]) + break + elif response.status_code not in double_check_job_has_finished_on_status_code_or_error: + hf_raise_for_status(response) + except httpx.HTTPStatusError: + raise + except httpx.DecodingError: + # Response ended prematurely + break + except KeyboardInterrupt: + break + except (httpx.HTTPError, httpcore.TimeoutException) as err: + is_no_new_line_timeout = ( + isinstance(err, httpx.NetworkError) + and err.__context__ + and isinstance(getattr(err.__context__, "__cause__", None), TimeoutError) + ) + if is_no_new_line_timeout: + # job is likely finished + pass + elif type(err) in double_check_job_has_finished_on_status_code_or_error: + pass + elif nb_tries >= max_retries: + raise + else: + nb_tries += 1 + sleep_time = min(max_wait_time, max(min_wait_time, sleep_time * 2)) + error_to_retry = err + job_status_response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(job_status_response) + job_status = job_status_response.json() + if "status" in job_status and job_status["status"]["stage"] not in ("RUNNING", "UPDATING"): + break + + def fetch_job_logs( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[str]: + """ + Fetch all the logs from a compute Job on Hugging Face infrastructure. 
+ + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + + ```python + >>> from huggingface_hub import fetch_job_logs, run_job + >>> job = run_job(image="python:3.12", command=["python", "-c" ,"print('Hello from HF compute!')"]) + >>> for log in fetch_job_logs(job_id=job.id): + ... print(log) + Hello from HF compute! + ``` + """ + # - We need to retry because sometimes the /logs doesn't return logs when the job just started. + # (for example it can return only two lines: one for "Job started" and one empty line) + # - Timeouts can happen in case of build errors + # - ChunkedEncodingError can happen in case of stopped logging in the middle of streaming + # - Infinite empty log stream can happen in case of build error + # (the logs stream is infinite and empty except for the Job started message) + # - there is a ": keep-alive" every 30 seconds + + seconds_between_keep_alive = 30 + for event in self._fetch_running_job_sse( + job_id=job_id, + route="logs", + timeout=4 * seconds_between_keep_alive, + skip_previous_events_on_retry=True, + double_check_job_has_finished_on_status_code_or_error=tuple(), + namespace=namespace, + token=token, + ): + # timestamp = event["timestamp"] + if not event["data"].startswith("===== Job started"): + log = event["data"] + yield log + + def fetch_job_metrics( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[dict[str, Any]]: + """ + Fetch all the live metrics from a compute Job on Hugging Face infrastructure. 
+ + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + + ```python + >>> from huggingface_hub import fetch_job_metrics, run_job + >>> job = run_job(image="python:3.12", command=["python", "-c" ,"print('Hello from HF compute!')"], flavor="a10g-small") + >>> for metrics in fetch_job_metrics(job_id=job.id): + ... print(metrics) + { + "cpu_usage_pct": 0, + "cpu_millicores": 3500, + "memory_used_bytes": 1306624, + "memory_total_bytes": 15032385536, + "rx_bps": 0, + "tx_bps": 0, + "gpus": { + "882fa930": { + "utilization": 0, + "memory_used_bytes": 0, + "memory_total_bytes": 22836000000 + } + }, + "replica": "57vr7" + } + ``` + """ + # - there is one "metric" event every second, like this: + # event: metric + # data: {"cpu_usage_pct":0,"cpu_millicores":3500,"memory_used_bytes":1417216,"memory_total_bytes":15032385536,"rx_bps":0,"tx_bps":0,"gpus":{"d901cd7f":{"utilization":0,"memory_used_bytes":0,"memory_total_bytes":22836000000}},"replica":"j6qz9"} + # - the stream doesn't end when the job finishes, so we rely on timeouts (httpx.NetworkError with Timeout as cause) + # - httpx.ReadTimeout can happen if the job is marked as running but the hardware is not available yet, that we can ignore + # - it returns an internal error 500 if the job has already finished, we simply ignore it + # - ChunkedEncodingError can happen in case of stopped logging in the middle of streaming + # - there is a ": keep-alive" every 30 seconds + seconds_between_events = 1 + yield from self._fetch_running_job_sse( + job_id=job_id, + route="metrics", + timeout=10 * 
seconds_between_events, + skip_previous_events_on_retry=False, + double_check_job_has_finished_on_status_code_or_error=(500, httpx.ReadTimeout), + namespace=namespace, + token=token, + ) + + def list_jobs( + self, + *, + timeout: Optional[int] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> list[JobInfo]: + """ + List compute Jobs on Hugging Face infrastructure. + + Args: + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + + namespace (`str`, *optional*): + The namespace from where it lists the jobs. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + """ + if namespace is None: + namespace = whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}", + headers=self._build_hf_headers(token=token), + timeout=timeout, + ) + response.raise_for_status() + return [JobInfo(**job_info, endpoint=self.endpoint) for job_info in response.json()] + + def list_jobs_hardware(self, token: Union[bool, str, None] = None) -> list[JobHardware]: + """ + List available hardware options for Jobs on Hugging Face infrastructure. + + Returns: + `list[JobHardware]`: A list of available hardware configurations. 
+ + Example: + + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> hardware_list = api.list_jobs_hardware() + >>> hardware_list[0] + JobHardware(name='cpu-basic', pretty_name='CPU Basic', cpu='2 vCPU', ram='16 GB', accelerator=None, unit_cost_micro_usd=167, unit_cost_usd=0.000167, unit_label='minute') + >>> hardware_list[0].name + 'cpu-basic' + + # Filter GPU options + >>> gpu_hardware = [hw for hw in hardware_list if hw.accelerator is not None] + >>> gpu_hardware[0].accelerator.model + 'T4' + ``` + """ + response = get_session().get(f"{self.endpoint}/api/jobs/hardware", headers=self._build_hf_headers(token=token)) + hf_raise_for_status(response) + return [JobHardware(**hardware) for hardware in response.json()] + + def inspect_job( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Inspect a compute Job on Hugging Face infrastructure. + + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. 
+ + Example: + + ```python + >>> from huggingface_hub import inspect_job, run_job + >>> job = run_job(image="python:3.12", command=["python", "-c" ,"print('Hello from HF compute!')"]) + >>> inspect_job(job.id) + JobInfo( + id='68780d00bbe36d38803f645f', + created_at=datetime.datetime(2025, 7, 16, 20, 35, 12, 808000, tzinfo=datetime.timezone.utc), + docker_image='python:3.12', + space_id=None, + command=['python', '-c', "print('Hello from HF compute!')"], + arguments=[], + environment={}, + secrets={}, + flavor='cpu-basic', + status=JobStatus(stage='RUNNING', message=None) + ) + ``` + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}", + headers=self._build_hf_headers(token=token), + ) + response.raise_for_status() + return JobInfo(**response.json(), endpoint=self.endpoint) + + def cancel_job( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Cancel a compute Job on Hugging Face infrastructure. + + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. 
+ """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + get_session().post( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}/cancel", + headers=self._build_hf_headers(token=token), + ).raise_for_status() + + @experimental + def run_uv_job( + self, + script: str, + *, + script_args: Optional[list[str]] = None, + dependencies: Optional[list[str]] = None, + python: Optional[str] = None, + image: Optional[str] = None, + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + labels: Optional[dict[str, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Run a UV script Job on Hugging Face infrastructure. + + Args: + script (`str`): + Path or URL of the UV script, or a command. + + script_args (`list[str]`, *optional*) + Arguments to pass to the script or command. + + dependencies (`list[str]`, *optional*) + Dependencies to use to run the UV script. + + python (`str`, *optional*) + Use a specific Python version. Default is 3.12. + + image (`str`, *optional*, defaults to "ghcr.io/astral-sh/uv:python3.12-bookworm"): + Use a custom Docker image with `uv` installed. + + env (`dict[str, Any]`, *optional*): + Defines the environment variables for the Job. + + secrets (`dict[str, Any]`, *optional*): + Defines the secret environment variables for the Job. + + flavor (`str`, *optional*): + Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values. + Defaults to `"cpu-basic"`. + + timeout (`Union[int, float, str]`, *optional*): + Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days). + Example: `300` or `"5m"` for 5 minutes. + + labels (`dict[str, str]`, *optional*): + Labels to attach to the job (key-value pairs). + + namespace (`str`, *optional*): + The namespace where the Job will be created. 
Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + + Run a script from a URL: + + ```python + >>> from huggingface_hub import run_uv_job + >>> script = "https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/trl/scripts/sft.py" + >>> script_args = ["--model_name_or_path", "Qwen/Qwen2-0.5B", "--dataset_name", "trl-lib/Capybara", "--push_to_hub"] + >>> run_uv_job(script, script_args=script_args, dependencies=["trl"], flavor="a10g-small") + ``` + + Run a local script: + + ```python + >>> from huggingface_hub import run_uv_job + >>> script = "my_sft.py" + >>> script_args = ["--model_name_or_path", "Qwen/Qwen2-0.5B", "--dataset_name", "trl-lib/Capybara", "--push_to_hub"] + >>> run_uv_job(script, script_args=script_args, dependencies=["trl"], flavor="a10g-small") + ``` + + Run a command: + + ```python + >>> from huggingface_hub import run_uv_job + >>> script = "lighteval" + >>> script_args= ["endpoint", "inference-providers", "model_name=openai/gpt-oss-20b,provider=auto", "lighteval|gsm8k|0|0"] + >>> run_uv_job(script, script_args=script_args, dependencies=["lighteval"], flavor="a10g-small") + ``` + """ + image = image or "ghcr.io/astral-sh/uv:python3.12-bookworm" + env = env or {} + secrets = secrets or {} + + # Build command + command, env, secrets = self._create_uv_command_env_and_secrets( + script=script, + script_args=script_args, + dependencies=dependencies, + python=python, + env=env, + secrets=secrets, + namespace=namespace, + token=token, + ) + # Create RunCommand args + return self.run_job( + image=image, + command=command, + env=env, + secrets=secrets, + flavor=flavor, + timeout=timeout, + labels=labels, + namespace=namespace, + 
token=token, + ) + + def create_scheduled_job( + self, + *, + image: str, + command: list[str], + schedule: str, + suspend: Optional[bool] = None, + concurrency: Optional[bool] = None, + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + labels: Optional[dict[str, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> ScheduledJobInfo: + """ + Create scheduled compute Jobs on Hugging Face infrastructure. + + Args: + image (`str`): + The Docker image to use. + Examples: `"ubuntu"`, `"python:3.12"`, `"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"`. + Example with an image from a Space: `"hf.co/spaces/lhoestq/duckdb"`. + + command (`list[str]`): + The command to run. Example: `["echo", "hello"]`. + + schedule (`str`): + One of "@annually", "@yearly", "@monthly", "@weekly", "@daily", "@hourly", or a + CRON schedule expression (e.g., '0 9 * * 1' for 9 AM every Monday). + + suspend (`bool`, *optional*): + If True, the scheduled Job is suspended (paused). Defaults to False. + + concurrency (`bool`, *optional*): + If True, multiple instances of this Job can run concurrently. Defaults to False. + + env (`dict[str, Any]`, *optional*): + Defines the environment variables for the Job. + + secrets (`dict[str, Any]`, *optional*): + Defines the secret environment variables for the Job. + + flavor (`str`, *optional*): + Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values. + Defaults to `"cpu-basic"`. + + timeout (`Union[int, float, str]`, *optional*): + Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days). + Example: `300` or `"5m"` for 5 minutes. + + labels (`dict[str, str]`, *optional*): + Labels to attach to the job (key-value pairs). + + namespace (`str`, *optional*): + The namespace where the Job will be created. 
Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + Create your first scheduled Job: + + ```python + >>> from huggingface_hub import create_scheduled_job + >>> create_scheduled_job(image="python:3.12", command=["python", "-c" ,"print('Hello from HF compute!')"], schedule="@hourly") + ``` + + Use a CRON schedule expression: + + ```python + >>> from huggingface_hub import create_scheduled_job + >>> create_scheduled_job(image="python:3.12", command=["python", "-c" ,"print('this runs every 5min')"], schedule="*/5 * * * *") + ``` + + Create a scheduled GPU Job: + + ```python + >>> from huggingface_hub import create_scheduled_job + >>> image = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel" + >>> command = ["python", "-c", "import torch; print(f"This code ran with the following GPU: {torch.cuda.get_device_name()}")"] + >>> create_scheduled_job(image, command, flavor="a10g-small", schedule="@hourly") + ``` + + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + + # prepare payload to send to HF Jobs API + job_spec = _create_job_spec( + image=image, + command=command, + env=env, + secrets=secrets, + flavor=flavor, + timeout=timeout, + labels=labels, + ) + input_json: dict[str, Any] = { + "jobSpec": job_spec, + "schedule": schedule, + } + if concurrency is not None: + input_json["concurrency"] = concurrency + if suspend is not None: + input_json["suspend"] = suspend + response = get_session().post( + f"{self.endpoint}/api/scheduled-jobs/{namespace}", + json=input_json, + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + scheduled_job_info = response.json() + return ScheduledJobInfo(**scheduled_job_info) + + 
def list_scheduled_jobs( + self, + *, + timeout: Optional[int] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> list[ScheduledJobInfo]: + """ + List scheduled compute Jobs on Hugging Face infrastructure. + + Args: + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + + namespace (`str`, *optional*): + The namespace from where it lists the jobs. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/scheduled-jobs/{namespace}", + headers=self._build_hf_headers(token=token), + timeout=timeout, + ) + hf_raise_for_status(response) + return [ScheduledJobInfo(**scheduled_job_info) for scheduled_job_info in response.json()] + + def inspect_scheduled_job( + self, + *, + scheduled_job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> ScheduledJobInfo: + """ + Inspect a scheduled compute Job on Hugging Face infrastructure. + + Args: + scheduled_job_id (`str`): + ID of the scheduled Job. + + namespace (`str`, *optional*): + The namespace where the scheduled Job is. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. 
+ + Example: + + ```python + >>> from huggingface_hub import inspect_job, create_scheduled_job + >>> scheduled_job = create_scheduled_job(image="python:3.12", command=["python", "-c" ,"print('Hello from HF compute!')"], schedule="@hourly") + >>> inspect_scheduled_job(scheduled_job.id) + ``` + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/scheduled-jobs/{namespace}/{scheduled_job_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + return ScheduledJobInfo(**response.json()) + + def delete_scheduled_job( + self, + *, + scheduled_job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Delete a scheduled compute Job on Hugging Face infrastructure. + + Args: + scheduled_job_id (`str`): + ID of the scheduled Job. + + namespace (`str`, *optional*): + The namespace where the scheduled Job is. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + response = get_session().delete( + f"{self.endpoint}/api/scheduled-jobs/{namespace}/{scheduled_job_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + def suspend_scheduled_job( + self, + *, + scheduled_job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Suspend (pause) a scheduled compute Job on Hugging Face infrastructure. + + Args: + scheduled_job_id (`str`): + ID of the scheduled Job. + + namespace (`str`, *optional*): + The namespace where the scheduled Job is. 
Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + get_session().post( + f"{self.endpoint}/api/scheduled-jobs/{namespace}/{scheduled_job_id}/suspend", + headers=self._build_hf_headers(token=token), + ).raise_for_status() + + def resume_scheduled_job( + self, + *, + scheduled_job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Resume (unpause) a scheduled compute Job on Hugging Face infrastructure. + + Args: + scheduled_job_id (`str`): + ID of the scheduled Job. + + namespace (`str`, *optional*): + The namespace where the scheduled Job is. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. 
+ """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + get_session().post( + f"{self.endpoint}/api/scheduled-jobs/{namespace}/{scheduled_job_id}/resume", + headers=self._build_hf_headers(token=token), + ).raise_for_status() + + @experimental + def create_scheduled_uv_job( + self, + script: str, + *, + script_args: Optional[list[str]] = None, + schedule: str, + suspend: Optional[bool] = None, + concurrency: Optional[bool] = None, + dependencies: Optional[list[str]] = None, + python: Optional[str] = None, + image: Optional[str] = None, + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + labels: Optional[dict[str, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> ScheduledJobInfo: + """ + Run a UV script Job on Hugging Face infrastructure. + + Args: + script (`str`): + Path or URL of the UV script, or a command. + + script_args (`list[str]`, *optional*) + Arguments to pass to the script, or a command. + + schedule (`str`): + One of "@annually", "@yearly", "@monthly", "@weekly", "@daily", "@hourly", or a + CRON schedule expression (e.g., '0 9 * * 1' for 9 AM every Monday). + + suspend (`bool`, *optional*): + If True, the scheduled Job is suspended (paused). Defaults to False. + + concurrency (`bool`, *optional*): + If True, multiple instances of this Job can run concurrently. Defaults to False. + + dependencies (`list[str]`, *optional*) + Dependencies to use to run the UV script. + + python (`str`, *optional*) + Use a specific Python version. Default is 3.12. + + image (`str`, *optional*, defaults to "ghcr.io/astral-sh/uv:python3.12-bookworm"): + Use a custom Docker image with `uv` installed. + + env (`dict[str, Any]`, *optional*): + Defines the environment variables for the Job. 
+ + secrets (`dict[str, Any]`, *optional*): + Defines the secret environment variables for the Job. + + flavor (`str`, *optional*): + Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values. + Defaults to `"cpu-basic"`. + + timeout (`Union[int, float, str]`, *optional*): + Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days). + Example: `300` or `"5m"` for 5 minutes. + + labels (`dict[str, str]`, *optional*): + Labels to attach to the job (key-value pairs). + + namespace (`str`, *optional*): + The namespace where the Job will be created. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + + Schedule a script from a URL: + + ```python + >>> from huggingface_hub import create_scheduled_uv_job + >>> script = "https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/trl/scripts/sft.py" + >>> script_args = ["--model_name_or_path", "Qwen/Qwen2-0.5B", "--dataset_name", "trl-lib/Capybara", "--push_to_hub"] + >>> create_scheduled_uv_job(script, script_args=script_args, dependencies=["trl"], flavor="a10g-small", schedule="@weekly") + ``` + + Schedule a local script: + + ```python + >>> from huggingface_hub import create_scheduled_uv_job + >>> script = "my_sft.py" + >>> script_args = ["--model_name_or_path", "Qwen/Qwen2-0.5B", "--dataset_name", "trl-lib/Capybara", "--push_to_hub"] + >>> create_scheduled_uv_job(script, script_args=script_args, dependencies=["trl"], flavor="a10g-small", schedule="@weekly") + ``` + + Schedule a command: + + ```python + >>> from huggingface_hub import create_scheduled_uv_job + >>> script = "lighteval" + >>> script_args= ["endpoint", 
"inference-providers", "model_name=openai/gpt-oss-20b,provider=auto", "lighteval|gsm8k|0|0"] + >>> create_scheduled_uv_job(script, script_args=script_args, dependencies=["lighteval"], flavor="a10g-small", schedule="@weekly") + ``` + """ + image = image or "ghcr.io/astral-sh/uv:python3.12-bookworm" + # Build command + command, env, secrets = self._create_uv_command_env_and_secrets( + script=script, + script_args=script_args, + dependencies=dependencies, + python=python, + env=env, + secrets=secrets, + namespace=namespace, + token=token, + ) + # Create RunCommand args + return self.create_scheduled_job( + image=image, + command=command, + schedule=schedule, + suspend=suspend, + concurrency=concurrency, + env=env, + secrets=secrets, + flavor=flavor, + timeout=timeout, + labels=labels, + namespace=namespace, + token=token, + ) + + def _create_uv_command_env_and_secrets( + self, + *, + script: str, + script_args: Optional[list[str]], + dependencies: Optional[list[str]], + python: Optional[str], + env: Optional[dict[str, Any]], + secrets: Optional[dict[str, Any]], + namespace: Optional[str], + token: Union[bool, str, None], + ) -> tuple[list[str], dict[str, Any], dict[str, Any]]: + env = env or {} + secrets = secrets or {} + + # Build command + uv_args = [] + if dependencies: + for dependency in dependencies: + uv_args += ["--with", dependency] + if python: + uv_args += ["--python", python] + script_args = script_args or [] + + if namespace is None: + namespace = self.whoami(token=token)["name"] + + # Find the local files to pass to the job + local_files_to_include = {candidate for candidate in [script] + script_args if Path(candidate).is_file()} + # Fail early for missing scripts or config files + missing_local_files = { + candidate + for candidate in [script] + script_args + if not Path(candidate).is_file() + and Path(candidate).suffix in [".py", ".sh", ".yaml", ".yml", ".toml"] + and not candidate.startswith("https://") + and not candidate.startswith("http://") + } + 
if missing_local_files: + raise FileNotFoundError(", ".join(missing_local_files)) + + if len(local_files_to_include) == 0: + # Direct URL execution or command - no upload needed + command = ["uv", "run"] + uv_args + [script] + script_args + else: + # Find appropriate remote file names + remote_to_local_file_names: dict[str, str] = {} + for local_file_to_include in local_files_to_include: + local_file_path = Path(local_file_to_include) + # remove spaces for proper xargs parsing + remote_file_path = Path(local_file_path.name.replace(" ", "_")) + if remote_file_path.name in remote_to_local_file_names: + for i in itertools.count(): + remote_file_name = remote_file_path.with_stem(remote_file_path.stem + f"({i})").name + if remote_file_name not in remote_to_local_file_names: + remote_to_local_file_names[remote_file_name] = local_file_to_include + break + else: + remote_to_local_file_names[remote_file_path.name] = local_file_to_include + local_to_remote_file_names = dict( + (local_file_to_include, remote_file_name) + for remote_file_name, local_file_to_include in remote_to_local_file_names.items() + ) + + # Replace local paths with remote paths in command + if script in local_to_remote_file_names: + script = local_to_remote_file_names[script] + script_args = [ + local_to_remote_file_names[arg] if arg in local_to_remote_file_names else arg for arg in script_args + ] + + # Load content to pass as environment variable with format + # file1 base64content1 + # file2 base64content2 + # ... 
+ env["LOCAL_FILES_ENCODED"] = "\n".join( + remote_file_name + " " + base64.b64encode(Path(local_file_to_include).read_bytes()).decode() + for remote_file_name, local_file_to_include in remote_to_local_file_names.items() + ) + command = [ + "bash", + "-c", + """echo $LOCAL_FILES_ENCODED | xargs -n 2 bash -c 'echo "$1" | base64 -d > "$0"' && """ + + f"uv run {' '.join(uv_args)} {script} {' '.join(script_args)}", + ] + return command, env, secrets + + +def _parse_revision_from_pr_url(pr_url: str) -> str: + """Safely parse revision number from a PR url. + + Example: + ```py + >>> _parse_revision_from_pr_url("https://huggingface.co/bigscience/bloom/discussions/2") + "refs/pr/2" + ``` + """ + re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) + if re_match is None: + raise RuntimeError(f"Unexpected response from the hub, expected a Pull Request URL but got: '{pr_url}'") + return f"refs/pr/{re_match[1]}" + + +def parse_local_safetensors_file_metadata(path: Union[str, Path]) -> SafetensorsFileMetadata: + """ + Parse metadata from a local safetensors file. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + path (`str` or `Path`): + Path to the safetensors file. + + Returns: + [`SafetensorsFileMetadata`]: information related to the safetensors file. + + Raises: + [`SafetensorsParsingError`]: + If the safetensors file header couldn't be parsed correctly. + `FileNotFoundError`: + If the file does not exist. + + Example: + ```py + >>> metadata = parse_local_safetensors_file_metadata("path/to/model.safetensors") + >>> metadata + SafetensorsFileMetadata( + metadata={'format': 'pt'}, + tensors={'layer.weight': TensorInfo(dtype='F32', shape=[512, 512], ...}, ...} + ) + >>> metadata.parameter_count + {'F32': 262144} + ``` + """ + path = Path(path) + filename = path.name + context_msg = f"path '{path}'" + + with open(path, "rb") as f: + # 1. 
Read first 8 bytes and parse/validate metadata size using shared helper + size_bytes = f.read(8) + metadata_size = _get_safetensors_metadata_size(size_bytes, filename, context_msg) + + # 2. Read metadata bytes + metadata_as_bytes = f.read(metadata_size) + if len(metadata_as_bytes) < metadata_size: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' ({context_msg}): file is truncated. Expected " + f"{metadata_size} bytes of metadata but got {len(metadata_as_bytes)}." + ) + + # 3. Parse using shared helper + return _parse_safetensors_header(metadata_as_bytes, filename, context_msg) + + +def get_local_safetensors_metadata(path: Union[str, Path]) -> SafetensorsRepoMetadata: + """ + Parse metadata for a local safetensors file or folder. + + Supports: + - Single safetensors file (e.g., `model.safetensors`) + - Directory with non-sharded model (contains `model.safetensors`) + - Directory with sharded model (contains `model.safetensors.index.json`) + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + path (`str` or `Path`): + Path to a safetensors file or directory containing safetensors files. + + Returns: + [`SafetensorsRepoMetadata`]: information related to the safetensors repo. + + Raises: + [`NotASafetensorsRepoError`]: + If the path is not a valid safetensors file or folder (i.e., doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file). + [`SafetensorsParsingError`]: + If a safetensors file header couldn't be parsed correctly. + `FileNotFoundError`: + If the path does not exist. 
+ + Example: + ```py + # Parse single safetensors file + >>> metadata = get_local_safetensors_metadata("path/to/model.safetensors") + >>> metadata + SafetensorsRepoMetadata(metadata=None, sharded=False, weight_map={...}, files_metadata={...}) + + # Parse directory with sharded model + >>> metadata = get_local_safetensors_metadata("path/to/model_folder") + >>> metadata + SafetensorsRepoMetadata(metadata={'total_size': ...}, sharded=True, weight_map={...}, files_metadata={...}) + >>> len(metadata.files_metadata) + 3 # Number of safetensors shards + ``` + """ + path = Path(path) + + # Case 1: Direct path to a safetensors file + if path.is_file(): + file_metadata = parse_local_safetensors_file_metadata(path) + return SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={tensor_name: path.name for tensor_name in file_metadata.tensors.keys()}, + files_metadata={path.name: file_metadata}, + ) + + # Case 2: Directory + if not path.is_dir(): + raise FileNotFoundError(f"Path '{path}' does not exist.") + + single_file_path = path / constants.SAFETENSORS_SINGLE_FILE + index_file_path = path / constants.SAFETENSORS_INDEX_FILE + + # Case 2a: Non-sharded model (single model.safetensors file) + if single_file_path.exists(): + file_metadata = parse_local_safetensors_file_metadata(single_file_path) + return SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={ + tensor_name: constants.SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys() + }, + files_metadata={constants.SAFETENSORS_SINGLE_FILE: file_metadata}, + ) + + # Case 2b: Sharded model (model.safetensors.index.json) + if index_file_path.exists(): + with open(index_file_path) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + + # Parse metadata from each shard + files_metadata = {} + for shard_filename in set(weight_map.values()): + shard_path = path / shard_filename + files_metadata[shard_filename] = 
parse_local_safetensors_file_metadata(shard_path) + + return SafetensorsRepoMetadata( + metadata=index.get("metadata", None), + sharded=True, + weight_map=weight_map, + files_metadata=files_metadata, + ) + + # Not a valid safetensors folder + raise NotASafetensorsRepoError( + f"'{path}' is not a valid safetensors folder. Couldn't find '{constants.SAFETENSORS_INDEX_FILE}' or " + f"'{constants.SAFETENSORS_SINGLE_FILE}' files." + ) + + +api = HfApi() + +whoami = api.whoami +auth_check = api.auth_check + +list_models = api.list_models +model_info = api.model_info + +list_datasets = api.list_datasets +dataset_info = api.dataset_info + +list_spaces = api.list_spaces +space_info = api.space_info + +list_papers = api.list_papers +paper_info = api.paper_info +list_daily_papers = api.list_daily_papers + +repo_exists = api.repo_exists +revision_exists = api.revision_exists +file_exists = api.file_exists +repo_info = api.repo_info +list_repo_files = api.list_repo_files +list_repo_refs = api.list_repo_refs +list_repo_commits = api.list_repo_commits +list_repo_tree = api.list_repo_tree +get_paths_info = api.get_paths_info +verify_repo_checksums = api.verify_repo_checksums + +get_model_tags = api.get_model_tags +get_dataset_tags = api.get_dataset_tags + +create_commit = api.create_commit +create_repo = api.create_repo +delete_repo = api.delete_repo +update_repo_settings = api.update_repo_settings +move_repo = api.move_repo +upload_file = api.upload_file +upload_folder = api.upload_folder +delete_file = api.delete_file +delete_folder = api.delete_folder +delete_files = api.delete_files +upload_large_folder = api.upload_large_folder +preupload_lfs_files = api.preupload_lfs_files +create_branch = api.create_branch +delete_branch = api.delete_branch +create_tag = api.create_tag +delete_tag = api.delete_tag +get_full_repo_name = api.get_full_repo_name + +# Danger-zone API +super_squash_history = api.super_squash_history +list_lfs_files = api.list_lfs_files 
+permanently_delete_lfs_files = api.permanently_delete_lfs_files + +# Safetensors helpers +get_safetensors_metadata = api.get_safetensors_metadata +parse_safetensors_file_metadata = api.parse_safetensors_file_metadata + +# Background jobs +run_as_future = api.run_as_future + +# Activity API +list_liked_repos = api.list_liked_repos +list_repo_likers = api.list_repo_likers +unlike = api.unlike + +# Community API +get_discussion_details = api.get_discussion_details +get_repo_discussions = api.get_repo_discussions +create_discussion = api.create_discussion +create_pull_request = api.create_pull_request +change_discussion_status = api.change_discussion_status +comment_discussion = api.comment_discussion +edit_discussion_comment = api.edit_discussion_comment +rename_discussion = api.rename_discussion +merge_pull_request = api.merge_pull_request + +# Space API +add_space_secret = api.add_space_secret +delete_space_secret = api.delete_space_secret +get_space_variables = api.get_space_variables +add_space_variable = api.add_space_variable +delete_space_variable = api.delete_space_variable +get_space_runtime = api.get_space_runtime +request_space_hardware = api.request_space_hardware +set_space_sleep_time = api.set_space_sleep_time +pause_space = api.pause_space +restart_space = api.restart_space +duplicate_space = api.duplicate_space +request_space_storage = api.request_space_storage +delete_space_storage = api.delete_space_storage + +# Inference Endpoint API +list_inference_endpoints = api.list_inference_endpoints +create_inference_endpoint = api.create_inference_endpoint +get_inference_endpoint = api.get_inference_endpoint +update_inference_endpoint = api.update_inference_endpoint +delete_inference_endpoint = api.delete_inference_endpoint +pause_inference_endpoint = api.pause_inference_endpoint +resume_inference_endpoint = api.resume_inference_endpoint +scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint +create_inference_endpoint_from_catalog = 
api.create_inference_endpoint_from_catalog +list_inference_catalog = api.list_inference_catalog + +# Collections API +get_collection = api.get_collection +list_collections = api.list_collections +create_collection = api.create_collection +update_collection_metadata = api.update_collection_metadata +delete_collection = api.delete_collection +add_collection_item = api.add_collection_item +update_collection_item = api.update_collection_item +delete_collection_item = api.delete_collection_item +delete_collection_item = api.delete_collection_item + +# Access requests API +list_pending_access_requests = api.list_pending_access_requests +list_accepted_access_requests = api.list_accepted_access_requests +list_rejected_access_requests = api.list_rejected_access_requests +cancel_access_request = api.cancel_access_request +accept_access_request = api.accept_access_request +reject_access_request = api.reject_access_request +grant_access = api.grant_access + +# Webhooks API +create_webhook = api.create_webhook +disable_webhook = api.disable_webhook +delete_webhook = api.delete_webhook +enable_webhook = api.enable_webhook +get_webhook = api.get_webhook +list_webhooks = api.list_webhooks +update_webhook = api.update_webhook + + +# User API +get_user_overview = api.get_user_overview +get_organization_overview = api.get_organization_overview +list_organization_followers = api.list_organization_followers +list_organization_members = api.list_organization_members +list_user_followers = api.list_user_followers +list_user_following = api.list_user_following + +# Jobs API +run_job = api.run_job +fetch_job_logs = api.fetch_job_logs +fetch_job_metrics = api.fetch_job_metrics +list_jobs = api.list_jobs +list_jobs_hardware = api.list_jobs_hardware +inspect_job = api.inspect_job +cancel_job = api.cancel_job +run_uv_job = api.run_uv_job +create_scheduled_job = api.create_scheduled_job +list_scheduled_jobs = api.list_scheduled_jobs +inspect_scheduled_job = api.inspect_scheduled_job 
+delete_scheduled_job = api.delete_scheduled_job +suspend_scheduled_job = api.suspend_scheduled_job +resume_scheduled_job = api.resume_scheduled_job +create_scheduled_uv_job = api.create_scheduled_uv_job diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/hf_file_system.py b/venv/lib/python3.10/site-packages/huggingface_hub/hf_file_system.py new file mode 100644 index 0000000000000000000000000000000000000000..f96b1fee1f89c54991bc11fc8ebf1818d32f0ad5 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/hf_file_system.py @@ -0,0 +1,1295 @@ +import os +import re +import tempfile +import threading +from collections import deque +from contextlib import ExitStack +from copy import deepcopy +from dataclasses import dataclass, field +from datetime import datetime +from itertools import chain +from pathlib import Path +from typing import Any, Iterator, NoReturn, Optional, Union +from urllib.parse import quote, unquote + +import fsspec +import httpx +from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback +from fsspec.utils import isfilelike + +from . 
import constants +from ._commit_api import CommitOperationCopy, CommitOperationDelete +from .errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError +from .file_download import hf_hub_url, http_get +from .hf_api import HfApi, LastCommitInfo, RepoFile +from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff +from .utils.insecure_hashlib import md5 + + +# Regex used to match special revisions with "/" in them (see #1710) +SPECIAL_REFS_REVISION_REGEX = re.compile( + r""" + (^refs\/convert\/\w+) # `refs/convert/parquet` revisions + | + (^refs\/pr\/\d+) # PR revisions + """, + re.VERBOSE, +) + + +@dataclass +class HfFileSystemResolvedPath: + """Data structure containing information about a resolved Hugging Face file system path.""" + + repo_type: str + repo_id: str + revision: str + path_in_repo: str + # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision. + # Used to reconstruct the unresolved path to return to the user. + _raw_revision: Optional[str] = field(default=None, repr=False) + + def unresolve(self) -> str: + repo_path = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id + if self._raw_revision: + return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/") + elif self.revision != constants.DEFAULT_REVISION: + return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/") + else: + return f"{repo_path}/{self.path_in_repo}".rstrip("/") + + +# We need to improve fsspec.spec._Cached which is AbstractFileSystem's metaclass +_cached_base: Any = type(fsspec.AbstractFileSystem) + + +class _Cached(_cached_base): + """ + Metaclass for caching HfFileSystem instances according to the args. + + This creates an additional reference to the filesystem, which prevents the + filesystem from being garbage collected when all *user* references go away. 
+ A call to the :meth:`AbstractFileSystem.clear_instance_cache` must *also* + be made for a filesystem instance to be garbage collected. + + This is a slightly modified version of `fsspec.spec._Cached` to improve it. + In particular in `_tokenize` the pid isn't taken into account for the + `fs_token` used to identify cached instances. The `fs_token` logic is also + robust to defaults values and the order of the args. Finally new instances + reuse the states from sister instances in the main thread. + """ + + def __init__(cls, *args, **kwargs): + # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L53 + super().__init__(*args, **kwargs) + # Note: we intentionally create a reference here, to avoid garbage + # collecting instances when all other references are gone. To really + # delete a FileSystem, the cache must be cleared. + cls._cache = {} + + def __call__(cls, *args, **kwargs): + # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L65 + skip = kwargs.pop("skip_instance_cache", False) + fs_token = cls._tokenize(cls, threading.get_ident(), *args, **kwargs) + fs_token_main_thread = cls._tokenize(cls, threading.main_thread().ident, *args, **kwargs) + if not skip and cls.cachable and fs_token in cls._cache: + # reuse cached instance + cls._latest = fs_token + return cls._cache[fs_token] + else: + # create new instance + obj = type.__call__(cls, *args, **kwargs) + if not skip and cls.cachable and fs_token_main_thread in cls._cache: + # reuse the cache from the main thread instance in the new instance + instance_state = cls._cache[fs_token_main_thread]._get_instance_state() + for attr, state_value in instance_state.items(): + setattr(obj, attr, state_value) + obj._fs_token_ = fs_token + obj.storage_args = args + obj.storage_options = kwargs + if cls.cachable and not skip: + cls._latest = fs_token + cls._cache[fs_token] = obj + return 
obj + + +class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached): + """ + Access a remote Hugging Face Hub repository as if were a local file system. + + > [!WARNING] + > [`HfFileSystem`] provides fsspec compatibility, which is useful for libraries that require it (e.g., reading + > Hugging Face datasets directly with `pandas`). However, it introduces additional overhead due to this compatibility + > layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible. + + Args: + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + token (`bool` or `str`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + block_size (`int`, *optional*): + Block size for reading and writing files. + expand_info (`bool`, *optional*): + Whether to expand the information of the files. + **storage_options (`dict`, *optional*): + Additional options for the filesystem. See [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.__init__). + + Usage: + + ```python + >>> from huggingface_hub import hffs + + >>> # List files + >>> hffs.glob("my-username/my-model/*.bin") + ['my-username/my-model/pytorch_model.bin'] + >>> hffs.ls("datasets/my-username/my-dataset", detail=False) + ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json'] + + >>> # Read/write files + >>> with hffs.open("my-username/my-model/pytorch_model.bin") as f: + ... data = f.read() + >>> with hffs.open("my-username/my-model/pytorch_model.bin", "wb") as f: + ... 
f.write(data) + ``` + + Specify a token for authentication: + ```python + >>> from huggingface_hub import HfFileSystem + >>> hffs = HfFileSystem(token=token) + ``` + """ + + root_marker = "" + protocol = "hf" + + def __init__( + self, + *args, + endpoint: Optional[str] = None, + token: Union[bool, str, None] = None, + block_size: Optional[int] = None, + expand_info: Optional[bool] = None, + **storage_options, + ): + super().__init__(*args, **storage_options) + self.endpoint = endpoint or constants.ENDPOINT + self.token = token + self._api = HfApi(endpoint=endpoint, token=token) + self.block_size = block_size + self.expand_info = expand_info + # Maps (repo_type, repo_id, revision) to a 2-tuple with: + # * the 1st element indicating whether the repositoy and the revision exist + # * the 2nd element being the exception raised if the repository or revision doesn't exist + self._repo_and_revision_exists_cache: dict[ + tuple[str, str, Optional[str]], tuple[bool, Optional[Exception]] + ] = {} + # Maps parent directory path to path infos + self.dircache: dict[str, list[dict[str, Any]]] = {} + + @classmethod + def _tokenize(cls, threading_ident: int, *args, **kwargs) -> str: + """Deterministic token for caching""" + # make fs_token robust to default values and to kwargs order + kwargs["endpoint"] = kwargs.get("endpoint") or constants.ENDPOINT + kwargs["token"] = kwargs.get("token") + kwargs = {key: kwargs[key] for key in sorted(kwargs)} + # contrary to fsspec, we don't include pid here + tokenize_args = (cls, threading_ident, args, kwargs) + h = md5(str(tokenize_args).encode()) + return h.hexdigest() + + def _repo_and_revision_exist( + self, repo_type: str, repo_id: str, revision: Optional[str] + ) -> tuple[bool, Optional[Exception]]: + if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache: + try: + self._api.repo_info( + repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT + ) + except (RepositoryNotFoundError, 
HFValidationError) as e: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e + except RevisionNotFoundError as e: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None + else: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None + return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] + + def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath: + """ + Resolve a Hugging Face file system path into its components. + + Args: + path (`str`): + Path to resolve. + revision (`str`, *optional*): + The revision of the repo to resolve. Defaults to the revision specified in the path. + + Returns: + [`HfFileSystemResolvedPath`]: Resolved path information containing `repo_type`, `repo_id`, `revision` and `path_in_repo`. + + Raises: + `ValueError`: + If path contains conflicting revision information. + `NotImplementedError`: + If trying to list repositories. + """ + + def _align_revision_in_path_with_revision( + revision_in_path: Optional[str], revision: Optional[str] + ) -> Optional[str]: + if revision is not None: + if revision_in_path is not None and revision_in_path != revision: + raise ValueError( + f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")' + " are not the same." 
+ ) + else: + revision = revision_in_path + return revision + + path = self._strip_protocol(path) + if not path: + # can't list repositories at root + raise NotImplementedError("Access to repositories lists is not implemented.") + elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values(): + if "/" not in path: + # can't list repositories at the repository type level + raise NotImplementedError("Access to repositories lists is not implemented.") + repo_type, path = path.split("/", 1) + repo_type = constants.REPO_TYPES_MAPPING[repo_type] + else: + repo_type = constants.REPO_TYPE_MODEL + if path.count("/") > 0: + if "@" in "/".join(path.split("/")[:2]): + repo_id, revision_in_path = path.split("@", 1) + if "/" in revision_in_path: + match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path) + if match is not None and revision in (None, match.group()): + # Handle `refs/convert/parquet` and PR revisions separately + path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/") + revision_in_path = match.group() + else: + revision_in_path, path_in_repo = revision_in_path.split("/", 1) + else: + path_in_repo = "" + revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + _raise_file_not_found(path, err) + else: + revision_in_path = None + repo_id_with_namespace = "/".join(path.split("/")[:2]) + path_in_repo_with_namespace = "/".join(path.split("/")[2:]) + repo_id_without_namespace = path.split("/")[0] + path_in_repo_without_namespace = "/".join(path.split("/")[1:]) + repo_id = repo_id_with_namespace + path_in_repo = path_in_repo_with_namespace + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + if isinstance(err, (RepositoryNotFoundError, HFValidationError)): + repo_id = repo_id_without_namespace + 
path_in_repo = path_in_repo_without_namespace + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + _raise_file_not_found(path, err) + else: + _raise_file_not_found(path, err) + else: + repo_id = path + path_in_repo = "" + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) + else: + revision_in_path = None + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + raise NotImplementedError("Access to repositories lists is not implemented.") + + revision = revision if revision is not None else constants.DEFAULT_REVISION + return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path) + + def invalidate_cache(self, path: Optional[str] = None) -> None: + """ + Clear the cache for a given path. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.invalidate_cache). + + Args: + path (`str`, *optional*): + Path to clear from cache. If not provided, clear the entire cache. 
+ + """ + if not path: + self.dircache.clear() + self._repo_and_revision_exists_cache.clear() + else: + resolved_path = self.resolve_path(path) + path = resolved_path.unresolve() + while path: + self.dircache.pop(path, None) + path = self._parent(path) + + # Only clear repo cache if path is to repo root + if not resolved_path.path_in_repo: + self._repo_and_revision_exists_cache.pop((resolved_path.repo_type, resolved_path.repo_id, None), None) + self._repo_and_revision_exists_cache.pop( + (resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision), None + ) + + def _open( # type: ignore[override] + self, + path: str, + mode: str = "rb", + block_size: Optional[int] = None, + revision: Optional[str] = None, + **kwargs, + ) -> Union["HfFileSystemFile", "HfFileSystemStreamFile"]: + block_size = block_size if block_size is not None else self.block_size + if block_size is not None: + kwargs["block_size"] = block_size + if "a" in mode: + raise NotImplementedError("Appending to remote files is not yet supported.") + if block_size == 0: + return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, **kwargs) + else: + return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs) + + def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path = self.resolve_path(path, revision=revision) + self._api.delete_file( + path_in_repo=resolved_path.path_in_repo, + repo_id=resolved_path.repo_id, + token=self.token, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message"), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def rm( + self, + path: str, + recursive: bool = False, + maxdepth: Optional[int] = None, + revision: Optional[str] = None, + **kwargs, + ) -> None: + """ + Delete files from a repository. 
+ + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.rm). + + > [!WARNING] + > Note: When possible, use `HfApi.delete_file()` for better performance. + + Args: + path (`str`): + Path to delete. + recursive (`bool`, *optional*): + If True, delete directory and all its contents. Defaults to False. + maxdepth (`int`, *optional*): + Maximum number of subdirectories to visit when deleting recursively. + revision (`str`, *optional*): + The git revision to delete from. + + """ + resolved_path = self.resolve_path(path, revision=revision) + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision) + paths_in_repo = [self.resolve_path(path).path_in_repo for path in paths if not self.isdir(path)] + operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo] + commit_message = f"Delete {path} " + commit_message += "recursively " if recursive else "" + commit_message += f"up to depth {maxdepth} " if maxdepth is not None else "" + # TODO: use `commit_description` to list all the deleted paths? + self._api.create_commit( + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + token=self.token, + operations=operations, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def ls( + self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs + ) -> list[Union[str, dict[str, Any]]]: + """ + List the contents of a directory. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.ls). + + > [!WARNING] + > Note: When possible, use `HfApi.list_repo_tree()` for better performance. 
+ + Args: + path (`str`): + Path to the directory. + detail (`bool`, *optional*): + If True, returns a list of dictionaries containing file information. If False, + returns a list of file paths. Defaults to True. + refresh (`bool`, *optional*): + If True, bypass the cache and fetch the latest data. Defaults to False. + revision (`str`, *optional*): + The git revision to list from. + + Returns: + `list[Union[str, dict[str, Any]]]`: List of file paths (if detail=False) or list of file information + dictionaries (if detail=True). + """ + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + try: + out = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs) + except EntryNotFoundError: + # Path could be a file + if not resolved_path.path_in_repo: + _raise_file_not_found(path, None) + out = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs) + out = [o for o in out if o["name"] == path] + if len(out) == 0: + _raise_file_not_found(path, None) + return out if detail else [o["name"] for o in out] + + def _ls_tree( + self, + path: str, + recursive: bool = False, + refresh: bool = False, + revision: Optional[str] = None, + expand_info: Optional[bool] = None, + maxdepth: Optional[int] = None, + ): + expand_info = ( + expand_info if expand_info is not None else (self.expand_info if self.expand_info is not None else False) + ) + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + root_path = HfFileSystemResolvedPath( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + path_in_repo="", + _raw_revision=resolved_path._raw_revision, + ).unresolve() + + out = [] + if path in self.dircache and not refresh: + cached_path_infos = self.dircache[path] + out.extend(cached_path_infos) + dirs_not_in_dircache = [] + if recursive: + # Use BFS to traverse the cache and build the "recursive "output + # (The Hub uses a so-called "tree first" 
strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same) + depth = 2 + dirs_to_visit = deque( + [(depth, path_info) for path_info in cached_path_infos if path_info["type"] == "directory"] + ) + while dirs_to_visit: + depth, dir_info = dirs_to_visit.popleft() + if maxdepth is None or depth <= maxdepth: + if dir_info["name"] not in self.dircache: + dirs_not_in_dircache.append(dir_info["name"]) + else: + cached_path_infos = self.dircache[dir_info["name"]] + out.extend(cached_path_infos) + dirs_to_visit.extend( + [ + (depth + 1, path_info) + for path_info in cached_path_infos + if path_info["type"] == "directory" + ] + ) + + dirs_not_expanded = [] + if expand_info: + # Check if there are directories with non-expanded entries + dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None] + + if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded): + # If the dircache is incomplete, find the common path of the missing and non-expanded entries + # and extend the output with the result of `_ls_tree(common_path, recursive=True)` + common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded) + # Get the parent directory if the common prefix itself is not a directory + common_path = ( + common_prefix.rstrip("/") + if common_prefix.endswith("/") + or common_prefix == root_path + or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded) + else self._parent(common_prefix) + ) + if maxdepth is not None: + common_path_depth = common_path[len(path) :].count("/") + maxdepth -= common_path_depth + out = [o for o in out if not o["name"].startswith(common_path + "/")] + for cached_path in list(self.dircache): + if cached_path.startswith(common_path + "/"): + self.dircache.pop(cached_path, None) + self.dircache.pop(common_path, None) + out.extend( + self._ls_tree( + common_path, + recursive=recursive, + refresh=True, + revision=revision, + 
expand_info=expand_info, + maxdepth=maxdepth, + ) + ) + else: + tree = self._api.list_repo_tree( + resolved_path.repo_id, + resolved_path.path_in_repo, + recursive=recursive, + expand=expand_info, + revision=resolved_path.revision, + repo_type=resolved_path.repo_type, + ) + for path_info in tree: + cache_path = root_path + "/" + path_info.path + if isinstance(path_info, RepoFile): + cache_path_info = { + "name": cache_path, + "size": path_info.size, + "type": "file", + "blob_id": path_info.blob_id, + "lfs": path_info.lfs, + "last_commit": path_info.last_commit, + "security": path_info.security, + } + else: + cache_path_info = { + "name": cache_path, + "size": 0, + "type": "directory", + "tree_id": path_info.tree_id, + "last_commit": path_info.last_commit, + } + parent_path = self._parent(cache_path_info["name"]) + self.dircache.setdefault(parent_path, []).append(cache_path_info) + depth = cache_path[len(path) :].count("/") + if maxdepth is None or depth <= maxdepth: + out.append(cache_path_info) + return out + + def walk(self, path: str, *args, **kwargs) -> Iterator[tuple[str, list[str], list[str]]]: + """ + Return all files below the given path. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.walk). + + Args: + path (`str`): + Root path to list files from. + + Returns: + `Iterator[tuple[str, list[str], list[str]]]`: An iterator of (path, list of directory names, list of file names) tuples. + """ + path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() + yield from super().walk(path, *args, **kwargs) + + def glob(self, path: str, maxdepth: Optional[int] = None, **kwargs) -> list[str]: + """ + Find files by glob-matching. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob). + + Args: + path (`str`): + Path pattern to match. 
+ + Returns: + `list[str]`: List of paths matching the pattern. + """ + path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() + return super().glob(path, maxdepth=maxdepth, **kwargs) + + def find( + self, + path: str, + maxdepth: Optional[int] = None, + withdirs: bool = False, + detail: bool = False, + refresh: bool = False, + revision: Optional[str] = None, + **kwargs, + ) -> Union[list[str], dict[str, dict[str, Any]]]: + """ + List all files below path. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.find). + + Args: + path (`str`): + Root path to list files from. + maxdepth (`int`, *optional*): + Maximum depth to descend into subdirectories. + withdirs (`bool`, *optional*): + Include directory paths in the output. Defaults to False. + detail (`bool`, *optional*): + If True, returns a dict mapping paths to file information. Defaults to False. + refresh (`bool`, *optional*): + If True, bypass the cache and fetch the latest data. Defaults to False. + revision (`str`, *optional*): + The git revision to list from. + + Returns: + `Union[list[str], dict[str, dict[str, Any]]]`: List of paths or dict of file information. 
+ """ + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + try: + out = self._ls_tree( + path, recursive=True, refresh=refresh, revision=resolved_path.revision, maxdepth=maxdepth, **kwargs + ) + except EntryNotFoundError: + # Path could be a file + try: + if self.info(path, revision=revision, **kwargs)["type"] == "file": + out = {path: {}} + else: + out = {} + except FileNotFoundError: + out = {} + else: + if not withdirs: + out = [o for o in out if o["type"] != "directory"] + else: + # If `withdirs=True`, include the directory itself to be consistent with the spec + path_info = self.info(path, revision=resolved_path.revision, **kwargs) + out = [path_info] + out if path_info["type"] == "directory" else out + out = {o["name"]: o for o in out} + names = sorted(out) + if not detail: + return names + else: + return {name: out[name] for name in names} + + def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None: + """ + Copy a file within or between repositories. + + > [!WARNING] + > Note: When possible, use `HfApi.upload_file()` for better performance. + + Args: + path1 (`str`): + Source path to copy from. + path2 (`str`): + Destination path to copy to. + revision (`str`, *optional*): + The git revision to copy from. 
+ + """ + resolved_path1 = self.resolve_path(path1, revision=revision) + resolved_path2 = self.resolve_path(path2, revision=revision) + + same_repo = ( + resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id + ) + + if same_repo: + commit_message = f"Copy {path1} to {path2}" + self._api.create_commit( + repo_id=resolved_path1.repo_id, + repo_type=resolved_path1.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description", ""), + operations=[ + CommitOperationCopy( + src_path_in_repo=resolved_path1.path_in_repo, + path_in_repo=resolved_path2.path_in_repo, + src_revision=resolved_path1.revision, + ) + ], + ) + else: + with self.open(path1, "rb", revision=resolved_path1.revision) as f: + content = f.read() + commit_message = f"Copy {path1} to {path2}" + self._api.upload_file( + path_or_fileobj=content, + path_in_repo=resolved_path2.path_in_repo, + repo_id=resolved_path2.repo_id, + token=self.token, + repo_type=resolved_path2.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path1.unresolve()) + self.invalidate_cache(path=resolved_path2.unresolve()) + + def modified(self, path: str, **kwargs) -> datetime: + """ + Get the last modified time of a file. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.modified). + + Args: + path (`str`): + Path to the file. + + Returns: + `datetime`: Last commit date of the file. 
+ """ + info = self.info(path, **{**kwargs, "expand_info": True}) # type: ignore + return info["last_commit"]["date"] + + def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> dict[str, Any]: + """ + Get information about a file or directory. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.info). + + > [!WARNING] + > Note: When possible, use `HfApi.get_paths_info()` or `HfApi.repo_info()` for better performance. + + Args: + path (`str`): + Path to get info for. + refresh (`bool`, *optional*): + If True, bypass the cache and fetch the latest data. Defaults to False. + revision (`str`, *optional*): + The git revision to get info from. + + Returns: + `dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.). + + """ + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + expand_info = kwargs.get( + "expand_info", self.expand_info if self.expand_info is not None else False + ) # don't expose it as a parameter in the public API to follow the spec + if not resolved_path.path_in_repo: + # Path is the root directory + out = { + "name": path, + "size": 0, + "type": "directory", + "last_commit": None, + } + if expand_info: + last_commit = self._api.list_repo_commits( + resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision + )[-1] + out = { + **out, + "tree_id": None, # TODO: tree_id of the root directory? 
+ "last_commit": LastCommitInfo( + oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at + ), + } + else: + out = None + parent_path = self._parent(path) + if not expand_info and parent_path not in self.dircache: + # Fill the cache with cheap call + self.ls(parent_path) + if parent_path in self.dircache: + # Check if the path is in the cache + out1 = [o for o in self.dircache[parent_path] if o["name"] == path] + if not out1: + _raise_file_not_found(path, None) + out = out1[0] + if refresh or out is None or (expand_info and out and out["last_commit"] is None): + paths_info = self._api.get_paths_info( + resolved_path.repo_id, + resolved_path.path_in_repo, + expand=expand_info, + revision=resolved_path.revision, + repo_type=resolved_path.repo_type, + ) + if not paths_info: + _raise_file_not_found(path, None) + path_info = paths_info[0] + root_path = HfFileSystemResolvedPath( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + path_in_repo="", + _raw_revision=resolved_path._raw_revision, + ).unresolve() + if isinstance(path_info, RepoFile): + out = { + "name": root_path + "/" + path_info.path, + "size": path_info.size, + "type": "file", + "blob_id": path_info.blob_id, + "lfs": path_info.lfs, + "last_commit": path_info.last_commit, + "security": path_info.security, + } + else: + out = { + "name": root_path + "/" + path_info.path, + "size": 0, + "type": "directory", + "tree_id": path_info.tree_id, + "last_commit": path_info.last_commit, + } + if not expand_info: + out = {k: out[k] for k in ["name", "size", "type"]} + assert out is not None + return out + + def exists(self, path, **kwargs): + """ + Check if a file exists. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists). + + > [!WARNING] + > Note: When possible, use `HfApi.file_exists()` for better performance. + + Args: + path (`str`): + Path to check. 
+ + Returns: + `bool`: True if file exists, False otherwise. + """ + try: + if kwargs.get("refresh", False): + self.invalidate_cache(path) + + self.info(path, **kwargs) + return True + except: # noqa: E722 + return False + + def isdir(self, path): + """ + Check if a path is a directory. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isdir). + + Args: + path (`str`): + Path to check. + + Returns: + `bool`: True if path is a directory, False otherwise. + """ + try: + return self.info(path)["type"] == "directory" + except OSError: + return False + + def isfile(self, path): + """ + Check if a path is a file. + + For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isfile). + + Args: + path (`str`): + Path to check. + + Returns: + `bool`: True if path is a file, False otherwise. + """ + try: + return self.info(path)["type"] == "file" + except: # noqa: E722 + return False + + def url(self, path: str) -> str: + """ + Get the HTTP URL of the given path. + + Args: + path (`str`): + Path to get URL for. + + Returns: + `str`: HTTP URL to access the file or directory on the Hub. + """ + resolved_path = self.resolve_path(path) + url = hf_hub_url( + resolved_path.repo_id, + resolved_path.path_in_repo, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + endpoint=self.endpoint, + ) + if self.isdir(path): + url = url.replace("/resolve/", "/tree/", 1) + return url + + def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None: + """ + Copy single remote file to local. + + > [!WARNING] + > Note: When possible, use `HfApi.hf_hub_download()` for better performance. + + Args: + rpath (`str`): + Remote path to download from. + lpath (`str`): + Local path to download to. + callback (`Callback`, *optional*): + Optional callback to track download progress. 
Defaults to no callback. + outfile (`IO`, *optional*): + Optional file-like object to write to. If provided, `lpath` is ignored. + + """ + revision = kwargs.get("revision") + unhandled_kwargs = set(kwargs.keys()) - {"revision"} + if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0: + # for now, let's not handle custom callbacks + # and let's not handle custom kwargs + return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs) + + # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883 + if isfilelike(lpath): + outfile = lpath + elif self.isdir(rpath): + os.makedirs(lpath, exist_ok=True) + return None + + if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object + os.makedirs(os.path.dirname(lpath), exist_ok=True) + + # Open file if not already open + close_file = False + if outfile is None: + outfile = open(lpath, "wb") + close_file = True + initial_pos = outfile.tell() + + # Custom implementation of `get_file` to use `http_get`. 
+ resolve_remote_path = self.resolve_path(rpath, revision=revision) + expected_size = self.info(rpath, revision=revision)["size"] + callback.set_size(expected_size) + try: + http_get( + url=hf_hub_url( + repo_id=resolve_remote_path.repo_id, + revision=resolve_remote_path.revision, + filename=resolve_remote_path.path_in_repo, + repo_type=resolve_remote_path.repo_type, + endpoint=self.endpoint, + ), + temp_file=outfile, # type: ignore[arg-type] + displayed_filename=rpath, + expected_size=expected_size, + resume_size=0, + headers=self._api._build_hf_headers(), + _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None, + ) + outfile.seek(initial_pos) + finally: + # Close file only if we opened it ourselves + if close_file: + outfile.close() + + @property + def transaction(self): + """A context within which files are committed together upon exit + + Requires the file class to implement `.commit()` and `.discard()` + for the normal and exception cases. + """ + # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231 + # See https://github.com/huggingface/huggingface_hub/issues/1733 + raise NotImplementedError("Transactional commits are not supported.") + + def start_transaction(self): + """Begin write transaction for deferring files, non-context version""" + # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241 + # See https://github.com/huggingface/huggingface_hub/issues/1733 + raise NotImplementedError("Transactional commits are not supported.") + + def __reduce__(self): + # re-populate the instance cache at HfFileSystem._cache and re-populate the state of every instance + return make_instance, ( + type(self), + self.storage_args, + self.storage_options, + self._get_instance_state(), + ) + + def _get_instance_state(self): + return { + "dircache": deepcopy(self.dircache), + "_repo_and_revision_exists_cache": 
deepcopy(self._repo_and_revision_exists_cache), + } + + +class HfFileSystemFile(fsspec.spec.AbstractBufferedFile): + def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs): + try: + self.resolved_path = fs.resolve_path(path, revision=revision) + except FileNotFoundError as e: + if "w" in kwargs.get("mode", ""): + raise FileNotFoundError( + f"{e}.\nMake sure the repository and revision exist before writing data." + ) from e + raise + super().__init__(fs, self.resolved_path.unresolve(), **kwargs) + self.fs: HfFileSystem + + def __del__(self): + if not hasattr(self, "resolved_path"): + # Means that the constructor failed. Nothing to do. + return + return super().__del__() + + def _fetch_range(self, start: int, end: int) -> bytes: + headers = { + "range": f"bytes={start}-{end - 1}", + **self.fs._api._build_hf_headers(), + } + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + r = http_backoff("GET", url, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT) + hf_raise_for_status(r) + return r.content + + def _initiate_upload(self) -> None: + self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False) + + def _upload_chunk(self, final: bool = False) -> None: + self.buffer.seek(0) + block = self.buffer.read() + self.temp_file.write(block) + if final: + self.temp_file.close() + self.fs._api.upload_file( + path_or_fileobj=self.temp_file.name, + path_in_repo=self.resolved_path.path_in_repo, + repo_id=self.resolved_path.repo_id, + token=self.fs.token, + repo_type=self.resolved_path.repo_type, + revision=self.resolved_path.revision, + commit_message=self.kwargs.get("commit_message"), + commit_description=self.kwargs.get("commit_description"), + ) + os.remove(self.temp_file.name) + self.fs.invalidate_cache( + path=self.resolved_path.unresolve(), + ) + + def 
read(self, length=-1): + """Read remote file. + + If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems the file is + loaded in memory directly. Otherwise, the file is downloaded to a temporary file and read from there. + """ + if self.mode == "rb" and (length is None or length == -1) and self.loc == 0: + with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming + out = f.read() + self.loc += len(out) + return out + return super().read(length) + + def url(self) -> str: + return self.fs.url(self.path) + + +class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile): + def __init__( + self, + fs: HfFileSystem, + path: str, + mode: str = "rb", + revision: Optional[str] = None, + block_size: int = 0, + cache_type: str = "none", + **kwargs, + ): + if block_size != 0: + raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}") + if cache_type != "none": + raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}") + if "w" in mode: + raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'") + try: + self.resolved_path = fs.resolve_path(path, revision=revision) + except FileNotFoundError as e: + if "w" in kwargs.get("mode", ""): + raise FileNotFoundError( + f"{e}.\nMake sure the repository and revision exist before writing data." 
+ ) from e + # avoid an unnecessary .info() call to instantiate .details + self.details = {"name": self.resolved_path.unresolve(), "size": None} + super().__init__( + fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs + ) + self.response: Optional[httpx.Response] = None + self.fs: HfFileSystem + self._exit_stack = ExitStack() + # streaming state + self._stream_iterator: Optional[Iterator[bytes]] = None + self._stream_buffer = bytearray() + + def seek(self, loc: int, whence: int = 0): + if loc == 0 and whence == 1: + return + if loc == self.loc and whence == 0: + return + raise ValueError("Cannot seek streaming HF file") + + def read(self, length: int = -1): + """Read the remote file. + + If the file is already open, we reuse the connection. + Otherwise, open a new connection and read from it. + + If reading the stream fails, we retry with a new connection. + """ + if self.response is None: + self._open_connection() + + retried_once = False + while True: + try: + if self.response is None or self._stream_iterator is None: + return b"" # Already read the entire file + out = self._read_from_stream(self._stream_iterator, length) + self.loc += len(out) + return out + except Exception: + if self.response is not None: + self.response.close() + if retried_once: # Already retried once, give up + raise + # First failure, retry with range header + self._open_connection() + retried_once = True + + def _read_from_stream(self, iterator: Iterator[bytes], length: int = -1) -> bytes: + """Read up to `length` bytes from stream buffer and stream. + + If length < 0, read until EOF. + + If EOF is reached before length, fewer bytes may be returned. 
+ """ + if length == 0: + return b"" + + if length < 0: + buf = bytearray(self._stream_buffer) + self._stream_buffer.clear() + for chunk in iterator: + buf.extend(chunk) + return bytes(buf) + + if length <= len(self._stream_buffer): + result = bytes(self._stream_buffer[:length]) + del self._stream_buffer[:length] + return result + + buf = bytearray(self._stream_buffer) + self._stream_buffer.clear() + for chunk in iterator: + need = length - len(buf) + if need > len(chunk): + buf.extend(chunk) + else: + buf.extend(chunk[:need]) + self._stream_buffer.extend(chunk[need:]) + break + return bytes(buf) + + def url(self) -> str: + return self.fs.url(self.path) + + def __del__(self): + if not hasattr(self, "resolved_path"): + # Means that the constructor failed. Nothing to do. + return + self._exit_stack.close() + return super().__del__() + + def __reduce__(self): + return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name) + + def _open_connection(self): + """Open a connection to the remote file.""" + # reset streaming state + self._stream_buffer.clear() + self._stream_iterator = None + + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + headers = self.fs._api._build_hf_headers() + if self.loc > 0: + headers["Range"] = f"bytes={self.loc}-" + self.response = self._exit_stack.enter_context( + http_stream_backoff( + "GET", + url, + headers=headers, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + ) + + try: + hf_raise_for_status(self.response) + except HfHubHTTPError as e: + if e.response.status_code == 416: + # Range not satisfiable => means that we have already read the entire file + self.response = None + return + raise + + self._stream_iterator = self.response.iter_bytes() + + +def safe_revision(revision: str) -> str: + return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) 
else safe_quote(revision) + + +def safe_quote(s: str) -> str: + return quote(s, safe="") + + +def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn: + msg = path + if isinstance(err, RepositoryNotFoundError): + msg = f"{path} (repository not found)" + elif isinstance(err, RevisionNotFoundError): + msg = f"{path} (revision not found)" + elif isinstance(err, HFValidationError): + msg = f"{path} (invalid repository id)" + raise FileNotFoundError(msg) from err + + +def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str): + return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type) + + +def make_instance(cls, args, kwargs, instance_state): + fs = cls(*args, **kwargs) + for attr, state_value in instance_state.items(): + setattr(fs, attr, state_value) + return fs + + +hffs = HfFileSystem() diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/hub_mixin.py b/venv/lib/python3.10/site-packages/huggingface_hub/hub_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f42d522381adc21124099d0c67b5d597d7babc --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/hub_mixin.py @@ -0,0 +1,831 @@ +import inspect +import json +import os +from dataclasses import Field, asdict, dataclass, is_dataclass +from pathlib import Path +from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union + +import packaging.version + +from . 
import constants +from .errors import EntryNotFoundError, HfHubHTTPError +from .file_download import hf_hub_download +from .hf_api import HfApi +from .repocard import ModelCard, ModelCardData +from .utils import ( + SoftTemporaryDirectory, + is_jsonable, + is_safetensors_available, + is_simple_optional_type, + is_torch_available, + logging, + unwrap_simple_optional_type, + validate_hf_hub_args, +) + + +if is_torch_available(): + import torch # type: ignore + +if is_safetensors_available(): + import safetensors + from safetensors.torch import load_model as load_model_as_safetensor + from safetensors.torch import save_model as save_model_as_safetensor + + +logger = logging.get_logger(__name__) + + +# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[dict[str, Field]] + + +# Generic variable that is either ModelHubMixin or a subclass thereof +T = TypeVar("T", bound="ModelHubMixin") +# Generic variable to represent an args type +ARGS_T = TypeVar("ARGS_T") +ENCODER_T = Callable[[ARGS_T], Any] +DECODER_T = Callable[[Any], ARGS_T] +CODER_T = tuple[ENCODER_T, DECODER_T] + + +DEFAULT_MODEL_CARD = """ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration: +- Code: {{ repo_url | default("[More Information Needed]", true) }} +- Paper: {{ paper_url | default("[More Information Needed]", true) }} +- Docs: {{ docs_url | default("[More Information Needed]", true) }} +""" + + +@dataclass +class MixinInfo: + model_card_template: str + model_card_data: 
ModelCardData + docs_url: Optional[str] = None + paper_url: Optional[str] = None + repo_url: Optional[str] = None + + +class ModelHubMixin: + """ + A generic mixin to integrate ANY machine learning framework with the Hub. + + To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models + have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example + of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. + + When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to + `__init__` but to the class definition itself. This is useful to define metadata about the library integrating + [`ModelHubMixin`]. + + For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations). + + Args: + repo_url (`str`, *optional*): + URL of the library repository. Used to generate model card. + paper_url (`str`, *optional*): + URL of the library paper. Used to generate model card. + docs_url (`str`, *optional*): + URL of the library documentation. Used to generate model card. + model_card_template (`str`, *optional*): + Template of the model card. Used to generate model card. Defaults to a generic template. + language (`str` or `list[str]`, *optional*): + Language supported by the library. Used to generate model card. + library_name (`str`, *optional*): + Name of the library integrating ModelHubMixin. Used to generate model card. + license (`str`, *optional*): + License of the library integrating ModelHubMixin. Used to generate model card. + E.g: "apache-2.0" + license_name (`str`, *optional*): + Name of the library integrating ModelHubMixin. Used to generate model card. + Only used if `license` is set to `other`. + E.g: "coqui-public-model-license". 
+ license_link (`str`, *optional*): + URL to the license of the library integrating ModelHubMixin. Used to generate model card. + Only used if `license` is set to `other` and `license_name` is set. + E.g: "https://coqui.ai/cpml". + pipeline_tag (`str`, *optional*): + Tag of the pipeline. Used to generate model card. E.g. "text-classification". + tags (`list[str]`, *optional*): + Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"] + coders (`dict[Type, tuple[Callable, Callable]]`, *optional*): + Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not + jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc. + + Example: + + ```python + >>> from huggingface_hub import ModelHubMixin + + # Inherit from ModelHubMixin + >>> class MyCustomModel( + ... ModelHubMixin, + ... library_name="my-library", + ... tags=["computer-vision"], + ... repo_url="https://github.com/huggingface/my-cool-library", + ... paper_url="https://arxiv.org/abs/2304.12244", + ... docs_url="https://huggingface.co/docs/my-cool-library", + ... # ^ optional metadata to generate model card + ... ): + ... def __init__(self, size: int = 512, device: str = "cpu"): + ... # define how to initialize your model + ... super().__init__() + ... ... + ... + ... def _save_pretrained(self, save_directory: Path) -> None: + ... # define how to serialize your model + ... ... + ... + ... @classmethod + ... def from_pretrained( + ... cls: type[T], + ... pretrained_model_name_or_path: Union[str, Path], + ... *, + ... force_download: bool = False, + ... token: Optional[Union[str, bool]] = None, + ... cache_dir: Optional[Union[str, Path]] = None, + ... local_files_only: bool = False, + ... revision: Optional[str] = None, + ... **model_kwargs, + ... ) -> T: + ... # define how to deserialize your model + ... ... 
+ + >>> model = MyCustomModel(size=256, device="gpu") + + # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + + # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + + # Download and initialize weights from the Hub + >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model") + >>> reloaded_model.size + 256 + + # Model card has been correctly populated + >>> from huggingface_hub import ModelCard + >>> card = ModelCard.load("username/my-awesome-model") + >>> card.data.tags + ["x-custom-tag", "pytorch_model_hub_mixin", "model_hub_mixin"] + >>> card.data.library_name + "my-library" + ``` + """ + + _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None + # ^ optional config attribute automatically set in `from_pretrained` + _hub_mixin_info: MixinInfo + # ^ information about the library integrating ModelHubMixin (used to generate model card) + _hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not + _hub_mixin_init_parameters: dict[str, inspect.Parameter] # __init__ parameters + _hub_mixin_jsonable_default_values: dict[str, Any] # default values for __init__ parameters + _hub_mixin_jsonable_custom_types: tuple[Type, ...] 
# custom types that can be encoded/decoded + _hub_mixin_coders: dict[Type, CODER_T] # encoders/decoders for custom types + # ^ internal values to handle config + + def __init_subclass__( + cls, + *, + # Generic info for model card + repo_url: Optional[str] = None, + paper_url: Optional[str] = None, + docs_url: Optional[str] = None, + # Model card template + model_card_template: str = DEFAULT_MODEL_CARD, + # Model card metadata + language: Optional[list[str]] = None, + library_name: Optional[str] = None, + license: Optional[str] = None, + license_name: Optional[str] = None, + license_link: Optional[str] = None, + pipeline_tag: Optional[str] = None, + tags: Optional[list[str]] = None, + # How to encode/decode arguments with custom type into a JSON config? + coders: Optional[ + dict[Type, CODER_T] + # Key is a type. + # Value is a tuple (encoder, decoder). + # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))} + ] = None, + ) -> None: + """Inspect __init__ signature only once when subclassing + handle modelcard.""" + super().__init_subclass__() + + # Will be reused when creating modelcard + tags = tags or [] + tags.append("model_hub_mixin") + + # Initialize MixinInfo if not existent + info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData()) + + # If parent class has a MixinInfo, inherit from it as a copy + if hasattr(cls, "_hub_mixin_info"): + # Inherit model card template from parent class if not explicitly set + if model_card_template == DEFAULT_MODEL_CARD: + info.model_card_template = cls._hub_mixin_info.model_card_template + + # Inherit from parent model card data + info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict()) + + # Inherit other info + info.docs_url = cls._hub_mixin_info.docs_url + info.paper_url = cls._hub_mixin_info.paper_url + info.repo_url = cls._hub_mixin_info.repo_url + cls._hub_mixin_info = info + + # Update MixinInfo with metadata + if model_card_template 
is not None and model_card_template != DEFAULT_MODEL_CARD: + info.model_card_template = model_card_template + if repo_url is not None: + info.repo_url = repo_url + if paper_url is not None: + info.paper_url = paper_url + if docs_url is not None: + info.docs_url = docs_url + if language is not None: + info.model_card_data.language = language + if library_name is not None: + info.model_card_data.library_name = library_name + if license is not None: + info.model_card_data.license = license + if license_name is not None: + info.model_card_data.license_name = license_name + if license_link is not None: + info.model_card_data.license_link = license_link + if pipeline_tag is not None: + info.model_card_data.pipeline_tag = pipeline_tag + if tags is not None: + normalized_tags = list(tags) + if info.model_card_data.tags is not None: + info.model_card_data.tags.extend(normalized_tags) + else: + info.model_card_data.tags = normalized_tags + + if info.model_card_data.tags is not None: + info.model_card_data.tags = sorted(set(info.model_card_data.tags)) + + # Handle encoders/decoders for args + cls._hub_mixin_coders = coders or {} + cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys()) + + # Inspect __init__ signature to handle config + cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters) + cls._hub_mixin_jsonable_default_values = { + param.name: cls._encode_arg(param.default) + for param in cls._hub_mixin_init_parameters.values() + if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default) + } + cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters + + def __new__(cls: type[T], *args, **kwargs) -> T: + """Create a new instance of the class and handle config. + + 3 cases: + - If `self._hub_mixin_config` is already set, do nothing. + - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`. 
+ - Otherwise, build `self._hub_mixin_config` from default values and passed values. + """ + instance = super().__new__(cls) + + # If `config` is already set, return early + if instance._hub_mixin_config is not None: + return instance + + # Infer passed values + passed_values = { + **{ + key: value + for key, value in zip( + # [1:] to skip `self` parameter + list(cls._hub_mixin_init_parameters)[1:], + args, + ) + }, + **kwargs, + } + + # If config passed as dataclass => set it and return early + if is_dataclass(passed_values.get("config")): + instance._hub_mixin_config = passed_values["config"] + return instance + + # Otherwise, build config from default + passed values + init_config = { + # default values + **cls._hub_mixin_jsonable_default_values, + # passed values + **{ + key: cls._encode_arg(value) # Encode custom types as jsonable value + for key, value in passed_values.items() + if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder + }, + } + passed_config = init_config.pop("config", {}) + + # Populate `init_config` with provided config + if isinstance(passed_config, dict): + init_config.update(passed_config) + + # Set `config` attribute and return + if init_config != {}: + instance._hub_mixin_config = init_config + return instance + + @classmethod + def _is_jsonable(cls, value: Any) -> bool: + """Check if a value is JSON serializable.""" + if is_dataclass(value): + return True + if isinstance(value, cls._hub_mixin_jsonable_custom_types): + return True + return is_jsonable(value) + + @classmethod + def _encode_arg(cls, arg: Any) -> Any: + """Encode an argument into a JSON serializable format.""" + if is_dataclass(arg): + return asdict(arg) # type: ignore[arg-type] + for type_, (encoder, _) in cls._hub_mixin_coders.items(): + if isinstance(arg, type_): + if arg is None: + return None + return encoder(arg) + return arg + + @classmethod + def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]: + """Decode a 
JSON serializable value into an argument.""" + if is_simple_optional_type(expected_type): + if value is None: + return None + expected_type = unwrap_simple_optional_type(expected_type) # type: ignore[assignment] + # Dataclass => handle it + if is_dataclass(expected_type): + return _load_dataclass(expected_type, value) # type: ignore[return-value] + # Otherwise => check custom decoders + for type_, (_, decoder) in cls._hub_mixin_coders.items(): + if inspect.isclass(expected_type) and issubclass(expected_type, type_): + return decoder(value) + # Otherwise => don't decode + return value + + def save_pretrained( + self, + save_directory: Union[str, Path], + *, + config: Optional[Union[dict, DataclassInstance]] = None, + repo_id: Optional[str] = None, + push_to_hub: bool = False, + model_card_kwargs: Optional[dict[str, Any]] = None, + **push_to_hub_kwargs, + ) -> Optional[str]: + """ + Save weights in local directory. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the model weights and configuration will be saved. + config (`dict` or `DataclassInstance`, *optional*): + Model configuration specified as a key/value dictionary or a dataclass instance. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Huggingface Hub after saving it. + repo_id (`str`, *optional*): + ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if + not provided. + model_card_kwargs (`dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. + push_to_hub_kwargs: + Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. + Returns: + `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. + """ + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + # Remove config.json if already exists. 
After `_save_pretrained` we don't want to overwrite config.json + # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite + # an existing config.json if it was not saved by `_save_pretrained`. + config_path = save_directory / constants.CONFIG_NAME + config_path.unlink(missing_ok=True) + + # save model weights/files (framework-specific) + self._save_pretrained(save_directory) + + # save config (if provided and if not serialized yet in `_save_pretrained`) + if config is None: + config = self._hub_mixin_config + if config is not None: + if is_dataclass(config): + config = asdict(config) # type: ignore[arg-type] + if not config_path.exists(): + config_str = json.dumps(config, sort_keys=True, indent=2) + config_path.write_text(config_str) + + # save model card + model_card_path = save_directory / "README.md" + model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {} + if not model_card_path.exists(): # do not overwrite if already exists + self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md") + + # push to the Hub if required + if push_to_hub: + kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input + if config is not None: # kwarg for `push_to_hub` + kwargs["config"] = config + if repo_id is None: + repo_id = save_directory.name # Defaults to `save_directory` name + return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs) + return None + + def _save_pretrained(self, save_directory: Path) -> None: + """ + Overwrite this method in subclass to define how to save your model. + Check out our [integration guide](../guides/integrations) for instructions. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the model weights and configuration will be saved. 
+ """ + raise NotImplementedError + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls: type[T], + pretrained_model_name_or_path: Union[str, Path], + *, + force_download: bool = False, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[Union[str, Path]] = None, + local_files_only: bool = False, + revision: Optional[str] = None, + **model_kwargs, + ) -> T: + """ + Download a model from the Huggingface Hub and instantiate it. + + Args: + pretrained_model_name_or_path (`str`, `Path`): + - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`. + - Or a path to a `directory` containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`. + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. + Defaults to the latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `hf auth login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs (`dict`, *optional*): + Additional kwargs to pass to the model during initialization. 
+ """ + model_id = str(pretrained_model_name_or_path) + config_file: Optional[str] = None + if os.path.isdir(model_id): + if constants.CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, constants.CONFIG_NAME) + else: + logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=constants.CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}") + + # Read config + config = None + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + + # Decode custom types in config + for key, value in config.items(): + if key in cls._hub_mixin_init_parameters: + expected_type = cls._hub_mixin_init_parameters[key].annotation + if expected_type is not inspect.Parameter.empty: + config[key] = cls._decode_arg(expected_type, value) + + # Populate model_kwargs from config + for param in cls._hub_mixin_init_parameters.values(): + if param.name not in model_kwargs and param.name in config: + model_kwargs[param.name] = config[param.name] + + # Check if `config` argument was passed at init + if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs: + # Decode `config` argument if it was passed + config_annotation = cls._hub_mixin_init_parameters["config"].annotation + config = cls._decode_arg(config_annotation, config) + + # Forward config to model initialization + model_kwargs["config"] = config + + # Inject config if `**kwargs` are expected + if is_dataclass(cls): + for key in cls.__dataclass_fields__: + if key not in model_kwargs and key in config: + model_kwargs[key] = config[key] + elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()): 
+ for key, value in config.items(): # type: ignore[union-attr] + if key not in model_kwargs: + model_kwargs[key] = value + + # Finally, also inject if `_from_pretrained` expects it + if cls._hub_mixin_inject_config and "config" not in model_kwargs: + model_kwargs["config"] = config + + instance = cls._from_pretrained( + model_id=str(model_id), + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + **model_kwargs, + ) + + # Implicitly set the config as instance attribute if not already set by the class + # This way `config` will be available when calling `save_pretrained` or `push_to_hub`. + if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})): + instance._hub_mixin_config = config + + return instance + + @classmethod + def _from_pretrained( + cls: type[T], + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + local_files_only: bool, + token: Optional[Union[str, bool]], + **model_kwargs, + ) -> T: + """Overwrite this method in subclass to define how to load your model from pretrained. + + Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most + args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this + method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location` + parameter to set on which device the model should be loaded. + + Check out our [integration guide](../guides/integrations) for more instructions. + + Args: + model_id (`str`): + ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`). + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the + latest commit on `main` branch. 
+ force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `hf auth login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs: + Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method. + """ + raise NotImplementedError + + @validate_hf_hub_args + def push_to_hub( + self, + repo_id: str, + *, + config: Optional[Union[dict, DataclassInstance]] = None, + commit_message: str = "Push model using huggingface_hub.", + private: Optional[bool] = None, + token: Optional[str] = None, + branch: Optional[str] = None, + create_pr: Optional[bool] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, + model_card_kwargs: Optional[dict[str, Any]] = None, + ) -> str: + """ + Upload model checkpoint to the Hub. + + Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use + `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more + details. + + Args: + repo_id (`str`): + ID of the repository to push to (example: `"username/my-model"`). + config (`dict` or `DataclassInstance`, *optional*): + Model configuration specified as a key/value dictionary or a dataclass instance. + commit_message (`str`, *optional*): + Message to commit while pushing. 
+ private (`bool`, *optional*): + Whether the repository created should be private. + If `None` (default), the repo will be public unless the organization's default is private. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `hf auth login`. + branch (`str`, *optional*): + The git branch on which to push the model. This defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. + allow_patterns (`list[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. + ignore_patterns (`list[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + delete_patterns (`list[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo. + model_card_kwargs (`dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. + + Returns: + The url of the commit of your model in the given repository. 
+ """ + api = HfApi(token=token) + repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp: + saved_path = Path(tmp) / repo_id + self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs) + return api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message=commit_message, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) + + def generate_model_card(self, *args, **kwargs) -> ModelCard: + card = ModelCard.from_template( + card_data=self._hub_mixin_info.model_card_data, + template_str=self._hub_mixin_info.model_card_template, + repo_url=self._hub_mixin_info.repo_url, + paper_url=self._hub_mixin_info.paper_url, + docs_url=self._hub_mixin_info.docs_url, + **kwargs, + ) + return card + + +class PyTorchModelHubMixin(ModelHubMixin): + """ + Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model + is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model, + you should first set it back in training mode with `model.train()`. + + See [`ModelHubMixin`] for more details on how to use the mixin. + + Example: + + ```python + >>> import torch + >>> import torch.nn as nn + >>> from huggingface_hub import PyTorchModelHubMixin + + >>> class MyModel( + ... nn.Module, + ... PyTorchModelHubMixin, + ... library_name="keras-nlp", + ... repo_url="https://github.com/keras-team/keras-nlp", + ... paper_url="https://arxiv.org/abs/2304.12244", + ... docs_url="https://keras.io/keras_nlp/", + ... # ^ optional metadata to generate model card + ... ): + ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4): + ... super().__init__() + ... 
self.param = nn.Parameter(torch.rand(hidden_size, vocab_size)) + ... self.linear = nn.Linear(output_size, vocab_size) + + ... def forward(self, x): + ... return self.linear(x + self.param) + >>> model = MyModel(hidden_size=256) + + # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + + # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + + # Download and initialize weights from the Hub + >>> model = MyModel.from_pretrained("username/my-awesome-model") + >>> model.hidden_size + 256 + ``` + """ + + def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None: + tags = tags or [] + tags.append("pytorch_model_hub_mixin") + kwargs["tags"] = tags + return super().__init_subclass__(*args, **kwargs) + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights from a Pytorch model to a local directory.""" + model_to_save = self.module if hasattr(self, "module") else self # type: ignore + save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE)) # type: ignore [arg-type] + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + model = cls(**model_kwargs) + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE) + return cls._load_as_safetensor(model, model_file, map_location, strict) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + token=token, + local_files_only=local_files_only, + ) + return 
cls._load_as_safetensor(model, model_file, map_location, strict) + except EntryNotFoundError: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.PYTORCH_WEIGHTS_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_pickle(model, model_file, map_location, strict) + + @classmethod + def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True) + model.load_state_dict(state_dict, strict=strict) # type: ignore + model.eval() # type: ignore + return model + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): # type: ignore [attr-defined] + load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] + if map_location != "cpu": + logger.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) # type: ignore [attr-defined] + else: + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) # type: ignore [arg-type] + return model + + +def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance: + """Load a dataclass instance from a dictionary. + + Fields not expected by the dataclass are ignored. 
+ """ + return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__}) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/lfs.py b/venv/lib/python3.10/site-packages/huggingface_hub/lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5df3a9f56eae3589f734885d3fa696d6145248 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/lfs.py @@ -0,0 +1,395 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Git LFS related type definitions and utilities""" + +import io +import re +from dataclasses import dataclass +from math import ceil +from os.path import getsize +from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, TypedDict +from urllib.parse import unquote + +from huggingface_hub import constants + +from .utils import ( + build_hf_headers, + fix_hf_endpoint_in_url, + hf_raise_for_status, + http_backoff, + logging, + validate_hf_hub_args, +) +from .utils._lfs import SliceFileObj +from .utils.sha import sha256, sha_fileobj + + +if TYPE_CHECKING: + from ._commit_api import CommitOperationAdd + +logger = logging.get_logger(__name__) + +OID_REGEX = re.compile(r"^[0-9a-f]{40}$") + +LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload" + +LFS_HEADERS = { + "Accept": "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", +} + + +@dataclass +class UploadInfo: + """ + Dataclass holding required information to determine whether a blob + should be uploaded to the hub using the LFS protocol or the regular protocol + + Args: + sha256 (`bytes`): + SHA256 hash of the blob + size (`int`): + Size in bytes of the blob + sample (`bytes`): + First 512 bytes of the blob + """ + + sha256: bytes + size: int + sample: bytes + + @classmethod + def from_path(cls, path: str): + size = getsize(path) + with io.open(path, "rb") as file: + sample = file.peek(512)[:512] + sha = sha_fileobj(file) + return cls(size=size, sha256=sha, sample=sample) + + @classmethod + def from_bytes(cls, data: bytes): + sha = sha256(data).digest() + return cls(size=len(data), sample=data[:512], sha256=sha) + + @classmethod + def from_fileobj(cls, fileobj: BinaryIO): + sample = fileobj.read(512) + fileobj.seek(0, io.SEEK_SET) + sha = sha_fileobj(fileobj) + size = fileobj.tell() + fileobj.seek(0, io.SEEK_SET) + return cls(size=size, sha256=sha, sample=sample) + + +@validate_hf_hub_args +def post_lfs_batch_info( + upload_infos: Iterable[UploadInfo], + token: Optional[str], + 
repo_type: str, + repo_id: str, + revision: Optional[str] = None, + endpoint: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + transfers: Optional[list[str]] = None, +) -> tuple[list[dict], list[dict], Optional[str]]: + """ + Requests the LFS batch endpoint to retrieve upload instructions + + Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + + Args: + upload_infos (`Iterable` of `UploadInfo`): + `UploadInfo` for the files that are being uploaded, typically obtained + from `CommitOperationAdd.upload_info` + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The git revision to upload to. + headers (`dict`, *optional*): + Additional headers to include in the request + transfers (`list`, *optional*): + List of transfer methods to use. Defaults to ["basic", "multipart"]. + + Returns: + `LfsBatchInfo`: 3-tuple: + - First element is the list of upload instructions from the server + - Second element is a list of errors, if any + - Third element is the chosen transfer adapter if provided by the server (e.g. "basic", "multipart", "xet") + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If an argument is invalid or the server response is malformed. + [`HfHubHTTPError`] + If the server returned an error. 
+ """ + endpoint = endpoint if endpoint is not None else constants.ENDPOINT + url_prefix = "" + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + url_prefix = constants.REPO_TYPES_URL_PREFIXES[repo_type] + batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch" + payload: dict = { + "operation": "upload", + "transfers": transfers if transfers is not None else ["basic", "multipart"], + "objects": [ + { + "oid": upload.sha256.hex(), + "size": upload.size, + } + for upload in upload_infos + ], + "hash_algo": "sha256", + } + if revision is not None: + payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted' + + headers = { + **LFS_HEADERS, + **build_hf_headers(token=token), + **(headers or {}), + } + resp = http_backoff("POST", batch_url, headers=headers, json=payload) + hf_raise_for_status(resp) + batch_info = resp.json() + + objects = batch_info.get("objects", None) + if not isinstance(objects, list): + raise ValueError("Malformed response from server") + + chosen_transfer = batch_info.get("transfer") + chosen_transfer = chosen_transfer if isinstance(chosen_transfer, str) else None + + return ( + [_validate_batch_actions(obj) for obj in objects if "error" not in obj], + [_validate_batch_error(obj) for obj in objects if "error" in obj], + chosen_transfer, + ) + + +class PayloadPartT(TypedDict): + partNumber: int + etag: str + + +class CompletionPayloadT(TypedDict): + """Payload that will be sent to the Hub when uploading multi-part.""" + + oid: str + parts: list[PayloadPartT] + + +def lfs_upload( + operation: "CommitOperationAdd", + lfs_batch_action: dict, + token: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + endpoint: Optional[str] = None, +) -> None: + """ + Handles uploading a given object to the Hub with the LFS protocol. + + Can be a No-op if the content of the file is already present on the hub large file storage. 
+ + Args: + operation (`CommitOperationAdd`): + The add operation triggering this upload. + lfs_batch_action (`dict`): + Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for + more details. + headers (`dict`, *optional*): + Headers to include in the request, including authentication and user agent headers. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `lfs_batch_action` is improperly formatted + [`HfHubHTTPError`] + If the upload resulted in an error + """ + # 0. If LFS file is already present, skip upload + _validate_batch_actions(lfs_batch_action) + actions = lfs_batch_action.get("actions") + if actions is None: + # The file was already uploaded + logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload") + return + + # 1. Validate server response (check required keys in dict) + upload_action = lfs_batch_action["actions"]["upload"] + _validate_lfs_action(upload_action) + verify_action = lfs_batch_action["actions"].get("verify") + if verify_action is not None: + _validate_lfs_action(verify_action) + + # 2. Upload file (either single part or multi-part) + header = upload_action.get("header", {}) + chunk_size = header.get("chunk_size") + upload_url = fix_hf_endpoint_in_url(upload_action["href"], endpoint=endpoint) + if chunk_size is not None: + try: + chunk_size = int(chunk_size) + except (ValueError, TypeError): + raise ValueError( + f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'." + ) + _upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_url) + else: + _upload_single_part(operation=operation, upload_url=upload_url) + + # 3. 
Verify upload went well + if verify_action is not None: + _validate_lfs_action(verify_action) + verify_url = fix_hf_endpoint_in_url(verify_action["href"], endpoint) + verify_resp = http_backoff( + "POST", + verify_url, + headers=build_hf_headers(token=token, headers=headers), + json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size}, + ) + hf_raise_for_status(verify_resp) + logger.debug(f"{operation.path_in_repo}: Upload successful") + + +def _validate_lfs_action(lfs_action: dict): + """validates response from the LFS batch endpoint""" + if not ( + isinstance(lfs_action.get("href"), str) + and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict)) + ): + raise ValueError("lfs_action is improperly formatted") + return lfs_action + + +def _validate_batch_actions(lfs_batch_actions: dict): + """validates response from the LFS batch endpoint""" + if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)): + raise ValueError("lfs_batch_actions is improperly formatted") + + upload_action = lfs_batch_actions.get("actions", {}).get("upload") + verify_action = lfs_batch_actions.get("actions", {}).get("verify") + if upload_action is not None: + _validate_lfs_action(upload_action) + if verify_action is not None: + _validate_lfs_action(verify_action) + return lfs_batch_actions + + +def _validate_batch_error(lfs_batch_error: dict): + """validates response from the LFS batch endpoint""" + if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)): + raise ValueError("lfs_batch_error is improperly formatted") + error_info = lfs_batch_error.get("error") + if not ( + isinstance(error_info, dict) + and isinstance(error_info.get("message"), str) + and isinstance(error_info.get("code"), int) + ): + raise ValueError("lfs_batch_error is improperly formatted") + return lfs_batch_error + + +def _upload_single_part(operation: "CommitOperationAdd", 
upload_url: str) -> None: + """ + Uploads `fileobj` as a single PUT HTTP request (basic LFS transfer protocol) + + Args: + upload_url (`str`): + The URL to PUT the file to. + fileobj: + The file-like object holding the data to upload. + + Raises: + [`HfHubHTTPError`] + If the upload resulted in an error. + """ + with operation.as_file(with_tqdm=True) as fileobj: + # S3 might raise a transient 500 error -> let's retry if that happens + response = http_backoff("PUT", upload_url, data=fileobj) + hf_raise_for_status(response) + + +def _upload_multi_part(operation: "CommitOperationAdd", header: dict, chunk_size: int, upload_url: str) -> None: + """ + Uploads file using HF multipart LFS transfer protocol. + """ + # 1. Get upload URLs for each part + sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size) + + # 2. Upload parts (pure Python) + response_headers = _upload_parts_iteratively( + operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size + ) + + # 3. Send completion request + # NOTE: `upload_url` is the Hub completion endpoint (not the S3 upload URLs). 
+ completion_res = http_backoff( + "POST", + upload_url, + json=_get_completion_payload(response_headers, operation.upload_info.sha256.hex()), + headers=LFS_HEADERS, + ) + hf_raise_for_status(completion_res) + + +def _get_sorted_parts_urls(header: dict, upload_info: UploadInfo, chunk_size: int) -> list[str]: + sorted_part_upload_urls = [ + upload_url + for _, upload_url in sorted( + [ + (int(part_num, 10), upload_url) + for part_num, upload_url in header.items() + if part_num.isdigit() and len(part_num) > 0 + ], + key=lambda t: t[0], + ) + ] + num_parts = len(sorted_part_upload_urls) + if num_parts != ceil(upload_info.size / chunk_size): + raise ValueError("Invalid server response to upload large LFS file") + return sorted_part_upload_urls + + +def _get_completion_payload(response_headers: list[dict], oid: str) -> CompletionPayloadT: + parts: list[PayloadPartT] = [] + for part_number, header in enumerate(response_headers): + etag = header.get("etag") + if etag is None or etag == "": + raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_number + 1}") + parts.append( + { + "partNumber": part_number + 1, + "etag": etag, + } + ) + return {"oid": oid, "parts": parts} + + +def _upload_parts_iteratively( + operation: "CommitOperationAdd", sorted_parts_urls: list[str], chunk_size: int +) -> list[dict]: + headers = [] + with operation.as_file(with_tqdm=True) as fileobj: + for part_idx, part_upload_url in enumerate(sorted_parts_urls): + with SliceFileObj( + fileobj, + seek_from=chunk_size * part_idx, + read_limit=chunk_size, + ) as fileobj_slice: + # S3 might raise a transient 500 error -> let's retry if that happens + part_upload_res = http_backoff("PUT", part_upload_url, data=fileobj_slice) + hf_raise_for_status(part_upload_res) + headers.append(part_upload_res.headers) + return headers # type: ignore diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/py.typed b/venv/lib/python3.10/site-packages/huggingface_hub/py.typed new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/repocard.py b/venv/lib/python3.10/site-packages/huggingface_hub/repocard.py new file mode 100644 index 0000000000000000000000000000000000000000..683162c9a666659e8c35c7a9e3c57c824c0e3e83 --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/repocard.py @@ -0,0 +1,826 @@ +import os +import re +from pathlib import Path +from typing import Any, Literal, Optional, Union + +import yaml + +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.hf_api import upload_file +from huggingface_hub.repocard_data import ( + CardData, + DatasetCardData, + EvalResult, + ModelCardData, + SpaceCardData, + eval_results_to_model_index, + model_index_to_eval_results, +) +from huggingface_hub.utils import HfHubHTTPError, get_session, hf_raise_for_status, is_jinja_available, yaml_dump + +from . import constants +from .errors import EntryNotFoundError +from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args + + +logger = logging.get_logger(__name__) + + +TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md" +TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md" + +# exact same regex as in the Hub server. Please keep in sync. +# See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18 +REGEX_YAML_BLOCK = re.compile(r"^(\s*---[\r\n]+)([\S\s]*?)([\r\n]+---(\r\n|\n|$))") + + +class RepoCard: + card_data_class = CardData + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "model" + + def __init__(self, content: str, ignore_metadata_errors: bool = False): + """Initialize a RepoCard from string content. The content should be a + Markdown file with a YAML block at the beginning and a Markdown body. + + Args: + content (`str`): The content of the Markdown file. 
+ + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> text = ''' + ... --- + ... language: en + ... license: mit + ... --- + ... + ... # My repo + ... ''' + >>> card = RepoCard(text) + >>> card.data.to_dict() + {'language': 'en', 'license': 'mit'} + >>> card.text + '\\n# My repo\\n' + + ``` + > [!TIP] + > Raises the following error: + > + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > when the content of the repo card metadata is not a dictionary. + """ + + # Set the content of the RepoCard, as well as underlying .data and .text attributes. + # See the `content` property setter for more details. + self.ignore_metadata_errors = ignore_metadata_errors + self.content = content + + @property + def content(self): + """The content of the RepoCard, including the YAML block and the Markdown body.""" + line_break = _detect_line_ending(self._content) or "\n" + return f"---{line_break}{self.data.to_yaml(line_break=line_break, original_order=self._original_order)}{line_break}---{line_break}{self.text}" + + @content.setter + def content(self, content: str): + """Set the content of the RepoCard.""" + self._content = content + + match = REGEX_YAML_BLOCK.search(content) + if match: + # Metadata found in the YAML block + yaml_block = match.group(2) + self.text = content[match.end() :] + data_dict = yaml.safe_load(yaml_block) + + if data_dict is None: + data_dict = {} + + # The YAML block's data should be a dictionary + if not isinstance(data_dict, dict): + raise ValueError("repo card metadata block should be a dict") + else: + # Model card without metadata... create empty metadata + logger.warning("Repo card metadata block was not found. 
Setting CardData to empty.") + data_dict = {} + self.text = content + + self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors) + self._original_order = list(data_dict.keys()) + + def __str__(self): + return self.content + + def save(self, filepath: Union[Path, str]): + r"""Save a RepoCard to a file. + + Args: + filepath (`Union[Path, str]`): Filepath to the markdown file to save. + + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card") + >>> card.save("/tmp/test.md") + + ``` + """ + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + # Preserve newlines as in the existing file. + with open(filepath, mode="w", newline="", encoding="utf-8") as f: + f.write(str(self)) + + @classmethod + def load( + cls, + repo_id_or_path: Union[str, Path], + repo_type: Optional[str] = None, + token: Optional[str] = None, + ignore_metadata_errors: bool = False, + ): + """Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath. + + Args: + repo_id_or_path (`Union[str, Path]`): + The repo ID associated with a Hugging Face Hub repo or a local filepath. + repo_type (`str`, *optional*): + The type of Hugging Face repo to push to. Defaults to None, which will use "model". Other options + are "dataset" and "space". Not used when loading from a local filepath. If this is called from a child + class, the default value will be the child class's `repo_type`. + token (`str`, *optional*): + Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token. + ignore_metadata_errors (`str`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. Use it at your own risk. 
+ + Returns: + [`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's + README.md file or filepath. + + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> card = RepoCard.load("nateraw/food") + >>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"] + + ``` + """ + + if Path(repo_id_or_path).is_file(): + card_path = Path(repo_id_or_path) + elif isinstance(repo_id_or_path, str): + card_path = Path( + hf_hub_download( + repo_id_or_path, + constants.REPOCARD_NAME, + repo_type=repo_type or cls.repo_type, + token=token, + ) + ) + else: + raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).") + + # Preserve newlines in the existing file. + with card_path.open(mode="r", newline="", encoding="utf-8") as f: + return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors) + + def validate(self, repo_type: Optional[str] = None): + """Validates card against Hugging Face Hub's card validation logic. + Using this function requires access to the internet, so it is only called + internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`]. + + Args: + repo_type (`str`, *optional*, defaults to "model"): + The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". + If this function is called from a child class, the default will be the child class's `repo_type`. + + > [!TIP] + > Raises the following errors: + > + > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + > if the card fails validation checks. + > - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + > if the request to the Hub API fails for any other reason. + """ + + # If repo type is provided, otherwise, use the repo type of the card. 
+ repo_type = repo_type or self.repo_type + + body = { + "repoType": repo_type, + "content": str(self), + } + headers = {"Accept": "text/plain"} + + try: + response = get_session().post("https://huggingface.co/api/validate-yaml", json=body, headers=headers) + hf_raise_for_status(response) + except HfHubHTTPError as exc: + if response.status_code == 400: + raise ValueError(response.text) + else: + raise exc + + def push_to_hub( + self, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ): + """Push a RepoCard to a Hugging Face Hub repo. + + Args: + repo_id (`str`): + The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food". + token (`str`, *optional*): + Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to + the stored token. + repo_type (`str`, *optional*, defaults to "model"): + The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this + function is called by a child class, it will default to the child class's `repo_type`. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. + commit_description (`str`, *optional*) + The description of the generated commit. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + create_pr (`bool`, *optional*): + Whether or not to create a Pull Request with this commit. Defaults to `False`. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. 
+ If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed too concurrently. + Returns: + `str`: URL of the commit which updated the card metadata. + """ + + # If repo type is provided, otherwise, use the repo type of the card. + repo_type = repo_type or self.repo_type + + # Validate card before pushing to hub + self.validate(repo_type=repo_type) + + with SoftTemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) / constants.REPOCARD_NAME + tmp_path.write_text(str(self), encoding="utf-8") + url = upload_file( + path_or_fileobj=str(tmp_path), + path_in_repo=constants.REPOCARD_NAME, + repo_id=repo_id, + token=token, + repo_type=repo_type, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + revision=revision, + parent_commit=parent_commit, + ) + return url + + @classmethod + def from_template( + cls, + card_data: CardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a RepoCard from a template. By default, it uses the default template. + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.CardData`): + A huggingface_hub.CardData instance containing the metadata you want to include in the YAML + header of the repo card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the + template. 
+ """ + if is_jinja_available(): + import jinja2 + else: + raise ImportError( + "Using RepoCard.from_template requires Jinja2 to be installed. Please" + " install it with `pip install Jinja2`." + ) + + kwargs = card_data.to_dict().copy() + kwargs.update(template_kwargs) # Template_kwargs have priority + + if template_path is not None: + template_str = Path(template_path).read_text() + if template_str is None: + template_str = Path(cls.default_template_path).read_text() + template = jinja2.Template(template_str) + content = template.render(card_data=card_data.to_yaml(), **kwargs) + return cls(content) + + +class ModelCard(RepoCard): + card_data_class = ModelCardData # type: ignore[assignment] + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "model" + + @classmethod + def from_template( # type: ignore # violates Liskov property but easier to use + cls, + card_data: ModelCardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here: + https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.ModelCardData`): + A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML + header of the model card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the + template. 
+ + Example: + ```python + >>> from huggingface_hub import ModelCard, ModelCardData, EvalResult + + >>> # Using the Default Template + >>> card_data = ModelCardData( + ... language='en', + ... license='mit', + ... library_name='timm', + ... tags=['image-classification', 'resnet'], + ... datasets=['beans'], + ... metrics=['accuracy'], + ... ) + >>> card = ModelCard.from_template( + ... card_data, + ... model_description='This model does x + y...' + ... ) + + >>> # Including Evaluation Results + >>> card_data = ModelCardData( + ... language='en', + ... tags=['image-classification', 'resnet'], + ... eval_results=[ + ... EvalResult( + ... task_type='image-classification', + ... dataset_type='beans', + ... dataset_name='Beans', + ... metric_type='accuracy', + ... metric_value=0.9, + ... ), + ... ], + ... model_name='my-cool-model', + ... ) + >>> card = ModelCard.from_template(card_data) + + >>> # Using a Custom Template + >>> card_data = ModelCardData( + ... language='en', + ... tags=['image-classification', 'resnet'] + ... ) + >>> card = ModelCard.from_template( + ... card_data=card_data, + ... template_path='./src/huggingface_hub/templates/modelcard_template.md', + ... custom_template_var='custom value', # will be replaced in template if it exists + ... ) + + ``` + """ + return super().from_template(card_data, template_path, template_str, **template_kwargs) + + +class DatasetCard(RepoCard): + card_data_class = DatasetCardData # type: ignore[assignment] + default_template_path = TEMPLATE_DATASETCARD_PATH + repo_type = "dataset" + + @classmethod + def from_template( # type: ignore # violates Liskov property but easier to use + cls, + card_data: DatasetCardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a DatasetCard from a template. 
By default, it uses the default template, which can be found here: + https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.DatasetCardData`): + A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML + header of the dataset card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the + template. + + Example: + ```python + >>> from huggingface_hub import DatasetCard, DatasetCardData + + >>> # Using the Default Template + >>> card_data = DatasetCardData( + ... language='en', + ... license='mit', + ... annotations_creators='crowdsourced', + ... task_categories=['text-classification'], + ... task_ids=['sentiment-classification', 'text-scoring'], + ... multilinguality='monolingual', + ... pretty_name='My Text Classification Dataset', + ... ) + >>> card = DatasetCard.from_template( + ... card_data, + ... pretty_name=card_data.pretty_name, + ... ) + + >>> # Using a Custom Template + >>> card_data = DatasetCardData( + ... language='en', + ... license='mit', + ... ) + >>> card = DatasetCard.from_template( + ... card_data=card_data, + ... template_path='./src/huggingface_hub/templates/datasetcard_template.md', + ... custom_template_var='custom value', # will be replaced in template if it exists + ... 
) + + ``` + """ + return super().from_template(card_data, template_path, template_str, **template_kwargs) + + +class SpaceCard(RepoCard): + card_data_class = SpaceCardData # type: ignore[assignment] + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "space" + + +def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722 + """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines. + + Uses same implementation as in Hub server, keep it in sync. + + Returns: + str: The detected line ending of the string. + """ + cr = content.count("\r") + lf = content.count("\n") + crlf = content.count("\r\n") + if cr + lf == 0: + return None + if crlf == cr and crlf == lf: + return "\r\n" + if cr > lf: + return "\r" + else: + return "\n" + + +def metadata_load(local_path: Union[str, Path]) -> Optional[dict]: + content = Path(local_path).read_text() + match = REGEX_YAML_BLOCK.search(content) + if match: + yaml_block = match.group(2) + data = yaml.safe_load(yaml_block) + if data is None or isinstance(data, dict): + return data + raise ValueError("repo card metadata block should be a dict") + else: + return None + + +def metadata_save(local_path: Union[str, Path], data: dict) -> None: + """ + Save the metadata dict in the upper YAML part Trying to preserve newlines as + in the existing file. 
Docs about open() with newline="" parameter: + https://docs.python.org/3/library/functions.html?highlight=open#open Does + not work with "^M" linebreaks, which are replaced by \n + """ + line_break = "\n" + content = "" + # try to detect existing newline character + if os.path.exists(local_path): + with open(local_path, "r", newline="", encoding="utf8") as readme: + content = readme.read() + if isinstance(readme.newlines, tuple): + line_break = readme.newlines[0] + elif isinstance(readme.newlines, str): + line_break = readme.newlines + + # creates a new file if it not + with open(local_path, "w", newline="", encoding="utf8") as readme: + data_yaml = yaml_dump(data, sort_keys=False, line_break=line_break) + # sort_keys: keep dict order + match = REGEX_YAML_BLOCK.search(content) + if match: + output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :] + else: + output = f"---{line_break}{data_yaml}---{line_break}{content}" + + readme.write(output) + readme.close() + + +def metadata_eval_result( + *, + model_pretty_name: str, + task_pretty_name: str, + task_id: str, + metrics_pretty_name: str, + metrics_id: str, + metrics_value: Any, + dataset_pretty_name: str, + dataset_id: str, + metrics_config: Optional[str] = None, + metrics_verified: bool = False, + dataset_config: Optional[str] = None, + dataset_split: Optional[str] = None, + dataset_revision: Optional[str] = None, + metrics_verification_token: Optional[str] = None, +) -> dict: + """ + Creates a metadata dict with the result from a model evaluated on a dataset. + + Args: + model_pretty_name (`str`): + The name of the model in natural language. + task_pretty_name (`str`): + The name of a task in natural language. + task_id (`str`): + Example: automatic-speech-recognition. A task id. + metrics_pretty_name (`str`): + A name for the metric in natural language. Example: Test WER. + metrics_id (`str`): + Example: wer. A metric id from https://hf.co/metrics. 
+ metrics_value (`Any`): + The value from the metric. Example: 20.0 or "20.0 ± 1.2". + dataset_pretty_name (`str`): + The name of the dataset in natural language. + dataset_id (`str`): + Example: common_voice. A dataset id from https://hf.co/datasets. + metrics_config (`str`, *optional*): + The name of the metric configuration used in `load_metric()`. + Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + metrics_verified (`bool`, *optional*, defaults to `False`): + Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + dataset_config (`str`, *optional*): + Example: fr. The name of the dataset configuration used in `load_dataset()`. + dataset_split (`str`, *optional*): + Example: test. The name of the dataset split used in `load_dataset()`. + dataset_revision (`str`, *optional*): + Example: 5503434ddd753f426f4b38109466949a1217c2bb. The name of the dataset dataset revision + used in `load_dataset()`. + metrics_verification_token (`bool`, *optional*): + A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + + Returns: + `dict`: a metadata dict with the result from a model evaluated on a dataset. + + Example: + ```python + >>> from huggingface_hub import metadata_eval_result + >>> results = metadata_eval_result( + ... model_pretty_name="RoBERTa fine-tuned on ReactionGIF", + ... task_pretty_name="Text Classification", + ... task_id="text-classification", + ... metrics_pretty_name="Accuracy", + ... metrics_id="accuracy", + ... metrics_value=0.2662102282047272, + ... dataset_pretty_name="ReactionJPEG", + ... dataset_id="julien-c/reactionjpeg", + ... dataset_config="default", + ... dataset_split="test", + ... ) + >>> results == { + ... 'model-index': [ + ... { + ... 
'name': 'RoBERTa fine-tuned on ReactionGIF', + ... 'results': [ + ... { + ... 'task': { + ... 'type': 'text-classification', + ... 'name': 'Text Classification' + ... }, + ... 'dataset': { + ... 'name': 'ReactionJPEG', + ... 'type': 'julien-c/reactionjpeg', + ... 'config': 'default', + ... 'split': 'test' + ... }, + ... 'metrics': [ + ... { + ... 'type': 'accuracy', + ... 'value': 0.2662102282047272, + ... 'name': 'Accuracy', + ... 'verified': False + ... } + ... ] + ... } + ... ] + ... } + ... ] + ... } + True + + ``` + """ + + return { + "model-index": eval_results_to_model_index( + model_name=model_pretty_name, + eval_results=[ + EvalResult( + task_name=task_pretty_name, + task_type=task_id, + metric_name=metrics_pretty_name, + metric_type=metrics_id, + metric_value=metrics_value, + dataset_name=dataset_pretty_name, + dataset_type=dataset_id, + metric_config=metrics_config, + verified=metrics_verified, + verify_token=metrics_verification_token, + dataset_config=dataset_config, + dataset_split=dataset_split, + dataset_revision=dataset_revision, + ) + ], + ) + } + + +@validate_hf_hub_args +def metadata_update( + repo_id: str, + metadata: dict, + *, + repo_type: Optional[str] = None, + overwrite: bool = False, + token: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + revision: Optional[str] = None, + create_pr: bool = False, + parent_commit: Optional[str] = None, +) -> str: + """ + Updates the metadata in the README.md of a repository on the Hugging Face Hub. + If the README.md file doesn't exist yet, a new one is created with metadata and + the default ModelCard or DatasetCard template. For `space` repo, an error is thrown + as a Space cannot exist without a `README.md` file. + + Args: + repo_id (`str`): + The name of the repository. + metadata (`dict`): + A dictionary containing the metadata to be updated. 
+ repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if updating to a dataset or space, + `None` or `"model"` if updating to a model. Default is `None`. + overwrite (`bool`, *optional*, defaults to `False`): + If set to `True` an existing field can be overwritten, otherwise + attempting to overwrite an existing field will cause an error. + token (`str`, *optional*): + The Hugging Face authentication token. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to + `f"Update metadata with huggingface_hub"` + commit_description (`str` *optional*) + The description of the generated commit + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the + `"main"` branch. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `revision` with that commit. + Defaults to `False`. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed too concurrently. + Returns: + `str`: URL of the commit which updated the card metadata. + + Example: + ```python + >>> from huggingface_hub import metadata_update + >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF', + ... 'results': [{'dataset': {'name': 'ReactionGIF', + ... 'type': 'julien-c/reactiongif'}, + ... 'metrics': [{'name': 'Recall', + ... 'type': 'recall', + ... 'value': 0.7762102282047272}], + ... 'task': {'name': 'Text Classification', + ... 
'type': 'text-classification'}}]}]} + >>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata) + + ``` + """ + commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub" + + # Card class given repo_type + card_class: type[RepoCard] + if repo_type is None or repo_type == "model": + card_class = ModelCard + elif repo_type == "dataset": + card_class = DatasetCard + elif repo_type == "space": + card_class = RepoCard + else: + raise ValueError(f"Unknown repo_type: {repo_type}") + + # Either load repo_card from the Hub or create an empty one. + # NOTE: Will not create the repo if it doesn't exist. + try: + card = card_class.load(repo_id, token=token, repo_type=repo_type) + except EntryNotFoundError: + if repo_type == "space": + raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.") + + # Initialize a ModelCard or DatasetCard from default template and no data. + # Cast to the concrete expected card type to satisfy type checkers. + card = card_class.from_template(CardData()) # type: ignore[return-value] + + for key, value in metadata.items(): + if key == "model-index": + # if the new metadata doesn't include a name, either use existing one or repo name + if "name" not in value[0]: + value[0]["name"] = getattr(card, "model_name", repo_id) + model_name, new_results = model_index_to_eval_results(value) + if card.data.eval_results is None: + card.data.eval_results = new_results + card.data.model_name = model_name + else: + existing_results = card.data.eval_results + + # Iterate over new results + # Iterate over existing results + # If both results describe the same metric but value is different: + # If overwrite=True: overwrite the metric value + # Else: raise ValueError + # Else: append new result to existing ones. 
+ for new_result in new_results: + result_found = False + for existing_result in existing_results: + if new_result.is_equal_except_value(existing_result): + if new_result != existing_result and not overwrite: + raise ValueError( + "You passed a new value for the existing metric" + f" 'name: {new_result.metric_name}, type: " + f"{new_result.metric_type}'. Set `overwrite=True`" + " to overwrite existing metrics." + ) + result_found = True + existing_result.metric_value = new_result.metric_value + if existing_result.verified is True: + existing_result.verify_token = new_result.verify_token + if not result_found: + card.data.eval_results.append(new_result) + else: + # Any metadata that is not a result metric + if card.data.get(key) is not None and not overwrite and card.data.get(key) != value: + raise ValueError( + f"You passed a new value for the existing meta data field '{key}'." + " Set `overwrite=True` to overwrite existing metadata." + ) + else: + card.data[key] = value + + return card.push_to_hub( + repo_id, + token=token, + repo_type=repo_type, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + revision=revision, + parent_commit=parent_commit, + ) diff --git a/venv/lib/python3.10/site-packages/huggingface_hub/repocard_data.py b/venv/lib/python3.10/site-packages/huggingface_hub/repocard_data.py new file mode 100644 index 0000000000000000000000000000000000000000..002ed0b4224477b4ab890ea59a4cf089300bf03c --- /dev/null +++ b/venv/lib/python3.10/site-packages/huggingface_hub/repocard_data.py @@ -0,0 +1,770 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Optional, Union + +from huggingface_hub.utils import logging, yaml_dump + + +logger = logging.get_logger(__name__) + + +@dataclass +class EvalResult: + """ + Flattened representation of individual evaluation results found in model-index of Model Cards. 
+ + For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1. + + Args: + task_type (`str`): + The task identifier. Example: "image-classification". + dataset_type (`str`): + The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets. + dataset_name (`str`): + A pretty name for the dataset. Example: "Common Voice (French)". + metric_type (`str`): + The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics. + metric_value (`Any`): + The metric value. Example: 0.9 or "20.0 ± 1.2". + task_name (`str`, *optional*): + A pretty name for the task. Example: "Speech Recognition". + dataset_config (`str`, *optional*): + The name of the dataset configuration used in `load_dataset()`. + Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info: + https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name + dataset_split (`str`, *optional*): + The split used in `load_dataset()`. Example: "test". + dataset_revision (`str`, *optional*): + The revision (AKA Git Sha) of the dataset used in `load_dataset()`. + Example: 5503434ddd753f426f4b38109466949a1217c2bb + dataset_args (`dict[str, Any]`, *optional*): + The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}` + metric_name (`str`, *optional*): + A pretty name for the metric. Example: "Test WER". + metric_config (`str`, *optional*): + The name of the metric configuration used in `load_metric()`. + Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations + metric_args (`dict[str, Any]`, *optional*): + The arguments passed during `Metric.compute()`. 
Example for `bleu`: max_order: 4 + verified (`bool`, *optional*): + Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + verify_token (`str`, *optional*): + A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + source_name (`str`, *optional*): + The name of the source of the evaluation result. Example: "Open LLM Leaderboard". + source_url (`str`, *optional*): + The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard". + """ + + # Required + + # The task identifier + # Example: automatic-speech-recognition + task_type: str + + # The dataset identifier + # Example: common_voice. Use dataset id from https://hf.co/datasets + dataset_type: str + + # A pretty name for the dataset. + # Example: Common Voice (French) + dataset_name: str + + # The metric identifier + # Example: wer. Use metric id from https://hf.co/metrics + metric_type: str + + # Value of the metric. + # Example: 20.0 or "20.0 ± 1.2" + metric_value: Any + + # Optional + + # A pretty name for the task. + # Example: Speech Recognition + task_name: Optional[str] = None + + # The name of the dataset configuration used in `load_dataset()`. + # Example: fr in `load_dataset("common_voice", "fr")`. + # See the `datasets` docs for more info: + # https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name + dataset_config: Optional[str] = None + + # The split used in `load_dataset()`. + # Example: test + dataset_split: Optional[str] = None + + # The revision (AKA Git Sha) of the dataset used in `load_dataset()`. 
+ # Example: 5503434ddd753f426f4b38109466949a1217c2bb + dataset_revision: Optional[str] = None + + # The arguments passed during `Metric.compute()`. + # Example for `bleu`: max_order: 4 + dataset_args: Optional[dict[str, Any]] = None + + # A pretty name for the metric. + # Example: Test WER + metric_name: Optional[str] = None + + # The name of the metric configuration used in `load_metric()`. + # Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + # See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations + metric_config: Optional[str] = None + + # The arguments passed during `Metric.compute()`. + # Example for `bleu`: max_order: 4 + metric_args: Optional[dict[str, Any]] = None + + # Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + verified: Optional[bool] = None + + # A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + verify_token: Optional[str] = None + + # The name of the source of the evaluation result. + # Example: Open LLM Leaderboard + source_name: Optional[str] = None + + # The URL of the source of the evaluation result. + # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard + source_url: Optional[str] = None + + @property + def unique_identifier(self) -> tuple: + """Returns a tuple that uniquely identifies this evaluation.""" + return ( + self.task_type, + self.dataset_type, + self.dataset_config, + self.dataset_split, + self.dataset_revision, + ) + + def is_equal_except_value(self, other: "EvalResult") -> bool: + """ + Return True if `self` and `other` describe exactly the same metric but with a + different value. 
+ """ + for key, _ in self.__dict__.items(): + if key == "metric_value": + continue + # For metrics computed by Hugging Face's evaluation service, `verify_token` is derived from `metric_value`, + # so we exclude it here in the comparison. + if key != "verify_token" and getattr(self, key) != getattr(other, key): + return False + return True + + def __post_init__(self) -> None: + if self.source_name is not None and self.source_url is None: + raise ValueError("If `source_name` is provided, `source_url` must also be provided.") + + +@dataclass +class CardData: + """Structure containing metadata from a RepoCard. + + [`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`]. + + Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data + (example: flatten evaluation results). `CardData` behaves as a dictionary (can get, pop, set values) but do not + inherit from `dict` to allow this export step. + """ + + def __init__(self, ignore_metadata_errors: bool = False, **kwargs): + self.__dict__.update(kwargs) + + def to_dict(self): + """Converts CardData to a dict. + + Returns: + `dict`: CardData represented as a dictionary ready to be dumped to a YAML + block for inclusion in a README.md file. + """ + + data_dict = copy.deepcopy(self.__dict__) + self._to_dict(data_dict) + return {key: value for key, value in data_dict.items() if value is not None} + + def _to_dict(self, data_dict): + """Use this method in child classes to alter the dict representation of the data. Alter the dict in-place. + + Args: + data_dict (`dict`): The raw dict representation of the card data. + """ + pass + + def to_yaml(self, line_break=None, original_order: Optional[list[str]] = None) -> str: + """Dumps CardData to a YAML block for inclusion in a README.md file. + + Args: + line_break (str, *optional*): + The line break to use when dumping to yaml. + + Returns: + `str`: CardData represented as a YAML block. 
+ """ + if original_order: + self.__dict__ = { + k: self.__dict__[k] + for k in original_order + list(set(self.__dict__.keys()) - set(original_order)) + if k in self.__dict__ + } + return yaml_dump(self.to_dict(), sort_keys=False, line_break=line_break).strip() + + def __repr__(self): + return repr(self.__dict__) + + def __str__(self): + return self.to_yaml() + + def get(self, key: str, default: Any = None) -> Any: + """Get value for a given metadata key.""" + value = self.__dict__.get(key) + return default if value is None else value + + def pop(self, key: str, default: Any = None) -> Any: + """Pop value for a given metadata key.""" + return self.__dict__.pop(key, default) + + def __getitem__(self, key: str) -> Any: + """Get value for a given metadata key.""" + return self.__dict__[key] + + def __setitem__(self, key: str, value: Any) -> None: + """Set value for a given metadata key.""" + self.__dict__[key] = value + + def __contains__(self, key: str) -> bool: + """Check if a given metadata key is set.""" + return key in self.__dict__ + + def __len__(self) -> int: + """Return the number of metadata keys set.""" + return len(self.__dict__) + + +def _validate_eval_results( + eval_results: Optional[Union[EvalResult, list[EvalResult]]], + model_name: Optional[str], +) -> list[EvalResult]: + if eval_results is None: + return [] + if isinstance(eval_results, EvalResult): + eval_results = [eval_results] + if not isinstance(eval_results, list) or not all(isinstance(r, EvalResult) for r in eval_results): + raise ValueError( + f"`eval_results` should be of type `EvalResult` or a list of `EvalResult`, got {type(eval_results)}." 
+ ) + if model_name is None: + raise ValueError("Passing `eval_results` requires `model_name` to be set.") + return eval_results + + +class ModelCardData(CardData): + """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + Args: + base_model (`str` or `list[str]`, *optional*): + The identifier of the base model from which the model derives. This is applicable for example if your model is a + fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs + if your model derives from multiple models). Defaults to None. + datasets (`Union[str, list[str]]`, *optional*): + Dataset or list of datasets that were used to train this model. Should be a dataset ID + found on https://hf.co/datasets. Defaults to None. + eval_results (`Union[list[EvalResult], EvalResult]`, *optional*): + List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided, + `model_name` is used to as a name on PapersWithCode's leaderboards. Defaults to `None`. + language (`Union[str, list[str]]`, *optional*): + Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or + 639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`. + library_name (`str`, *optional*): + Name of library used by this model. Example: keras or any library from + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts. + Defaults to None. + license (`str`, *optional*): + License of this model. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. Defaults to None. + license_name (`str`, *optional*): + Name of the license of this model. Defaults to None. To be used in conjunction with `license_link`. + Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a name. In that case, use `license` instead. 
+ license_link (`str`, *optional*): + Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`. + Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead. + metrics (`list[str]`, *optional*): + List of metrics used to evaluate this model. Should be a metric name that can be found + at https://hf.co/metrics. Example: 'accuracy'. Defaults to None. + model_name (`str`, *optional*): + A name for this model. It is used along with + `eval_results` to construct the `model-index` within the card's metadata. The name + you supply here is what will be used on PapersWithCode's leaderboards. If None is provided + then the repo name is used as a default. Defaults to None. + pipeline_tag (`str`, *optional*): + The pipeline tag associated with the model. Example: "text-classification". + tags (`list[str]`, *optional*): + List of tags to add to your model that can be used when filtering on the Hugging + Face Hub. Defaults to None. + ignore_metadata_errors (`str`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. Use it at your own risk. + kwargs (`dict`, *optional*): + Additional metadata that will be added to the model card. Defaults to None. + + Example: + ```python + >>> from huggingface_hub import ModelCardData + >>> card_data = ModelCardData( + ... language="en", + ... license="mit", + ... library_name="timm", + ... tags=['image-classification', 'resnet'], + ... 
) + >>> card_data.to_dict() + {'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']} + + ``` + """ + + def __init__( + self, + *, + base_model: Optional[Union[str, list[str]]] = None, + datasets: Optional[Union[str, list[str]]] = None, + eval_results: Optional[list[EvalResult]] = None, + language: Optional[Union[str, list[str]]] = None, + library_name: Optional[str] = None, + license: Optional[str] = None, + license_name: Optional[str] = None, + license_link: Optional[str] = None, + metrics: Optional[list[str]] = None, + model_name: Optional[str] = None, + pipeline_tag: Optional[str] = None, + tags: Optional[list[str]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.base_model = base_model + self.datasets = datasets + self.eval_results = eval_results + self.language = language + self.library_name = library_name + self.license = license + self.license_name = license_name + self.license_link = license_link + self.metrics = metrics + self.model_name = model_name + self.pipeline_tag = pipeline_tag + self.tags = _to_unique_list(tags) + + model_index = kwargs.pop("model-index", None) + if model_index: + try: + model_name, eval_results = model_index_to_eval_results(model_index) + self.model_name = model_name + self.eval_results = eval_results + except (KeyError, TypeError) as error: + if ignore_metadata_errors: + logger.warning("Invalid model-index. Not loading eval results into CardData.") + else: + raise ValueError( + f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass" + " `ignore_metadata_errors=True` to ignore this error while loading a Model Card. Warning:" + " some information will be lost. Use it at your own risk." 
+ ) + + super().__init__(**kwargs) + + if self.eval_results: + try: + self.eval_results = _validate_eval_results(self.eval_results, self.model_name) + except Exception as e: + if ignore_metadata_errors: + logger.warning(f"Failed to validate eval_results: {e}. Not loading eval results into CardData.") + else: + raise ValueError(f"Failed to validate eval_results: {e}") from e + + def _to_dict(self, data_dict): + """Format the internal data dict. In this case, we convert eval results to a valid model index""" + if self.eval_results is not None: + data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results) # type: ignore + del data_dict["eval_results"], data_dict["model_name"] + + +class DatasetCardData(CardData): + """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + Args: + language (`list[str]`, *optional*): + Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or + 639-3 code (two/three letters), or a special value like "code", "multilingual". + license (`Union[str, list[str]]`, *optional*): + License(s) of this dataset. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. + annotations_creators (`Union[str, list[str]]`, *optional*): + How the annotations for the dataset were created. + Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'. + language_creators (`Union[str, list[str]]`, *optional*): + How the text-based data in the dataset was created. + Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other' + multilinguality (`Union[str, list[str]]`, *optional*): + Whether the dataset is multilingual. + Options are: 'monolingual', 'multilingual', 'translation', 'other'. + size_categories (`Union[str, list[str]]`, *optional*): + The number of examples in the dataset. Options are: 'n<1K', '1K1T', and 'other'. 
+ source_datasets (`list[str]]`, *optional*): + Indicates whether the dataset is an original dataset or extended from another existing dataset. + Options are: 'original' and 'extended'. + task_categories (`Union[str, list[str]]`, *optional*): + What categories of task does the dataset support? + task_ids (`Union[str, list[str]]`, *optional*): + What specific tasks does the dataset support? + paperswithcode_id (`str`, *optional*): + ID of the dataset on PapersWithCode. + pretty_name (`str`, *optional*): + A more human-readable name for the dataset. (ex. "Cats vs. Dogs") + train_eval_index (`dict`, *optional*): + A dictionary that describes the necessary spec for doing evaluation on the Hub. + If not provided, it will be gathered from the 'train-eval-index' key of the kwargs. + config_names (`Union[str, list[str]]`, *optional*): + A list of the available dataset configs for the dataset. + """ + + def __init__( + self, + *, + language: Optional[Union[str, list[str]]] = None, + license: Optional[Union[str, list[str]]] = None, + annotations_creators: Optional[Union[str, list[str]]] = None, + language_creators: Optional[Union[str, list[str]]] = None, + multilinguality: Optional[Union[str, list[str]]] = None, + size_categories: Optional[Union[str, list[str]]] = None, + source_datasets: Optional[list[str]] = None, + task_categories: Optional[Union[str, list[str]]] = None, + task_ids: Optional[Union[str, list[str]]] = None, + paperswithcode_id: Optional[str] = None, + pretty_name: Optional[str] = None, + train_eval_index: Optional[dict] = None, + config_names: Optional[Union[str, list[str]]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.annotations_creators = annotations_creators + self.language_creators = language_creators + self.language = language + self.license = license + self.multilinguality = multilinguality + self.size_categories = size_categories + self.source_datasets = source_datasets + self.task_categories = task_categories + 
self.task_ids = task_ids + self.paperswithcode_id = paperswithcode_id + self.pretty_name = pretty_name + self.config_names = config_names + + # TODO - maybe handle this similarly to EvalResult? + self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None) + super().__init__(**kwargs) + + def _to_dict(self, data_dict): + data_dict["train-eval-index"] = data_dict.pop("train_eval_index") + + +class SpaceCardData(CardData): + """Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + To get an exhaustive reference of Spaces configuration, please visit https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference. + + Args: + title (`str`, *optional*) + Title of the Space. + sdk (`str`, *optional*) + SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`). + sdk_version (`str`, *optional*) + Version of the used SDK (if Gradio/Streamlit sdk). + python_version (`str`, *optional*) + Python version used in the Space (if Gradio/Streamlit sdk). + app_file (`str`, *optional*) + Path to your main application file (which contains either gradio or streamlit Python code, or static html code). + Path is relative to the root of the repository. + app_port (`str`, *optional*) + Port on which your application is running. Used only if sdk is `docker`. + license (`str`, *optional*) + License of this model. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. + duplicated_from (`str`, *optional*) + ID of the original Space if this is a duplicated Space. + models (list[`str`], *optional*) + List of models related to this Space. Should be a dataset ID found on https://hf.co/models. + datasets (`list[str]`, *optional*) + List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets. + tags (`list[str]`, *optional*) + List of tags to add to your Space that can be used when filtering on the Hub. 
+ ignore_metadata_errors (`str`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. Use it at your own risk. + kwargs (`dict`, *optional*): + Additional metadata that will be added to the space card. + + Example: + ```python + >>> from huggingface_hub import SpaceCardData + >>> card_data = SpaceCardData( + ... title="Dreambooth Training", + ... license="mit", + ... sdk="gradio", + ... duplicated_from="multimodalart/dreambooth-training" + ... ) + >>> card_data.to_dict() + {'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'} + ``` + """ + + def __init__( + self, + *, + title: Optional[str] = None, + sdk: Optional[str] = None, + sdk_version: Optional[str] = None, + python_version: Optional[str] = None, + app_file: Optional[str] = None, + app_port: Optional[int] = None, + license: Optional[str] = None, + duplicated_from: Optional[str] = None, + models: Optional[list[str]] = None, + datasets: Optional[list[str]] = None, + tags: Optional[list[str]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.title = title + self.sdk = sdk + self.sdk_version = sdk_version + self.python_version = python_version + self.app_file = app_file + self.app_port = app_port + self.license = license + self.duplicated_from = duplicated_from + self.models = models + self.datasets = datasets + self.tags = _to_unique_list(tags) + super().__init__(**kwargs) + + +def model_index_to_eval_results(model_index: list[dict[str, Any]]) -> tuple[str, list[EvalResult]]: + """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects. + + A detailed spec of the model index can be found here: + https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 + + Args: + model_index (`list[dict[str, Any]]`): + A model index data structure, likely coming from a README.md file on the + Hugging Face Hub. 
+ + Returns: + model_name (`str`): + The name of the model as found in the model index. This is used as the + identifier for the model on leaderboards like PapersWithCode. + eval_results (`list[EvalResult]`): + A list of `huggingface_hub.EvalResult` objects containing the metrics + reported in the provided model_index. + + Example: + ```python + >>> from huggingface_hub.repocard_data import model_index_to_eval_results + >>> # Define a minimal model index + >>> model_index = [ + ... { + ... "name": "my-cool-model", + ... "results": [ + ... { + ... "task": { + ... "type": "image-classification" + ... }, + ... "dataset": { + ... "type": "beans", + ... "name": "Beans" + ... }, + ... "metrics": [ + ... { + ... "type": "accuracy", + ... "value": 0.9 + ... } + ... ] + ... } + ... ] + ... } + ... ] + >>> model_name, eval_results = model_index_to_eval_results(model_index) + >>> model_name + 'my-cool-model' + >>> eval_results[0].task_type + 'image-classification' + >>> eval_results[0].metric_type + 'accuracy' + + ``` + """ + + eval_results = [] + for elem in model_index: + name = elem["name"] + results = elem["results"] + for result in results: + task_type = result["task"]["type"] + task_name = result["task"].get("name") + dataset_type = result["dataset"]["type"] + dataset_name = result["dataset"]["name"] + dataset_config = result["dataset"].get("config") + dataset_split = result["dataset"].get("split") + dataset_revision = result["dataset"].get("revision") + dataset_args = result["dataset"].get("args") + source_name = result.get("source", {}).get("name") + source_url = result.get("source", {}).get("url") + + for metric in result["metrics"]: + metric_type = metric["type"] + metric_value = metric["value"] + metric_name = metric.get("name") + metric_args = metric.get("args") + metric_config = metric.get("config") + verified = metric.get("verified") + verify_token = metric.get("verifyToken") + + eval_result = EvalResult( + task_type=task_type, # Required + 
dataset_type=dataset_type, # Required + dataset_name=dataset_name, # Required + metric_type=metric_type, # Required + metric_value=metric_value, # Required + task_name=task_name, + dataset_config=dataset_config, + dataset_split=dataset_split, + dataset_revision=dataset_revision, + dataset_args=dataset_args, + metric_name=metric_name, + metric_args=metric_args, + metric_config=metric_config, + verified=verified, + verify_token=verify_token, + source_name=source_name, + source_url=source_url, + ) + eval_results.append(eval_result) + return name, eval_results + + +def _remove_none(obj): + """ + Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778 + """ + if isinstance(obj, (list, tuple, set)): + return type(obj)(_remove_none(x) for x in obj if x is not None) + elif isinstance(obj, dict): + return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None) + else: + return obj + + +def eval_results_to_model_index(model_name: str, eval_results: list[EvalResult]) -> list[dict[str, Any]]: + """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a + valid model-index that will be compatible with the format expected by the + Hugging Face Hub. + + Args: + model_name (`str`): + Name of the model (ex. "my-cool-model"). This is used as the identifier + for the model on leaderboards like PapersWithCode. + eval_results (`list[EvalResult]`): + List of `huggingface_hub.EvalResult` objects containing the metrics to be + reported in the model-index. + + Returns: + model_index (`list[dict[str, Any]]`): The eval_results converted to a model-index. + + Example: + ```python + >>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult + >>> # Define minimal eval_results + >>> eval_results = [ + ... EvalResult( + ... task_type="image-classification", # Required + ... dataset_type="beans", # Required + ... dataset_name="Beans", # Required + ... 
metric_type="accuracy", # Required + ... metric_value=0.9, # Required + ... ) + ... ] + >>> eval_results_to_model_index("my-cool-model", eval_results) + [{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}] + + ``` + """ + + # Metrics are reported on a unique task-and-dataset basis. + # Here, we make a map of those pairs and the associated EvalResults. + task_and_ds_types_map: dict[Any, list[EvalResult]] = defaultdict(list) + for eval_result in eval_results: + task_and_ds_types_map[eval_result.unique_identifier].append(eval_result) + + # Use the map from above to generate the model index data. + model_index_data = [] + for results in task_and_ds_types_map.values(): + # All items from `results` share same metadata + sample_result = results[0] + data = { + "task": { + "type": sample_result.task_type, + "name": sample_result.task_name, + }, + "dataset": { + "name": sample_result.dataset_name, + "type": sample_result.dataset_type, + "config": sample_result.dataset_config, + "split": sample_result.dataset_split, + "revision": sample_result.dataset_revision, + "args": sample_result.dataset_args, + }, + "metrics": [ + { + "type": result.metric_type, + "value": result.metric_value, + "name": result.metric_name, + "config": result.metric_config, + "args": result.metric_args, + "verified": result.verified, + "verifyToken": result.verify_token, + } + for result in results + ], + } + if sample_result.source_url is not None: + source = { + "url": sample_result.source_url, + } + if sample_result.source_name is not None: + source["name"] = sample_result.source_name + data["source"] = source + model_index_data.append(data) + + # TODO - Check if there cases where this list is longer than one? + # Finally, the model index itself is list of dicts. 
+ model_index = [ + { + "name": model_name, + "results": model_index_data, + } + ] + return _remove_none(model_index) + + +def _to_unique_list(tags: Optional[list[str]]) -> Optional[list[str]]: + if tags is None: + return tags + unique_tags = [] # make tags unique + keep order explicitly + for tag in tags: + if tag not in unique_tags: + unique_tags.append(tag) + return unique_tags