| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| """Download manager interface.""" |
|
|
| import enum |
| import io |
| import multiprocessing |
| import os |
| from datetime import datetime |
| from functools import partial |
| from typing import Optional, Union |
|
|
| import fsspec |
| from fsspec.core import url_to_fs |
| from tqdm.contrib.concurrent import thread_map |
|
|
| from .. import config |
| from ..utils import tqdm as hf_tqdm |
| from ..utils.file_utils import ( |
| ArchiveIterable, |
| FilesIterable, |
| cached_path, |
| is_relative_path, |
| stack_multiprocessing_download_progress_bars, |
| url_or_path_join, |
| ) |
| from ..utils.info_utils import get_size_checksum_dict |
| from ..utils.logging import get_logger, tqdm |
| from ..utils.py_utils import NestedDataStructure, map_nested |
| from ..utils.track import tracked_str |
| from .download_config import DownloadConfig |
|
|
|
|
| logger = get_logger(__name__) |
|
|
|
|
| class DownloadMode(enum.Enum): |
| """`Enum` for how to treat pre-existing downloads and data. |
| |
| The default mode is `REUSE_DATASET_IF_EXISTS`, which will reuse both |
| raw downloads and the prepared dataset if they exist. |
| |
| The generations modes: |
| |
| | | Downloads | Dataset | |
| |-------------------------------------|-----------|---------| |
| | `REUSE_DATASET_IF_EXISTS` (default) | Reuse | Reuse | |
| | `REUSE_CACHE_IF_EXISTS` | Reuse | Fresh | |
| | `FORCE_REDOWNLOAD` | Fresh | Fresh | |
| |
| """ |
|
|
| REUSE_DATASET_IF_EXISTS = "reuse_dataset_if_exists" |
| REUSE_CACHE_IF_EXISTS = "reuse_cache_if_exists" |
| FORCE_REDOWNLOAD = "force_redownload" |
|
|
|
|
| class DownloadManager: |
| is_streaming = False |
|
|
| def __init__( |
| self, |
| dataset_name: Optional[str] = None, |
| data_dir: Optional[str] = None, |
| download_config: Optional[DownloadConfig] = None, |
| base_path: Optional[str] = None, |
| record_checksums=True, |
| ): |
| """Download manager constructor. |
| |
| Args: |
| data_dir: |
| can be used to specify a manual directory to get the files from. |
| dataset_name (`str`): |
| name of dataset this instance will be used for. If |
| provided, downloads will contain which datasets they were used for. |
| download_config (`DownloadConfig`): |
| to specify the cache directory and other |
| download options |
| base_path (`str`): |
| base path that is used when relative paths are used to |
| download files. This can be a remote url. |
| record_checksums (`bool`, defaults to `True`): |
| Whether to record the checksums of the downloaded files. If None, the value is inferred from the builder. |
| """ |
| self._dataset_name = dataset_name |
| self._data_dir = data_dir |
| self._base_path = base_path or os.path.abspath(".") |
| |
| self._recorded_sizes_checksums: dict[str, dict[str, Optional[Union[int, str]]]] = {} |
| self.record_checksums = record_checksums |
| self.download_config = download_config or DownloadConfig() |
| self.downloaded_paths = {} |
| self.extracted_paths = {} |
|
|
| @property |
| def manual_dir(self): |
| return self._data_dir |
|
|
| @property |
| def downloaded_size(self): |
| """Returns the total size of downloaded files.""" |
| return sum(checksums_dict["num_bytes"] for checksums_dict in self._recorded_sizes_checksums.values()) |
|
|
| def _record_sizes_checksums(self, url_or_urls: NestedDataStructure, downloaded_path_or_paths: NestedDataStructure): |
| """Record size/checksum of downloaded files.""" |
| delay = 5 |
| for url, path in hf_tqdm( |
| list(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten())), |
| delay=delay, |
| desc="Computing checksums", |
| ): |
| |
| self._recorded_sizes_checksums[str(url)] = get_size_checksum_dict( |
| path, record_checksum=self.record_checksums |
| ) |
|
|
| def download(self, url_or_urls): |
| """Download given URL(s). |
| |
| By default, only one process is used for download. Pass customized `download_config.num_proc` to change this behavior. |
| |
| Args: |
| url_or_urls (`str` or `list` or `dict`): |
| URL or `list` or `dict` of URLs to download. Each URL is a `str`. |
| |
| Returns: |
| `str` or `list` or `dict`: |
| The downloaded paths matching the given input `url_or_urls`. |
| |
| Example: |
| |
| ```py |
| >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz') |
| ``` |
| """ |
| download_config = self.download_config.copy() |
| download_config.extract_compressed_file = False |
| if download_config.download_desc is None: |
| download_config.download_desc = "Downloading data" |
|
|
| download_func = partial(self._download_batched, download_config=download_config) |
|
|
| start_time = datetime.now() |
| with stack_multiprocessing_download_progress_bars(): |
| downloaded_path_or_paths = map_nested( |
| download_func, |
| url_or_urls, |
| map_tuple=True, |
| num_proc=download_config.num_proc, |
| desc="Downloading data files", |
| batched=True, |
| batch_size=-1, |
| ) |
| duration = datetime.now() - start_time |
| logger.info(f"Downloading took {duration.total_seconds() // 60} min") |
| url_or_urls = NestedDataStructure(url_or_urls) |
| downloaded_path_or_paths = NestedDataStructure(downloaded_path_or_paths) |
| self.downloaded_paths.update(dict(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten()))) |
|
|
| start_time = datetime.now() |
| self._record_sizes_checksums(url_or_urls, downloaded_path_or_paths) |
| duration = datetime.now() - start_time |
| logger.info(f"Checksum Computation took {duration.total_seconds() // 60} min") |
|
|
| return downloaded_path_or_paths.data |
|
|
| def _download_batched( |
| self, |
| url_or_filenames: list[str], |
| download_config: DownloadConfig, |
| ) -> list[str]: |
| if len(url_or_filenames) >= 16: |
| download_config = download_config.copy() |
| download_config.disable_tqdm = True |
| download_func = partial(self._download_single, download_config=download_config) |
|
|
| fs: fsspec.AbstractFileSystem |
| path = str(url_or_filenames[0]) |
| if is_relative_path(path): |
| |
| path = url_or_path_join(self._base_path, path) |
| fs, path = url_to_fs(path, **download_config.storage_options) |
| size = 0 |
| try: |
| size = fs.info(path).get("size", 0) |
| except Exception: |
| pass |
| max_workers = ( |
| config.HF_DATASETS_MULTITHREADING_MAX_WORKERS if size < (20 << 20) else 1 |
| ) |
|
|
| return thread_map( |
| download_func, |
| url_or_filenames, |
| desc=download_config.download_desc or "Downloading", |
| unit="files", |
| position=multiprocessing.current_process()._identity[-1] |
| if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1" |
| and multiprocessing.current_process()._identity |
| else None, |
| max_workers=max_workers, |
| tqdm_class=tqdm, |
| ) |
| else: |
| return [ |
| self._download_single(url_or_filename, download_config=download_config) |
| for url_or_filename in url_or_filenames |
| ] |
|
|
| def _download_single(self, url_or_filename: str, download_config: DownloadConfig) -> str: |
| url_or_filename = str(url_or_filename) |
| if is_relative_path(url_or_filename): |
| |
| url_or_filename = url_or_path_join(self._base_path, url_or_filename) |
| out = cached_path(url_or_filename, download_config=download_config) |
| out = tracked_str(out) |
| out.set_origin(url_or_filename) |
| return out |
|
|
| def iter_archive(self, path_or_buf: Union[str, io.BufferedReader]): |
| """Iterate over files within an archive. |
| |
| Args: |
| path_or_buf (`str` or `io.BufferedReader`): |
| Archive path or archive binary file object. |
| |
| Yields: |
| `tuple[str, io.BufferedReader]`: |
| 2-tuple (path_within_archive, file_object). |
| File object is opened in binary mode. |
| |
| Example: |
| |
| ```py |
| >>> archive = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz') |
| >>> files = dl_manager.iter_archive(archive) |
| ``` |
| """ |
|
|
| if hasattr(path_or_buf, "read"): |
| return ArchiveIterable.from_buf(path_or_buf) |
| else: |
| return ArchiveIterable.from_urlpath(path_or_buf) |
|
|
| def iter_files(self, paths: Union[str, list[str]]): |
| """Iterate over file paths. |
| |
| Args: |
| paths (`str` or `list` of `str`): |
| Root paths. |
| |
| Yields: |
| `str`: File path. |
| |
| Example: |
| |
| ```py |
| >>> files = dl_manager.download_and_extract('https://huggingface.co/datasets/beans/resolve/main/data/train.zip') |
| >>> files = dl_manager.iter_files(files) |
| ``` |
| """ |
| return FilesIterable.from_urlpaths(paths) |
|
|
| def extract(self, path_or_paths): |
| """Extract given path(s). |
| |
| Args: |
| path_or_paths (path or `list` or `dict`): |
| Path of file to extract. Each path is a `str`. |
| |
| Returns: |
| extracted_path(s): `str`, The extracted paths matching the given input |
| path_or_paths. |
| |
| Example: |
| |
| ```py |
| >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz') |
| >>> extracted_files = dl_manager.extract(downloaded_files) |
| ``` |
| """ |
| download_config = self.download_config.copy() |
| download_config.extract_compressed_file = True |
| extract_func = partial(self._download_single, download_config=download_config) |
| extracted_paths = map_nested( |
| extract_func, |
| path_or_paths, |
| num_proc=download_config.num_proc, |
| desc="Extracting data files", |
| ) |
| path_or_paths = NestedDataStructure(path_or_paths) |
| extracted_paths = NestedDataStructure(extracted_paths) |
| self.extracted_paths.update(dict(zip(path_or_paths.flatten(), extracted_paths.flatten()))) |
| return extracted_paths.data |
|
|
| def download_and_extract(self, url_or_urls): |
| """Download and extract given `url_or_urls`. |
| |
| Is roughly equivalent to: |
| |
| ``` |
| extracted_paths = dl_manager.extract(dl_manager.download(url_or_urls)) |
| ``` |
| |
| Args: |
| url_or_urls (`str` or `list` or `dict`): |
| URL or `list` or `dict` of URLs to download and extract. Each URL is a `str`. |
| |
| Returns: |
| extracted_path(s): `str`, extracted paths of given URL(s). |
| """ |
| return self.extract(self.download(url_or_urls)) |
|
|
| def get_recorded_sizes_checksums(self): |
| return self._recorded_sizes_checksums.copy() |
|
|
| def delete_extracted_files(self): |
| paths_to_delete = set(self.extracted_paths.values()) - set(self.downloaded_paths.values()) |
| for key, path in list(self.extracted_paths.items()): |
| if path in paths_to_delete and os.path.isfile(path): |
| os.remove(path) |
| del self.extracted_paths[key] |
|
|
| def manage_extracted_files(self): |
| if self.download_config.delete_extracted: |
| self.delete_extracted_files() |
|
|