import glob
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional, Union

import pyarrow as pa

import datasets
import datasets.config
import datasets.data_files
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split


logger = datasets.utils.logging.get_logger(__name__)


def _get_modification_time(cached_directory_path):
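    """Return the last-modification time of a cached dataset directory."""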
    return Path(cached_directory_path).stat().st_mtime


def _find_hash_in_cache(
    dataset_name: str,
    config_name: Optional[str],
    cache_dir: Optional[str],
    config_kwargs: dict,
    custom_features: Optional[datasets.Features],
) -> tuple[str, str, str]:
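    """Search the local cache for an already-prepared version of the dataset and return its (config_name, version, hash)."""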
    if config_name or config_kwargs or custom_features:
        config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
            config_kwargs=config_kwargs, custom_features=custom_features
        )
    else:
        config_id = None
    cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
    namespace_and_dataset_name = dataset_name.split("/")
    namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
    cached_relative_path = "___".join(namespace_and_dataset_name)
    cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
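    # Cached datasets live under <cache_dir>/<namespace>___<dataset_name>/<config_id>/<version>/<hash>,
    # so the glob below lists candidate <hash> directories for the requested config (or any config).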
    cached_directory_paths = [
        cached_directory_path
        for cached_directory_path in glob.glob(
            os.path.join(cached_datasets_directory_path_root, config_id or "*", "*", "*")
        )
        if os.path.isdir(cached_directory_path)
        and (
            config_kwargs
            or custom_features
            or json.loads(Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
            == Path(cached_directory_path).parts[-3]
        )
    ]
    if not cached_directory_paths:
        cached_directory_paths = [
            cached_directory_path
            for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
            if os.path.isdir(cached_directory_path)
        ]
        available_configs = sorted(
            {Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths}
        )
        raise ValueError(
            f"Couldn't find cache for {dataset_name}"
            + (f" for config '{config_id}'" if config_id else "")
            + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
        )
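    # Pick the most recently modified cache directory among the matches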
    cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
    version, hash = cached_directory_path.parts[-2:]
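    # Collect the other configs that share this version and hash, to detect ambiguous requests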
    other_configs = [
        Path(_cached_directory_path).parts[-3]
        for _cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
        if os.path.isdir(_cached_directory_path)
        and (
            config_kwargs
            or custom_features
            or json.loads(Path(_cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
            == Path(_cached_directory_path).parts[-3]
        )
    ]
    if not config_id and len(other_configs) > 1:
        raise ValueError(
            f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
            f"\nPlease specify which configuration to reload from the cache, e.g."
            f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')"
        )
    config_name = cached_directory_path.parts[-3]
    warning_msg = (
        f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} "
        f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})."
    )
    logger.warning(warning_msg)
    return config_name, version, hash


class Cache(datasets.ArrowBasedBuilder):
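    """Builder that reloads a dataset already prepared in the local cache, instead of downloading and generating it again."""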
    def __init__(
        self,
        cache_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        config_name: Optional[str] = None,
        version: Optional[str] = "0.0.0",
        hash: Optional[str] = None,
        base_path: Optional[str] = None,
        info: Optional[datasets.DatasetInfo] = None,
        features: Optional[datasets.Features] = None,
        token: Optional[Union[bool, str]] = None,
        repo_id: Optional[str] = None,
        data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
        data_dir: Optional[str] = None,
        storage_options: Optional[dict] = None,
        writer_batch_size: Optional[int] = None,
        **config_kwargs,
    ):
        if repo_id is None and dataset_name is None:
            raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")
        if data_files is not None:
            config_kwargs["data_files"] = data_files
        if data_dir is not None:
            config_kwargs["data_dir"] = data_dir
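        # version="auto" and hash="auto" mean: reload whatever matching entry is already in the cache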
        if hash == "auto" and version == "auto":
            config_name, version, hash = _find_hash_in_cache(
                dataset_name=repo_id or dataset_name,
                config_name=config_name,
                cache_dir=cache_dir,
                config_kwargs=config_kwargs,
                custom_features=features,
            )
        elif hash == "auto" or version == "auto":
            raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
        super().__init__(
            cache_dir=cache_dir,
            dataset_name=dataset_name,
            config_name=config_name,
            version=version,
            hash=hash,
            base_path=base_path,
            info=info,
            token=token,
            repo_id=repo_id,
            storage_options=storage_options,
            writer_batch_size=writer_batch_size,
        )

    def _info(self) -> datasets.DatasetInfo:
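        # Return an empty placeholder: the real info (features, splits, ...) is expected to be
        # loaded from the dataset_info.json already stored in the cache directory.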
        return datasets.DatasetInfo()

    def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
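        # Nothing to download or prepare: just check that the cache directory exists
        # and optionally copy it to the requested output directory.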
        if not os.path.exists(self.cache_dir):
            raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
        if output_dir is not None and output_dir != self.cache_dir:
            shutil.copytree(self.cache_dir, output_dir)

    def _split_generators(self, dl_manager):
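        # Rebuild one SplitGenerator per cached split, pointing at the existing Arrow files
        # so the dataset can be read (or streamed) directly from the cache.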
        if isinstance(self.info.splits, datasets.SplitDict):
            split_infos: list[datasets.SplitInfo] = list(self.info.splits.values())
        else:
            raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
        return [
            datasets.SplitGenerator(
                name=split_info.name,
                gen_kwargs={
                    "files": filenames_for_dataset_split(
                        self.cache_dir,
                        dataset_name=self.dataset_name,
                        split=split_info.name,
                        filetype_suffix="arrow",
                        shard_lengths=split_info.shard_lengths,
                    )
                },
            )
            for split_info in split_infos
        ]

    def _generate_tables(self, files):
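        # Read the cached Arrow files and yield one pyarrow Table per record batch.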
        for file_idx, file in enumerate(files):
            with open(file, "rb") as f:
                try:
                    for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
                        pa_table = pa.Table.from_batches([record_batch])
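                        # Key each yielded table by file index and batch index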
                        yield f"{file_idx}_{batch_idx}", pa_table
                except ValueError as e:
                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                    raise