| """ |
| Utilities for working with the local dataset cache. |
| This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp |
| Copyright by the AllenNLP authors. |
| """ |
|
|
| import asyncio |
| import glob |
| import io |
| import json |
| import multiprocessing |
| import os |
| import posixpath |
| import re |
| import shutil |
| import tarfile |
| import time |
| import xml.dom.minidom |
| import zipfile |
| from collections.abc import Generator |
| from io import BytesIO |
| from itertools import chain |
| from pathlib import Path, PurePosixPath |
| from typing import Any, Optional, TypeVar, Union |
| from unittest.mock import patch |
| from urllib.parse import urlparse |
| from xml.etree import ElementTree as ET |
|
|
| import fsspec |
| import httpx |
| import huggingface_hub |
| import huggingface_hub.errors |
| import requests |
| from fsspec.core import strip_protocol, url_to_fs |
| from fsspec.utils import can_be_local |
| from huggingface_hub.utils import get_session, insecure_hashlib |
| from packaging import version |
|
|
| from .. import __version__, config |
| from ..download.download_config import DownloadConfig |
| from ..filesystems import COMPRESSION_FILESYSTEMS |
| from . import _tqdm, logging |
| from ._filelock import FileLock |
| from .extract import ExtractManager |
| from .track import TrackedIterableFromGenerator |
|
|
|
|
# `aiohttp` is optional. When it cannot be imported, define a stand-in exception class so that
# `_AiohttpClientError` can still be listed in exception tuples such as CONNECTION_ERRORS_TO_RETRY.
# Nothing in this module raises the stand-in itself.
try:
    from aiohttp.client_exceptions import ClientError as _AiohttpClientError
except ImportError:

    class _AiohttpClientError(Exception):
        pass
|
|
|
|
# Module-level logger.
logger = logging.get_logger(__name__)


# Filename suffix presumably meant to mark partial (in-progress) downloads.
INCOMPLETE_SUFFIX = ".incomplete"


# Type variable constrained to `str` or `Path`, so a helper can return the same type it was given.
T = TypeVar("T", str, Path)


# Exceptions treated as transient connection failures: the streaming open/read helpers catch
# these, log a warning, sleep, and retry.
CONNECTION_ERRORS_TO_RETRY = (
    _AiohttpClientError,
    asyncio.TimeoutError,
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
    httpx.RequestError,
)
|
|
|
|
def is_remote_url(url_or_filename: str) -> bool:
    """Return True when `url_or_filename` carries a URL scheme that is not a mounted drive.

    `urlparse` reports e.g. ``C:/data`` as having scheme ``c``. A scheme whose ``<scheme>:/``
    root is a mount point is therefore treated as a local path, not a remote URL.
    """
    scheme = urlparse(url_or_filename).scheme
    if not scheme:
        return False
    return not os.path.ismount(scheme + ":/")
|
|
|
|
def is_local_path(url_or_filename: str) -> bool:
    """Return True when `url_or_filename` refers to the local filesystem.

    This is the case when there is no URL scheme at all, or when the parsed "scheme"
    maps to a mount point (``<scheme>:/``), e.g. a Windows drive letter.
    """
    scheme = urlparse(url_or_filename).scheme
    return not scheme or os.path.ismount(f"{scheme}:/")
|
|
|
|
def is_relative_path(url_or_filename: str) -> bool:
    """Return True for a scheme-less path that is not absolute."""
    if urlparse(url_or_filename).scheme:
        return False
    return not os.path.isabs(url_or_filename)
|
|
|
|
def relative_to_absolute_path(path: T) -> T:
    """Convert relative path to absolute path.

    Environment variables and ``~`` are expanded first. The result is a `Path` if the
    input was a `Path`, otherwise a `str`.
    """
    expanded = os.path.expanduser(os.path.expandvars(str(path)))
    absolute = os.path.abspath(expanded)
    if isinstance(path, Path):
        return Path(absolute)
    return absolute
|
|
|
|
def url_or_path_join(base_name: str, *pathnames: str) -> str:
    """Join path components onto a URL or a local path.

    For remote URLs, each component has OS separators turned into "/" and leading slashes
    stripped before a POSIX join. Local paths are joined with `Path` and returned in POSIX form.
    """
    if not is_remote_url(base_name):
        return Path(base_name, *pathnames).as_posix()
    cleaned = [str(part).replace(os.sep, "/").lstrip("/") for part in pathnames]
    return posixpath.join(base_name, *cleaned)
|
|
|
|
def url_or_path_parent(url_or_path: str) -> str:
    """Return the parent of a URL (everything before the last "/") or of a local path."""
    if not is_remote_url(url_or_path):
        return os.path.dirname(url_or_path)
    last_slash = url_or_path.rindex("/")
    return url_or_path[:last_slash]
|
|
|
|
def hash_url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.

    The filename is the SHA-256 hex digest of `url`. If `etag` is given (and non-empty),
    a period followed by the SHA-256 hex digest of `etag` is appended. If the url ends
    with ".py", the ".py" suffix is appended to the filename as well.

    Args:
        url (`str`): URL to hash.
        etag (`str`, *optional*): ETag of the remote resource.

    Returns:
        `str`: the cache filename.
    """
    filename = insecure_hashlib.sha256(url.encode("utf-8")).hexdigest()

    if etag:
        filename += "." + insecure_hashlib.sha256(etag.encode("utf-8")).hexdigest()

    # Keep the ".py" extension on cached scripts.
    if url.endswith(".py"):
        filename += ".py"

    return filename
|
|
|
|
def cached_path(
    url_or_filename,
    download_config=None,
    **download_kwargs,
) -> str:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        url_or_filename (`str` or `Path`): URL or local path. `hf://` URLs (including Hub
            `/resolve/` URLs, which are rewritten to `hf://`) are downloaded with
            `huggingface_hub.HfApi.hf_hub_download`. Other remote URLs go through `get_from_cache`.
        download_config (`DownloadConfig`, *optional*): download, cache and extraction settings.
            When `None`, one is built from `**download_kwargs`.
        **download_kwargs: forwarded to `DownloadConfig` only when `download_config` is `None`.

    Return:
        Local path (string). If `download_config.extract_compressed_file` and
        `download_config.extract_on_the_fly` are both set and a compression protocol is
        detected (tar archives excepted), an fsspec chained URL is returned instead:
        `<protocol>://<inner file>::<absolute local path>` for single-file compressions, or
        `<protocol>://::<absolute local path>` otherwise.

    Raises:
        FileNotFoundError: in case of non-recoverable file
            (non-existent or no cache on disk)
        ConnectionError: in case of unreachable url
            and no cache on disk
        ValueError: if it couldn't parse the url or filename correctly
        httpx.NetworkError or requests.exceptions.ConnectionError: in case of internet connection issue
    """
    if download_config is None:
        download_config = DownloadConfig(**download_kwargs)

    cache_dir = download_config.cache_dir or config.DOWNLOADED_DATASETS_PATH
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)

    # If fsspec says this path can be local (e.g. a "file://" URL), drop the protocol prefix
    # so it is handled as a plain local path below.
    if can_be_local(url_or_filename):
        url_or_filename = strip_protocol(url_or_filename)

    if is_remote_url(url_or_filename):
        # Normalize the URL (Hub "/resolve/" URLs become "hf://") and build per-protocol storage options.
        url_or_filename, storage_options = _prepare_path_and_storage_options(
            url_or_filename, download_config=download_config
        )
        # Hub files are fetched through huggingface_hub. No `cache_dir` is passed here, so the
        # download presumably lands in huggingface_hub's default cache rather than `cache_dir`.
        if url_or_filename.startswith("hf://"):
            resolved_path = huggingface_hub.HfFileSystem(
                endpoint=config.HF_ENDPOINT, token=download_config.token
            ).resolve_path(url_or_filename)
            try:
                output_path = huggingface_hub.HfApi(
                    endpoint=config.HF_ENDPOINT,
                    token=download_config.token,
                    library_name="datasets",
                    library_version=__version__,
                    user_agent=get_datasets_user_agent(download_config.user_agent),
                ).hf_hub_download(
                    repo_id=resolved_path.repo_id,
                    repo_type=resolved_path.repo_type,
                    revision=resolved_path.revision,
                    filename=resolved_path.path_in_repo,
                    force_download=download_config.force_download,
                    proxies=download_config.proxies,
                )
            except (
                huggingface_hub.utils.RepositoryNotFoundError,
                huggingface_hub.utils.EntryNotFoundError,
                huggingface_hub.utils.RevisionNotFoundError,
                huggingface_hub.utils.GatedRepoError,
            ) as e:
                # Surface Hub "not found"/"gated" errors as a plain FileNotFoundError.
                raise FileNotFoundError(str(e)) from e
        # Any other remote URL: download through fsspec into the datasets cache.
        else:
            output_path = get_from_cache(
                url_or_filename,
                cache_dir=cache_dir,
                force_download=download_config.force_download,
                user_agent=download_config.user_agent,
                use_etag=download_config.use_etag,
                token=download_config.token,
                storage_options=storage_options,
                download_desc=download_config.download_desc,
                disable_tqdm=download_config.disable_tqdm,
            )
    elif os.path.exists(url_or_filename):
        # Existing local file: use it as-is.
        output_path = url_or_filename
    elif is_local_path(url_or_filename):
        # Local path that does not exist.
        raise FileNotFoundError(f"Local file {url_or_filename} doesn't exist")
    else:
        # Neither a remote URL nor a local path.
        raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")

    # Defensive: pass a None result straight through (no extraction attempted).
    if output_path is None:
        return output_path

    if download_config.extract_compressed_file:
        if download_config.extract_on_the_fly:
            # Lazy extraction: instead of extracting now, prefix the path with its compression
            # protocol so `xopen` can decompress while reading. Tar archives are excluded here.
            protocol = _get_extraction_protocol(output_path, download_config=download_config)
            extension = _get_path_extension(url_or_filename.split("::")[0])
            if (
                protocol
                and extension not in ["tgz", "tar"]
                and not url_or_filename.split("::")[0].endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
            ):
                output_path = relative_to_absolute_path(output_path)
                if protocol in SINGLE_FILE_COMPRESSION_PROTOCOLS:
                    # Single-file compression: name the inner file after the compressed file minus its last extension.
                    inner_file = os.path.basename(output_path)
                    inner_file = inner_file[: inner_file.rindex(".")] if "." in inner_file else inner_file
                    output_path = f"{protocol}://{inner_file}::{output_path}"
                else:
                    output_path = f"{protocol}://::{output_path}"
                return output_path

        # Eager extraction with ExtractManager.
        # NOTE(review): this uses `download_config.cache_dir`, not the `cache_dir` fallback computed above.
        output_path = ExtractManager(cache_dir=download_config.cache_dir).extract(
            output_path, force_extract=download_config.force_extract
        )
    return relative_to_absolute_path(output_path)
|
|
|
|
def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
    """Build the "; "-separated user-agent string sent by `datasets`.

    Always lists the datasets, python, huggingface_hub and pyarrow versions. It also lists
    torch/tensorflow/jax when they are available. A caller-provided `user_agent` is appended:
    a dict as `key/value` pairs, a str verbatim.
    """
    fragments = [
        f"datasets/{__version__}",
        f"python/{config.PY_VERSION}",
        f"hf_hub/{huggingface_hub.__version__}",
        f"pyarrow/{config.PYARROW_VERSION}",
    ]
    if config.TORCH_AVAILABLE:
        fragments.append(f"torch/{config.TORCH_VERSION}")
    if config.TF_AVAILABLE:
        fragments.append(f"tensorflow/{config.TF_VERSION}")
    if config.JAX_AVAILABLE:
        fragments.append(f"jax/{config.JAX_VERSION}")
    if isinstance(user_agent, dict):
        fragments.append("; ".join(f"{key}/{value}" for key, value in user_agent.items()))
    elif isinstance(user_agent, str):
        fragments.append(user_agent)
    return "; ".join(fragments)
|
|
|
|
def get_authentication_headers_for_url(url: str, token: Optional[Union[str, bool]] = None) -> dict:
    """Handle the HF authentication.

    Returns Hub headers (built by `huggingface_hub`) for URLs under `config.HF_ENDPOINT`,
    and an empty dict for any other URL.
    """
    if not url.startswith(config.HF_ENDPOINT):
        return {}
    return huggingface_hub.utils.build_hf_headers(token=token, library_name="datasets", library_version=__version__)
|
|
|
|
def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
    """Raise an OfflineModeIsEnabled error (subclass of ConnectionError) if HF_HUB_OFFLINE is True.

    `msg`, when given, is appended to the base message.
    """
    if not config.HF_HUB_OFFLINE:
        return
    if msg is None:
        message = "Offline mode is enabled."
    else:
        message = "Offline mode is enabled. " + str(msg)
    raise huggingface_hub.errors.OfflineModeIsEnabled(message)
|
|
|
|
def fsspec_head(url, storage_options=None):
    """Return fsspec's `info()` metadata dict for `url` (raises in offline mode)."""
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    options = storage_options or {}
    filesystem, inner_path = url_to_fs(url, **options)
    return filesystem.info(inner_path)
|
|
|
|
def stack_multiprocessing_download_progress_bars():
    """Return a context manager that sets
    HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS=1 in `os.environ` while active.

    `fsspec_get` reads this variable to choose a progress-bar position per worker process.
    """
    env_override = {"HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS": "1"}
    return patch.dict(os.environ, env_override)
|
|
|
|
class TqdmCallback(fsspec.callbacks.TqdmCallback):
    """fsspec progress callback that renders with the `datasets` tqdm wrapper (`_tqdm`).

    fsspec >= 2024.2.0 accepts the progress-bar class via the `tqdm_cls` kwarg. For older
    versions, the `_tqdm` module is assigned to `self._tqdm` after initialization instead.
    """

    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
        uses_legacy_fsspec = config.FSSPEC_VERSION < version.parse("2024.2.0")
        if not uses_legacy_fsspec:
            kwargs["tqdm_cls"] = _tqdm.tqdm
        super().__init__(tqdm_kwargs, *args, **kwargs)
        if uses_legacy_fsspec:
            self._tqdm = _tqdm
|
|
|
|
def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False):
    """Download `url` with fsspec to the file named by `temp_file.name`, showing a byte progress bar.

    Raises if offline mode is enabled. When
    HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS == "1" and this runs in a
    child process, the bar's position is the last element of the process `_identity`.
    This presumably keeps bars from different workers on separate rows.
    """
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    fs, path = url_to_fs(url, **(storage_options or {}))

    position = None
    if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1":
        identity = multiprocessing.current_process()._identity
        if identity:
            position = identity[-1]

    progress_options = {
        "desc": desc or "Downloading",
        "unit": "B",
        "unit_scale": True,
        "position": position,
        "disable": disable_tqdm,
    }
    fs.get_file(path, temp_file.name, callback=TqdmCallback(tqdm_kwargs=progress_options))
|
|
|
|
def get_from_cache(
    url,
    cache_dir=None,
    force_download=False,
    user_agent=None,
    use_etag=True,
    token=None,
    storage_options=None,
    download_desc=None,
    disable_tqdm=False,
) -> str:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache filename is `hash_url_to_filename(url, etag)`. When `use_etag` is True, the etag
    comes from `fsspec_head(url)`. A `<cache path>.json` file holding `{"url", "etag"}` is written
    next to a freshly downloaded file. Concurrent downloads of the same file are serialized
    with a `<cache path>.lock` file lock.

    Args:
        url (`str`): remote URL to fetch.
        cache_dir (`str` or `Path`, *optional*): defaults to `config.HF_DATASETS_CACHE`.
        force_download (`bool`): re-download even if a cached file exists.
        user_agent (`str`, *optional*), token (*optional*): used to build request headers
            (see the note in the body: those headers are not currently sent).
        use_etag (`bool`): include the remote ETag in the cache filename.
        storage_options (`dict`, *optional*): fsspec options passed to `fsspec_head`/`fsspec_get`.
        download_desc (`str`, *optional*), disable_tqdm (`bool`): progress bar settings.

    Return:
        Local path (string)

    Raises:
        FileNotFoundError: in case of non-recoverable file
            (non-existent or no cache on disk)
        ConnectionError: in case of unreachable url
            and no cache on disk
    """
    if storage_options is None:
        storage_options = {}
    if cache_dir is None:
        cache_dir = config.HF_DATASETS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Without ETag checking, a file cached under the URL-only name is reused without a network call.
    filename = hash_url_to_filename(url, etag=None)
    cache_path = os.path.join(cache_dir, filename)
    if os.path.exists(cache_path) and not force_download and not use_etag:
        return cache_path

    # NOTE(review): these headers are built but never passed on; fsspec_head/fsspec_get only receive
    # `storage_options`. Kept because building them may validate `token`. Confirm before removing.
    headers = get_authentication_headers_for_url(url, token=token)
    if user_agent is not None:
        headers["user-agent"] = user_agent

    response = fsspec_head(url, storage_options=storage_options)
    etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None

    # Recompute the cache filename now that the ETag is known.
    filename = hash_url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # Re-check: another process may have completed the download while we waited for the lock.
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        incomplete_path = cache_path + INCOMPLETE_SUFFIX

        # Download to a temporary file and move it into place only once complete, so an
        # interrupted download never leaves a corrupt cache entry.
        with open(incomplete_path, "w+b") as temp_file:
            logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
            fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm)

        # Move only after the temporary file has been closed (renaming an open file fails on Windows).
        logger.info(f"storing {url} in cache at {cache_path}")
        shutil.move(incomplete_path, cache_path)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w", encoding="utf-8") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
|
|
|
def add_start_docstrings(*docstr):
    """Return a decorator that prepends the concatenated `docstr` pieces to a function's docstring.

    The pieces are joined without a separator, followed by a blank line and then the function's
    original docstring (an empty string if it had none).
    """

    def docstring_decorator(fn):
        original_doc = "" if fn.__doc__ is None else fn.__doc__
        fn.__doc__ = "".join(docstr) + "\n\n" + original_doc
        return fn

    return docstring_decorator
|
|
|
|
def add_end_docstrings(*docstr):
    """Return a decorator that appends the concatenated `docstr` pieces to a function's docstring.

    The result is the original docstring (an empty string if none), a blank line, then the
    pieces joined without a separator.
    """

    def docstring_decorator(fn):
        original_doc = "" if fn.__doc__ is None else fn.__doc__
        fn.__doc__ = original_doc + "\n\n" + "".join(docstr)
        return fn

    return docstring_decorator
|
|
|
|
def estimate_dataset_size(paths):
    """Return the total size in bytes of `paths` (objects exposing `.stat().st_size`, e.g. `Path`)."""
    total_bytes = 0
    for file_path in paths:
        total_bytes += file_path.stat().st_size
    return total_bytes
|
|
|
|
def readline(f: io.RawIOBase):
    """Read and return one line from `f`, including its trailing newline byte if present.

    Bytes are read one at a time, presumably so that nothing past the newline is consumed from
    an unbuffered stream. Reading stops at a newline or when `read` returns an empty/falsy value.
    """
    line = bytearray()
    while chunk := f.read(1):
        line += chunk
        if line.endswith(b"\n"):
            break
    return bytes(line)
|
|
|
|
| |
| |
| |
|
|
# Extensions of files that `_get_extraction_protocol` treats as not compressed (no extraction protocol).
BASE_KNOWN_EXTENSIONS = [
    "txt",
    "csv",
    "json",
    "jsonl",
    "tsv",
    "conll",
    "conllu",
    "orig",
    "parquet",
    "pkl",
    "pickle",
    "rel",
    "xml",
    "arrow",
]
# Compression file extension (without the leading dot) -> fsspec protocol used to read it.
COMPRESSION_EXTENSION_TO_PROTOCOL = {
    # Extensions declared by the compression filesystems in COMPRESSION_FILESYSTEMS.
    **{
        extension.lstrip("."): fs_class.protocol
        for fs_class in COMPRESSION_FILESYSTEMS
        for extension in fs_class.extensions
    },
    # Plus zip archives.
    "zip": "zip",
}
# Same mapping without "zip": only the filesystems from COMPRESSION_FILESYSTEMS.
SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = {
    extension.lstrip("."): fs_class.protocol
    for fs_class in COMPRESSION_FILESYSTEMS
    for extension in fs_class.extensions
}
# Protocol names of the COMPRESSION_FILESYSTEMS.
SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}
# Matches ":/" when it is not immediately preceded by another ":".
SINGLE_SLASH_AFTER_PROTOCOL_PATTERN = re.compile(r"(?<!:):/")
|
|
|
|
# Leading file bytes ("magic numbers") of supported compressed formats -> compression protocol name.
MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL = {
    bytes.fromhex("504B0304"): "zip",
    bytes.fromhex("504B0506"): "zip",
    bytes.fromhex("504B0708"): "zip",
    bytes.fromhex("425A68"): "bz2",
    bytes.fromhex("1F8B"): "gzip",
    bytes.fromhex("FD377A585A00"): "xz",
    bytes.fromhex("04224D18"): "lz4",
    bytes.fromhex("28B52FFD"): "zstd",
}
# Magic numbers of formats that are recognized but unsupported; detecting one raises NotImplementedError.
MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL = {
    b"Rar!": "rar",
}
# Number of header bytes to read so that every known magic number can be matched.
MAGIC_NUMBER_MAX_LENGTH = max(
    len(magic_number)
    for magic_number in chain(MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL, MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL)
)
|
|
|
|
class NonStreamableDatasetError(Exception):
    """Raised by `xopen` when a remote file cannot be opened in streaming mode."""
|
|
|
|
| def _get_path_extension(path: str) -> str: |
| |
| extension = path.split(".")[-1] |
| |
| |
| for symb in "?-_": |
| extension = extension.split(symb)[0] |
| return extension |
|
|
|
|
def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
    """read the magic number from a file-like object and return the compression protocol

    Returns None when `f` is not seekable or no known magic number matches. Longer magic
    numbers are tried first. The stream is rewound to position 0 afterwards.

    Raises:
        NotImplementedError: if the header matches a known but unsupported format.
    """
    try:
        f.seek(0)
    except (AttributeError, io.UnsupportedOperation):
        return None
    header = f.read(MAGIC_NUMBER_MAX_LENGTH)
    f.seek(0)
    for length in range(MAGIC_NUMBER_MAX_LENGTH, 0, -1):
        prefix = header[:length]
        supported = MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL.get(prefix)
        if supported is not None:
            return supported
        unsupported = MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL.get(prefix)
        if unsupported is not None:
            raise NotImplementedError(f"Compression protocol '{unsupported}' not implemented.")
    return None
|
|
|
|
def _get_extraction_protocol(urlpath: str, download_config: Optional[DownloadConfig] = None) -> Optional[str]:
    """Infer the compression protocol needed to read `urlpath`, or None.

    The extension of the first "::" hop is checked first. Known plain-data extensions and tar
    archives yield None, and known compression extensions yield their protocol. Otherwise the
    file is opened and its magic number is inspected.

    Raises:
        FileNotFoundError: if the file cannot be opened. For URLs under the HF endpoint, a
            login hint is added to the message.
        NotImplementedError: from magic-number detection, for unsupported formats.
    """
    urlpath = str(urlpath)
    first_hop = urlpath.split("::")[0]
    ext = _get_path_extension(first_hop)
    is_plain_or_tar = (
        ext in BASE_KNOWN_EXTENSIONS
        or ext in ["tgz", "tar"]
        or first_hop.endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
    )
    if is_plain_or_tar:
        return None
    if ext in COMPRESSION_EXTENSION_TO_PROTOCOL:
        return COMPRESSION_EXTENSION_TO_PROTOCOL[ext]

    prepared_urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
    try:
        with fsspec.open(prepared_urlpath, **(storage_options or {})) as f:
            return _get_extraction_protocol_with_magic_number(f)
    except FileNotFoundError:
        if not prepared_urlpath.startswith(config.HF_ENDPOINT):
            raise
        raise FileNotFoundError(
            prepared_urlpath + "\nIf the repo is private or gated, make sure to log in with `huggingface-cli login`."
        ) from None
|
|
|
|
def xjoin(a, *p):
    """
    Extend `os.path.join` to paths and URLs that may use the "::" hop separator.

    Chained URLs such as ``zip://folder1/file.txt::https://host.com/archive.zip`` access a file
    inside a remote archive. Only the first hop of the chain is joined; the remaining hops are
    kept unchanged. A local first hop is joined with `os.path.join`, and the other hops are
    dropped in that case.

    Example::

        >>> xjoin("zip://folder1::https://host.com/archive.zip", "file.txt")
        zip://folder1/file.txt::https://host.com/archive.zip
    """
    first, *others = str(a).split("::")
    if is_local_path(first):
        return os.path.join(first, *p)
    joined = posixpath.join(first, *p)
    return "::".join([joined, *others])
|
|
|
|
def xdirname(a):
    """
    Extend `os.path.dirname` to paths and URLs that may use the "::" hop separator.

    Only the first hop of a chained URL (e.g. ``zip://folder1/file.txt::https://host.com/archive.zip``)
    is affected. If the parent ends in a bare ``protocol:``, "//" is re-added.

    Example::

        >>> xdirname("zip://folder1/file.txt::https://host.com/archive.zip")
        zip://folder1::https://host.com/archive.zip
    """
    first, *others = str(a).split("::")
    if is_local_path(first):
        parent = os.path.dirname(Path(first).as_posix())
    else:
        parent = posixpath.dirname(first)
    # e.g. "zip://file.txt" -> "zip:" must become "zip://"
    if parent.endswith(":"):
        parent = parent + "//"
    return "::".join([parent, *others])
|
|
|
|
def xexists(urlpath: str, download_config: Optional[DownloadConfig] = None):
    """Extend `os.path.exists` function to support both local and remote files.

    Args:
        urlpath (`str`): URL path.
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `bool`: for local paths, `os.path.exists` of the first hop. For remote paths, the fsspec
        filesystem's `exists` on the prepared first hop.
    """
    first_hop = _as_str(urlpath).split("::")[0]
    if is_local_path(first_hop):
        return os.path.exists(first_hop)
    prepared, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
    filesystem = url_to_fs(prepared, **storage_options)[0]
    return filesystem.exists(prepared.split("::")[0])
|
|
|
|
def xbasename(a):
    """
    Extend `os.path.basename` to paths and URLs that may use the "::" hop separator.

    The basename is taken from the first hop of a chained URL
    (e.g. ``zip://folder1/file.txt::https://host.com/archive.zip``).

    Example::

        >>> xbasename("zip://folder1/file.txt::https://host.com/archive.zip")
        file.txt
    """
    first = str(a).split("::")[0]
    if is_local_path(first):
        return os.path.basename(Path(first).as_posix())
    return posixpath.basename(first)
|
|
|
|
def xsplit(a):
    """
    Extend `os.path.split` to paths and URLs that may use the "::" hop separator.

    The first hop of a chained URL is split into (head, tail). The other hops are re-attached
    to the head, and a bare ``protocol:`` head gets its "//" back. For a local first hop,
    `os.path.split` is returned directly.

    Example::

        >>> xsplit("zip://folder1/file.txt::https://host.com/archive.zip")
        ('zip://folder1::https://host.com/archive.zip', 'file.txt')
    """
    first, *others = str(a).split("::")
    if is_local_path(first):
        return os.path.split(Path(first).as_posix())
    head, tail = posixpath.split(first)
    if head.endswith(":"):
        head += "//"
    return "::".join([head, *others]), tail
|
|
|
|
def xsplitext(a):
    """
    Extend `os.path.splitext` to paths and URLs that may use the "::" hop separator.

    The extension is split off the first hop of a chained URL, and the other hops are
    re-attached to the root. For a local first hop, `os.path.splitext` is returned directly.

    Example::

        >>> xsplitext("zip://folder1/file.txt::https://host.com/archive.zip")
        ('zip://folder1/file::https://host.com/archive.zip', '.txt')
    """
    first, *others = str(a).split("::")
    if is_local_path(first):
        return os.path.splitext(Path(first).as_posix())
    root, extension = posixpath.splitext(first)
    return "::".join([root, *others]), extension
|
|
|
|
def xisfile(path, download_config: Optional[DownloadConfig] = None) -> bool:
    """Extend `os.path.isfile` function to support remote files.

    Args:
        path (`str`): URL path.
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `bool`
    """
    first_hop = str(path).split("::")[0]
    if is_local_path(first_hop):
        return os.path.isfile(path)
    prepared, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
    filesystem = url_to_fs(prepared, **storage_options)[0]
    return filesystem.isfile(prepared.split("::")[0])
|
|
|
|
def xgetsize(path, download_config: Optional[DownloadConfig] = None) -> int:
    """Extend `os.path.getsize` function to support remote files.

    Args:
        path (`str`): URL path.
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `int`: size in bytes.

    Raises:
        FileNotFoundError: if the Hub reports the entry does not exist.
    """
    main_hop = str(path).split("::")[0]
    if is_local_path(main_hop):
        return os.path.getsize(path)
    path, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
    main_hop = path.split("::")[0]
    fs, *_ = url_to_fs(path, **storage_options)
    try:
        size = fs.size(main_hop)
    except huggingface_hub.utils.EntryNotFoundError as err:
        raise FileNotFoundError(f"No such file: {path}") from err
    if size is None:
        # The filesystem did not report a size: fall back to reading the whole file.
        with xopen(path, download_config=download_config) as f:
            size = len(f.read())
    return size
|
|
|
|
def xisdir(path, download_config: Optional[DownloadConfig] = None) -> bool:
    """Extend `os.path.isdir` function to support remote files.

    Args:
        path (`str`): URL path.
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `bool`: True for an empty/root inner path, otherwise the filesystem's `isdir` result.
    """
    main_hop = str(path).split("::")[0]
    if is_local_path(main_hop):
        return os.path.isdir(path)
    path, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
    main_hop = path.split("::")[0]
    fs, *_ = url_to_fs(path, **storage_options)
    inner_path = main_hop.split("://")[-1]
    # An empty inner path (e.g. "zip://") is the filesystem root, which is treated as a directory.
    if not inner_path.strip("/"):
        return True
    return fs.isdir(inner_path)
|
|
|
|
def xrelpath(path, start=None):
    """Extend `os.path.relpath` function to support remote files.

    Only the first "::" hop of `path` (and of `start`) is considered.

    Args:
        path (`str`): URL path.
        start (`str`): Start URL directory path.

    Returns:
        `str`
    """
    first_hop = str(path).split("::")[0]
    if is_local_path(first_hop):
        if start:
            return os.path.relpath(first_hop, start=start)
        return os.path.relpath(first_hop)
    if start:
        return posixpath.relpath(first_hop, start=str(start).split("::")[0])
    return os.path.relpath(first_hop)
|
|
|
|
def _add_retries_to_file_obj_read_method(file_obj):
    """Make `file_obj.read` retry on connection errors.

    Each call to `read` is retried up to `config.STREAMING_READ_MAX_RETRIES` times on
    CONNECTION_ERRORS_TO_RETRY, sleeping `config.STREAMING_READ_RETRY_INTERVAL` seconds
    between attempts. When retries are exhausted, `ConnectionError("Server Disconnected")`
    is raised.

    Returns:
        `file_obj` with a patched `read`. If `read` cannot be reassigned on `file_obj`, a
        wrapper object whose `read` retries is returned instead. Attributes it does not define
        itself are forwarded to the original object.
    """
    read = file_obj.read
    max_retries = config.STREAMING_READ_MAX_RETRIES

    def read_with_retries(*args, **kwargs):
        disconnect_err = None
        for retry in range(1, max_retries + 1):
            try:
                out = read(*args, **kwargs)
                break
            except CONNECTION_ERRORS_TO_RETRY as err:
                disconnect_err = err
                logger.warning(
                    f"Got disconnected from remote data host. Retrying in {config.STREAMING_READ_RETRY_INTERVAL}sec [{retry}/{max_retries}]"
                )
                time.sleep(config.STREAMING_READ_RETRY_INTERVAL)
        else:
            raise ConnectionError("Server Disconnected") from disconnect_err
        return out

    try:
        file_obj.read = read_with_retries
    except AttributeError:
        # `read` cannot be overridden on this object, so wrap it instead. `__getattr__` must be
        # defined on the class: Python never consults a `__getattr__` stored in an instance dict,
        # so assigning it on an instance would silently fail to forward attribute lookups.
        orig_file_obj = file_obj

        class _ReadRetryWrapper(io.RawIOBase):
            def __getattr__(self, attr):
                return getattr(orig_file_obj, attr)

        file_obj = _ReadRetryWrapper()
        file_obj.read = read_with_retries
    return file_obj
|
|
|
|
def _prepare_path_and_storage_options(
    urlpath: str, download_config: Optional[DownloadConfig] = None
) -> tuple[str, dict[str, dict[str, Any]]]:
    """Prepare every hop of a (possibly "::"-chained) URL and collect its storage options.

    Each hop goes through `_prepare_single_hop_path_and_storage_options`.

    Returns:
        The prepared URL (hops re-joined with "::") and the storage options of ALL hops merged
        into one `{protocol: options}` dict. When two hops share a protocol, the later hop wins.
    """
    prepared_urlpath = []
    prepared_storage_options = {}
    for hop in urlpath.split("::"):
        hop, hop_storage_options = _prepare_single_hop_path_and_storage_options(hop, download_config=download_config)
        prepared_urlpath.append(hop)
        prepared_storage_options.update(hop_storage_options)
    # Return the merged options; returning only the last hop's dict would silently drop the
    # options of every earlier hop in a chained URL.
    return "::".join(prepared_urlpath), prepared_storage_options
|
|
|
|
def _prepare_single_hop_path_and_storage_options(
    urlpath: str, download_config: Optional[DownloadConfig] = None
) -> tuple[str, dict[str, dict[str, Any]]]:
    """
    Prepare one URL hop and the fsspec storage options to use for it.

    - Hub URLs under `config.HF_ENDPOINT` containing "/resolve/" are rewritten to
      `hf://<repo path>@<revision>/<file path>`.
    - Options come from `download_config.storage_options[protocol]` if present. Otherwise they
      are the top-level entries of `download_config.storage_options` whose keys are not fsspec
      protocol names.
    - http(s): `trust_env=True` is added to `client_kwargs`. Google Drive URLs get a `confirm`
      query parameter (and cookies when a "download_warning" cookie is set).
      raw.githubusercontent.com requests get `Accept-Encoding: identity`. User-provided values
      take precedence.
    - hf: the endpoint and the `download_config` token are added.

    Returns:
        The (possibly rewritten) hop and its options formatted as `{protocol: options}`,
        or `{}` when there are no options.
    """
    token = None if download_config is None else download_config.token
    # "<endpoint>/<repo path>/resolve/<revision>/<file>" -> "hf://<repo path>@<revision>/<file>"
    if urlpath.startswith(config.HF_ENDPOINT) and "/resolve/" in urlpath:
        urlpath = "hf://" + urlpath[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
    protocol = urlpath.split("://")[0] if "://" in urlpath else "file"
    if download_config is not None and protocol in download_config.storage_options:
        # Copy so the pops/assignments below don't mutate the user's download_config.
        storage_options = download_config.storage_options[protocol].copy()
    elif download_config is not None and protocol not in download_config.storage_options:
        # Keep only generic options: keys that are not themselves fsspec protocol names.
        storage_options = {
            option_name: option_value
            for option_name, option_value in download_config.storage_options.items()
            if option_name not in fsspec.available_protocols()
        }
    else:
        storage_options = {}
    if protocol in {"http", "https"}:
        # Enable the HTTP client's trust_env (reading proxy settings from environment variables);
        # user-provided client_kwargs take precedence.
        client_kwargs = storage_options.pop("client_kwargs", {})
        storage_options["client_kwargs"] = {"trust_env": True, **client_kwargs}
        if "drive.google.com" in urlpath:
            # A HEAD request may set a "download_warning*" cookie; its value is sent back as the
            # `confirm` query parameter and the response cookies are forwarded.
            response = get_session().head(urlpath, timeout=10)
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    urlpath += "&confirm=" + v
                    cookies = response.cookies
                    storage_options = {"cookies": cookies, **storage_options}
            # Otherwise force "confirm=t" (presumably to skip Drive's download warning page — verify).
            if "confirm=" not in urlpath:
                urlpath += "&confirm=t"
        if urlpath.startswith("https://raw.githubusercontent.com/"):
            # Request an uncompressed response; user-provided headers take precedence.
            headers = storage_options.pop("headers", {})
            storage_options["headers"] = {"Accept-Encoding": "identity", **headers}
    elif protocol == "hf":
        # Default Hub endpoint and token; explicit user options override them.
        storage_options = {
            "endpoint": config.HF_ENDPOINT,
            "token": token,
            **storage_options,
        }
    # Key the options by protocol so they can be merged across hops of a chained URL.
    if storage_options:
        storage_options = {protocol: storage_options}
    return urlpath, storage_options
|
|
|
|
def xopen(file: str, mode="r", *args, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `open` function to support remote files using `fsspec`.

    It also has a retry mechanism in case connection fails.
    For remote files, `mode`, `args` and `kwargs` are passed to `fsspec.open`, together with the
    storage options (auth, cookies, headers) derived from `download_config`. For a local first hop,
    the builtin `open` is used on that hop. Any further "::" hops are ignored, and `block_size`
    is dropped from `kwargs`.

    Args:
        file (`str`): Path name of the file to be opened.
        mode (`str`, *optional*, default "r"): Mode in which the file is opened.
        *args: Arguments to be passed to `fsspec.open` (after `mode`).
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments to be passed to `fsspec.open`.

    Returns:
        file object (for remote files, its `read` retries on connection errors)

    Raises:
        NonStreamableDatasetError: if the server does not support HTTP range requests.
        FileNotFoundError: if the remote file does not exist.
        ConnectionError: if opening keeps failing after `config.STREAMING_OPEN_MAX_RETRIES` attempts.
    """
    # Accept `Path`-like inputs as well as strings.
    file_str = _as_str(file)
    main_hop, *rest_hops = file_str.split("::")
    if is_local_path(main_hop):
        # `block_size` is an fsspec option that the builtin `open` does not accept.
        kwargs.pop("block_size", None)
        return open(main_hop, mode, *args, **kwargs)
    # Add auth/cookie/header storage options for the HF Hub, Google Drive, etc.
    file, storage_options = _prepare_path_and_storage_options(file_str, download_config=download_config)
    kwargs = {**kwargs, **(storage_options or {})}

    max_retries = config.STREAMING_OPEN_MAX_RETRIES

    disconnect_err = None
    for retry in range(1, max_retries + 1):
        try:
            # Pass `mode` positionally: with `mode=mode`, any extra positional `args` would also
            # bind to `mode` and raise "got multiple values for argument 'mode'".
            file_obj = fsspec.open(file, mode, *args, **kwargs).open()
            break
        except CONNECTION_ERRORS_TO_RETRY as err:
            disconnect_err = err
            logger.warning(
                f"Failed to connect to remote data host. Retrying in {config.STREAMING_OPEN_RETRY_INTERVAL}sec [{retry}/{max_retries}]"
            )
            time.sleep(config.STREAMING_OPEN_RETRY_INTERVAL)
        except ValueError as e:
            if str(e) == "Cannot seek streaming HTTP file":
                raise NonStreamableDatasetError(
                    "Streaming is not possible for this dataset because data host server doesn't support HTTP range "
                    "requests. You can still load this dataset in non-streaming mode by passing `streaming=False` (default)"
                ) from e
            else:
                raise
        except FileNotFoundError:
            if file.startswith(config.HF_ENDPOINT):
                raise FileNotFoundError(
                    file + "\nIf the repo is private or gated, make sure to log in with `huggingface-cli login`."
                ) from None
            else:
                raise
    else:
        raise ConnectionError("Server Disconnected") from disconnect_err
    file_obj = _add_retries_to_file_obj_read_method(file_obj)
    return file_obj
|
|
|
|
def xlistdir(path: str, download_config: Optional[DownloadConfig] = None) -> list[str]:
    """List the entries of a local or remote directory, like `os.listdir`.

    Args:
        path (`str`): Local path or URL of the directory (may contain `::` chained hops).
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `list` of `str`: Base names of the entries found in the directory.

    Raises:
        FileNotFoundError: If a remote, non-root path is not a directory.
    """
    first_hop = _as_str(path).split("::")[0]
    if is_local_path(first_hop):
        return os.listdir(path)

    # Remote listing: resolve storage options from `download_config` before touching the filesystem.
    prepared_path, options = _prepare_path_and_storage_options(path, download_config=download_config)
    fs, *_ = url_to_fs(prepared_path, **options)
    target_dir = prepared_path.split("::")[0].split("://")[-1]
    # The filesystem root ("" or "/") is always listable; anything else must be an existing directory.
    if target_dir.strip("/") and not fs.isdir(target_dir):
        raise FileNotFoundError(f"Directory doesn't exist: {prepared_path}")
    entries = fs.listdir(target_dir, detail=False)
    return [os.path.basename(entry.rstrip("/")) for entry in entries]
|
|
|
|
def xglob(urlpath, *, recursive=False, download_config: Optional[DownloadConfig] = None):
    """Extend `glob.glob` function to support remote files.

    Args:
        urlpath (`str`): URL path with shell-style wildcard patterns.
        recursive (`bool`, default `False`): Whether to match the "**" pattern recursively to zero or more
            directories or subdirectories. Only used for local paths; remote patterns are passed as-is to
            the fsspec filesystem's `glob`.
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `list` of `str`: Matching paths. Remote matches are rebuilt as `<protocol>://<path>` with any
        `::` chained hops re-appended.
    """
    main_hop, *rest_hops = _as_str(urlpath).split("::")
    if is_local_path(main_hop):
        return glob.glob(main_hop, recursive=recursive)

    # Resolve storage options (e.g. auth) from `download_config` before touching the remote filesystem.
    urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
    main_hop, *rest_hops = urlpath.split("::")
    fs, *_ = url_to_fs(urlpath, **storage_options)
    # Strip only the leading "<protocol>://". Indexing `split("://")[1]` raised IndexError for hops without
    # "://" and truncated paths that contain a second "://".
    inner_path = main_hop.split("://", 1)[-1]
    globbed_paths = fs.glob(inner_path)
    # `fs.protocol` may be a tuple of aliases; use a single name to rebuild the URLs.
    protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1]
    return ["::".join([f"{protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]
|
|
|
|
def xwalk(urlpath, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `os.walk` function to support remote files.

    Args:
        urlpath (`str`): URL root path.
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Additional keyword arguments forwarded to `os.walk` (local) or the fsspec filesystem's `walk` (remote).

    Yields:
        `tuple`: 3-tuple (dirpath, dirnames, filenames). For remote paths `dirpath` is a full
        `<protocol>://...` URL with any `::` chained hops re-appended. A remote, non-root path that is
        not a directory yields nothing.
    """
    main_hop, *rest_hops = _as_str(urlpath).split("::")
    if is_local_path(main_hop):
        yield from os.walk(main_hop, **kwargs)
        return

    # Resolve storage options (e.g. auth) from `download_config` before touching the remote filesystem.
    urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
    main_hop, *rest_hops = urlpath.split("::")
    fs, *_ = url_to_fs(urlpath, **storage_options)
    inner_path = main_hop.split("://")[-1]
    if inner_path.strip("/") and not fs.isdir(inner_path):
        # Nothing to walk. A bare `return` ends the generator; a returned value would never reach callers.
        return
    # `fs.protocol` may be a tuple of aliases; use a single name to rebuild the URLs.
    protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1]
    for dirpath, dirnames, filenames in fs.walk(inner_path, **kwargs):
        yield "::".join([f"{protocol}://{dirpath}"] + rest_hops), dirnames, filenames
|
|
|
|
class xPath(type(Path())):
    """Extension of `pathlib.Path` to support both local paths and remote URLs.

    Remote paths may chain several hops with `::`. The path operations below act on the first ("main") hop
    and keep the remaining hops unchanged.
    """

    def __str__(self):
        path_str = super().__str__()
        main_hop, *rest_hops = path_str.split("::")
        if is_local_path(main_hop):
            return main_hop
        # pathlib collapses repeated slashes ("https://host" -> "https:/host") and may use "\\" separators;
        # convert back to a proper URL.
        path_as_posix = path_str.replace("\\", "/")
        path_as_posix = SINGLE_SLASH_AFTER_PROTOCOL_PATTERN.sub("://", path_as_posix)
        path_as_posix += "//" if path_as_posix.endswith(":") else ""
        return path_as_posix

    def exists(self, download_config: Optional[DownloadConfig] = None):
        """Extend `pathlib.Path.exists` method to support both local and remote files.

        Args:
            download_config : mainly use token or storage_options to support different platforms and auth types.

        Returns:
            `bool`
        """
        return xexists(str(self), download_config=download_config)

    def glob(self, pattern, download_config: Optional[DownloadConfig] = None):
        """Glob function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

        Args:
            pattern (`str`): Pattern that resulting paths must match.
            download_config : mainly use token or storage_options to support different platforms and auth types.

        Yields:
            [`xPath`] for remote paths; plain `pathlib.Path` objects for local paths.
        """
        posix_path = self.as_posix()
        main_hop, *rest_hops = posix_path.split("::")
        if is_local_path(main_hop):
            yield from Path(main_hop).glob(pattern)
        else:
            # Storage options (e.g. auth) for the chained hop are keyed by that hop's protocol.
            if rest_hops:
                urlpath = rest_hops[0]
                urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
                storage_options = {urlpath.split("://")[0]: storage_options}
                posix_path = "::".join([main_hop, urlpath, *rest_hops[1:]])
            else:
                storage_options = None
            fs, *_ = url_to_fs(xjoin(posix_path, pattern), **(storage_options or {}))
            globbed_paths = fs.glob(xjoin(main_hop, pattern))
            # `fs.protocol` may be a tuple of aliases; formatting the tuple directly produced invalid
            # paths such as "('gcs', 'gs')://...". Use a single name, as `xglob`/`xwalk` do.
            protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1]
            for globbed_path in globbed_paths:
                yield type(self)("::".join([f"{protocol}://{globbed_path}"] + rest_hops))

    def rglob(self, pattern, **kwargs):
        """Rglob function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

        Args:
            pattern (`str`): Pattern that resulting paths must match.
            **kwargs: Forwarded to :meth:`glob` (e.g. `download_config`).

        Yields:
            [`xPath`]
        """
        return self.glob("**/" + pattern, **kwargs)

    @property
    def parent(self) -> "xPath":
        """Parent function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

        Returns:
            [`xPath`]
        """
        return type(self)(xdirname(self.as_posix()))

    @property
    def name(self) -> str:
        """Name function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

        Returns:
            `str`
        """
        return PurePosixPath(self.as_posix().split("::")[0]).name

    @property
    def stem(self) -> str:
        """Stem function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

        Returns:
            `str`
        """
        return PurePosixPath(self.as_posix().split("::")[0]).stem

    @property
    def suffix(self) -> str:
        """Suffix function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

        Returns:
            `str`
        """
        return PurePosixPath(self.as_posix().split("::")[0]).suffix

    def open(self, *args, **kwargs):
        """Extend :func:`xopen` to support argument of type :obj:`~pathlib.Path`.

        Args:
            *args: Arguments passed to :func:`fsspec.open`.
            **kwargs: Keyword arguments passed to :func:`fsspec.open`.

        Returns:
            `io.FileIO`: File-like object.
        """
        return xopen(str(self), *args, **kwargs)

    def joinpath(self, *p: str) -> "xPath":
        """Extend :func:`xjoin` to support argument of type :obj:`~pathlib.Path`.

        Args:
            *p (`str`): Other path components.

        Returns:
            [`xPath`]
        """
        return type(self)(xjoin(self.as_posix(), *p))

    def __truediv__(self, p: str) -> "xPath":
        return self.joinpath(p)

    def with_suffix(self, suffix):
        """Return a new path with the suffix of the main hop changed; chained hops are kept as-is."""
        main_hop, *rest_hops = str(self).split("::")
        if is_local_path(main_hop):
            return type(self)(str(super().with_suffix(suffix)))
        return type(self)("::".join([type(self)(PurePosixPath(main_hop).with_suffix(suffix)).as_posix()] + rest_hops))
|
|
|
|
def _as_str(path: Union[str, Path, xPath]):
    """Return `path` as a string, routing non-`xPath` inputs through `xPath` so URLs are normalized."""
    if not isinstance(path, xPath):
        path = xPath(str(path))
    return str(path)
|
|
|
|
def xgzip_open(filepath_or_buffer, *args, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `gzip.open` to support remote files.

    Args:
        filepath_or_buffer: File path/URL, or an already-open file object (anything with a `read` method).
        *args: Positional arguments passed to `gzip.open`.
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments passed to `gzip.open`.

    Returns:
        The object returned by `gzip.open`.
    """
    import gzip

    source = filepath_or_buffer
    if not hasattr(source, "read"):
        # Paths and URLs are opened in binary mode through `xopen`; the stream stays open for the lazy gzip reader.
        source = xopen(str(source), "rb", download_config=download_config)
    return gzip.open(source, *args, **kwargs)
|
|
|
|
def xnumpy_load(filepath_or_buffer, *args, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `numpy.load` to support remote files.

    Args:
        filepath_or_buffer: File path/URL, or an already-open file object (anything with a `read` method).
        *args: Positional arguments passed to `numpy.load`.
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments passed to `numpy.load`.

    Returns:
        Whatever `numpy.load` returns for the given data.
    """
    import numpy as np

    source = filepath_or_buffer
    if not hasattr(source, "read"):
        # Left open on purpose: `.npz` archives are read lazily from the file object.
        source = xopen(str(source), "rb", download_config=download_config)
    return np.load(source, *args, **kwargs)
|
|
|
|
def xpandas_read_csv(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `pandas.read_csv` to support remote files.

    Args:
        filepath_or_buffer: File path/URL, or an already-open file object (anything with a `read` method).
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments passed to `pandas.read_csv`.

    Returns:
        Whatever `pandas.read_csv` returns (a DataFrame, or a reader when chunking/iterating).
    """
    import pandas as pd

    if hasattr(filepath_or_buffer, "read"):
        return pd.read_csv(filepath_or_buffer, **kwargs)

    path = str(filepath_or_buffer)
    # pandas cannot infer compression from an opened file object, so derive it from the path instead.
    requested_compression = kwargs.get("compression", "infer")
    if requested_compression == "infer":
        kwargs["compression"] = _get_extraction_protocol(path, download_config=download_config)
    # The handle is not closed here: with `chunksize`/`iterator` pandas keeps reading from it lazily.
    handle = xopen(path, "rb", download_config=download_config)
    return pd.read_csv(handle, **kwargs)
|
|
|
|
def xpandas_read_excel(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `pandas.read_excel` to support remote files.

    Args:
        filepath_or_buffer: File path/URL, or an already-open file object (anything with a `read` method).
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments passed to `pandas.read_excel`.

    Returns:
        Whatever `pandas.read_excel` returns (DataFrame or dict of DataFrames).
    """
    import pandas as pd

    if hasattr(filepath_or_buffer, "read"):
        try:
            return pd.read_excel(filepath_or_buffer, **kwargs)
        except ValueError:
            # Fallback: hand pandas an in-memory copy of the (remaining) stream.
            return pd.read_excel(BytesIO(filepath_or_buffer.read()), **kwargs)
    else:
        filepath_or_buffer = str(filepath_or_buffer)
        # `read_excel` fully materializes its result, so the handles we open can be closed afterwards.
        # Previously both handles were left open.
        try:
            with xopen(filepath_or_buffer, "rb", download_config=download_config) as f:
                return pd.read_excel(f, **kwargs)
        except ValueError:
            # The first handle may be partially consumed; reopen and buffer the whole file in memory.
            with xopen(filepath_or_buffer, "rb", download_config=download_config) as f:
                return pd.read_excel(BytesIO(f.read()), **kwargs)
|
|
|
|
def xpyarrow_parquet_read_table(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `pyarrow.parquet.read_table` to support remote files.

    Args:
        filepath_or_buffer: File path/URL, or an already-open file object (anything with a `read` method).
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments passed to `pyarrow.parquet.read_table`.

    Returns:
        `pyarrow.Table`: The table read from the Parquet file.
    """
    import pyarrow.parquet as pq

    if hasattr(filepath_or_buffer, "read"):
        return pq.read_table(filepath_or_buffer, **kwargs)
    else:
        filepath_or_buffer = str(filepath_or_buffer)
        # `read_table` reads the whole table eagerly, so close the handle we opened instead of leaking it.
        with xopen(filepath_or_buffer, mode="rb", download_config=download_config) as f:
            return pq.read_table(f, **kwargs)
|
|
|
|
def xsio_loadmat(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `scipy.io.loadmat` to support remote files.

    Args:
        filepath_or_buffer: File path/URL, or an already-open file object (anything with a `read` method).
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs: Keyword arguments passed to `scipy.io.loadmat`.

    Returns:
        `dict`: Variable names mapped to loaded values, as returned by `scipy.io.loadmat`.
    """
    import scipy.io as sio

    if hasattr(filepath_or_buffer, "read"):
        return sio.loadmat(filepath_or_buffer, **kwargs)
    else:
        # `loadmat` reads everything eagerly and never closes caller-supplied file objects,
        # so close the handle we opened ourselves.
        with xopen(str(filepath_or_buffer), "rb", download_config=download_config) as f:
            return sio.loadmat(f, **kwargs)
|
|
|
|
def xet_parse(source, parser=None, download_config: Optional[DownloadConfig] = None):
    """Extend `xml.etree.ElementTree.parse` function to support remote files.

    Args:
        source: File path/URL, or an already-open file object (anything with a `read` method).
        parser (`XMLParser`, *optional*, default `XMLParser`): Parser instance.
        download_config : mainly use token or storage_options to support different platforms and auth types.

    Returns:
        `xml.etree.ElementTree.ElementTree`: The parsed document (call `.getroot()` for the root element).
    """
    if not hasattr(source, "read"):
        # Parsing is eager, so the file opened here can be closed as soon as parsing finishes.
        with xopen(source, "rb", download_config=download_config) as xml_file:
            return ET.parse(xml_file, parser=parser)
    return ET.parse(source, parser=parser)
|
|
|
|
def xxml_dom_minidom_parse(filename_or_file, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `xml.dom.minidom.parse` function to support remote files.

    Args:
        filename_or_file (`str` or file): File path/URL, or an already-open file object.
        download_config : mainly use token or storage_options to support different platforms and auth types.
        **kwargs (optional): Additional keyword arguments passed to `xml.dom.minidom.parse`.

    Returns:
        :obj:`xml.dom.minidom.Document`: Parsed document.
    """
    if not hasattr(filename_or_file, "read"):
        # Open paths/URLs in binary mode and close them once the DOM is built.
        with xopen(filename_or_file, "rb", download_config=download_config) as xml_file:
            return xml.dom.minidom.parse(xml_file, **kwargs)
    return xml.dom.minidom.parse(filename_or_file, **kwargs)
|
|
|
|
class ArchiveIterable(TrackedIterableFromGenerator):
    """An iterable of (path, fileobj) pairs from a TAR or ZIP archive, used by `iter_archive`.

    Only regular files are yielded; entries whose base name starts with "." or "__" are skipped.
    """

    @staticmethod
    def _iter_tar(f):
        # "r|*" is tarfile's stream mode: members are read sequentially with compression auto-detected.
        # NOTE(review): in stream mode each yielded file object presumably has to be consumed before
        # the iteration advances to the next member.
        stream = tarfile.open(fileobj=f, mode="r|*")
        for tarinfo in stream:
            file_path = tarinfo.name
            # Skip directories, links and other non-regular entries.
            if not tarinfo.isreg():
                continue
            if file_path is None:
                continue
            if os.path.basename(file_path).startswith((".", "__")):
                # Skip hidden/metadata files (base name starting with "." or "__").
                continue
            file_obj = stream.extractfile(tarinfo)
            yield file_path, file_obj
            # Clear the TarFile's list of already-seen members after each one, presumably so memory does
            # not grow with the number of entries in very large archives.
            stream.members = []
        del stream

    @staticmethod
    def _iter_zip(f):
        # NOTE(review): `zipfile.ZipFile` reads the central directory at the end of the archive,
        # so `f` presumably needs to be seekable here.
        zipf = zipfile.ZipFile(f)
        for member in zipf.infolist():
            file_path = member.filename
            if member.is_dir():
                continue
            if file_path is None:
                continue
            if os.path.basename(file_path).startswith((".", "__")):
                # Skip hidden/metadata files (base name starting with "." or "__").
                continue
            file_obj = zipf.open(member)
            yield file_path, file_obj

    @classmethod
    def _iter_from_fileobj(cls, f) -> Generator[tuple, None, None]:
        # Detect the archive type from the file's content (magic number); anything that is not ZIP
        # is handed to tarfile.
        compression = _get_extraction_protocol_with_magic_number(f)
        if compression == "zip":
            yield from cls._iter_zip(f)
        else:
            yield from cls._iter_tar(f)

    @classmethod
    def _iter_from_urlpath(
        cls, urlpath: str, download_config: Optional[DownloadConfig] = None
    ) -> Generator[tuple, None, None]:
        # Detect the archive type from the path/URL; anything that is not ZIP is handed to tarfile.
        compression = _get_extraction_protocol(urlpath, download_config=download_config)
        # `block_size=0` is forwarded to fsspec for remote files (xopen drops it for local paths).
        # NOTE(review): presumably requests a plain sequential stream without read-ahead buffering — confirm.
        # The archive file is closed when the generator finishes or is closed.
        with xopen(urlpath, "rb", download_config=download_config, block_size=0) as f:
            if compression == "zip":
                yield from cls._iter_zip(f)
            else:
                yield from cls._iter_tar(f)

    @classmethod
    def from_buf(cls, fileobj) -> "ArchiveIterable":
        # Wrap an already-open archive file object.
        return cls(cls._iter_from_fileobj, fileobj)

    @classmethod
    def from_urlpath(cls, urlpath_or_buf, download_config: Optional[DownloadConfig] = None) -> "ArchiveIterable":
        # Wrap an archive given by local path or URL; it is opened via `xopen` when iterated.
        return cls(cls._iter_from_urlpath, urlpath_or_buf, download_config)
|
|
|
|
class FilesIterable(TrackedIterableFromGenerator):
    """Iterable over the file paths contained in one or several local/remote files or directories.

    Directories are walked recursively in sorted order. Directories and files whose name starts with
    "." or "__" are skipped.
    """

    @classmethod
    def _iter_from_urlpaths(
        cls, urlpaths: Union[str, list[str]], download_config: Optional[DownloadConfig] = None
    ) -> Generator[str, None, None]:
        hidden_prefixes = (".", "__")
        targets = urlpaths if isinstance(urlpaths, list) else [urlpaths]
        for target in targets:
            if xisfile(target, download_config=download_config):
                yield target
                continue
            if not xisdir(target, download_config=download_config):
                raise FileNotFoundError(target)
            for root, subdirs, files in xwalk(target, download_config=download_config):
                # Prune hidden directories in place so the walk does not descend into them,
                # and sort the others so traversal order is deterministic.
                subdirs[:] = sorted(d for d in subdirs if not d.startswith(hidden_prefixes))
                if xbasename(root).startswith(hidden_prefixes):
                    continue
                visible_files = [name for name in sorted(files) if not name.startswith(hidden_prefixes)]
                for name in visible_files:
                    yield xjoin(root, name)

    @classmethod
    def from_urlpaths(cls, urlpaths, download_config: Optional[DownloadConfig] = None) -> "FilesIterable":
        return cls(cls._iter_from_urlpaths, urlpaths, download_config)
|
|