| import os |
| import re |
| from functools import partial |
| from glob import has_magic |
| from pathlib import Path, PurePath |
| from typing import Callable, Optional, Union |
|
|
| import huggingface_hub |
| from fsspec.core import url_to_fs |
| from huggingface_hub import HfFileSystem |
| from packaging import version |
| from tqdm.contrib.concurrent import thread_map |
|
|
| from . import config |
| from .download import DownloadConfig |
| from .naming import _split_re |
| from .splits import Split |
| from .utils import logging |
| from .utils import tqdm as hf_tqdm |
| from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin |
| from .utils.py_utils import string_to_dict |
|
|
|
|
# Origin metadata recorded for a single data file by `_get_single_origin_metadata`:
#   - (repo_id, revision) for files hosted on the Hugging Face Hub,
#   - (value,) with the stringified "ETag", "etag" or "mtime" entry of the file info,
#   - () when none of those is available.
SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]]


# Split name used when the user passes patterns without naming a split.
SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)


logger = logging.get_logger(__name__)
|
|
|
|
class Url(str):
    """A plain ``str`` subclass with no extra behavior; presumably used to tag a string as a URL (confirm at call sites)."""
|
|
|
|
class EmptyDatasetError(FileNotFoundError):
    """Raised when no data files can be found for a dataset; being a FileNotFoundError, it can be caught as one."""
|
|
|
|
# Pattern for sharded files named like "data/<split>-00000-of-00005.<ext>".
# "{split}" is a placeholder: `_get_data_files_patterns` replaces it by "*" to find the shards of
# every split at once, then formats it with each discovered split name.
SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"


# Keywords that identify each default split when they appear in a file or directory name.
SPLIT_KEYWORDS = {
    Split.TRAIN: ["train", "training"],
    Split.VALIDATION: ["validation", "valid", "dev", "val"],
    Split.TEST: ["test", "testing", "eval", "evaluation"],
}
# Characters that may surround a keyword; inserted inside glob character classes "[...]" below.
NON_WORDS_CHARS = "-._ 0-9"
# The base patterns depend on the installed fsspec version; presumably its handling of "**"
# changed across releases (for fsspec>=2023.12.0, `resolve_pattern` documents that "**" is only
# supported as a whole path component, as in Python's glob).
if config.FSSPEC_VERSION < version.parse("2023.9.0"):
    KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
    KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
        "{keyword}/**",
        "{keyword}[{sep}]*/**",
        "**[{sep}/]{keyword}/**",
        "**[{sep}/]{keyword}[{sep}]*/**",
    ]
elif config.FSSPEC_VERSION < version.parse("2023.12.0"):
    KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/*[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
    KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
        "{keyword}/**/*",
        "{keyword}[{sep}]*/**/*",
        "**/*[{sep}/]{keyword}/**/*",
        "**/*[{sep}/]{keyword}[{sep}]*/**/*",
    ]
else:
    # "**" always appears as its own path component here.
    KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/{keyword}[{sep}]*", "**/*[{sep}]{keyword}[{sep}]*"]
    KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
        "**/{keyword}/**",
        "**/{keyword}[{sep}]*/**",
        "**/*[{sep}]{keyword}/**",
        "**/*[{sep}]{keyword}[{sep}]*/**",
    ]
|
|
# Default splits; `_get_data_files_patterns` lists them first, in this order.
DEFAULT_SPLITS = [Split.TRAIN, Split.VALIDATION, Split.TEST]
# split -> glob patterns matching files whose *file name* contains one of the split's keywords.
DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
    split: [
        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
        for keyword in SPLIT_KEYWORDS[split]
        for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
    ]
    for split in DEFAULT_SPLITS
}
# split -> glob patterns matching files inside a *directory* whose name contains one of the split's keywords.
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = {
    split: [
        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
        for keyword in SPLIT_KEYWORDS[split]
        for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
    ]
    for split in DEFAULT_SPLITS
}


# Fallback: every file is assigned to the train split.
DEFAULT_PATTERNS_ALL = {
    Split.TRAIN: ["**"],
}


# Candidate pattern sets tried in order by `_get_data_files_patterns`: first the sharded
# pattern(s), then these dicts in list order; the first one that matches any file wins.
ALL_SPLIT_PATTERNS = [SPLIT_PATTERN_SHARDED]
ALL_DEFAULT_PATTERNS = [
    DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME,
    DEFAULT_PATTERNS_SPLIT_IN_FILENAME,
    DEFAULT_PATTERNS_ALL,
]
# Characters treated as glob wildcards by `contains_wildcards`.
WILDCARD_CHARACTERS = "*[]"
# File names dropped by `resolve_pattern`, unless the pattern's base name is exactly that file name.
FILES_TO_IGNORE = [
    "README.md",
    "config.json",
    "dataset_info.json",
    "dataset_infos.json",
    "dummy_data.zip",
    "dataset_dict.json",
]
|
|
|
|
def contains_wildcards(pattern: str) -> bool:
    """Return True if ``pattern`` contains at least one of the glob characters in ``WILDCARD_CHARACTERS``."""
    # A non-empty intersection between the wildcard characters and the pattern's characters.
    return not frozenset(WILDCARD_CHARACTERS).isdisjoint(pattern)
|
|
|
|
def sanitize_patterns(patterns: Union[dict, list, str]) -> dict[str, Union[list[str], "DataFilesList"]]:
    """
    Normalize the user's ``data_files`` patterns into a ``{split_name: [patterns, ...]}`` dict.

    Accepted inputs:
        - a string: one pattern for the default split ("train")
        - a list of strings: several patterns for the default split
        - a list of ``{"split": ..., "path": ...}`` dicts, where "path" is a string or a list of strings
        - a dict mapping split names to one pattern or a list of patterns
        - any other iterable, which is converted to a list first

    Patterns are paths or URLs. Split names are converted with ``str``.

    Returns:
        patterns: dictionary of split_name -> list of patterns

    Raises:
        ValueError: if the list contains a dict but some entry is not exactly a ``{"split", "path"}``
            dict with a string/list path, or if a split name appears more than once.
    """
    if isinstance(patterns, dict):
        return {str(split): (value if isinstance(value, list) else [value]) for split, value in patterns.items()}
    if isinstance(patterns, str):
        return {SANITIZED_DEFAULT_SPLIT: [patterns]}
    if not isinstance(patterns, list):
        # Other iterables (e.g. tuples) are handled as lists.
        return sanitize_patterns(list(patterns))

    # A plain list of patterns all goes to the default split.
    if not any(isinstance(entry, dict) for entry in patterns):
        return {SANITIZED_DEFAULT_SPLIT: patterns}

    # Otherwise every entry must be a {"split": ..., "path": ...} dict.
    for pattern in patterns:
        is_valid_entry = (
            isinstance(pattern, dict)
            and len(pattern) == 2
            and "split" in pattern
            and isinstance(pattern.get("path"), (str, list))
        )
        if not is_valid_entry:
            raise ValueError(
                f"Expected each split to have a 'path' key which can be a string or a list of strings, but got {pattern}"
            )
    splits = [entry["split"] for entry in patterns]
    if len(splits) != len(set(splits)):
        raise ValueError(f"Some splits are duplicated in data_files: {splits}")
    sanitized = {}
    for entry in patterns:
        path = entry["path"]
        sanitized[str(entry["split"])] = path if isinstance(path, list) else [path]
    return sanitized
|
|
|
|
| def _is_inside_unrequested_special_dir(matched_rel_path: str, pattern: str) -> bool: |
| """ |
| When a path matches a pattern, we additionally check if it's inside a special directory |
| we ignore by default (if it starts with a double underscore). |
| |
| Users can still explicitly request a filepath inside such a directory if "__pycache__" is |
| mentioned explicitly in the requested pattern. |
| |
| Some examples: |
| |
| base directory: |
| |
| ./ |
| └── __pycache__ |
| └── b.txt |
| |
| >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "**") |
| True |
| >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "*/b.txt") |
| True |
| >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__pycache__/*") |
| False |
| >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__*/*") |
| False |
| """ |
| |
| |
| |
| data_dirs_to_ignore_in_path = [part for part in PurePath(matched_rel_path).parent.parts if part.startswith("__")] |
| data_dirs_to_ignore_in_pattern = [part for part in PurePath(pattern).parent.parts if part.startswith("__")] |
| return len(data_dirs_to_ignore_in_path) != len(data_dirs_to_ignore_in_pattern) |
|
|
|
|
| def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_path: str, pattern: str) -> bool: |
| """ |
| When a path matches a pattern, we additionally check if it's a hidden file or if it's inside |
| a hidden directory we ignore by default, i.e. if the file name or a parent directory name starts with a dot. |
| |
| Users can still explicitly request a filepath that is hidden or is inside a hidden directory |
| if the hidden part is mentioned explicitly in the requested pattern. |
| |
| Some examples: |
| |
| base directory: |
| |
| ./ |
| └── .hidden_file.txt |
| |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", "**") |
| True |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", ".*") |
| False |
| |
| base directory: |
| |
| ./ |
| └── .hidden_dir |
| └── a.txt |
| |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", "**") |
| True |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".*/*") |
| False |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".hidden_dir/*") |
| False |
| |
| base directory: |
| |
| ./ |
| └── .hidden_dir |
| └── .hidden_file.txt |
| |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", "**") |
| True |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/*") |
| True |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/.*") |
| False |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/*") |
| True |
| >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/.*") |
| False |
| """ |
| |
| |
| |
| hidden_directories_in_path = [ |
| part for part in PurePath(matched_rel_path).parts if part.startswith(".") and not set(part) == {"."} |
| ] |
| hidden_directories_in_pattern = [ |
| part for part in PurePath(pattern).parts if part.startswith(".") and not set(part) == {"."} |
| ] |
| return len(hidden_directories_in_path) != len(hidden_directories_in_pattern) |
|
|
|
|
def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]:
    """
    Get the default patterns of a directory or repository by testing all the supported patterns.
    The first pattern set that resolves to a non-empty list of data files is returned.

    In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.

    Args:
        pattern_resolver: callable mapping a glob pattern to the list of matching data files.
            It may raise FileNotFoundError; such patterns are treated as matching nothing.

    Returns:
        dict mapping each split name found to its list of patterns.

    Raises:
        ValueError: if a split name extracted from sharded file names does not match `_split_re`.
        FileNotFoundError: if no candidate pattern resolves to any data file.
    """
    for split_pattern in ALL_SPLIT_PATTERNS:
        # Replace the "{split}" placeholder by a wildcard to find the shards of every split at once.
        pattern = split_pattern.replace("{split}", "*")
        try:
            data_files = pattern_resolver(pattern)
        except FileNotFoundError:
            continue
        if len(data_files) > 0:
            # Recover each split name by matching file base names against the placeholder pattern.
            splits: set[str] = set()
            for p in data_files:
                p_parts = string_to_dict(xbasename(p), xbasename(split_pattern))
                # NOTE(review): string_to_dict presumably returns None when the name does not fit
                # the template; the glob above should make that impossible.
                assert p_parts is not None
                splits.add(p_parts["split"])

            if any(not re.match(_split_re, split) for split in splits):
                raise ValueError(f"Split name should match '{_split_re}' but got '{splits}'.")
            # Default splits come first in their canonical order, then the other splits alphabetically.
            sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
                splits - {str(split) for split in DEFAULT_SPLITS}
            )
            return {split: [split_pattern.format(split=split)] for split in sorted_splits}

    for patterns_dict in ALL_DEFAULT_PATTERNS:
        non_empty_splits = []
        for split, patterns in patterns_dict.items():
            for pattern in patterns:
                try:
                    data_files = pattern_resolver(pattern)
                except FileNotFoundError:
                    continue
                if len(data_files) > 0:
                    # One matching pattern is enough to keep this split.
                    non_empty_splits.append(split)
                    break
        if non_empty_splits:
            # Return the full pattern list of every non-empty split of the first matching pattern set.
            return {split: patterns_dict[split] for split in non_empty_splits}
    raise FileNotFoundError(f"Couldn't resolve pattern {pattern} with resolver {pattern_resolver}")
|
|
|
|
def resolve_pattern(
    pattern: str,
    base_path: str,
    allowed_extensions: Optional[list[str]] = None,
    download_config: Optional[DownloadConfig] = None,
) -> list[str]:
    """
    Resolve the paths and URLs of the data files from the pattern passed by the user.

    You can use patterns to resolve multiple local files. Here are a few examples:
    - *.csv to match all the CSV files at the first level
    - **.csv to match all the CSV files at any level
    - data/* to match all the files inside "data"
    - data/** to match all the files inside "data" and its subdirectories

    The patterns are resolved using the fsspec glob. In fsspec>=2023.12.0 this is equivalent to
    Python's glob.glob, Path.glob, Path.match and fnmatch where ** is unsupported with a prefix/suffix
    other than a forward slash /.

    More generally:
    - '*' matches any character except a forward-slash (to match just the file or directory name)
    - '**' matches any character including a forward-slash /

    Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
    The same applies to special directories that start with a double underscore like "__pycache__".
    You can still include one if the pattern explicitly mentions it:
    - to include a hidden file: "*/.hidden.txt" or "*/.*"
    - to include a hidden directory: ".hidden/*" or ".*/*"
    - to include a special directory: "__special__/*" or "__*/*"

    Files listed in FILES_TO_IGNORE (e.g. "README.md") are skipped too, unless the pattern's
    base name is exactly that file name.

    Example::

        >>> from datasets.data_files import resolve_pattern
        >>> base_path = "."
        >>> resolve_pattern("docs/**/*.py", base_path)
        ['/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']

    Args:
        pattern (str): Unix pattern or paths or URLs of the data files to resolve.
            The paths can be absolute or relative to base_path.
            Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
        base_path (str): Base path to use when resolving relative paths.
        allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
            For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
        download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.
    Returns:
        List[str]: List of paths or URLs to the local or remote files that match the patterns.
    Raises:
        FileNotFoundError: if no file matches (with an allowed extension, when given).
    """
    # Only relative patterns need base_path; absolute paths and URLs are used as-is.
    if is_relative_path(pattern):
        pattern = xjoin(base_path, pattern)
    pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config)
    fs, fs_pattern = url_to_fs(pattern, **storage_options)
    # An ignored file name can still be matched when the pattern targets it by name.
    files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
    protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
    # Local files are returned without a "file://" prefix.
    protocol_prefix = protocol + "://" if protocol != "file" else ""
    glob_kwargs = {}
    if protocol == "hf":
        # NOTE(review): `expand_info=False` presumably skips fetching extra per-file info on the
        # Hub filesystem to make globbing cheaper — confirm against the huggingface_hub docs.
        glob_kwargs["expand_info"] = False
    matched_paths = [
        filepath if filepath.startswith(protocol_prefix) else protocol_prefix + filepath
        for filepath, info in fs.glob(pattern, detail=True, **glob_kwargs).items()
        # Keep regular files, and symlinks whose target is a local regular file.
        if (info["type"] == "file" or (info.get("islink") and os.path.isfile(os.path.realpath(filepath))))
        and (xbasename(filepath) not in files_to_ignore)
        and not _is_inside_unrequested_special_dir(filepath, fs_pattern)
        and not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(filepath, fs_pattern)
    ]
    if allowed_extensions is not None:
        # A file is kept if any of its dot-separated suffixes (e.g. ".csv" in "x.csv.gz") is allowed.
        out = [
            filepath
            for filepath in matched_paths
            if any("." + suffix in allowed_extensions for suffix in xbasename(filepath).split(".")[1:])
        ]
        if len(out) < len(matched_paths):
            kept = set(out)
            # Preserve the glob order so that the log message is deterministic.
            invalid_matched_files = [filepath for filepath in matched_paths if filepath not in kept]
            logger.info(
                f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: {invalid_matched_files}"
            )
    else:
        out = matched_paths
    if not out:
        error_msg = f"Unable to find '{pattern}'"
        if allowed_extensions is not None:
            error_msg += f" with any supported extension {list(allowed_extensions)}"
        raise FileNotFoundError(error_msg)
    return out
|
|
|
|
def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig] = None) -> dict[str, list[str]]:
    """
    Infer the default data files patterns of the directory or repository at ``base_path``.

    Every supported pattern set is tried in order, and the first one matching at least one
    data file is returned. The sharded pattern SPLIT_PATTERN_SHARDED is tried first, then the
    pattern sets of ALL_DEFAULT_PATTERNS.

    Some examples of supported patterns:

    Input:

        my_dataset_repository/
        ├── README.md
        └── dataset.csv

    Output:

        {'train': ['**']}

    Input:

        my_dataset_repository/
        ├── README.md
        ├── train.csv
        └── test.csv

        my_dataset_repository/
        ├── README.md
        └── data/
            ├── train.csv
            └── test.csv

        my_dataset_repository/
        ├── README.md
        ├── train_0.csv
        ├── train_1.csv
        ├── train_2.csv
        ├── train_3.csv
        ├── test_0.csv
        └── test_1.csv

    Output:

        {'train': ['**/train[-._ 0-9]*', '**/*[-._ 0-9]train[-._ 0-9]*', '**/training[-._ 0-9]*', '**/*[-._ 0-9]training[-._ 0-9]*'],
         'test': ['**/test[-._ 0-9]*', '**/*[-._ 0-9]test[-._ 0-9]*', '**/testing[-._ 0-9]*', '**/*[-._ 0-9]testing[-._ 0-9]*', ...]}

    Input:

        my_dataset_repository/
        ├── README.md
        └── data/
            ├── train/
            │   ├── shard_0.csv
            │   ├── shard_1.csv
            │   ├── shard_2.csv
            │   └── shard_3.csv
            └── test/
                ├── shard_0.csv
                └── shard_1.csv

    Output:

        {'train': ['**/train/**', '**/train[-._ 0-9]*/**', '**/*[-._ 0-9]train/**', '**/*[-._ 0-9]train[-._ 0-9]*/**', ...],
         'test': ['**/test/**', '**/test[-._ 0-9]*/**', '**/*[-._ 0-9]test/**', '**/*[-._ 0-9]test[-._ 0-9]*/**', ...]}

    Input:

        my_dataset_repository/
        ├── README.md
        └── data/
            ├── train-00000-of-00003.csv
            ├── train-00001-of-00003.csv
            ├── train-00002-of-00003.csv
            ├── test-00000-of-00001.csv
            ├── random-00000-of-00003.csv
            ├── random-00001-of-00003.csv
            └── random-00002-of-00003.csv

    Output:

        {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
         'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
         'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}

    Raises:
        EmptyDatasetError: if no supported pattern matches any data file.
    """
    try:
        return _get_data_files_patterns(
            partial(resolve_pattern, base_path=base_path, download_config=download_config)
        )
    except FileNotFoundError:
        # Drop the internal resolution error; callers only need to know the directory has no data.
        raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None
|
|
|
|
def _get_single_origin_metadata(
    data_file: str,
    download_config: Optional[DownloadConfig] = None,
) -> SingleOriginMetadata:
    """
    Compute the origin metadata of one data file, used to detect when the file changes.

    Args:
        data_file: local path or URL of the data file.
        download_config: optional download configuration (storage options and Hub token).

    Returns:
        - ``(repo_id, revision)`` for files on the Hugging Face Hub (``hf://`` paths, or URLs
          starting with ``config.HF_ENDPOINT``),
        - ``(value,)`` with the stringified "ETag", "etag" or "mtime" entry of the file info,
          whichever is found first,
        - ``()`` when none of these is available.
    """
    data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
    fs, *_ = url_to_fs(data_file, **storage_options)
    if isinstance(fs, HfFileSystem):
        resolved_path = fs.resolve_path(data_file)
        return resolved_path.repo_id, resolved_path.revision
    elif data_file.startswith(config.HF_ENDPOINT):
        # `download_config` is optional: only read the token when a config was provided.
        token = download_config.token if download_config is not None else None
        hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=token)
        # Turn "<endpoint>/<path>/resolve/<rest>" into "hf://<path>@<rest>" (first "/resolve/" only).
        data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
        resolved_path = hffs.resolve_path(data_file)
        return resolved_path.repo_id, resolved_path.revision
    info = fs.info(data_file)
    # Use the first available change marker from the filesystem info.
    for key in ["ETag", "etag", "mtime"]:
        if key in info:
            return (str(info[key]),)
    return ()
|
|
|
|
def _get_origin_metadata(
    data_files: list[str],
    download_config: Optional[DownloadConfig] = None,
    max_workers: Optional[int] = None,
) -> list[SingleOriginMetadata]:
    """
    Compute the origin metadata of every data file, in the same order as ``data_files``.

    If every path contains "hf://", the files are processed sequentially. Otherwise they are
    processed with a thread pool of ``max_workers`` threads (default:
    ``config.HF_DATASETS_MULTITHREADING_MAX_WORKERS``). The progress bar is always disabled
    for 16 files or fewer.
    """
    if max_workers is None:
        max_workers = config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
    get_metadata = partial(_get_single_origin_metadata, download_config=download_config)
    # For more than 16 files pass `None` instead of `False`: tqdm then disables the bar on non-TTY outputs.
    disable_progress_bar = len(data_files) <= 16 or None
    if not all("hf://" in data_file for data_file in data_files):
        return thread_map(
            get_metadata,
            data_files,
            max_workers=max_workers,
            tqdm_class=hf_tqdm,
            desc="Resolving data files",
            disable=disable_progress_bar,
        )
    # Hub files only: resolved one after another, without a thread pool.
    progress = hf_tqdm(data_files, desc="Resolving data files", disable=disable_progress_bar)
    return [get_metadata(data_file) for data_file in progress]
|
|
|
|
class DataFilesList(list[str]):
    """
    List of data files (absolute local paths or URLs).
    It has two construction methods given the user's data files patterns:
    - ``from_hf_repo``: resolve patterns inside a dataset repository
    - ``from_local_or_remote``: resolve patterns from a local path

    Moreover, DataFilesList has an additional attribute ``origin_metadata``.
    It can store:
    - the last modified time of local files
    - ETag of remote files
    - commit sha of a dataset repository

    Thanks to this additional attribute, it is possible to hash the list
    and get a different hash if and only if at least one file changed.
    This is useful for caching Dataset objects that are obtained from a list of data files.
    """

    def __init__(self, data_files: list[str], origin_metadata: list[SingleOriginMetadata]) -> None:
        super().__init__(data_files)
        self.origin_metadata = origin_metadata

    def __add__(self, other: "DataFilesList") -> "DataFilesList":
        # Concatenate both the files and their origin metadata.
        return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)

    @classmethod
    def from_hf_repo(
        cls,
        patterns: list[str],
        dataset_info: huggingface_hub.hf_api.DatasetInfo,
        base_path: Optional[str] = None,
        allowed_extensions: Optional[list[str]] = None,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesList":
        """Resolve ``patterns`` inside the Hub dataset repository at the commit ``dataset_info.sha``."""
        # Pin the repository revision so that resolution is reproducible.
        base_path = f"hf://datasets/{dataset_info.id}@{dataset_info.sha}/{base_path or ''}".rstrip("/")
        return cls.from_patterns(
            patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config
        )

    @classmethod
    def from_local_or_remote(
        cls,
        patterns: list[str],
        base_path: Optional[str] = None,
        allowed_extensions: Optional[list[str]] = None,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesList":
        """Resolve ``patterns`` relative to ``base_path`` (default: the current working directory)."""
        base_path = base_path if base_path is not None else Path().resolve().as_posix()
        return cls.from_patterns(
            patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config
        )

    @classmethod
    def from_patterns(
        cls,
        patterns: list[str],
        base_path: Optional[str] = None,
        allowed_extensions: Optional[list[str]] = None,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesList":
        """
        Resolve every pattern and compute the origin metadata of the resulting files.

        A pattern without glob magic must match at least one file, otherwise FileNotFoundError is
        raised. A pattern with glob magic that matches nothing is skipped.
        """
        base_path = base_path if base_path is not None else Path().resolve().as_posix()
        data_files = []
        for pattern in patterns:
            try:
                data_files.extend(
                    resolve_pattern(
                        pattern,
                        base_path=base_path,
                        allowed_extensions=allowed_extensions,
                        download_config=download_config,
                    )
                )
            except FileNotFoundError:
                if not has_magic(pattern):
                    raise
        origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
        return cls(data_files, origin_metadata)

    def filter(
        self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None
    ) -> "DataFilesList":
        """
        Return a new DataFilesList keeping only the files that match at least one criterion.

        Args:
            extensions: keep files whose name ends with one of these extensions, optionally
                followed by further suffixes (e.g. ".csv" keeps "x.csv" and "x.csv.gz").
            file_names: keep files whose base name (the part after the last "/" or "\\") is
                exactly one of these names.

        If neither argument is given, all files are kept. ``origin_metadata`` is carried over unchanged.
        """
        patterns = []
        if extensions:
            ext_pattern = "|".join(re.escape(ext) for ext in extensions)
            patterns.append(re.compile(f".*({ext_pattern})(\\..+)?$"))
        if file_names:
            fn_pattern = "|".join(re.escape(fn) for fn in file_names)
            # Anchor on a path separator (or the start of the string) so that only exact base
            # names match: "metadata.csv" must not match "train_metadata.csv".
            patterns.append(re.compile(rf"(?:.*[\\/])?({fn_pattern})$"))
        if patterns:
            return DataFilesList(
                [data_file for data_file in self if any(pattern.match(data_file) for pattern in patterns)],
                origin_metadata=self.origin_metadata,
            )
        else:
            return DataFilesList(list(self), origin_metadata=self.origin_metadata)
|
|
|
|
class DataFilesDict(dict[str, DataFilesList]):
    """
    Mapping of split name -> DataFilesList (absolute local paths or URLs of data files).

    It can be built from the user's data files patterns with:
    - ``from_hf_repo``: resolve patterns inside a dataset repository
    - ``from_local_or_remote``: resolve patterns from a local path
    - ``from_patterns``: resolve patterns relative to ``base_path``

    Values that already are DataFilesList instances are kept as-is instead of being resolved again.

    Each value is a DataFilesList, so it is possible to hash the dictionary and get a different
    hash if and only if at least one file changed. For more info, see [`DataFilesList`].

    This is useful for caching Dataset objects that are obtained from a list of data files.

    Changing the order of the keys of this dictionary also doesn't change its hash.
    """

    @classmethod
    def from_local_or_remote(
        cls,
        patterns: dict[str, Union[list[str], DataFilesList]],
        base_path: Optional[str] = None,
        allowed_extensions: Optional[list[str]] = None,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesDict":
        """Resolve each split's patterns with ``DataFilesList.from_local_or_remote``."""
        resolved = cls()
        for split, split_patterns in patterns.items():
            if isinstance(split_patterns, DataFilesList):
                resolved[split] = split_patterns
            else:
                resolved[split] = DataFilesList.from_local_or_remote(
                    split_patterns,
                    base_path=base_path,
                    allowed_extensions=allowed_extensions,
                    download_config=download_config,
                )
        return resolved

    @classmethod
    def from_hf_repo(
        cls,
        patterns: dict[str, Union[list[str], DataFilesList]],
        dataset_info: huggingface_hub.hf_api.DatasetInfo,
        base_path: Optional[str] = None,
        allowed_extensions: Optional[list[str]] = None,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesDict":
        """Resolve each split's patterns inside a Hub dataset repository with ``DataFilesList.from_hf_repo``."""
        resolved = cls()
        for split, split_patterns in patterns.items():
            if isinstance(split_patterns, DataFilesList):
                resolved[split] = split_patterns
            else:
                resolved[split] = DataFilesList.from_hf_repo(
                    split_patterns,
                    dataset_info=dataset_info,
                    base_path=base_path,
                    allowed_extensions=allowed_extensions,
                    download_config=download_config,
                )
        return resolved

    @classmethod
    def from_patterns(
        cls,
        patterns: dict[str, Union[list[str], DataFilesList]],
        base_path: Optional[str] = None,
        allowed_extensions: Optional[list[str]] = None,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesDict":
        """Resolve each split's patterns with ``DataFilesList.from_patterns``."""
        resolved = cls()
        for split, split_patterns in patterns.items():
            if isinstance(split_patterns, DataFilesList):
                resolved[split] = split_patterns
            else:
                resolved[split] = DataFilesList.from_patterns(
                    split_patterns,
                    base_path=base_path,
                    allowed_extensions=allowed_extensions,
                    download_config=download_config,
                )
        return resolved

    def filter(
        self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None
    ) -> "DataFilesDict":
        """Apply ``DataFilesList.filter`` to every split and return the results in a new dict of the same type."""
        filtered = type(self)()
        for split, split_files in self.items():
            filtered[split] = split_files.filter(extensions=extensions, file_names=file_names)
        return filtered
|
|
|
|
class DataFilesPatternsList(list[str]):
    """
    List of data files patterns (absolute local paths or URLs).
    For each pattern there is also an entry in ``allowed_extensions``: the list of extensions
    to keep for that pattern, or None to keep all the files matched by the pattern.
    """

    def __init__(
        self,
        patterns: list[str],
        allowed_extensions: list[Optional[list[str]]],
    ):
        super().__init__(patterns)
        # Parallel to the patterns: allowed_extensions[i] applies to self[i].
        self.allowed_extensions = allowed_extensions

    def __add__(self, other):
        # Concatenate the patterns and their per-pattern allowed extensions. The result must stay a
        # DataFilesPatternsList: building a DataFilesList here would store the extensions as its
        # `origin_metadata`.
        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)

    @classmethod
    def from_patterns(
        cls, patterns: list[str], allowed_extensions: Optional[list[str]] = None
    ) -> "DataFilesPatternsList":
        """Build the list with the same ``allowed_extensions`` for every pattern."""
        return cls(patterns, [allowed_extensions] * len(patterns))

    def resolve(
        self,
        base_path: str,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesList":
        """
        Resolve every pattern into concrete data files and compute their origin metadata.

        ``base_path`` defaults to the current working directory when None. A pattern without glob
        magic must match at least one file, otherwise FileNotFoundError is raised. A pattern with
        glob magic that matches nothing is skipped.
        """
        base_path = base_path if base_path is not None else Path().resolve().as_posix()
        data_files = []
        for pattern, allowed_extensions in zip(self, self.allowed_extensions):
            try:
                data_files.extend(
                    resolve_pattern(
                        pattern,
                        base_path=base_path,
                        allowed_extensions=allowed_extensions,
                        download_config=download_config,
                    )
                )
            except FileNotFoundError:
                if not has_magic(pattern):
                    raise
        origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
        return DataFilesList(data_files, origin_metadata)

    def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsList":
        """Return a copy in which ``extensions`` is appended to each pattern's allowed extensions."""
        # NOTE(review): a None entry in self.allowed_extensions makes `None + extensions` raise
        # TypeError — confirm whether None entries can reach this method.
        return DataFilesPatternsList(
            self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
        )
|
|
|
|
class DataFilesPatternsDict(dict[str, DataFilesPatternsList]):
    """
    Mapping of split name -> DataFilesPatternsList (unresolved data files patterns,
    as absolute local paths or URLs).
    """

    @classmethod
    def from_patterns(
        cls, patterns: dict[str, list[str]], allowed_extensions: Optional[list[str]] = None
    ) -> "DataFilesPatternsDict":
        """Wrap each split's patterns in a DataFilesPatternsList; existing DataFilesPatternsList values are kept as-is."""
        out = cls()
        for split, split_patterns in patterns.items():
            if not isinstance(split_patterns, DataFilesPatternsList):
                split_patterns = DataFilesPatternsList.from_patterns(
                    split_patterns,
                    allowed_extensions=allowed_extensions,
                )
            out[split] = split_patterns
        return out

    def resolve(
        self,
        base_path: str,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesDict":
        """Resolve the patterns of every split into a DataFilesDict of concrete data files."""
        resolved = DataFilesDict()
        for split, split_patterns in self.items():
            resolved[split] = split_patterns.resolve(base_path, download_config)
        return resolved

    def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsDict":
        """Apply ``DataFilesPatternsList.filter_extensions`` to every split and return a new dict of the same type."""
        filtered = type(self)()
        for split, split_patterns in self.items():
            filtered[split] = split_patterns.filter_extensions(extensions)
        return filtered
|
|