| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| """DatasetBuilder base class.""" |
|
|
| import abc |
| import contextlib |
| import copy |
| import inspect |
| import os |
| import posixpath |
| import shutil |
| import textwrap |
| import time |
| import urllib |
| from collections.abc import Iterable, Mapping |
| from dataclasses import dataclass |
| from functools import partial |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Optional, Union |
| from unittest.mock import patch |
|
|
| import fsspec |
| from fsspec.core import url_to_fs |
| from multiprocess import Pool |
| from tqdm.contrib.concurrent import thread_map |
|
|
| from . import config, utils |
| from .arrow_dataset import Dataset |
| from .arrow_reader import ( |
| ArrowReader, |
| ReadInstruction, |
| ) |
| from .arrow_writer import ArrowWriter, ParquetWriter, SchemaInferenceError |
| from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns |
| from .dataset_dict import DatasetDict, IterableDatasetDict |
| from .download.download_config import DownloadConfig |
| from .download.download_manager import DownloadManager, DownloadMode |
| from .download.streaming_download_manager import StreamingDownloadManager, xjoin |
| from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError |
| from .features import Features |
| from .filesystems import ( |
| is_remote_filesystem, |
| rename, |
| ) |
| from .fingerprint import Hasher |
| from .info import DatasetInfo, PostProcessedInfo |
| from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset |
| from .keyhash import DuplicatedKeysError |
| from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase |
| from .splits import Split, SplitDict, SplitGenerator, SplitInfo |
| from .streaming import extend_dataset_builder_for_streaming |
| from .table import CastError |
| from .utils import logging |
| from .utils import tqdm as hf_tqdm |
| from .utils._filelock import FileLock |
| from .utils.file_utils import is_remote_url |
| from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits |
| from .utils.py_utils import ( |
| classproperty, |
| convert_file_size_to_int, |
| has_sufficient_disk_space, |
| iflatmap_unordered, |
| map_nested, |
| memoize, |
| size_str, |
| temporary_assignment, |
| ) |
| from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs |
| from .utils.track import tracked_list |
|
|
|
|
| if TYPE_CHECKING: |
| from .load import DatasetModule |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| class InvalidConfigName(ValueError): |
| pass |
|
|
|
|
| @dataclass |
| class BuilderConfig: |
| """Base class for `DatasetBuilder` data configuration. |
| |
| `DatasetBuilder` subclasses with data configuration options should subclass |
| `BuilderConfig` and add their own properties. |
| |
| Attributes: |
| name (`str`, defaults to `default`): |
| The name of the configuration. |
| version (`Version` or `str`, defaults to `0.0.0`): |
| The version of the configuration. |
| data_dir (`str`, *optional*): |
| Path to the directory containing the source data. |
| data_files (`str` or `Sequence` or `Mapping`, *optional*): |
| Path(s) to source data file(s). |
| description (`str`, *optional*): |
| A human description of the configuration. |
| """ |
|
|
| name: str = "default" |
| version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0") |
| data_dir: Optional[str] = None |
| data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None |
| description: Optional[str] = None |
|
|
| def __post_init__(self): |
| |
| for invalid_char in INVALID_WINDOWS_CHARACTERS_IN_PATH: |
| if invalid_char in self.name: |
| raise InvalidConfigName( |
| f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. " |
| f"They could create issues when creating a directory for this config on Windows filesystem." |
| ) |
| if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)): |
| raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}") |
|
|
| def __eq__(self, o): |
| |
| |
| if set(self.__dict__.keys()) != set(o.__dict__.keys()): |
| return False |
| return all((k, getattr(self, k)) == (k, getattr(o, k)) for k in self.__dict__.keys()) |
|
|
| def create_config_id( |
| self, |
| config_kwargs: dict, |
| custom_features: Optional[Features] = None, |
| ) -> str: |
| """ |
| The config id is used to build the cache directory. |
| By default it is equal to the config name. |
| However the name of a config is not sufficient to have a unique identifier for the dataset being generated |
| since it doesn't take into account: |
| - the config kwargs that can be used to overwrite attributes |
| - the custom features used to write the dataset |
| - the data_files for json/text/csv/pandas datasets |
| |
| Therefore the config id is just the config name with an optional suffix based on these. |
| """ |
| |
| suffix: Optional[str] = None |
| config_kwargs_to_add_to_suffix = config_kwargs.copy() |
| |
| config_kwargs_to_add_to_suffix.pop("name", None) |
| config_kwargs_to_add_to_suffix.pop("version", None) |
| |
| |
| |
| |
| if "data_dir" in config_kwargs_to_add_to_suffix: |
| if config_kwargs_to_add_to_suffix["data_dir"] is None: |
| config_kwargs_to_add_to_suffix.pop("data_dir", None) |
| else: |
| |
| |
| data_dir = config_kwargs_to_add_to_suffix["data_dir"] |
| data_dir = os.path.normpath(data_dir) |
| config_kwargs_to_add_to_suffix["data_dir"] = data_dir |
| if config_kwargs_to_add_to_suffix: |
| |
| config_kwargs_to_add_to_suffix = { |
| k: config_kwargs_to_add_to_suffix[k] for k in sorted(config_kwargs_to_add_to_suffix) |
| } |
| if all(isinstance(v, (str, bool, int, float)) for v in config_kwargs_to_add_to_suffix.values()): |
| suffix = ",".join( |
| str(k) + "=" + urllib.parse.quote_plus(str(v)) for k, v in config_kwargs_to_add_to_suffix.items() |
| ) |
| if len(suffix) > 32: |
| suffix = Hasher.hash(config_kwargs_to_add_to_suffix) |
| else: |
| suffix = Hasher.hash(config_kwargs_to_add_to_suffix) |
|
|
| if custom_features is not None: |
| m = Hasher() |
| if suffix: |
| m.update(suffix) |
| m.update(custom_features) |
| suffix = m.hexdigest() |
|
|
| if suffix: |
| config_id = self.name + "-" + suffix |
| if len(config_id) > config.MAX_DATASET_CONFIG_ID_READABLE_LENGTH: |
| config_id = self.name + "-" + Hasher.hash(suffix) |
| return config_id |
| else: |
| return self.name |
|
|
| def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None: |
| if isinstance(self.data_files, DataFilesPatternsDict): |
| base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path |
| self.data_files = self.data_files.resolve(base_path, download_config) |
|
|
|
|
| class DatasetBuilder: |
| """Abstract base class for all datasets. |
| |
| `DatasetBuilder` has 3 key methods: |
| |
| - [`DatasetBuilder.info`]: Documents the dataset, including feature |
| names, types, shapes, version, splits, citation, etc. |
| - [`DatasetBuilder.download_and_prepare`]: Downloads the source data |
| and writes it to disk. |
| - [`DatasetBuilder.as_dataset`]: Generates a [`Dataset`]. |
| |
| Some `DatasetBuilder`s expose multiple variants of the |
| dataset by defining a [`BuilderConfig`] subclass and accepting a |
| config object (or name) on construction. Configurable datasets expose a |
| pre-defined set of configurations in [`DatasetBuilder.builder_configs`]. |
| |
| Args: |
| cache_dir (`str`, *optional*): |
| Directory to cache data. Defaults to `"~/.cache/huggingface/datasets"`. |
| dataset_name (`str`, *optional*): |
| Name of the dataset, if different from the builder name. Useful for packaged builders |
| like csv, imagefolder, audiofolder, etc. to reflect the difference between datasets |
| that use the same packaged builder. |
| config_name (`str`, *optional*): |
| Name of the dataset configuration. |
| It affects the data generated on disk. Different configurations will have their own subdirectories and |
| versions. |
| If not provided, the default configuration is used (if it exists). |
| |
| <Added version="2.3.0"> |
| |
| Parameter `name` was renamed to `config_name`. |
| |
| </Added> |
| hash (`str`, *optional*): |
| Hash specific to the dataset builder code. Used to update the caching directory when the |
| dataset builder code is updated (to avoid reusing old data). |
| The typical caching directory (defined in `self._relative_data_dir`) is `name/version/hash/`. |
| base_path (`str`, *optional*): |
| Base path for relative paths that are used to download files. |
| This can be a remote URL. |
| features ([`Features`], *optional*): |
| Features types to use with this dataset. |
| It can be used to change the [`Features`] types of a dataset, for example. |
| token (`str` or `bool`, *optional*): |
| String or boolean to use as Bearer token for remote files on the |
| Datasets Hub. If `True`, will get token from `"~/.huggingface"`. |
| repo_id (`str`, *optional*): |
| ID of the dataset repository. |
| Used to distinguish builders with the same name but not coming from the same namespace, for example "rajpurkar/squad" |
| and "lhoestq/squad" repo IDs. In the latter, the builder name would be "lhoestq___squad". |
| data_files (`str` or `Sequence` or `Mapping`, *optional*): |
| Path(s) to source data file(s). |
| For builders like "csv" or "json" that need the user to specify data files. They can be either |
| local or remote files. For convenience, you can use a `DataFilesDict`. |
| data_dir (`str`, *optional*): |
| Path to directory containing source data file(s). |
| Use only if `data_files` is not passed, in which case it is equivalent to passing |
| `os.path.join(data_dir, "**")` as `data_files`. |
| For builders that require manual download, it must be the path to the local directory containing the |
| manually downloaded data. |
| storage_options (`dict`, *optional*): |
| Key/value pairs to be passed on to the dataset file-system backend, if any. |
| writer_batch_size (`int`, *optional*): |
| Batch size used by the ArrowWriter. |
| It defines the number of samples that are kept in memory before writing them |
| and also the length of the arrow chunks. |
| None means that the ArrowWriter will use its default value. |
| **config_kwargs (additional keyword arguments): Keyword arguments to be passed to the corresponding builder |
| configuration class, set on the class attribute [`DatasetBuilder.BUILDER_CONFIG_CLASS`]. The builder |
| configuration class is [`BuilderConfig`] or a subclass of it. |
| """ |
|
|
| |
| VERSION = None |
|
|
| |
| BUILDER_CONFIG_CLASS = BuilderConfig |
|
|
| |
| BUILDER_CONFIGS = [] |
|
|
| |
| DEFAULT_CONFIG_NAME = None |
|
|
| |
| |
| |
| |
| DEFAULT_WRITER_BATCH_SIZE = None |
|
|
| def __init__( |
| self, |
| cache_dir: Optional[str] = None, |
| dataset_name: Optional[str] = None, |
| config_name: Optional[str] = None, |
| hash: Optional[str] = None, |
| base_path: Optional[str] = None, |
| info: Optional[DatasetInfo] = None, |
| features: Optional[Features] = None, |
| token: Optional[Union[bool, str]] = None, |
| repo_id: Optional[str] = None, |
| data_files: Optional[Union[str, list, dict, DataFilesDict]] = None, |
| data_dir: Optional[str] = None, |
| storage_options: Optional[dict] = None, |
| writer_batch_size: Optional[int] = None, |
| config_id: Optional[str] = None, |
| **config_kwargs, |
| ): |
| |
| self.name: str = camelcase_to_snakecase(self.__module__.split(".")[-1]) |
| self.hash: Optional[str] = hash |
| self.base_path = base_path |
| self.token = token |
| self.repo_id = repo_id |
| self.storage_options = storage_options or {} |
| self.dataset_name = camelcase_to_snakecase(dataset_name) if dataset_name else self.name |
| self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE |
|
|
| if data_files is not None and not isinstance(data_files, DataFilesDict): |
| data_files = DataFilesDict.from_patterns( |
| sanitize_patterns(data_files), |
| base_path=base_path, |
| download_config=DownloadConfig(token=token, storage_options=self.storage_options), |
| ) |
|
|
| |
| if "features" in inspect.signature(self.BUILDER_CONFIG_CLASS.__init__).parameters and features is not None: |
| config_kwargs["features"] = features |
| if data_files is not None: |
| config_kwargs["data_files"] = data_files |
| if data_dir is not None: |
| config_kwargs["data_dir"] = data_dir |
| self.config_kwargs = config_kwargs |
| self.config, self.config_id = self._create_builder_config( |
| config_name=config_name, |
| custom_features=features, |
| config_id=config_id, |
| **config_kwargs, |
| ) |
|
|
| |
| |
| if info is None: |
| info = self._info() |
| info.builder_name = self.name |
| info.dataset_name = self.dataset_name |
| info.config_name = self.config.name |
| info.version = self.config.version |
| self.info = info |
| |
| if features is not None: |
| self.info.features = features |
|
|
| |
| |
| self._cache_dir_root = str(cache_dir or config.HF_DATASETS_CACHE) |
| self._cache_dir_root = ( |
| self._cache_dir_root if is_remote_url(self._cache_dir_root) else os.path.expanduser(self._cache_dir_root) |
| ) |
| self._cache_downloaded_dir = ( |
| posixpath.join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) |
| if cache_dir |
| else str(config.DOWNLOADED_DATASETS_PATH) |
| ) |
| self._cache_downloaded_dir = ( |
| self._cache_downloaded_dir |
| if is_remote_url(self._cache_downloaded_dir) |
| else os.path.expanduser(self._cache_downloaded_dir) |
| ) |
|
|
| |
| self._legacy_relative_data_dir = None |
|
|
| self._cache_dir = self._build_cache_dir() |
| if not is_remote_url(self._cache_dir_root): |
| os.makedirs(self._cache_dir_root, exist_ok=True) |
| lock_path = os.path.join( |
| self._cache_dir_root, Path(self._cache_dir).as_posix().replace("/", "_") + ".lock" |
| ) |
| with FileLock(lock_path): |
| if os.path.exists(self._cache_dir): |
| if len(os.listdir(self._cache_dir)) > 0: |
| if os.path.exists(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)): |
| logger.debug("Overwrite dataset info from restored data version if exists.") |
| self.info = DatasetInfo.from_directory(self._cache_dir) |
| else: |
| logger.warning( |
| f"Old caching folder {self._cache_dir} for dataset {self.dataset_name} exists but no data were found. Removing it. " |
| ) |
| os.rmdir(self._cache_dir) |
|
|
| |
| self._output_dir = self._cache_dir |
| self._fs: fsspec.AbstractFileSystem = fsspec.filesystem("file") |
|
|
| |
| self.dl_manager = None |
|
|
| |
| self._record_infos = False |
|
|
| |
| self._file_format = None |
|
|
| |
| extend_dataset_builder_for_streaming(self) |
|
|
| def __getstate__(self): |
| return self.__dict__ |
|
|
| def __setstate__(self, d): |
| self.__dict__ = d |
| |
| extend_dataset_builder_for_streaming(self) |
|
|
| |
| |
| |
| |
| |
| |
| @property |
| def manual_download_instructions(self) -> Optional[str]: |
| return None |
|
|
| def _check_legacy_cache(self) -> Optional[str]: |
| """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13""" |
| if ( |
| self.__module__.startswith("datasets.") |
| and not is_remote_url(self._cache_dir_root) |
| and self.config.name == "default" |
| ): |
| from .packaged_modules import _PACKAGED_DATASETS_MODULES |
|
|
| namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None |
| config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name |
| config_id = config_name + self.config_id[len(self.config.name) :] |
| hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1] |
| legacy_relative_data_dir = posixpath.join( |
| self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}", |
| config_id, |
| "0.0.0", |
| hash, |
| ) |
| legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir) |
| if os.path.isdir(legacy_cache_dir): |
| return legacy_relative_data_dir |
|
|
| def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]: |
| """Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15""" |
| if ( |
| self.__module__.startswith("datasets.") |
| and not is_remote_url(self._cache_dir_root) |
| and not (set(self.config_kwargs) - {"data_files", "data_dir"}) |
| ): |
| from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES |
| from .utils._dill import Pickler |
|
|
| def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str: |
| """ |
| Used to update hash of packaged modules which is used for creating unique cache directories to reflect |
| different config parameters which are passed in metadata from readme. |
| """ |
| params_to_exclude = {"config_name", "version", "description"} |
| params_to_add_to_hash = { |
| param: value |
| for param, value in sorted(config_parameters.items()) |
| if param not in params_to_exclude |
| } |
| m = Hasher() |
| m.update(hash) |
| m.update(params_to_add_to_hash) |
| return m.hexdigest() |
|
|
| namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None |
| with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True): |
| config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files}) |
| hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing") |
| if ( |
| dataset_module.builder_configs_parameters.metadata_configs |
| and self.config.name in dataset_module.builder_configs_parameters.metadata_configs |
| ): |
| hash = update_hash_with_config_parameters( |
| hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name] |
| ) |
| legacy_relative_data_dir = posixpath.join( |
| self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}", |
| config_id, |
| "0.0.0", |
| hash, |
| ) |
| legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir) |
| if os.path.isdir(legacy_cache_dir): |
| return legacy_relative_data_dir |
|
|
| def _create_builder_config( |
| self, config_name=None, custom_features=None, config_id=None, **config_kwargs |
| ) -> tuple[BuilderConfig, str]: |
| """Create and validate BuilderConfig object as well as a unique config id for this config. |
| Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None. |
| config_kwargs override the defaults kwargs in config |
| """ |
| builder_config = None |
|
|
| |
| if config_name is None and self.BUILDER_CONFIGS: |
| if self.DEFAULT_CONFIG_NAME is not None: |
| builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME) |
| logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}") |
| else: |
| if len(self.BUILDER_CONFIGS) > 1: |
| if not config_kwargs: |
| example_of_usage = ( |
| f"load_dataset('{self.repo_id or self.dataset_name}', '{self.BUILDER_CONFIGS[0].name}')" |
| ) |
| raise ValueError( |
| "Config name is missing." |
| f"\nPlease pick one among the available configs: {list(self.builder_configs.keys())}" |
| + f"\nExample of usage:\n\t`{example_of_usage}`" |
| ) |
| else: |
| builder_config = self.BUILDER_CONFIGS[0] |
| logger.info( |
| f"No config specified, defaulting to the single config: {self.dataset_name}/{builder_config.name}" |
| ) |
|
|
| |
| if isinstance(config_name, str): |
| builder_config = self.builder_configs.get(config_name) |
| if builder_config is None and self.BUILDER_CONFIGS: |
| raise ValueError( |
| f"BuilderConfig '{config_name}' not found. Available: {list(self.builder_configs.keys())}" |
| ) |
|
|
| |
| if not builder_config: |
| if config_name is not None: |
| config_kwargs["name"] = config_name |
| elif self.DEFAULT_CONFIG_NAME and not config_kwargs: |
| |
| config_kwargs["name"] = self.DEFAULT_CONFIG_NAME |
| if "version" not in config_kwargs and hasattr(self, "VERSION") and self.VERSION: |
| config_kwargs["version"] = self.VERSION |
| builder_config = self.BUILDER_CONFIG_CLASS(**config_kwargs) |
|
|
| |
| else: |
| builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config |
| for key, value in config_kwargs.items(): |
| if value is not None: |
| if not hasattr(builder_config, key): |
| raise ValueError(f"BuilderConfig {builder_config} doesn't have a '{key}' key.") |
| setattr(builder_config, key, value) |
|
|
| if not builder_config.name: |
| raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}") |
|
|
| |
| builder_config._resolve_data_files( |
| base_path=self.base_path, |
| download_config=DownloadConfig(token=self.token, storage_options=self.storage_options), |
| ) |
|
|
| |
| if config_id is None: |
| config_id = builder_config.create_config_id( |
| config_kwargs, |
| custom_features=custom_features, |
| ) |
| is_custom = (config_id not in self.builder_configs) and config_id != "default" |
| if is_custom: |
| logger.info(f"Using custom data configuration {config_id}") |
| else: |
| if ( |
| builder_config.name in self.builder_configs |
| and builder_config != self.builder_configs[builder_config.name] |
| ): |
| raise ValueError( |
| "Cannot name a custom BuilderConfig the same as an available " |
| f"BuilderConfig. Change the name. Available BuilderConfigs: {list(self.builder_configs.keys())}" |
| ) |
| if not builder_config.version: |
| raise ValueError(f"BuilderConfig {builder_config.name} must have a version") |
|
|
| return builder_config, config_id |
|
|
| @classproperty |
| @classmethod |
| @memoize() |
| def builder_configs(cls) -> dict[str, BuilderConfig]: |
| """Dictionary of pre-defined configurations for this builder class.""" |
| configs = {config.name: config for config in cls.BUILDER_CONFIGS} |
| if len(configs) != len(cls.BUILDER_CONFIGS): |
| names = [config.name for config in cls.BUILDER_CONFIGS] |
| raise ValueError(f"Names in BUILDER_CONFIGS must not be duplicated. Got {names}") |
| return configs |
|
|
| @property |
| def cache_dir(self): |
| return self._cache_dir |
|
|
| def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"): |
| |
| self._legacy_relative_data_dir = ( |
| self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None |
| ) |
| self._cache_dir = self._build_cache_dir() |
| self._output_dir = self._cache_dir |
|
|
| def _relative_data_dir(self, with_version=True, with_hash=True) -> str: |
| """Relative path of this dataset in cache_dir: |
| Will be: |
| self.dataset_name/self.config.version/self.hash/ |
| or if a repo_id with a namespace has been specified: |
| self.namespace___self.dataset_name/self.config.version/self.hash/ |
| If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped. |
| """ |
| if self._legacy_relative_data_dir is not None and with_version and with_hash: |
| return self._legacy_relative_data_dir |
|
|
| namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None |
| builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}" |
| builder_data_dir = posixpath.join(builder_data_dir, self.config_id) |
| if with_version: |
| builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version)) |
| if with_hash and self.hash and isinstance(self.hash, str): |
| builder_data_dir = posixpath.join(builder_data_dir, self.hash) |
| return builder_data_dir |
|
|
| def _build_cache_dir(self): |
| """Return the data directory for the current version.""" |
| builder_data_dir = posixpath.join(self._cache_dir_root, self._relative_data_dir(with_version=False)) |
| version_data_dir = posixpath.join(self._cache_dir_root, self._relative_data_dir(with_version=True)) |
|
|
| def _other_versions_on_disk(): |
| """Returns previous versions on disk.""" |
| if not os.path.exists(builder_data_dir): |
| return [] |
|
|
| version_dirnames = [] |
| for dir_name in os.listdir(builder_data_dir): |
| try: |
| version_dirnames.append((utils.Version(dir_name), dir_name)) |
| except ValueError: |
| pass |
| version_dirnames.sort(reverse=True) |
| return version_dirnames |
|
|
| |
| if not is_remote_url(builder_data_dir): |
| version_dirs = _other_versions_on_disk() |
| if version_dirs: |
| other_version = version_dirs[0][0] |
| if other_version != self.config.version: |
| warn_msg = ( |
| f"Found a different version {str(other_version)} of dataset {self.dataset_name} in " |
| f"cache_dir {self._cache_dir_root}. Using currently defined version " |
| f"{str(self.config.version)}." |
| ) |
| logger.warning(warn_msg) |
|
|
| return version_data_dir |
|
|
| @abc.abstractmethod |
| def _info(self) -> DatasetInfo: |
| """Construct the DatasetInfo object. See `DatasetInfo` for details. |
| |
| Warning: This function is only called once and the result is cached for all |
| following .info() calls. |
| |
| Returns: |
| info: (DatasetInfo) The dataset information |
| """ |
| raise NotImplementedError |
|
|
| @classmethod |
| def get_imported_module_dir(cls): |
| """Return the path of the module of this class or subclass.""" |
| return os.path.dirname(inspect.getfile(inspect.getmodule(cls))) |
|
|
| def _rename(self, src: str, dst: str): |
| rename(self._fs, src, dst) |
|
|
| def download_and_prepare( |
| self, |
| output_dir: Optional[str] = None, |
| download_config: Optional[DownloadConfig] = None, |
| download_mode: Optional[Union[DownloadMode, str]] = None, |
| verification_mode: Optional[Union[VerificationMode, str]] = None, |
| dl_manager: Optional[DownloadManager] = None, |
| base_path: Optional[str] = None, |
| file_format: str = "arrow", |
| max_shard_size: Optional[Union[int, str]] = None, |
| num_proc: Optional[int] = None, |
| storage_options: Optional[dict] = None, |
| **download_and_prepare_kwargs, |
| ): |
| """Downloads and prepares dataset for reading. |
| |
| Args: |
| output_dir (`str`, *optional*): |
| Output directory for the dataset. |
| Default to this builder's `cache_dir`, which is inside `~/.cache/huggingface/datasets` by default. |
| |
| <Added version="2.5.0"/> |
| download_config (`DownloadConfig`, *optional*): |
| Specific download configuration parameters. |
| download_mode ([`DownloadMode`] or `str`, *optional*): |
| Select the download/generate mode, default to `REUSE_DATASET_IF_EXISTS`. |
| verification_mode ([`VerificationMode`] or `str`, defaults to `BASIC_CHECKS`): |
| Verification mode determining the checks to run on the downloaded/processed dataset information (checksums/size/splits/...). |
| |
| <Added version="2.9.1"/> |
| dl_manager (`DownloadManager`, *optional*): |
| Specific `DownloadManger` to use. |
| base_path (`str`, *optional*): |
| Base path for relative paths that are used to download files. This can be a remote url. |
| If not specified, the value of the `base_path` attribute (`self.base_path`) will be used instead. |
| file_format (`str`, *optional*): |
| Format of the data files in which the dataset will be written. |
| Supported formats: "arrow", "parquet". Default to "arrow" format. |
| If the format is "parquet", then image and audio data are embedded into the Parquet files instead of pointing to local files. |
| |
| <Added version="2.5.0"/> |
| max_shard_size (`Union[str, int]`, *optional*): |
| Maximum number of bytes written per shard, default is "500MB". |
| The size is based on uncompressed data size, so in practice your shard files may be smaller than |
| `max_shard_size` thanks to Parquet compression for example. |
| |
| <Added version="2.5.0"/> |
| num_proc (`int`, *optional*, defaults to `None`): |
| Number of processes when downloading and generating the dataset locally. |
| Multiprocessing is disabled by default. |
| |
| <Added version="2.7.0"/> |
| storage_options (`dict`, *optional*): |
| Key/value pairs to be passed on to the caching file-system backend, if any. |
| |
| <Added version="2.5.0"/> |
| **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. |
| |
| Example: |
| |
| Download and prepare the dataset as Arrow files that can be loaded as a Dataset using `builder.as_dataset()`: |
| |
| ```py |
| >>> from datasets import load_dataset_builder |
| >>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes") |
| >>> builder.download_and_prepare() |
| ``` |
| |
| Download and prepare the dataset as sharded Parquet files locally: |
| |
| ```py |
| >>> from datasets import load_dataset_builder |
| >>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes") |
| >>> builder.download_and_prepare("./output_dir", file_format="parquet") |
| ``` |
| |
| Download and prepare the dataset as sharded Parquet files in a cloud storage: |
| |
| ```py |
| >>> from datasets import load_dataset_builder |
| >>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} |
| >>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes") |
| >>> builder.download_and_prepare("s3://my-bucket/my_rotten_tomatoes", storage_options=storage_options, file_format="parquet") |
| ``` |
| """ |
| output_dir = output_dir if output_dir is not None else self._cache_dir |
| |
| fs, output_dir = url_to_fs(output_dir, **(storage_options or {})) |
| self._fs = fs |
| self._output_dir = output_dir if not is_remote_filesystem(self._fs) else self._fs.unstrip_protocol(output_dir) |
|
|
| download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) |
| verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS) |
| base_path = base_path if base_path is not None else self.base_path |
|
|
| if file_format is not None and file_format not in ["arrow", "parquet"]: |
| raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") |
| self._file_format = file_format |
|
|
| if self._fs._strip_protocol(self._output_dir) == "": |
| |
| |
| |
| raise RuntimeError( |
| f"Unable to download and prepare the dataset at the root {self._output_dir}. " |
| f"Please specify a subdirectory, e.g. '{self._output_dir + self.dataset_name}'" |
| ) |
|
|
| if dl_manager is None: |
| if download_config is None: |
| download_config = DownloadConfig( |
| cache_dir=self._cache_downloaded_dir, |
| force_download=download_mode == DownloadMode.FORCE_REDOWNLOAD, |
| force_extract=download_mode == DownloadMode.FORCE_REDOWNLOAD, |
| use_etag=False, |
| num_proc=num_proc, |
| token=self.token, |
| storage_options=self.storage_options, |
| ) |
|
|
| dl_manager = DownloadManager( |
| dataset_name=self.dataset_name, |
| download_config=download_config, |
| data_dir=self.config.data_dir, |
| base_path=base_path, |
| record_checksums=(self._record_infos or verification_mode == VerificationMode.ALL_CHECKS), |
| ) |
|
|
| is_local = not is_remote_filesystem(self._fs) |
| self.dl_manager = dl_manager |
|
|
| |
| if is_local: |
| |
| Path(self._output_dir).parent.mkdir(parents=True, exist_ok=True) |
| lock_path = self._output_dir + "_builder.lock" |
|
|
| |
| with FileLock(lock_path) if is_local else contextlib.nullcontext(): |
| |
| data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) |
| if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: |
| logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") |
| |
| |
| self.info = self._load_info() |
| self.download_post_processing_resources(dl_manager) |
| return |
|
|
| logger.info(f"Generating dataset {self.dataset_name} ({self._output_dir})") |
| if is_local: |
| if not has_sufficient_disk_space( |
| self.info.size_in_bytes or 0, directory=Path(self._output_dir).parent |
| ): |
| raise OSError( |
| f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})" |
| ) |
|
|
| @contextlib.contextmanager |
| def incomplete_dir(dirname): |
| """Create temporary dir for dirname and rename on exit.""" |
| if not is_local: |
| self._fs.makedirs(dirname, exist_ok=True) |
| yield dirname |
| else: |
| tmp_dir = dirname + ".incomplete" |
| os.makedirs(tmp_dir, exist_ok=True) |
| try: |
| yield tmp_dir |
| if os.path.isdir(dirname): |
| shutil.rmtree(dirname) |
| |
| shutil.move(tmp_dir, dirname) |
| finally: |
| if os.path.exists(tmp_dir): |
| shutil.rmtree(tmp_dir) |
|
|
| |
| |
| |
| if self.info.size_in_bytes: |
| logger.info( |
| f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} " |
| f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, " |
| f"post-processed: {size_str(self.info.post_processing_size)}, " |
| f"total: {size_str(self.info.size_in_bytes)}) to {self._output_dir}..." |
| ) |
| else: |
| _dest = self._fs._strip_protocol(self._output_dir) if is_local else self._output_dir |
| logger.info(f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} to {_dest}...") |
|
|
| self._check_manual_download(dl_manager) |
|
|
| |
| with incomplete_dir(self._output_dir) as tmp_output_dir: |
| |
| |
| with temporary_assignment(self, "_output_dir", tmp_output_dir): |
| prepare_split_kwargs = {"file_format": file_format} |
| if max_shard_size is not None: |
| prepare_split_kwargs["max_shard_size"] = max_shard_size |
| if num_proc is not None: |
| prepare_split_kwargs["num_proc"] = num_proc |
| self._download_and_prepare( |
| dl_manager=dl_manager, |
| verification_mode=verification_mode, |
| **prepare_split_kwargs, |
| **download_and_prepare_kwargs, |
| ) |
| |
| self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) |
| self.info.download_checksums = dl_manager.get_recorded_sizes_checksums() |
| if self.info.download_size is not None: |
| self.info.size_in_bytes = self.info.dataset_size + self.info.download_size |
| |
| self._save_info() |
|
|
| |
| self.download_post_processing_resources(dl_manager) |
|
|
| logger.info( |
| f"Dataset {self.dataset_name} downloaded and prepared to {self._output_dir}. " |
| f"Subsequent calls will reuse this data." |
| ) |
|
|
| def _check_manual_download(self, dl_manager): |
| if self.manual_download_instructions is not None and dl_manager.manual_dir is None: |
| raise ManualDownloadError( |
| textwrap.dedent( |
| f"""\ |
| The dataset {self.dataset_name} with config {self.config.name} requires manual data. |
| Please follow the manual download instructions: |
| {self.manual_download_instructions} |
| Manual data can be loaded with: |
| datasets.load_dataset("{self.repo_id or self.dataset_name}", data_dir="<path/to/manual/data>")""" |
| ) |
| ) |
|
|
| def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_kwargs): |
| """Downloads and prepares dataset for reading. |
| |
| This is the internal implementation to overwrite called when user calls |
| `download_and_prepare`. It should download all required data and generate |
| the pre-processed datasets files. |
| |
| Args: |
| dl_manager ([`DownloadManager`]): |
| `DownloadManager` used to download and cache data. |
| verification_mode ([`VerificationMode`]): |
| if `ALL_CHECKS`, perform all the verifications including checksums. |
| if `BASIC_CHECKS`, do not perform checksums, only perform split tests. |
| if `NO_CHECKS`, do not perform any verification. |
| prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size` |
| """ |
| |
| split_dict = SplitDict(dataset_name=self.dataset_name) |
| split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) |
| split_generators = self._split_generators(dl_manager, **split_generators_kwargs) |
|
|
| |
| if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums: |
| verify_checksums( |
| self.info.download_checksums, dl_manager.get_recorded_sizes_checksums(), "dataset source files" |
| ) |
|
|
| |
| for split_generator in split_generators: |
| if str(split_generator.split_info.name).lower() == "all": |
| raise ValueError( |
| "`all` is a special split keyword corresponding to the " |
| "union of all splits, so cannot be used as key in " |
| "._split_generator()." |
| ) |
|
|
| logger.info(f"Generating {split_generator.split_info.name} split") |
| split_dict.add(split_generator.split_info) |
|
|
| try: |
| |
| self._prepare_split(split_generator, **prepare_split_kwargs) |
| except OSError as e: |
| raise OSError( |
| "Cannot find data file. " |
| + (self.manual_download_instructions or "") |
| + "\nOriginal error:\n" |
| + str(e) |
| ) from None |
| |
| except DuplicatedKeysError as e: |
| raise DuplicatedKeysError( |
| e.key, |
| e.duplicate_key_indices, |
| fix_msg=f"To avoid duplicate keys, please fix the dataset splits for {self.name}", |
| ) from None |
| dl_manager.manage_extracted_files() |
|
|
| if verification_mode == VerificationMode.BASIC_CHECKS or verification_mode == VerificationMode.ALL_CHECKS: |
| verify_splits(self.info.splits, split_dict) |
|
|
| |
| self.info.splits = split_dict |
| self.info.download_size = dl_manager.downloaded_size |
|
|
| def download_post_processing_resources(self, dl_manager): |
| for split in self.info.splits or []: |
| for resource_name, resource_file_name in self._post_processing_resources(split).items(): |
| if not not is_remote_filesystem(self._fs): |
| raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") |
| if os.sep in resource_file_name: |
| raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") |
| resource_path = os.path.join(self._output_dir, resource_file_name) |
| if not os.path.exists(resource_path): |
| downloaded_resource_path = self._download_post_processing_resources( |
| split, resource_name, dl_manager |
| ) |
| if downloaded_resource_path: |
| logger.info(f"Downloaded post-processing resource {resource_name} as {resource_file_name}") |
| shutil.move(downloaded_resource_path, resource_path) |
|
|
| def _load_info(self) -> DatasetInfo: |
| return DatasetInfo.from_directory(self._output_dir, storage_options=self._fs.storage_options) |
|
|
| def _save_info(self): |
| file_lock = ( |
| FileLock(self._output_dir + "_info.lock") |
| if not is_remote_filesystem(self._fs) |
| else contextlib.nullcontext() |
| ) |
| with file_lock: |
| self.info.write_to_directory(self._output_dir, storage_options=self._fs.storage_options) |
|
|
| def _make_split_generators_kwargs(self, prepare_split_kwargs): |
| """Get kwargs for `self._split_generators()` from `prepare_split_kwargs`.""" |
| del prepare_split_kwargs |
| return {} |
|
|
| def as_dataset( |
| self, |
| split: Optional[Union[str, Split, list[str], list[Split]]] = None, |
| run_post_process=True, |
| verification_mode: Optional[Union[VerificationMode, str]] = None, |
| in_memory=False, |
| ) -> Union[Dataset, DatasetDict]: |
| """Return a Dataset for the specified split. |
| |
| Args: |
| split (`datasets.Split`): |
| Which subset of the data to return. |
| run_post_process (`bool`, defaults to `True`): |
| Whether to run post-processing dataset transforms and/or add |
| indexes. |
| verification_mode ([`VerificationMode`] or `str`, defaults to `BASIC_CHECKS`): |
| Verification mode determining the checks to run on the |
| downloaded/processed dataset information (checksums/size/splits/...). |
| |
| <Added version="2.9.1"/> |
| in_memory (`bool`, defaults to `False`): |
| Whether to copy the data in-memory. |
| |
| Returns: |
| datasets.Dataset |
| |
| Example: |
| |
| ```py |
| >>> from datasets import load_dataset_builder |
| >>> builder = load_dataset_builder('cornell-movie-review-data/rotten_tomatoes') |
| >>> builder.download_and_prepare() |
| >>> ds = builder.as_dataset(split='train') |
| >>> ds |
| Dataset({ |
| features: ['text', 'label'], |
| num_rows: 8530 |
| }) |
| ``` |
| """ |
| if self._file_format is not None and self._file_format != "arrow": |
| raise FileFormatError('Loading a dataset not written in the "arrow" format is not supported.') |
| if is_remote_filesystem(self._fs): |
| raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.") |
| if not os.path.exists(self._output_dir): |
| raise FileNotFoundError( |
| f"Dataset {self.dataset_name}: could not find data in {self._output_dir}. Please make sure to call " |
| "builder.download_and_prepare(), or use " |
| "datasets.load_dataset() before trying to access the Dataset object." |
| ) |
|
|
| logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}") |
|
|
| |
| if split is None: |
| split = {s: s for s in self.info.splits} |
|
|
| verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS) |
|
|
| |
| datasets = map_nested( |
| partial( |
| self._build_single_dataset, |
| run_post_process=run_post_process, |
| verification_mode=verification_mode, |
| in_memory=in_memory, |
| ), |
| split, |
| map_tuple=True, |
| disable_tqdm=True, |
| ) |
| if isinstance(datasets, dict): |
| datasets = DatasetDict(datasets) |
| return datasets |
|
|
| def _build_single_dataset( |
| self, |
| split: Union[str, ReadInstruction, Split], |
| run_post_process: bool, |
| verification_mode: VerificationMode, |
| in_memory: bool = False, |
| ): |
| """as_dataset for a single split.""" |
| if not isinstance(split, ReadInstruction): |
| split = str(split) |
| if split == "all": |
| split = "+".join(self.info.splits.keys()) |
| split = Split(split) |
|
|
| |
| ds = self._as_dataset( |
| split=split, |
| in_memory=in_memory, |
| ) |
| if run_post_process: |
| for resource_file_name in self._post_processing_resources(split).values(): |
| if os.sep in resource_file_name: |
| raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") |
| resources_paths = { |
| resource_name: os.path.join(self._output_dir, resource_file_name) |
| for resource_name, resource_file_name in self._post_processing_resources(split).items() |
| } |
| post_processed = self._post_process(ds, resources_paths) |
| if post_processed is not None: |
| ds = post_processed |
| recorded_checksums = {} |
| record_checksums = False |
| for resource_name, resource_path in resources_paths.items(): |
| size_checksum = get_size_checksum_dict(resource_path) |
| recorded_checksums[resource_name] = size_checksum |
| if verification_mode == VerificationMode.ALL_CHECKS and record_checksums: |
| if self.info.post_processed is None or self.info.post_processed.resources_checksums is None: |
| expected_checksums = None |
| else: |
| expected_checksums = self.info.post_processed.resources_checksums.get(split) |
| verify_checksums(expected_checksums, recorded_checksums, "post processing resources") |
| if self.info.post_processed is None: |
| self.info.post_processed = PostProcessedInfo() |
| if self.info.post_processed.resources_checksums is None: |
| self.info.post_processed.resources_checksums = {} |
| self.info.post_processed.resources_checksums[str(split)] = recorded_checksums |
| self.info.post_processing_size = sum( |
| checksums_dict["num_bytes"] |
| for split_checksums_dicts in self.info.post_processed.resources_checksums.values() |
| for checksums_dict in split_checksums_dicts.values() |
| ) |
| if self.info.dataset_size is not None and self.info.download_size is not None: |
| self.info.size_in_bytes = ( |
| self.info.dataset_size + self.info.download_size + self.info.post_processing_size |
| ) |
| self._save_info() |
| ds._info.post_processed = self.info.post_processed |
| ds._info.post_processing_size = self.info.post_processing_size |
| ds._info.size_in_bytes = self.info.size_in_bytes |
| if self.info.post_processed.features is not None: |
| if self.info.post_processed.features.type != ds.features.type: |
| raise ValueError( |
| f"Post-processed features info don't match the dataset:\nGot\n{self.info.post_processed.features}\nbut expected something like\n{ds.features}" |
| ) |
| else: |
| ds.info.features = self.info.post_processed.features |
|
|
| return ds |
|
|
| def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_memory: bool = False) -> Dataset: |
| """Constructs a `Dataset`. |
| |
| This is the internal implementation to overwrite called when user calls |
| `as_dataset`. It should read the pre-processed datasets files and generate |
| the `Dataset` object. |
| |
| Args: |
| split (`datasets.Split`): |
| which subset of the data to read. |
| in_memory (`bool`, defaults to `False`): |
| Whether to copy the data in-memory. |
| |
| Returns: |
| `Dataset` |
| """ |
| cache_dir = self._fs._strip_protocol(self._output_dir) |
| dataset_name = self.dataset_name |
| if self._check_legacy_cache(): |
| dataset_name = self.name |
| dataset_kwargs = ArrowReader(cache_dir, self.info).read( |
| name=dataset_name, |
| instructions=split, |
| split_infos=self.info.splits.values(), |
| in_memory=in_memory, |
| ) |
| fingerprint = self._get_dataset_fingerprint(split) |
| return Dataset(fingerprint=fingerprint, **dataset_kwargs) |
|
|
| def _get_dataset_fingerprint(self, split: Union[ReadInstruction, Split]) -> str: |
| """The dataset fingerprint is the hash of the relative directory dataset_name/config_name/version/hash, as well as the split specs.""" |
| hasher = Hasher() |
| hasher.update(Path(self._relative_data_dir()).as_posix()) |
| hasher.update(str(split)) |
| fingerprint = hasher.hexdigest() |
| return fingerprint |
|
|
| def as_streaming_dataset( |
| self, |
| split: Optional[str] = None, |
| base_path: Optional[str] = None, |
| ) -> Union[dict[str, IterableDataset], IterableDataset]: |
| if is_remote_filesystem(self._fs): |
| raise NotImplementedError( |
| f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet." |
| ) |
|
|
| dl_manager = StreamingDownloadManager( |
| base_path=base_path or self.base_path, |
| download_config=DownloadConfig(token=self.token, storage_options=self.storage_options), |
| dataset_name=self.dataset_name, |
| data_dir=self.config.data_dir, |
| ) |
| self._check_manual_download(dl_manager) |
| splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager)} |
| |
| if split is None: |
| splits_generator = splits_generators |
| elif split in splits_generators: |
| splits_generator = splits_generators[split] |
| else: |
| raise ValueError(f"Bad split: {split}. Available splits: {list(splits_generators)}") |
|
|
| |
| datasets = map_nested( |
| self._as_streaming_dataset_single, |
| splits_generator, |
| map_tuple=True, |
| ) |
| if isinstance(datasets, dict): |
| datasets = IterableDatasetDict(datasets) |
| return datasets |
|
|
| def _as_streaming_dataset_single( |
| self, |
| splits_generator, |
| ) -> IterableDataset: |
| ex_iterable = self._get_examples_iterable_for_split(splits_generator) |
| |
| token_per_repo_id = {self.repo_id: self.token} if self.repo_id else {} |
| return IterableDataset( |
| ex_iterable, info=self.info, split=splits_generator.name, token_per_repo_id=token_per_repo_id |
| ) |
|
|
| def _post_process(self, dataset: Dataset, resources_paths: Mapping[str, str]) -> Optional[Dataset]: |
| """Run dataset transforms or add indexes""" |
| return None |
|
|
| def _post_processing_resources(self, split: str) -> dict[str, str]: |
| """Mapping resource_name -> resource_file_name""" |
| return {} |
|
|
| def _download_post_processing_resources( |
| self, split: str, resource_name: str, dl_manager: DownloadManager |
| ) -> Optional[str]: |
| """Download the resource using the download manager and return the downloaded path.""" |
| return None |
|
|
| @abc.abstractmethod |
| def _split_generators(self, dl_manager: Union[DownloadManager, StreamingDownloadManager]): |
| """Specify feature dictionary generators and dataset splits. |
| |
| This function returns a list of `SplitGenerator`s defining how to generate |
| data and what splits to use. |
| |
| Example: |
| |
| return [ |
| datasets.SplitGenerator( |
| name=datasets.Split.TRAIN, |
| gen_kwargs={'file': 'train_data.zip'}, |
| ), |
| datasets.SplitGenerator( |
| name=datasets.Split.TEST, |
| gen_kwargs={'file': 'test_data.zip'}, |
| ), |
| ] |
| |
| The above code will first call `_generate_examples(file='train_data.zip')` |
| to write the train data, then `_generate_examples(file='test_data.zip')` to |
| write the test data. |
| |
| Datasets are typically split into different subsets to be used at various |
| stages of training and evaluation. |
| |
| Note that for datasets without a `VALIDATION` split, you can use a |
| fraction of the `TRAIN` data for evaluation as you iterate on your model |
| so as not to overfit to the `TEST` data. |
| |
| For downloads and extractions, use the given `download_manager`. |
| Note that the `DownloadManager` caches downloads, so it is fine to have each |
| generator attempt to download the source data. |
| |
| A good practice is to download all data in this function, and then |
| distribute the relevant parts to each split with the `gen_kwargs` argument |
| |
| Args: |
| dl_manager (`Union[DownloadManager, StreamingDownloadManager]`): |
| Download manager to download the data |
| |
| Returns: |
| `list<SplitGenerator>`. |
| """ |
| raise NotImplementedError() |
|
|
| @abc.abstractmethod |
| def _prepare_split( |
| self, |
| split_generator: SplitGenerator, |
| file_format: str = "arrow", |
| max_shard_size: Optional[Union[str, int]] = None, |
| num_proc: Optional[int] = None, |
| **kwargs, |
| ): |
| """Generate the examples and record them on disk. |
| |
| Args: |
| split_generator (`SplitGenerator`): |
| Split generator to process |
| file_format (`str`, *optional*): |
| format of the data files in which the dataset will be written. |
| Supported formats: "arrow", "parquet". Default to "arrow" format. |
| max_shard_size (`Union[str, int]`, *optional*): |
| Maximum number of bytes written per shard, default is "500MB". |
| The size is based on uncompressed data size, so in practice your shard files may be smaller than |
| `max_shard_size` thanks to Parquet compression for example. |
| num_proc (`int`, *optional*, defaults to `None`): |
| Number of processes when downloading and generating the dataset locally. |
| Multiprocessing is disabled by default. |
| |
| <Added version="2.7.0"/> |
| **kwargs: Additional kwargs forwarded from _download_and_prepare |
| """ |
| raise NotImplementedError() |
|
|
| def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: |
| """Generate the examples on the fly. |
| |
| Args: |
| split_generator (`SplitGenerator`): |
| Split generator to process |
| """ |
| raise NotImplementedError() |
|
|
|
|
| class GeneratorBasedBuilder(DatasetBuilder): |
| """Base class for datasets with data generation based on dict generators. |
| |
| `GeneratorBasedBuilder` is a convenience class that abstracts away much |
| of the data writing and reading of `DatasetBuilder`. It expects subclasses to |
| implement generators of feature dictionaries across the dataset splits |
| (`_split_generators`). See the method docstrings for details. |
| """ |
|
|
| @abc.abstractmethod |
| def _generate_examples(self, **kwargs): |
| """Default function generating examples for each `SplitGenerator`. |
| |
| This function preprocess the examples from the raw data to the preprocessed |
| dataset files. |
| This function is called once for each `SplitGenerator` defined in |
| `_split_generators`. The examples yielded here will be written on |
| disk. |
| |
| Args: |
| **kwargs (additional keyword arguments): |
| Arguments forwarded from the SplitGenerator.gen_kwargs |
| |
| Yields: |
| key: `str` or `int`, a unique deterministic example identification key. |
| * Unique: An error will be raised if two examples are yield with the |
| same key. |
| * Deterministic: When generating the dataset twice, the same example |
| should have the same key. |
| Good keys can be the image id, or line number if examples are extracted |
| from a text file. |
| The key will be hashed and sorted to shuffle examples deterministically, |
| such as generating the dataset multiple times keep examples in the |
| same order. |
| example: `dict<str feature_name, feature_value>`, a feature dictionary |
| ready to be encoded and written to disk. The example will be |
| encoded with `self.info.features.encode_example({...})`. |
| """ |
| raise NotImplementedError() |
|
|
| def _prepare_split( |
| self, |
| split_generator: SplitGenerator, |
| check_duplicate_keys: bool, |
| file_format="arrow", |
| num_proc: Optional[int] = None, |
| max_shard_size: Optional[Union[int, str]] = None, |
| ): |
| max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) |
|
|
| if self.info.splits is not None: |
| split_info = self.info.splits[split_generator.name] |
| else: |
| split_info = split_generator.split_info |
|
|
| SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" |
| fname = f"{self.dataset_name}-{split_generator.name}{SUFFIX}.{file_format}" |
| fpath = posixpath.join(self._output_dir, fname) |
|
|
| if num_proc and num_proc > 1: |
| num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs) |
| if num_input_shards <= 1: |
| logger.warning( |
| f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard." |
| ) |
| num_proc = 1 |
| elif num_input_shards < num_proc: |
| logger.warning( |
| f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards." |
| ) |
| num_proc = num_input_shards |
|
|
| pbar = hf_tqdm( |
| unit=" examples", |
| total=split_info.num_examples, |
| desc=f"Generating {split_info.name} split", |
| ) |
|
|
| _prepare_split_args = { |
| "fpath": fpath, |
| "file_format": file_format, |
| "max_shard_size": max_shard_size, |
| "split_info": split_info, |
| "check_duplicate_keys": check_duplicate_keys, |
| } |
|
|
| if num_proc is None or num_proc == 1: |
| result = None |
| gen_kwargs = split_generator.gen_kwargs |
| job_id = 0 |
| with pbar: |
| for job_id, done, content in self._prepare_split_single( |
| gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args |
| ): |
| if done: |
| result = content |
| else: |
| pbar.update(content) |
| |
| assert result is not None, "Failed to retrieve results from prepare_split" |
| examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = ( |
| [item] for item in result |
| ) |
| else: |
| kwargs_per_job = [ |
| {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} |
| for job_id, gen_kwargs in enumerate( |
| _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc) |
| ) |
| ] |
| num_jobs = len(kwargs_per_job) |
|
|
| examples_per_job = [None] * num_jobs |
| bytes_per_job = [None] * num_jobs |
| features_per_job = [None] * num_jobs |
| shards_per_job = [None] * num_jobs |
| shard_lengths_per_job = [None] * num_jobs |
|
|
| with Pool(num_proc) as pool: |
| with pbar: |
| for job_id, done, content in iflatmap_unordered( |
| pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job |
| ): |
| if done: |
| |
| ( |
| examples_per_job[job_id], |
| bytes_per_job[job_id], |
| features_per_job[job_id], |
| shards_per_job[job_id], |
| shard_lengths_per_job[job_id], |
| ) = content |
| else: |
| |
| pbar.update(content) |
|
|
| assert None not in examples_per_job, ( |
| f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results" |
| ) |
|
|
| total_shards = sum(shards_per_job) |
| total_num_examples = sum(examples_per_job) |
| total_num_bytes = sum(bytes_per_job) |
| features = features_per_job[0] |
|
|
| split_generator.split_info.num_examples = total_num_examples |
| split_generator.split_info.num_bytes = total_num_bytes |
|
|
| |
| logger.debug(f"Renaming {total_shards} shards.") |
| if total_shards > 1: |
| |
|
|
| def _rename_shard(shard_and_job: tuple[int]): |
| shard_id, job_id = shard_and_job |
| global_shard_id = sum(shards_per_job[:job_id]) + shard_id |
| self._rename( |
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"), |
| ) |
|
|
| shards_and_jobs = [ |
| (shard_id, job_id) |
| for job_id, num_shards in enumerate(shards_per_job) |
| for shard_id in range(num_shards) |
| ] |
| thread_map(_rename_shard, shards_and_jobs, disable=True, max_workers=64) |
|
|
| split_generator.split_info.shard_lengths = [ |
| shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths |
| ] |
| else: |
| |
| shard_id, job_id = 0, 0 |
| self._rename( |
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| fpath.replace(SUFFIX, ""), |
| ) |
|
|
| if self.info.features is None: |
| self.info.features = features |
|
|
| def _prepare_split_single( |
| self, |
| gen_kwargs: dict, |
| fpath: str, |
| file_format: str, |
| max_shard_size: int, |
| split_info: SplitInfo, |
| check_duplicate_keys: bool, |
| job_id: int, |
| ) -> Iterable[tuple[int, bool, Union[int, tuple]]]: |
| generator = self._generate_examples(**gen_kwargs) |
| writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter |
| embed_local_files = file_format == "parquet" |
| shard_lengths = [] |
| total_num_examples, total_num_bytes = 0, 0 |
|
|
| shard_id = 0 |
| num_examples_progress_update = 0 |
| try: |
| writer = writer_class( |
| features=self.info.features, |
| path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| writer_batch_size=self._writer_batch_size, |
| hash_salt=split_info.name, |
| check_duplicates=check_duplicate_keys, |
| storage_options=self._fs.storage_options, |
| embed_local_files=embed_local_files, |
| ) |
| try: |
| _time = time.time() |
| for key, record in generator: |
| if max_shard_size is not None and writer._num_bytes > max_shard_size: |
| num_examples, num_bytes = writer.finalize() |
| writer.close() |
| shard_lengths.append(num_examples) |
| total_num_examples += num_examples |
| total_num_bytes += num_bytes |
| shard_id += 1 |
| writer = writer_class( |
| features=writer._features, |
| path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| writer_batch_size=self._writer_batch_size, |
| hash_salt=split_info.name, |
| check_duplicates=check_duplicate_keys, |
| storage_options=self._fs.storage_options, |
| embed_local_files=embed_local_files, |
| ) |
| example = self.info.features.encode_example(record) if self.info.features is not None else record |
| writer.write(example, key) |
| num_examples_progress_update += 1 |
| if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL: |
| _time = time.time() |
| yield job_id, False, num_examples_progress_update |
| num_examples_progress_update = 0 |
| finally: |
| yield job_id, False, num_examples_progress_update |
| num_shards = shard_id + 1 |
| num_examples, num_bytes = writer.finalize() |
| writer.close() |
| shard_lengths.append(num_examples) |
| total_num_examples += num_examples |
| total_num_bytes += num_bytes |
| except Exception as e: |
| |
| if isinstance(e, SchemaInferenceError) and e.__context__ is not None: |
| e = e.__context__ |
| raise DatasetGenerationError("An error occurred while generating the dataset") from e |
|
|
| yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths) |
|
|
| def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_kwargs): |
| super()._download_and_prepare( |
| dl_manager, |
| verification_mode, |
| check_duplicate_keys=verification_mode == VerificationMode.BASIC_CHECKS |
| or verification_mode == VerificationMode.ALL_CHECKS, |
| **prepare_splits_kwargs, |
| ) |
|
|
| def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: |
| return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs) |
|
|
|
|
| class ArrowBasedBuilder(DatasetBuilder): |
| """Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet).""" |
|
|
| @abc.abstractmethod |
| def _generate_tables(self, **kwargs): |
| """Default function generating examples for each `SplitGenerator`. |
| |
| This function preprocess the examples from the raw data to the preprocessed |
| dataset files. |
| This function is called once for each `SplitGenerator` defined in |
| `_split_generators`. The examples yielded here will be written on |
| disk. |
| |
| Args: |
| **kwargs (additional keyword arguments): |
| Arguments forwarded from the SplitGenerator.gen_kwargs |
| |
| Yields: |
| key: `str` or `int`, a unique deterministic example identification key. |
| * Unique: An error will be raised if two examples are yield with the |
| same key. |
| * Deterministic: When generating the dataset twice, the same example |
| should have the same key. |
| Good keys can be the image id, or line number if examples are extracted |
| from a text file. |
| The key will be hashed and sorted to shuffle examples deterministically, |
| such as generating the dataset multiple times keep examples in the |
| same order. |
| example: `pyarrow.Table`, a feature table |
| ready to be encoded and written to disk. |
| """ |
| raise NotImplementedError() |
|
|
| def _prepare_split( |
| self, |
| split_generator: SplitGenerator, |
| file_format: str = "arrow", |
| num_proc: Optional[int] = None, |
| max_shard_size: Optional[Union[str, int]] = None, |
| ): |
| max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) |
|
|
| try: |
| split_info = self.info.splits[split_generator.name] |
| except Exception: |
| split_info = split_generator.split_info |
|
|
| SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" |
| fname = f"{self.dataset_name}-{split_generator.name}{SUFFIX}.{file_format}" |
| fpath = posixpath.join(self._output_dir, fname) |
|
|
| if num_proc and num_proc > 1: |
| num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs) |
| if num_input_shards <= 1: |
| logger.warning( |
| f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard." |
| ) |
| num_proc = 1 |
| elif num_input_shards < num_proc: |
| logger.warning( |
| f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards." |
| ) |
| num_proc = num_input_shards |
|
|
| pbar = hf_tqdm( |
| unit=" examples", |
| total=split_info.num_examples, |
| desc=f"Generating {split_info.name} split", |
| ) |
|
|
| _prepare_split_args = { |
| "fpath": fpath, |
| "file_format": file_format, |
| "max_shard_size": max_shard_size, |
| } |
|
|
| if num_proc is None or num_proc == 1: |
| result = None |
| gen_kwargs = split_generator.gen_kwargs |
| job_id = 0 |
| with pbar: |
| for job_id, done, content in self._prepare_split_single( |
| gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args |
| ): |
| if done: |
| result = content |
| else: |
| pbar.update(content) |
| |
| assert result is not None, "Failed to retrieve results from prepare_split" |
| examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = ( |
| [item] for item in result |
| ) |
| else: |
| kwargs_per_job = [ |
| {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} |
| for job_id, gen_kwargs in enumerate( |
| _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc) |
| ) |
| ] |
| num_jobs = len(kwargs_per_job) |
|
|
| examples_per_job = [None] * num_jobs |
| bytes_per_job = [None] * num_jobs |
| features_per_job = [None] * num_jobs |
| shards_per_job = [None] * num_jobs |
| shard_lengths_per_job = [None] * num_jobs |
|
|
| with Pool(num_proc) as pool: |
| with pbar: |
| for job_id, done, content in iflatmap_unordered( |
| pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job |
| ): |
| if done: |
| |
| ( |
| examples_per_job[job_id], |
| bytes_per_job[job_id], |
| features_per_job[job_id], |
| shards_per_job[job_id], |
| shard_lengths_per_job[job_id], |
| ) = content |
| else: |
| |
| pbar.update(content) |
|
|
| assert None not in examples_per_job, ( |
| f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results" |
| ) |
|
|
| total_shards = sum(shards_per_job) |
| total_num_examples = sum(examples_per_job) |
| total_num_bytes = sum(bytes_per_job) |
| features = features_per_job[0] |
|
|
| split_generator.split_info.num_examples = total_num_examples |
| split_generator.split_info.num_bytes = total_num_bytes |
|
|
| |
| logger.debug(f"Renaming {total_shards} shards.") |
| if total_shards > 1: |
| |
|
|
| def _rename_shard(shard_id_and_job: tuple[int]): |
| shard_id, job_id = shard_id_and_job |
| global_shard_id = sum(shards_per_job[:job_id]) + shard_id |
| self._rename( |
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"), |
| ) |
|
|
| shard_ids_and_jobs = [ |
| (shard_id, job_id) |
| for job_id, num_shards in enumerate(shards_per_job) |
| for shard_id in range(num_shards) |
| ] |
| thread_map(_rename_shard, shard_ids_and_jobs, disable=True, max_workers=64) |
|
|
| split_generator.split_info.shard_lengths = [ |
| shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths |
| ] |
| else: |
| |
| shard_id, job_id = 0, 0 |
| self._rename( |
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| fpath.replace(SUFFIX, ""), |
| ) |
|
|
| if self.info.features is None: |
| self.info.features = features |
|
|
| def _prepare_split_single( |
| self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int |
| ) -> Iterable[tuple[int, bool, Union[int, tuple]]]: |
| gen_kwargs = {k: tracked_list(v) if isinstance(v, list) else v for k, v in gen_kwargs.items()} |
| generator = self._generate_tables(**gen_kwargs) |
| writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter |
| embed_local_files = file_format == "parquet" |
| shard_lengths = [] |
| total_num_examples, total_num_bytes = 0, 0 |
|
|
| shard_id = 0 |
| num_examples_progress_update = 0 |
| try: |
| writer = writer_class( |
| features=self.info.features, |
| path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| writer_batch_size=self._writer_batch_size, |
| storage_options=self._fs.storage_options, |
| embed_local_files=embed_local_files, |
| ) |
| try: |
| _time = time.time() |
| for _, table in generator: |
| if max_shard_size is not None and writer._num_bytes > max_shard_size: |
| num_examples, num_bytes = writer.finalize() |
| writer.close() |
| shard_lengths.append(num_examples) |
| total_num_examples += num_examples |
| total_num_bytes += num_bytes |
| shard_id += 1 |
| writer = writer_class( |
| features=writer._features, |
| path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), |
| writer_batch_size=self._writer_batch_size, |
| storage_options=self._fs.storage_options, |
| embed_local_files=embed_local_files, |
| ) |
| try: |
| writer.write_table(table) |
| except CastError as cast_error: |
| raise DatasetGenerationCastError.from_cast_error( |
| cast_error=cast_error, |
| builder_name=self.info.builder_name, |
| gen_kwargs=gen_kwargs, |
| token=self.token, |
| ) |
| num_examples_progress_update += len(table) |
| if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL: |
| _time = time.time() |
| yield job_id, False, num_examples_progress_update |
| num_examples_progress_update = 0 |
| finally: |
| yield job_id, False, num_examples_progress_update |
| num_shards = shard_id + 1 |
| num_examples, num_bytes = writer.finalize() |
| writer.close() |
| shard_lengths.append(num_examples) |
| total_num_examples += num_examples |
| total_num_bytes += num_bytes |
| except Exception as e: |
| |
| if isinstance(e, SchemaInferenceError) and e.__context__ is not None: |
| e = e.__context__ |
| if isinstance(e, DatasetGenerationError): |
| raise |
| raise DatasetGenerationError("An error occurred while generating the dataset") from e |
|
|
| yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths) |
|
|
| def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: |
| return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs) |
|
|