fuuuzzy committed on
Commit
fdb0460
·
verified ·
1 Parent(s): 2d38e41

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +22 -0
  2. .venv/lib/python3.11/site-packages/datasets/__init__.py +47 -0
  3. .venv/lib/python3.11/site-packages/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_dataset.cpython-311.pyc +3 -0
  5. .venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_reader.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_writer.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/datasets/__pycache__/builder.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/datasets/__pycache__/combine.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/datasets/__pycache__/config.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/datasets/__pycache__/data_files.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/datasets/__pycache__/dataset_dict.cpython-311.pyc +3 -0
  12. .venv/lib/python3.11/site-packages/datasets/__pycache__/exceptions.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/datasets/__pycache__/fingerprint.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/datasets/__pycache__/info.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/datasets/__pycache__/inspect.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/datasets/__pycache__/iterable_dataset.cpython-311.pyc +3 -0
  17. .venv/lib/python3.11/site-packages/datasets/__pycache__/keyhash.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/datasets/__pycache__/load.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/datasets/__pycache__/naming.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/datasets/__pycache__/search.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/datasets/__pycache__/splits.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/datasets/__pycache__/streaming.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/datasets/__pycache__/table.cpython-311.pyc +3 -0
  24. .venv/lib/python3.11/site-packages/datasets/arrow_dataset.py +0 -0
  25. .venv/lib/python3.11/site-packages/datasets/arrow_reader.py +620 -0
  26. .venv/lib/python3.11/site-packages/datasets/arrow_writer.py +766 -0
  27. .venv/lib/python3.11/site-packages/datasets/builder.py +1866 -0
  28. .venv/lib/python3.11/site-packages/datasets/combine.py +223 -0
  29. .venv/lib/python3.11/site-packages/datasets/commands/__init__.py +13 -0
  30. .venv/lib/python3.11/site-packages/datasets/commands/datasets_cli.py +39 -0
  31. .venv/lib/python3.11/site-packages/datasets/commands/delete_from_hub.py +42 -0
  32. .venv/lib/python3.11/site-packages/datasets/commands/env.py +41 -0
  33. .venv/lib/python3.11/site-packages/datasets/commands/test.py +180 -0
  34. .venv/lib/python3.11/site-packages/datasets/config.py +271 -0
  35. .venv/lib/python3.11/site-packages/datasets/data_files.py +811 -0
  36. .venv/lib/python3.11/site-packages/datasets/dataset_dict.py +0 -0
  37. .venv/lib/python3.11/site-packages/datasets/distributed.py +39 -0
  38. .venv/lib/python3.11/site-packages/datasets/download/__init__.py +10 -0
  39. .venv/lib/python3.11/site-packages/datasets/download/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/datasets/download/__pycache__/download_config.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/datasets/download/__pycache__/download_manager.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/datasets/download/__pycache__/streaming_download_manager.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/datasets/download/download_config.py +81 -0
  44. .venv/lib/python3.11/site-packages/datasets/download/download_manager.py +340 -0
  45. .venv/lib/python3.11/site-packages/datasets/download/streaming_download_manager.py +219 -0
  46. .venv/lib/python3.11/site-packages/datasets/exceptions.py +119 -0
  47. .venv/lib/python3.11/site-packages/datasets/features/__init__.py +26 -0
  48. .venv/lib/python3.11/site-packages/datasets/features/__pycache__/__init__.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/datasets/features/__pycache__/audio.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/datasets/features/__pycache__/features.cpython-311.pyc +3 -0
.gitattributes CHANGED
@@ -64,3 +64,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
64
  .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12 filter=lfs diff=lfs merge=lfs -text
65
  .venv/lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
66
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_graph.so.9 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12 filter=lfs diff=lfs merge=lfs -text
65
  .venv/lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
66
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_graph.so.9 filter=lfs diff=lfs merge=lfs -text
67
+ .venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libnvperf_target.so filter=lfs diff=lfs merge=lfs -text
68
+ .venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 filter=lfs diff=lfs merge=lfs -text
69
+ .venv/lib/python3.11/site-packages/datasets/features/__pycache__/features.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
70
+ .venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
71
+ .venv/lib/python3.11/site-packages/datasets/__pycache__/iterable_dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
72
+ .venv/lib/python3.11/site-packages/datasets/__pycache__/dataset_dict.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
73
+ .venv/lib/python3.11/site-packages/datasets/__pycache__/table.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
74
+ .venv/lib/python3.11/site-packages/yaml/_yaml.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
75
+ .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libcublas.so.12 filter=lfs diff=lfs merge=lfs -text
76
+ .venv/lib/python3.11/site-packages/nvidia/curand/lib/libcurand.so.10 filter=lfs diff=lfs merge=lfs -text
77
+ .venv/lib/python3.11/site-packages/nvidia/nvshmem/lib/libnvshmem_device.bc filter=lfs diff=lfs merge=lfs -text
78
+ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
79
+ .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text
80
+ .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text
81
+ .venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
82
+ .venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
83
+ .venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libnvperf_host.so filter=lfs diff=lfs merge=lfs -text
84
+ .venv/lib/python3.11/site-packages/nvidia/cufile/lib/libcufile.so.0 filter=lfs diff=lfs merge=lfs -text
85
+ .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
86
+ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
87
+ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_adv.so.9 filter=lfs diff=lfs merge=lfs -text
88
+ .venv/lib/python3.11/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/datasets/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __version__ = "4.4.1"
16
+
17
+ from .arrow_dataset import Column, Dataset
18
+ from .arrow_reader import ReadInstruction
19
+ from .builder import ArrowBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
20
+ from .combine import concatenate_datasets, interleave_datasets
21
+ from .dataset_dict import DatasetDict, IterableDatasetDict
22
+ from .download import *
23
+ from .features import *
24
+ from .fingerprint import disable_caching, enable_caching, is_caching_enabled
25
+ from .info import DatasetInfo
26
+ from .inspect import (
27
+ get_dataset_config_info,
28
+ get_dataset_config_names,
29
+ get_dataset_default_config_name,
30
+ get_dataset_infos,
31
+ get_dataset_split_names,
32
+ )
33
+ from .iterable_dataset import IterableColumn, IterableDataset
34
+ from .load import load_dataset, load_dataset_builder, load_from_disk
35
+ from .splits import (
36
+ NamedSplit,
37
+ NamedSplitAll,
38
+ Split,
39
+ SplitBase,
40
+ SplitDict,
41
+ SplitGenerator,
42
+ SplitInfo,
43
+ SubSplitInfo,
44
+ percent,
45
+ )
46
+ from .utils import *
47
+ from .utils import logging
.venv/lib/python3.11/site-packages/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.87 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_dataset.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bce41b6d6ccd430a8b99967f0b45c68048991596350316f088e4d4fe78d909f3
3
+ size 350055
.venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_reader.cpython-311.pyc ADDED
Binary file (31.4 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/arrow_writer.cpython-311.pyc ADDED
Binary file (44.5 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/builder.cpython-311.pyc ADDED
Binary file (99.9 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/combine.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/config.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/data_files.cpython-311.pyc ADDED
Binary file (44.3 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/dataset_dict.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e341f4b5d4dcbcf7f0c3fafd96cb990b9e3d15edfe7b88f633a0bf65e79c3c34
3
+ size 149986
.venv/lib/python3.11/site-packages/datasets/__pycache__/exceptions.cpython-311.pyc ADDED
Binary file (7.9 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/fingerprint.cpython-311.pyc ADDED
Binary file (24.4 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/info.cpython-311.pyc ADDED
Binary file (29.2 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/inspect.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/iterable_dataset.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d8aa868348076f7093248ba04fc0c398cf44d39457261531e63e0ae717f7bb7
3
+ size 262991
.venv/lib/python3.11/site-packages/datasets/__pycache__/keyhash.cpython-311.pyc ADDED
Binary file (5.36 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/load.cpython-311.pyc ADDED
Binary file (75.3 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/naming.cpython-311.pyc ADDED
Binary file (4.76 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/search.cpython-311.pyc ADDED
Binary file (46.9 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/splits.cpython-311.pyc ADDED
Binary file (32.8 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/streaming.cpython-311.pyc ADDED
Binary file (8.46 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/__pycache__/table.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c3c3494184b2e8541c43f547b4b6d02e41755ed903dadb22854c60d68556fba
3
+ size 121451
.venv/lib/python3.11/site-packages/datasets/arrow_dataset.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/datasets/arrow_reader.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Lint as: python3
16
+ """Arrow ArrowReader."""
17
+
18
+ import copy
19
+ import math
20
+ import os
21
+ import re
22
+ from dataclasses import dataclass
23
+ from functools import partial
24
+ from typing import TYPE_CHECKING, Optional, Union
25
+
26
+ import pyarrow as pa
27
+ import pyarrow.parquet as pq
28
+ from tqdm.contrib.concurrent import thread_map
29
+
30
+ from .download.download_config import DownloadConfig # noqa: F401
31
+ from .naming import _split_re, filenames_for_dataset_split
32
+ from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables
33
+ from .utils import logging
34
+ from .utils import tqdm as hf_tqdm
35
+
36
+
37
+ if TYPE_CHECKING:
38
+ from .info import DatasetInfo # noqa: F401
39
+ from .splits import Split, SplitInfo # noqa: F401
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ HF_GCP_BASE_URL = "https://storage.googleapis.com/huggingface-nlp/cache/datasets"
45
+
46
+ _SUB_SPEC_RE = re.compile(
47
+ rf"""
48
+ ^
49
+ (?P<split>{_split_re[1:-1]})
50
+ (\[
51
+ ((?P<from>-?[\d_]+)
52
+ (?P<from_pct>%)?)?
53
+ :
54
+ ((?P<to>-?[\d_]+)
55
+ (?P<to_pct>%)?)?
56
+ \])?(\((?P<rounding>[^\)]*)\))?
57
+ $
58
+ """, # remove ^ and $
59
+ re.X,
60
+ )
61
+
62
+ _ADDITION_SEP_RE = re.compile(r"\s*\+\s*")
63
+
64
+
65
class DatasetNotOnHfGcsError(ConnectionError):
    """Raised when the dataset cannot be retrieved from the Hf google cloud storage."""

    pass
69
+
70
+
71
class MissingFilesOnHfGcsError(ConnectionError):
    """Raised when some dataset files are missing on the Hf Google cloud storage."""

    pass
75
+
76
+
77
@dataclass(frozen=True)
class FileInstructions:
    """File-level reading instructions derived from a split `ReadInstruction`.

    Attributes:
        num_examples: `int`, total number of examples covered by the instructions.
        file_instructions: List[dict(filename, skip, take)], one entry per file.
            Filenames are relative paths, not absolute ones.
            `skip`/`take` select which slice of the file to read: `ds.slice(skip, take)`.
    """

    num_examples: int
    file_instructions: list[dict]
90
+
91
+
92
def make_file_instructions(
    name: str,
    split_infos: list["SplitInfo"],
    instruction: Union[str, "ReadInstruction"],
    filetype_suffix: Optional[str] = None,
    prefix_path: Optional[str] = None,
) -> FileInstructions:
    """Returns instructions of the split dict.

    Args:
        name (`str`): Name of the dataset.
        split_infos (`list` of `[SplitInfo]`): Dataset splits information.
        instruction ([`ReadInstruction`] or `str`): Reading instruction for a dataset.
        filetype_suffix (`str`, *optional*): Suffix of dataset files, e.g. 'arrow' or 'parquet'.
        prefix_path (`str`, *optional*): Prefix of dataset files, e.g. directory name.

    Returns:
        [`FileInstructions`]
    """
    if not isinstance(name, str):
        raise TypeError(f"Expected str 'name', but got: {type(name).__name__}")
    elif not name:
        raise ValueError("Expected non-empty str 'name'")
    # Per-split lookup tables: total example count, shard sizes, and on-disk filenames.
    name2len = {info.name: info.num_examples for info in split_infos}
    name2shard_lengths = {info.name: info.shard_lengths for info in split_infos}
    name2filenames = {
        info.name: filenames_for_dataset_split(
            path=prefix_path,
            dataset_name=name,
            split=info.name,
            filetype_suffix=filetype_suffix,
            shard_lengths=name2shard_lengths[info.name],
        )
        for info in split_infos
    }
    # Accept a plain spec string such as "train[10:20]" as well as a ReadInstruction.
    if not isinstance(instruction, ReadInstruction):
        instruction = ReadInstruction.from_spec(instruction)
    # Create the absolute instruction (per split)
    absolute_instructions = instruction.to_absolute(name2len)

    # For each split, return the files instruction (skip/take)
    file_instructions = []
    num_examples = 0
    for abs_instr in absolute_instructions:
        split_length = name2len[abs_instr.splitname]
        filenames = name2filenames[abs_instr.splitname]
        shard_lengths = name2shard_lengths[abs_instr.splitname]
        from_ = 0 if abs_instr.from_ is None else abs_instr.from_
        to = split_length if abs_instr.to is None else abs_instr.to
        if shard_lengths is None:  # not sharded
            for filename in filenames:
                take = to - from_
                if take == 0:
                    continue
                num_examples += take
                file_instructions.append({"filename": filename, "skip": from_, "take": take})
        else:  # sharded
            # Walk shards with a moving [index_start, index_end) window over global indices
            # and intersect each shard's window with the requested [from_, to) range.
            index_start = 0  # Beginning (included) of moving window.
            index_end = 0  # End (excluded) of moving window.
            for filename, shard_length in zip(filenames, shard_lengths):
                index_end += shard_length
                if from_ < index_end and to > index_start:  # There is something to take.
                    skip = from_ - index_start if from_ > index_start else 0
                    # take == -1 is a sentinel meaning "read this shard through to its end".
                    take = to - index_start - skip if to < index_end else -1
                    if take == 0:
                        continue
                    file_instructions.append({"filename": filename, "skip": skip, "take": take})
                    num_examples += shard_length - skip if take == -1 else take
                index_start += shard_length
    return FileInstructions(
        num_examples=num_examples,
        file_instructions=file_instructions,
    )
165
+
166
+
167
class BaseReader:
    """
    Build a Dataset object out of Instruction instance(s).
    """

    def __init__(self, path: str, info: Optional["DatasetInfo"]):
        """Initializes ArrowReader.

        Args:
            path (str): path where tfrecords are stored.
            info (DatasetInfo): info about the dataset.
        """
        self._path: str = path
        self._info: Optional["DatasetInfo"] = info
        # Set by concrete subclasses (e.g. "arrow" or "parquet").
        self._filetype_suffix: Optional[str] = None

    def _get_table_from_filename(self, filename_skip_take, in_memory=False) -> Table:
        """Returns a Dataset instance from given (filename, skip, take)."""
        # Subclasses implement the actual file-format-specific read.
        raise NotImplementedError

    def _read_files(self, files, in_memory=False) -> Table:
        """Returns Dataset for given file instructions.

        Args:
            files: List[dict(filename, skip, take)], the files information.
                The filenames contain the absolute path, not relative.
                skip/take indicates which example read in the file: `ds.slice(skip, take)`
            in_memory (bool, default False): Whether to copy the data in-memory.
        """
        if len(files) == 0 or not all(isinstance(f, dict) for f in files):
            raise ValueError("please provide valid file informations")
        # Deep-copy so the caller's instruction dicts are not mutated below.
        files = copy.deepcopy(files)
        for f in files:
            f["filename"] = os.path.join(self._path, f["filename"])

        # Read shards concurrently; thread_map shows a progress bar for many shards.
        pa_tables = thread_map(
            partial(self._get_table_from_filename, in_memory=in_memory),
            files,
            tqdm_class=hf_tqdm,
            desc="Loading dataset shards",
            # set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
            disable=len(files) <= 16 or None,
        )
        # Drop empty shards; an all-empty result needs features to build a typed empty table.
        pa_tables = [t for t in pa_tables if len(t) > 0]
        if not pa_tables and (self._info is None or self._info.features is None):
            raise ValueError(
                "Tried to read an empty table. Please specify at least info.features to create an empty table with the right type."
            )
        pa_tables = pa_tables or [InMemoryTable.from_batches([], schema=pa.schema(self._info.features.type))]
        # Avoid a needless concat when there is a single shard.
        pa_table = concat_tables(pa_tables) if len(pa_tables) != 1 else pa_tables[0]
        return pa_table

    def get_file_instructions(self, name, instruction, split_infos):
        """Return list of dict {'filename': str, 'skip': int, 'take': int}"""
        file_instructions = make_file_instructions(
            name, split_infos, instruction, filetype_suffix=self._filetype_suffix, prefix_path=self._path
        )
        files = file_instructions.file_instructions
        return files

    def read(
        self,
        name,
        instructions,
        split_infos,
        in_memory=False,
    ):
        """Returns Dataset instance(s).

        Args:
            name (str): name of the dataset.
            instructions (ReadInstruction): instructions to read.
                Instruction can be string and will then be passed to the Instruction
                constructor as it.
            split_infos (list of SplitInfo proto): the available splits for dataset.
            in_memory (bool, default False): Whether to copy the data in-memory.

        Returns:
            kwargs to build a single Dataset instance.

        Raises:
            ValueError: if the instruction matches no data at all.
        """

        files = self.get_file_instructions(name, instructions, split_infos)
        if not files:
            msg = f'Instruction "{instructions}" corresponds to no data!'
            raise ValueError(msg)
        return self.read_files(files=files, original_instructions=instructions, in_memory=in_memory)

    def read_files(
        self,
        files: list[dict],
        original_instructions: Union[None, "ReadInstruction", "Split"] = None,
        in_memory=False,
    ):
        """Returns single Dataset instance for the set of file instructions.

        Args:
            files: List[dict(filename, skip, take)], the files information.
                The filenames contains the relative path, not absolute.
                skip/take indicates which example read in the file: `ds.skip().take()`
            original_instructions: store the original instructions used to build the dataset split in the dataset.
            in_memory (bool, default False): Whether to copy the data in-memory.

        Returns:
            kwargs to build a Dataset instance.
        """
        # Prepend path to filename
        pa_table = self._read_files(files, in_memory=in_memory)
        # If original_instructions is not None, convert it to a human-readable NamedSplit
        if original_instructions is not None:
            from .splits import Split  # noqa

            split = Split(str(original_instructions))
        else:
            split = None
        dataset_kwargs = {"arrow_table": pa_table, "info": self._info, "split": split}
        return dataset_kwargs
283
+
284
+
285
class ArrowReader(BaseReader):
    """
    Build a Dataset object out of Instruction instance(s).
    This Reader uses either memory mapping or file descriptors (in-memory) on arrow files.
    """

    def __init__(self, path: str, info: Optional["DatasetInfo"]):
        """Initializes ArrowReader.

        Args:
            path (str): path where Arrow files are stored.
            info (DatasetInfo): info about the dataset.
        """
        super().__init__(path, info)
        self._filetype_suffix = "arrow"

    def _get_table_from_filename(self, filename_skip_take, in_memory=False) -> Table:
        """Returns a Dataset instance from given (filename, skip, take)."""
        filename = filename_skip_take["filename"]
        skip = filename_skip_take.get("skip")
        take = filename_skip_take.get("take")
        table = ArrowReader.read_table(filename, in_memory=in_memory)
        # take == -1 means "through to the end of the table".
        if take == -1:
            take = len(table) - skip
        # here we don't want to slice an empty table, or it may segfault
        if skip is not None and take is not None and not (skip == 0 and take == len(table)):
            table = table.slice(skip, take)
        return table

    @staticmethod
    def read_table(filename, in_memory=False) -> Table:
        """
        Read table from file.

        Args:
            filename (str): File name of the table.
            in_memory (bool, default=False): Whether to copy the data in-memory.

        Returns:
            pyarrow.Table
        """
        reader_cls = InMemoryTable if in_memory else MemoryMappedTable
        return reader_cls.from_file(filename)
330
+
331
+
332
class ParquetReader(BaseReader):
    """
    Build a Dataset object out of Instruction instance(s).
    This Reader uses memory mapping on parquet files.
    """

    def __init__(self, path: str, info: Optional["DatasetInfo"]):
        """Initializes ParquetReader.

        Args:
            path (str): path where tfrecords are stored.
            info (DatasetInfo): info about the dataset.
        """
        super().__init__(path, info)
        self._filetype_suffix = "parquet"

    def _get_table_from_filename(self, filename_skip_take, **kwargs):
        """Returns a Dataset instance from given (filename, skip, take)."""
        filename = filename_skip_take["filename"]
        skip = filename_skip_take.get("skip")
        take = filename_skip_take.get("take")
        # Parquet read_table always loads data in memory, independently of memory_map
        pa_table = pq.read_table(filename, memory_map=True)
        # here we don't want to slice an empty table, or it may segfault
        if skip is not None and take is not None and not (skip == 0 and take == len(pa_table)):
            pa_table = pa_table.slice(skip, take)
        return pa_table
361
+
362
+
363
+ @dataclass(frozen=True)
364
+ class _AbsoluteInstruction:
365
+ """A machine friendly slice: defined absolute positive boundaries."""
366
+
367
+ splitname: str
368
+ from_: int # uint (starting index).
369
+ to: int # uint (ending index).
370
+
371
+
372
+ @dataclass(frozen=True)
373
+ class _RelativeInstruction:
374
+ """Represents a single parsed slicing instruction, can use % and negatives."""
375
+
376
+ splitname: str
377
+ from_: Optional[int] = None # int (starting index) or None if no lower boundary.
378
+ to: Optional[int] = None # int (ending index) or None if no upper boundary.
379
+ unit: Optional[str] = None
380
+ rounding: Optional[str] = None
381
+
382
+ def __post_init__(self):
383
+ if self.unit is not None and self.unit not in ["%", "abs"]:
384
+ raise ValueError("unit must be either % or abs")
385
+ if self.rounding is not None and self.rounding not in ["closest", "pct1_dropremainder"]:
386
+ raise ValueError("rounding must be either closest or pct1_dropremainder")
387
+ if self.unit != "%" and self.rounding is not None:
388
+ raise ValueError("It is forbidden to specify rounding if not using percent slicing.")
389
+ if self.unit == "%" and self.from_ is not None and abs(self.from_) > 100:
390
+ raise ValueError("Percent slice boundaries must be > -100 and < 100.")
391
+ if self.unit == "%" and self.to is not None and abs(self.to) > 100:
392
+ raise ValueError("Percent slice boundaries must be > -100 and < 100.")
393
+ # Update via __dict__ due to instance being "frozen"
394
+ self.__dict__["rounding"] = "closest" if self.rounding is None and self.unit == "%" else self.rounding
395
+
396
+
397
def _str_to_read_instruction(spec):
    """Returns ReadInstruction for given string.

    Raises:
        ValueError: if `spec` does not match the split-spec grammar.
    """
    res = _SUB_SPEC_RE.match(spec)
    if not res:
        raise ValueError(f"Unrecognized instruction format: {spec}")
    # A '%' on either boundary makes the whole instruction percent-based.
    unit = "%" if res.group("from_pct") or res.group("to_pct") else "abs"
    return ReadInstruction(
        split_name=res.group("split"),
        rounding=res.group("rounding"),
        from_=int(res.group("from")) if res.group("from") else None,
        to=int(res.group("to")) if res.group("to") else None,
        unit=unit,
    )
410
+
411
+
412
+ def _pct_to_abs_pct1(boundary, num_examples):
413
+ # Using math.trunc here, since -99.5% should give -99%, not -100%.
414
+ if num_examples < 100:
415
+ msg = (
416
+ 'Using "pct1_dropremainder" rounding on a split with less than 100 '
417
+ "elements is forbidden: it always results in an empty dataset."
418
+ )
419
+ raise ValueError(msg)
420
+ return boundary * math.trunc(num_examples / 100.0)
421
+
422
+
423
+ def _pct_to_abs_closest(boundary, num_examples):
424
+ return int(round(boundary * num_examples / 100.0))
425
+
426
+
427
def _rel_to_abs_instr(rel_instr, name2len):
    """Returns _AbsoluteInstruction instance for given RelativeInstruction.

    Args:
        rel_instr: RelativeInstruction instance.
        name2len: dict {split_name: num_examples}.

    Raises:
        ValueError: if the split name is unknown.
    """
    pct_to_abs = _pct_to_abs_closest if rel_instr.rounding == "closest" else _pct_to_abs_pct1
    split = rel_instr.splitname
    if split not in name2len:
        raise ValueError(f'Unknown split "{split}". Should be one of {list(name2len)}.')
    num_examples = name2len[split]
    from_ = rel_instr.from_
    to = rel_instr.to
    # Missing boundaries default to the full split; percent boundaries are scaled first.
    if rel_instr.unit == "%":
        from_ = 0 if from_ is None else pct_to_abs(from_, num_examples)
        to = num_examples if to is None else pct_to_abs(to, num_examples)
    else:
        from_ = 0 if from_ is None else from_
        to = num_examples if to is None else to
    # Negative boundaries count back from the end of the split, floored at 0.
    if from_ < 0:
        from_ = max(num_examples + from_, 0)
    if to < 0:
        to = max(num_examples + to, 0)
    # Clamp to the split size so out-of-range slices degrade gracefully.
    from_ = min(from_, num_examples)
    to = min(to, num_examples)
    return _AbsoluteInstruction(split, from_, to)
454
+
455
+
456
class ReadInstruction:
    """Reading instruction for a dataset.

    Examples::

        # The following lines are equivalent:
        ds = datasets.load_dataset('mnist', split='test[:33%]')
        ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec('test[:33%]'))
        ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction('test', to=33, unit='%'))
        ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction(
            'test', from_=0, to=33, unit='%'))

        # The following lines are equivalent:
        ds = datasets.load_dataset('mnist', split='test[:33%]+train[1:-1]')
        ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec(
            'test[:33%]+train[1:-1]'))
        ds = datasets.load_dataset('mnist', split=(
            datasets.ReadInstruction('test', to=33, unit='%') +
            datasets.ReadInstruction('train', from_=1, to=-1, unit='abs')))

        # The following lines are equivalent:
        ds = datasets.load_dataset('mnist', split='test[:33%](pct1_dropremainder)')
        ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec(
            'test[:33%](pct1_dropremainder)'))
        ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction(
            'test', from_=0, to=33, unit='%', rounding="pct1_dropremainder"))

        # 10-fold validation:
        tests = datasets.load_dataset(
            'mnist',
            [datasets.ReadInstruction('train', from_=k, to=k+10, unit='%')
             for k in range(0, 100, 10)])
        trains = datasets.load_dataset(
            'mnist',
            [datasets.ReadInstruction('train', to=k, unit='%') + datasets.ReadInstruction('train', from_=k+10, unit='%')
             for k in range(0, 100, 10)])

    """

    def _init(self, relative_instructions):
        # Private initializer, shared by __init__ and the factory classmethod below.
        self._relative_instructions = relative_instructions

    @classmethod
    def _read_instruction_from_relative_instructions(cls, relative_instructions):
        """Returns ReadInstruction obj initialized with relative_instructions."""
        # Use __new__ to bypass __init__ used by public API and not convenient here.
        result = cls.__new__(cls)
        result._init(relative_instructions)  # pylint: disable=protected-access
        return result

    def __init__(self, split_name, rounding=None, from_=None, to=None, unit=None):
        """Initialize ReadInstruction.

        Args:
            split_name (str): name of the split to read. Eg: 'train'.
            rounding (str, optional): The rounding behaviour to use when percent slicing is
                used. Ignored when slicing with absolute indices.
                Possible values:
                    - 'closest' (default): The specified percentages are rounded to the
                        closest value. Use this if you want specified percents to be as
                        much exact as possible.
                    - 'pct1_dropremainder': the specified percentages are treated as
                        multiple of 1%. Use this option if you want consistency. Eg:
                            len(5%) == 5 * len(1%).
                        Using this option, one might not be able to use the full set of
                        examples, if the number of those is not a multiple of 100.
            from_ (int):
            to (int): alternative way of specifying slicing boundaries. If any of
                {from_, to, unit} argument is used, slicing cannot be specified as
                string.
            unit (str): optional, one of:
                '%': to set the slicing unit as percents of the split size.
                'abs': to set the slicing unit as absolute numbers.
        """
        # This constructor is not always called. See factory method
        # `_read_instruction_from_relative_instructions`. Common init instructions
        # MUST be placed in the _init method.
        self._init([_RelativeInstruction(split_name, from_, to, unit, rounding)])

    @classmethod
    def from_spec(cls, spec):
        """Creates a `ReadInstruction` instance out of a string spec.

        Args:
            spec (`str`):
                Split(s) + optional slice(s) to read + optional rounding
                if percents are used as the slicing unit. A slice can be specified,
                using absolute numbers (`int`) or percentages (`int`).

        Examples:

            ```
            test: test split.
            test + validation: test split + validation split.
            test[10:]: test split, minus its first 10 records.
            test[:10%]: first 10% records of test split.
            test[:20%](pct1_dropremainder): first 20% records, rounded with the pct1_dropremainder rounding.
            test[:-5%]+train[40%:60%]: first 95% of test + middle 20% of train.
            ```

        Returns:
            ReadInstruction instance.
        """
        spec = str(spec)  # Need to convert to str in case of NamedSplit instance.
        subs = _ADDITION_SEP_RE.split(spec)
        if not subs:
            raise ValueError(f"No instructions could be built out of {spec}")
        instruction = _str_to_read_instruction(subs[0])
        # '+'-separated sub-specs are combined via ReadInstruction.__add__.
        return sum((_str_to_read_instruction(sub) for sub in subs[1:]), instruction)

    def to_spec(self):
        # Inverse of `from_spec`: serialize the instruction back to string form.
        rel_instr_specs = []
        for rel_instr in self._relative_instructions:
            rel_instr_spec = rel_instr.splitname
            if rel_instr.from_ is not None or rel_instr.to is not None:
                from_ = rel_instr.from_
                to = rel_instr.to
                unit = rel_instr.unit
                rounding = rel_instr.rounding
                unit = unit if unit == "%" else ""
                from_ = str(from_) + unit if from_ is not None else ""
                to = str(to) + unit if to is not None else ""
                slice_str = f"[{from_}:{to}]"
                # Only a non-default rounding of a percent slice is part of the spec.
                rounding_str = (
                    f"({rounding})" if unit == "%" and rounding is not None and rounding != "closest" else ""
                )
                rel_instr_spec += slice_str + rounding_str
            rel_instr_specs.append(rel_instr_spec)
        return "+".join(rel_instr_specs)

    def __add__(self, other):
        """Returns a new ReadInstruction obj, result of appending other to self."""
        if not isinstance(other, ReadInstruction):
            msg = "ReadInstruction can only be added to another ReadInstruction obj."
            raise TypeError(msg)
        self_ris = self._relative_instructions
        other_ris = other._relative_instructions  # pylint: disable=protected-access
        # Mixing different percent roundings would silently change slice sizes.
        if (
            self_ris[0].unit != "abs"
            and other_ris[0].unit != "abs"
            and self._relative_instructions[0].rounding != other_ris[0].rounding
        ):
            raise ValueError("It is forbidden to sum ReadInstruction instances with different rounding values.")
        return self._read_instruction_from_relative_instructions(self_ris + other_ris)

    def __str__(self):
        return self.to_spec()

    def __repr__(self):
        return f"ReadInstruction({self._relative_instructions})"

    def to_absolute(self, name2len):
        """Translate instruction into a list of absolute instructions.

        Those absolute instructions are then to be added together.

        Args:
            name2len (`dict`):
                Associating split names to number of examples.

        Returns:
            list of _AbsoluteInstruction instances (corresponds to the + in spec).
        """
        return [_rel_to_abs_instr(rel_instr, name2len) for rel_instr in self._relative_instructions]
.venv/lib/python3.11/site-packages/datasets/arrow_writer.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ # Lint as: python3
14
+ """To write records into Parquet files."""
15
+
16
+ import json
17
+ import sys
18
+ from collections.abc import Iterable
19
+ from typing import Any, Optional, Union
20
+
21
+ import fsspec
22
+ import numpy as np
23
+ import pyarrow as pa
24
+ import pyarrow.parquet as pq
25
+ from fsspec.core import url_to_fs
26
+
27
+ from . import config
28
+ from .features import Audio, Features, Image, Pdf, Value, Video
29
+ from .features.features import (
30
+ FeatureType,
31
+ List,
32
+ _ArrayXDExtensionType,
33
+ _visit,
34
+ cast_to_python_objects,
35
+ generate_from_arrow_type,
36
+ get_nested_type,
37
+ list_of_np_array_to_pyarrow_listarray,
38
+ numpy_to_pyarrow_listarray,
39
+ to_pyarrow_listarray,
40
+ )
41
+ from .filesystems import is_remote_filesystem
42
+ from .info import DatasetInfo
43
+ from .keyhash import DuplicatedKeysError, KeyHasher
44
+ from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
45
+ from .utils import logging
46
+ from .utils.py_utils import asdict, convert_file_size_to_int, first_non_null_non_empty_value
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ type_ = type # keep python's type function
52
+
53
+
54
def get_arrow_writer_batch_size_from_features(features: Optional[Features]) -> Optional[int]:
    """
    Get the writer_batch_size that defines the maximum record batch size in the arrow files based on configuration values.
    The default value is 100 for image/audio datasets and 10 for videos.
    This allows to avoid overflows in arrow buffers.

    Args:
        features (`datasets.Features` or `None`):
            Dataset Features from `datasets`.
    Returns:
        writer_batch_size (`Optional[int]`):
            Writer batch size to pass to a dataset builder.
            If `None`, then it will use the `datasets` default, i.e. `datasets.config.DEFAULT_MAX_BATCH_SIZE`.
    """
    if not features:
        return None

    # Track the smallest configured cap across all nested feature types.
    smallest = [np.inf]

    def _maybe_lower(feature: FeatureType) -> None:
        if isinstance(feature, Image):
            cap = config.ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS
        elif isinstance(feature, Audio):
            cap = config.ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS
        elif isinstance(feature, Video):
            cap = config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS
        elif isinstance(feature, Value) and feature.dtype == "binary":
            cap = config.ARROW_RECORD_BATCH_SIZE_FOR_BINARY_DATASETS
        else:
            cap = None
        if cap is not None:
            smallest[0] = min(smallest[0], cap)

    _visit(features, _maybe_lower)

    # No media/binary feature found (or all caps unset): let the caller use the default.
    return None if smallest[0] is np.inf else smallest[0]
91
+
92
+
93
def get_writer_batch_size_from_features(features: Optional[Features]) -> Optional[int]:
    """
    Get the writer_batch_size that defines the maximum row group size in the parquet files based on configuration values.
    By default these are not set, but it can be helpful to hard set those values in some cases.
    This allows to optimize random access to parquet file, since accessing 1 row requires
    to read its entire row group.

    Args:
        features (`datasets.Features` or `None`):
            Dataset Features from `datasets`.
    Returns:
        writer_batch_size (`Optional[int]`):
            Writer batch size to pass to a parquet writer.
            If `None`, then it will use the `datasets` default, i.e. aiming for row groups of 100MB.
    """
    if not features:
        return None

    # Track the smallest configured row-group cap across all nested feature types.
    smallest = [np.inf]

    def _maybe_lower(feature: FeatureType) -> None:
        if isinstance(feature, Image):
            cap = config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS
        elif isinstance(feature, Audio):
            cap = config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS
        elif isinstance(feature, Video):
            cap = config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS
        elif isinstance(feature, Value) and feature.dtype == "binary":
            cap = config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS
        else:
            cap = None
        if cap is not None:
            smallest[0] = min(smallest[0], cap)

    _visit(features, _maybe_lower)

    # No media/binary feature found (or all caps unset): let the caller use the default.
    return None if smallest[0] is np.inf else smallest[0]
131
+
132
+
133
def get_writer_batch_size_from_data_size(num_rows: int, num_bytes: int) -> int:
    """
    Get the writer_batch_size that defines the maximum row group size in the parquet files.
    The default in `datasets` is aiming for row groups of maximum 100MB uncompressed.
    This allows to optimize random access to parquet file, since accessing 1 row requires
    to read its entire row group.

    This can be improved to get optimized size for querying/iterating
    but at least it matches the dataset viewer expectations on HF.

    Args:
        num_rows (`int`):
            Number of rows in the dataset.
        num_bytes (`int`):
            Number of bytes in the dataset.
            For dataset with external files to embed (image, audio, videos), this can also be an
            estimate from `dataset._estimate_nbytes()`.
    Returns:
        writer_batch_size (`Optional[int]`):
            Writer batch size to pass to a parquet writer.
    """
    if num_bytes <= 0:
        return 1
    # Rows per group so that one group weighs ~MAX_ROW_GROUP_SIZE, never below 10 rows.
    max_group_bytes = convert_file_size_to_int(config.MAX_ROW_GROUP_SIZE)
    return max(10, num_rows * max_group_bytes // num_bytes)
155
+
156
+
157
class SchemaInferenceError(ValueError):
    """Raised when an Arrow schema cannot be inferred from the written data."""
159
+
160
+
161
class TypedSequence:
    """
    This data container generalizes the typing when instantiating pyarrow arrays, tables or batches.

    More specifically it adds several features:
    - Support extension types like ``datasets.features.Array2DExtensionType``:
        By default pyarrow arrays don't return extension arrays. One has to call
        ``pa.ExtensionArray.from_storage(type, pa.array(data, type.storage_type))``
        in order to get an extension array.
    - Support for ``try_type`` parameter that can be used instead of ``type``:
        When an array is transformed, we like to keep the same type as before if possible.
        For example when calling :func:`datasets.Dataset.map`, we don't want to change the type
        of each column by default.
    - Better error message when a pyarrow array overflows.

    Example::

        from datasets.features import Array2D, Array2DExtensionType, Value
        from datasets.arrow_writer import TypedSequence
        import pyarrow as pa

        arr = pa.array(TypedSequence([1, 2, 3], type=Value("int32")))
        assert arr.type == pa.int32()

        arr = pa.array(TypedSequence([1, 2, 3], try_type=Value("int32")))
        assert arr.type == pa.int32()

        arr = pa.array(TypedSequence(["foo", "bar"], try_type=Value("int32")))
        assert arr.type == pa.string()

        arr = pa.array(TypedSequence([[[1, 2, 3]]], type=Array2D((1, 3), "int64")))
        assert arr.type == Array2DExtensionType((1, 3), "int64")

        table = pa.Table.from_pydict({
            "image": TypedSequence([[[1, 2, 3]]], type=Array2D((1, 3), "int64"))
        })
        assert table["image"].type == Array2DExtensionType((1, 3), "int64")

    """

    def __init__(
        self,
        data: Iterable,
        type: Optional[FeatureType] = None,
        try_type: Optional[FeatureType] = None,
        optimized_int_type: Optional[FeatureType] = None,
    ):
        # `type` and `try_type` are mutually exclusive.
        if type is not None and try_type is not None:
            raise ValueError("You cannot specify both type and try_type")
        # set attributes
        self.data = data
        self.type = type
        self.try_type = try_type  # is ignored if it doesn't match the data
        self.optimized_int_type = optimized_int_type
        # when trying a type (is ignored if data is not compatible)
        self.trying_type = self.try_type is not None
        self.trying_int_optimization = optimized_int_type is not None and type is None and try_type is None
        # used to get back the inferred type after __arrow_array__() is called once
        self._inferred_type = None

    def get_inferred_type(self) -> FeatureType:
        """Return the inferred feature type.
        This is done by converting the sequence to an Arrow array, and getting the corresponding
        feature type.

        Since building the Arrow array can be expensive, the value of the inferred type is cached
        as soon as pa.array is called on the typed sequence.

        Returns:
            FeatureType: inferred feature type of the sequence.
        """
        if self._inferred_type is None:
            self._inferred_type = generate_from_arrow_type(pa.array(self).type)
        return self._inferred_type

    @staticmethod
    def _infer_custom_type_and_encode(data: Iterable) -> tuple[Iterable, Optional[FeatureType]]:
        """Implement type inference for custom objects like PIL.Image.Image -> Image type.

        This function is only used for custom python objects that can't be directly passed to build
        an Arrow array. In such cases it infers the feature type to use, and it encodes the data so
        that they can be passed to an Arrow array.

        Args:
            data (Iterable): array of data to infer the type, e.g. a list of PIL images.

        Returns:
            Tuple[Iterable, Optional[FeatureType]]: a tuple with:
                - the (possibly encoded) array, if the inferred feature type requires encoding
                - the inferred feature type if the array is made of supported custom objects like
                    PIL images, else None.
        """
        # Only attempt PIL inference when the user's process has already imported PIL.
        if config.PIL_AVAILABLE and "PIL" in sys.modules:
            import PIL.Image

            non_null_idx, non_null_value = first_non_null_non_empty_value(data)
            if isinstance(non_null_value, PIL.Image.Image):
                return [Image().encode_example(value) if value is not None else None for value in data], Image()
            if isinstance(non_null_value, list) and isinstance(non_null_value[0], PIL.Image.Image):
                return [
                    [Image().encode_example(x) for x in value] if value is not None else None for value in data
                ], List(Image())
        # Same lazy approach for PDFs via pdfplumber.
        if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules:
            import pdfplumber

            non_null_idx, non_null_value = first_non_null_non_empty_value(data)
            if isinstance(non_null_value, pdfplumber.pdf.PDF):
                return [Pdf().encode_example(value) if value is not None else None for value in data], Pdf()
            if isinstance(non_null_value, list) and isinstance(non_null_value[0], pdfplumber.pdf.PDF):
                return [
                    [Pdf().encode_example(x) for x in value] if value is not None else None for value in data
                ], List(Pdf())
        return data, None

    def __arrow_array__(self, type: Optional[pa.DataType] = None):
        """This function is called when calling pa.array(typed_sequence)"""

        if type is not None:
            raise ValueError("TypedSequence is supposed to be used with pa.array(typed_sequence, type=None)")
        del type  # make sure we don't use it
        data = self.data
        # automatic type inference for custom objects
        if self.type is None and self.try_type is None:
            data, self._inferred_type = self._infer_custom_type_and_encode(data)
        if self._inferred_type is None:
            type = self.try_type if self.trying_type else self.type
        else:
            type = self._inferred_type
        pa_type = get_nested_type(type) if type is not None else None
        optimized_int_pa_type = (
            get_nested_type(self.optimized_int_type) if self.optimized_int_type is not None else None
        )
        trying_cast_to_python_objects = False
        try:
            # custom pyarrow types
            if isinstance(pa_type, _ArrayXDExtensionType):
                storage = to_pyarrow_listarray(data, pa_type)
                return pa.ExtensionArray.from_storage(pa_type, storage)

            # efficient np array to pyarrow array
            if isinstance(data, np.ndarray):
                out = numpy_to_pyarrow_listarray(data)
            elif isinstance(data, list) and data and isinstance(first_non_null_non_empty_value(data)[1], np.ndarray):
                out = list_of_np_array_to_pyarrow_listarray(data)
            else:
                trying_cast_to_python_objects = True
                out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
            # use smaller integer precisions if possible
            if self.trying_int_optimization:
                if pa.types.is_int64(out.type):
                    out = out.cast(optimized_int_pa_type)
                elif pa.types.is_list(out.type):
                    if pa.types.is_int64(out.type.value_type):
                        out = array_cast(out, pa.list_(optimized_int_pa_type))
                    elif pa.types.is_list(out.type.value_type) and pa.types.is_int64(out.type.value_type.value_type):
                        out = array_cast(out, pa.list_(pa.list_(optimized_int_pa_type)))
            # otherwise we can finally use the user's type
            elif type is not None:
                # We use cast_array_to_feature to support casting to custom types like Audio and Image
                # Also, when trying type "string", we don't want to convert integers or floats to "string".
                # We only do it if trying_type is False - since this is what the user asks for.
                out = cast_array_to_feature(
                    out, type, allow_primitive_to_str=not self.trying_type, allow_decimal_to_str=not self.trying_type
                )
            return out
        except (
            TypeError,
            pa.lib.ArrowInvalid,
            pa.lib.ArrowNotImplementedError,
        ) as e:  # handle type errors and overflows
            # Ignore ArrowNotImplementedError caused by trying type, otherwise re-raise
            if not self.trying_type and isinstance(e, pa.lib.ArrowNotImplementedError):
                raise

            if self.trying_type:
                try:  # second chance: drop the tried type and build the array untyped
                    if isinstance(data, np.ndarray):
                        return numpy_to_pyarrow_listarray(data)
                    elif isinstance(data, list) and data and any(isinstance(value, np.ndarray) for value in data):
                        return list_of_np_array_to_pyarrow_listarray(data)
                    else:
                        trying_cast_to_python_objects = True
                        return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
                except pa.lib.ArrowInvalid as e:
                    if "overflow" in str(e):
                        raise OverflowError(
                            f"There was an overflow with type {type_(data)}. Try to reduce writer_batch_size to have batches smaller than 2GB.\n({e})"
                        ) from None
                    elif self.trying_int_optimization and "not in range" in str(e):
                        optimized_int_pa_type_str = np.dtype(optimized_int_pa_type.to_pandas_dtype()).name
                        logger.info(
                            f"Failed to cast a sequence to {optimized_int_pa_type_str}. Falling back to int64."
                        )
                        return out
                    elif trying_cast_to_python_objects and "Could not convert" in str(e):
                        # Retry without the list-casting optimization, which can trip on mixed content.
                        out = pa.array(
                            cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False)
                        )
                        if type is not None:
                            out = cast_array_to_feature(
                                out, type, allow_primitive_to_str=True, allow_decimal_to_str=True
                            )
                        return out
                    else:
                        raise
            elif "overflow" in str(e):
                raise OverflowError(
                    f"There was an overflow with type {type_(data)}. Try to reduce writer_batch_size to have batches smaller than 2GB.\n({e})"
                ) from None
            elif self.trying_int_optimization and "not in range" in str(e):
                optimized_int_pa_type_str = np.dtype(optimized_int_pa_type.to_pandas_dtype()).name
                logger.info(f"Failed to cast a sequence to {optimized_int_pa_type_str}. Falling back to int64.")
                return out
            elif trying_cast_to_python_objects and "Could not convert" in str(e):
                # Retry without the list-casting optimization, which can trip on mixed content.
                out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False))
                if type is not None:
                    out = cast_array_to_feature(out, type, allow_primitive_to_str=True, allow_decimal_to_str=True)
                return out
            else:
                raise
382
+
383
+
384
class OptimizedTypedSequence(TypedSequence):
    """A :class:`TypedSequence` that downcasts well-known tokenizer columns.

    When no explicit ``type``/``try_type`` is given, columns commonly produced by
    tokenizers (``input_ids`` and the various masks) are stored with a smaller
    integer precision to save space.
    """

    def __init__(
        self,
        data,
        type: Optional[FeatureType] = None,
        try_type: Optional[FeatureType] = None,
        col: Optional[str] = None,
        optimized_int_type: Optional[FeatureType] = None,
    ):
        if type is None and try_type is None:
            # Known value ranges per column:
            # - attention_mask / special_tokens_mask are binary tensors
            # - token_type_ids is a binary mask; some models (XLNetModel) also use a 2
            # - input_ids hold vocab ids (typical vocab size: 0-50k, max ~500k, never > 1M)
            downcast_by_col = {
                "attention_mask": Value("int8"),
                "special_tokens_mask": Value("int8"),
                "input_ids": Value("int32"),
                "token_type_ids": Value("int8"),
            }
            optimized_int_type = downcast_by_col.get(col, None)
        super().__init__(data, type=type, try_type=try_type, optimized_int_type=optimized_int_type)
404
+
405
+
406
+ class ArrowWriter:
407
+ """Shuffles and writes Examples to Arrow files."""
408
+
409
    def __init__(
        self,
        schema: Optional[pa.Schema] = None,
        features: Optional[Features] = None,
        path: Optional[str] = None,
        stream: Optional[pa.NativeFile] = None,
        fingerprint: Optional[str] = None,
        writer_batch_size: Optional[int] = None,
        hash_salt: Optional[str] = None,
        check_duplicates: Optional[bool] = False,
        disable_nullable: bool = False,
        update_features: bool = False,
        with_metadata: bool = True,
        unit: str = "examples",
        embed_local_files: bool = False,
        storage_options: Optional[dict] = None,
    ):
        """Create a writer that streams record batches to `path` or `stream`.

        At least one of `path`/`stream` is required. If `features` (or `schema`)
        is given it fixes the output schema; otherwise the schema is determined
        later (see `_build_writer`). `storage_options` are forwarded to fsspec
        when `path` is used.
        """
        if path is None and stream is None:
            raise ValueError("At least one of path and stream must be provided.")
        # `features` takes precedence over `schema`; either one determines the other.
        if features is not None:
            self._features = features
            self._schema = None
        elif schema is not None:
            self._schema: pa.Schema = schema
            self._features = Features.from_arrow_schema(self._schema)
        else:
            self._features = None
            self._schema = None

        if hash_salt is not None:
            # Create KeyHasher instance using split name as hash salt
            self._hasher = KeyHasher(hash_salt)
        else:
            self._hasher = KeyHasher("")

        self._check_duplicates = check_duplicates
        self._disable_nullable = disable_nullable

        # We own (and must close) the stream only if we opened it ourselves from `path`.
        if stream is None:
            fs, path = url_to_fs(path, **(storage_options or {}))
            self._fs: fsspec.AbstractFileSystem = fs
            self._path = path if not is_remote_filesystem(self._fs) else self._fs.unstrip_protocol(path)
            self.stream = self._fs.open(path, "wb")
            self._closable_stream = True
        else:
            self._fs = None
            self._path = None
            self.stream = stream
            self._closable_stream = False

        self.fingerprint = fingerprint
        self.disable_nullable = disable_nullable
        # Explicit value wins, then a features-derived cap, then the global default.
        self.writer_batch_size = (
            writer_batch_size
            or get_arrow_writer_batch_size_from_features(self._features)
            or config.DEFAULT_MAX_BATCH_SIZE
        )
        self.update_features = update_features
        self.with_metadata = with_metadata
        self.unit = unit
        self.embed_local_files = embed_local_files

        # Counters and in-memory pools flushed by write_examples_on_file / write_rows_on_file.
        self._num_examples = 0
        self._num_bytes = 0
        self.current_examples: list[tuple[dict[str, Any], str]] = []
        self.current_rows: list[pa.Table] = []
        self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
        self.hkey_record = []
477
+
478
+ def __len__(self):
479
+ """Return the number of writed and staged examples"""
480
+ return self._num_examples + len(self.current_examples) + len(self.current_rows)
481
+
482
    def __enter__(self):
        # Support `with ArrowWriter(...) as writer:`; cleanup happens in __exit__.
        return self
484
+
485
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the writer/stream, even if the `with` body raised.
        self.close()
487
+
488
    def close(self):
        """Close the pyarrow writer and, if this writer opened it, the output stream."""
        # Try closing if opened; if closed: pyarrow.lib.ArrowInvalid: Invalid operation on closed file
        if self.pa_writer:  # it might be None
            try:
                self.pa_writer.close()
            except Exception:  # pyarrow.lib.ArrowInvalid, OSError
                pass
        if self._closable_stream and not self.stream.closed:
            self.stream.close()  # This also closes self.pa_writer if it is opened
497
+
498
    def _build_schema(self, inferred_schema: pa.Schema):
        """Combine the writer's features/schema with a schema inferred from data.

        Returns a ``(schema, features)`` pair, with nullability and metadata
        applied according to the writer's settings.
        """
        schema = self.schema
        features = self._features
        inferred_features = Features.from_arrow_schema(inferred_schema)
        if self._features is not None:
            if self.update_features:  # keep original features if they match, or update them
                fields = {field.name: field for field in self._features.type}
                for inferred_field in inferred_features.type:
                    name = inferred_field.name
                    if name in fields:
                        if inferred_field == fields[name]:
                            inferred_features[name] = self._features[name]
                features = inferred_features
                schema: pa.Schema = inferred_schema
        else:
            features = inferred_features
            schema: pa.Schema = inferred_features.arrow_schema

        if self.disable_nullable:
            schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in schema)
        if self.with_metadata:
            schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=features), self.fingerprint))
        else:
            schema = schema.with_metadata({})

        return schema, features
524
+
525
    def _build_writer(self, inferred_schema: pa.Schema):
        # Finalize schema/features from the inferred schema, then open the stream writer.
        self._schema, self._features = self._build_schema(inferred_schema)
        self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
528
+
529
    @property
    def schema(self):
        """Current pyarrow schema, or ``[]`` when no schema/features are known yet."""
        _schema = (
            self._schema
            if self._schema is not None
            else (pa.schema(self._features.type) if self._features is not None else None)
        )
        if self._disable_nullable and _schema is not None:
            _schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in _schema)
        return _schema if _schema is not None else []
539
+
540
+ @staticmethod
541
+ def _build_metadata(info: DatasetInfo, fingerprint: Optional[str] = None) -> dict[str, str]:
542
+ info_keys = ["features"] # we can add support for more DatasetInfo keys in the future
543
+ info_as_dict = asdict(info)
544
+ metadata = {}
545
+ metadata["info"] = {key: info_as_dict[key] for key in info_keys}
546
+ if fingerprint is not None:
547
+ metadata["fingerprint"] = fingerprint
548
+ return {"huggingface": json.dumps(metadata)}
549
+
550
    def write_examples_on_file(self):
        """Write stored examples from the write-pool of examples. It makes a table out of the examples and write it."""
        if not self.current_examples:
            return
        # preserve the order the columns
        if self.schema:
            schema_cols = set(self.schema.names)
            examples_cols = self.current_examples[0][0].keys()  # .keys() preserves the order (unlike set)
            common_cols = [col for col in self.schema.names if col in examples_cols]
            extra_cols = [col for col in examples_cols if col not in schema_cols]
            cols = common_cols + extra_cols
        else:
            cols = list(self.current_examples[0][0])
        batch_examples = {}
        for col in cols:
            # We use row[0][col] since current_examples contains (example, key) tuples.
            # Moreover, examples could be Arrow arrays of 1 element.
            # This can happen in `.map()` when we want to re-write the same Arrow data
            if all(isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) for row in self.current_examples):
                # Fast path: every value is already Arrow data — flatten ChunkedArrays into their
                # chunks and concatenate once, avoiding a round-trip through Python objects.
                arrays = [row[0][col] for row in self.current_examples]
                arrays = [
                    chunk
                    for array in arrays
                    for chunk in (array.chunks if isinstance(array, pa.ChunkedArray) else [array])
                ]
                batch_examples[col] = pa.concat_arrays(arrays)
            else:
                # Mixed or plain-Python values: convert any stray Arrow arrays back to a single
                # Python value and collect everything as an ordinary list.
                batch_examples[col] = [
                    row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
                    for row in self.current_examples
                ]
        self.write_batch(batch_examples=batch_examples)
        self.current_examples = []
583
+
584
+ def write_rows_on_file(self):
585
+ """Write stored rows from the write-pool of rows. It concatenates the single-row tables and it writes the resulting table."""
586
+ if not self.current_rows:
587
+ return
588
+ table = pa.concat_tables(self.current_rows)
589
+ self.write_table(table)
590
+ self.current_rows = []
591
+
592
+ def write(
593
+ self,
594
+ example: dict[str, Any],
595
+ key: Optional[Union[str, int, bytes]] = None,
596
+ writer_batch_size: Optional[int] = None,
597
+ ):
598
+ """Add a given (Example,Key) pair to the write-pool of examples which is written to file.
599
+
600
+ Args:
601
+ example: the Example to add.
602
+ key: Optional, a unique identifier(str, int or bytes) associated with each example
603
+ """
604
+ # Utilize the keys and duplicate checking when `self._check_duplicates` is passed True
605
+ if self._check_duplicates:
606
+ # Create unique hash from key and store as (key, example) pairs
607
+ hash = self._hasher.hash(key)
608
+ self.current_examples.append((example, hash))
609
+ # Maintain record of keys and their respective hashes for checking duplicates
610
+ self.hkey_record.append((hash, key))
611
+ else:
612
+ # Store example as a tuple so as to keep the structure of `self.current_examples` uniform
613
+ self.current_examples.append((example, ""))
614
+
615
+ if writer_batch_size is None:
616
+ writer_batch_size = self.writer_batch_size
617
+ if writer_batch_size is not None and len(self.current_examples) >= writer_batch_size:
618
+ if self._check_duplicates:
619
+ self.check_duplicate_keys()
620
+ # Re-initializing to empty list for next batch
621
+ self.hkey_record = []
622
+
623
+ self.write_examples_on_file()
624
+
625
+ def check_duplicate_keys(self):
626
+ """Raises error if duplicates found in a batch"""
627
+ tmp_record = set()
628
+ for hash, key in self.hkey_record:
629
+ if hash in tmp_record:
630
+ duplicate_key_indices = [
631
+ str(self._num_examples + index)
632
+ for index, (duplicate_hash, _) in enumerate(self.hkey_record)
633
+ if duplicate_hash == hash
634
+ ]
635
+
636
+ raise DuplicatedKeysError(key, duplicate_key_indices)
637
+ else:
638
+ tmp_record.add(hash)
639
+
640
+ def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None):
641
+ """Add a given single-row Table to the write-pool of rows which is written to file.
642
+
643
+ Args:
644
+ row: the row to add.
645
+ """
646
+ if len(row) != 1:
647
+ raise ValueError(f"Only single-row pyarrow tables are allowed but got table with {len(row)} rows.")
648
+ self.current_rows.append(row)
649
+ if writer_batch_size is None:
650
+ writer_batch_size = self.writer_batch_size
651
+ if writer_batch_size is not None and len(self.current_rows) >= writer_batch_size:
652
+ self.write_rows_on_file()
653
+
654
    def write_batch(
        self,
        batch_examples: dict[str, list],
        writer_batch_size: Optional[int] = None,
        try_original_type: Optional[bool] = True,
    ):
        """Write a batch of Example to file.
        Ignores the batch if it appears to be empty,
        preventing a potential schema update of unknown types.

        Args:
            batch_examples: the batch of examples to add.
            writer_batch_size: Optional, overrides ``self.writer_batch_size`` for this call.
            try_original_type: use `try_type` when instantiating OptimizedTypedSequence if `True`, otherwise `try_type = None`.
        """
        if batch_examples and len(next(iter(batch_examples.values()))) == 0:
            return
        # Before the first write (and only when feature updates are allowed), leave `features`
        # unset so types are inferred, and use the current features as a best-effort hint.
        features = None if self.pa_writer is None and self.update_features else self._features
        try_features = self._features if self.pa_writer is None and self.update_features else None
        arrays = []
        inferred_features = Features()
        # preserve the order the columns
        if self.schema:
            schema_cols = set(self.schema.names)
            batch_cols = batch_examples.keys()  # .keys() preserves the order (unlike set)
            common_cols = [col for col in self.schema.names if col in batch_cols]
            extra_cols = [col for col in batch_cols if col not in schema_cols]
            cols = common_cols + extra_cols
        else:
            cols = list(batch_examples)
        for col in cols:
            col_values = batch_examples[col]
            col_type = features[col] if features else None
            if isinstance(col_values, (pa.Array, pa.ChunkedArray)):
                # Already Arrow data: cast in place when a target type is known.
                array = cast_array_to_feature(col_values, col_type) if col_type is not None else col_values
                arrays.append(array)
                inferred_features[col] = generate_from_arrow_type(col_values.type)
            else:
                # Plain Python values: let OptimizedTypedSequence pick/try the type.
                col_try_type = (
                    try_features[col]
                    if try_features is not None and col in try_features and try_original_type
                    else None
                )
                typed_sequence = OptimizedTypedSequence(col_values, type=col_type, try_type=col_try_type, col=col)
                arrays.append(pa.array(typed_sequence))
                inferred_features[col] = typed_sequence.get_inferred_type()
        # Before the first write the inferred schema wins; afterwards the writer's schema is fixed.
        schema = inferred_features.arrow_schema if self.pa_writer is None else self.schema
        pa_table = pa.Table.from_arrays(arrays, schema=schema)
        self.write_table(pa_table, writer_batch_size)
702
+
703
    def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None):
        """Write a Table to file.

        Args:
            pa_table: the Table to add.
            writer_batch_size: Optional, length of the written record batches; defaults to ``self.writer_batch_size``.
        """
        if writer_batch_size is None:
            writer_batch_size = self.writer_batch_size
        if self.pa_writer is None:
            # First write: build the underlying writer from this table's schema.
            self._build_writer(inferred_schema=pa_table.schema)
        pa_table = pa_table.combine_chunks()
        # Align the table with the writer's resolved schema.
        pa_table = table_cast(pa_table, self._schema)
        if self.embed_local_files:
            pa_table = embed_table_storage(pa_table)
        self._num_bytes += pa_table.nbytes
        self._num_examples += pa_table.num_rows
        self.pa_writer.write_table(pa_table, writer_batch_size)
720
+
721
    def finalize(self, close_stream=True):
        """Flush all pending rows/examples, close the writer and return (num_examples, num_bytes).

        Args:
            close_stream: whether to also close the underlying stream.

        Raises:
            SchemaInferenceError: if no schema could be determined (no features and no examples).
        """
        self.write_rows_on_file()
        # In case current_examples < writer_batch_size, but user uses finalize()
        if self._check_duplicates:
            self.check_duplicate_keys()
            # Re-initializing to empty list for next batch
            self.hkey_record = []
        self.write_examples_on_file()
        # If schema is known, infer features even if no examples were written
        if self.pa_writer is None and self.schema:
            self._build_writer(self.schema)
        if self.pa_writer is not None:
            self.pa_writer.close()
            self.pa_writer = None
            if close_stream:
                self.stream.close()
        else:
            # No schema could be built: still close the stream if asked, then fail loudly.
            if close_stream:
                self.stream.close()
            raise SchemaInferenceError("Please pass `features` or at least one example when writing data")
        logger.debug(
            f"Done writing {self._num_examples} {self.unit} in {self._num_bytes} bytes {self._path if self._path else ''}."
        )
        return self._num_examples, self._num_bytes
745
+
746
+
747
class ParquetWriter(ArrowWriter):
    """ArrowWriter variant that writes Parquet files via `pyarrow.parquet.ParquetWriter`."""

    def __init__(self, *args, use_content_defined_chunking=True, write_page_index=True, **kwargs):
        """
        Args:
            use_content_defined_chunking: `True` selects the library default CDC options
                (`config.DEFAULT_CDC_OPTIONS`); pass a dict for custom options or `False` to disable.
            write_page_index: whether to write the Parquet page index.
        """
        super().__init__(*args, **kwargs)
        if use_content_defined_chunking is True:
            # Replace the bare `True` with the concrete default options dict.
            use_content_defined_chunking = config.DEFAULT_CDC_OPTIONS
        self.use_content_defined_chunking = use_content_defined_chunking
        self.write_page_index = write_page_index

    def _build_writer(self, inferred_schema: pa.Schema):
        """Resolve schema/features and open the underlying Parquet writer."""
        self._schema, self._features = self._build_schema(inferred_schema)
        self.pa_writer = pq.ParquetWriter(
            self.stream,
            self._schema,
            use_content_defined_chunking=self.use_content_defined_chunking,
            write_page_index=self.write_page_index,
        )
        if self.use_content_defined_chunking is not False:
            # Record the CDC options in the file's key-value metadata for reproducibility.
            self.pa_writer.add_key_value_metadata(
                {"content_defined_chunking": json.dumps(self.use_content_defined_chunking)}
            )
.venv/lib/python3.11/site-packages/datasets/builder.py ADDED
@@ -0,0 +1,1866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Lint as: python3
16
+ """DatasetBuilder base class."""
17
+
18
+ import abc
19
+ import contextlib
20
+ import copy
21
+ import inspect
22
+ import os
23
+ import posixpath
24
+ import shutil
25
+ import textwrap
26
+ import time
27
+ import urllib
28
+ from collections.abc import Iterable, Mapping
29
+ from dataclasses import dataclass
30
+ from functools import partial
31
+ from pathlib import Path
32
+ from typing import TYPE_CHECKING, Optional, Union
33
+ from unittest.mock import patch
34
+
35
+ import fsspec
36
+ from fsspec.core import url_to_fs
37
+ from multiprocess import Pool
38
+ from tqdm.contrib.concurrent import thread_map
39
+
40
+ from . import config, utils
41
+ from .arrow_dataset import Dataset
42
+ from .arrow_reader import (
43
+ ArrowReader,
44
+ ReadInstruction,
45
+ )
46
+ from .arrow_writer import ArrowWriter, ParquetWriter, SchemaInferenceError
47
+ from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns
48
+ from .dataset_dict import DatasetDict, IterableDatasetDict
49
+ from .download.download_config import DownloadConfig
50
+ from .download.download_manager import DownloadManager, DownloadMode
51
+ from .download.streaming_download_manager import StreamingDownloadManager, xjoin
52
+ from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError
53
+ from .features import Features
54
+ from .filesystems import (
55
+ is_remote_filesystem,
56
+ rename,
57
+ )
58
+ from .fingerprint import Hasher
59
+ from .info import DatasetInfo, PostProcessedInfo
60
+ from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
61
+ from .keyhash import DuplicatedKeysError
62
+ from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
63
+ from .splits import Split, SplitDict, SplitGenerator, SplitInfo
64
+ from .streaming import extend_dataset_builder_for_streaming
65
+ from .table import CastError
66
+ from .utils import logging
67
+ from .utils import tqdm as hf_tqdm
68
+ from .utils._filelock import FileLock
69
+ from .utils.file_utils import is_remote_url
70
+ from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
71
+ from .utils.py_utils import (
72
+ classproperty,
73
+ convert_file_size_to_int,
74
+ has_sufficient_disk_space,
75
+ iflatmap_unordered,
76
+ map_nested,
77
+ memoize,
78
+ size_str,
79
+ temporary_assignment,
80
+ )
81
+ from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs
82
+ from .utils.track import tracked_list
83
+
84
+
85
+ if TYPE_CHECKING:
86
+ from .load import DatasetModule
87
+
88
+
89
+ logger = logging.get_logger(__name__)
90
+
91
+
92
class InvalidConfigName(ValueError):
    """Raised when a `BuilderConfig` name contains characters that are invalid in a cache directory path."""
94
+
95
+
96
@dataclass
class BuilderConfig:
    """Base class for `DatasetBuilder` data configuration.

    `DatasetBuilder` subclasses with data configuration options should subclass
    `BuilderConfig` and add their own properties.

    Attributes:
        name (`str`, defaults to `default`):
            The name of the configuration.
        version (`Version` or `str`, defaults to `0.0.0`):
            The version of the configuration.
        data_dir (`str`, *optional*):
            Path to the directory containing the source data.
        data_files (`str` or `Sequence` or `Mapping`, *optional*):
            Path(s) to source data file(s).
        description (`str`, *optional*):
            A human description of the configuration.
    """

    name: str = "default"
    version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
    data_dir: Optional[str] = None
    data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None
    description: Optional[str] = None

    def __post_init__(self):
        # The config name is used to name the cache directory.
        for invalid_char in INVALID_WINDOWS_CHARACTERS_IN_PATH:
            if invalid_char in self.name:
                raise InvalidConfigName(
                    f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. "
                    f"They could create issues when creating a directory for this config on Windows filesystem."
                )
        if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)):
            raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}")

    def __eq__(self, o):
        # we need to override the default dataclass __eq__ since it doesn't check for
        # other attributes than the ones of the signature.
        if set(self.__dict__.keys()) != set(o.__dict__.keys()):
            return False
        return all((k, getattr(self, k)) == (k, getattr(o, k)) for k in self.__dict__.keys())

    def create_config_id(
        self,
        config_kwargs: dict,
        custom_features: Optional[Features] = None,
    ) -> str:
        """
        The config id is used to build the cache directory.
        By default it is equal to the config name.
        However the name of a config is not sufficient to have a unique identifier for the dataset being generated
        since it doesn't take into account:
        - the config kwargs that can be used to overwrite attributes
        - the custom features used to write the dataset
        - the data_files for json/text/csv/pandas datasets

        Therefore the config id is just the config name with an optional suffix based on these.
        """
        # Possibly add a suffix to the name to handle custom features/data_files/config_kwargs
        suffix: Optional[str] = None
        config_kwargs_to_add_to_suffix = config_kwargs.copy()
        # name and version are already used to build the cache directory
        config_kwargs_to_add_to_suffix.pop("name", None)
        config_kwargs_to_add_to_suffix.pop("version", None)
        # data dir handling (when specified it points to the manually downloaded data):
        # it was previously ignored before the introduction of config id because we didn't want
        # to change the config name. Now it's fine to take it into account for the config id.
        # config_kwargs_to_add_to_suffix.pop("data_dir", None)
        if "data_dir" in config_kwargs_to_add_to_suffix:
            if config_kwargs_to_add_to_suffix["data_dir"] is None:
                config_kwargs_to_add_to_suffix.pop("data_dir", None)
            else:
                # canonicalize the data dir to avoid two paths to the same location having different
                # hashes
                data_dir = config_kwargs_to_add_to_suffix["data_dir"]
                data_dir = os.path.normpath(data_dir)
                config_kwargs_to_add_to_suffix["data_dir"] = data_dir
        if config_kwargs_to_add_to_suffix:
            # we don't care about the order of the kwargs
            config_kwargs_to_add_to_suffix = {
                k: config_kwargs_to_add_to_suffix[k] for k in sorted(config_kwargs_to_add_to_suffix)
            }
            if all(isinstance(v, (str, bool, int, float)) for v in config_kwargs_to_add_to_suffix.values()):
                # Simple scalar kwargs can be spelled out in a human-readable suffix.
                suffix = ",".join(
                    str(k) + "=" + urllib.parse.quote_plus(str(v)) for k, v in config_kwargs_to_add_to_suffix.items()
                )
                if len(suffix) > 32:  # hash if too long
                    suffix = Hasher.hash(config_kwargs_to_add_to_suffix)
            else:
                # Non-scalar kwargs (lists, dicts, DataFilesDict, ...) are always hashed.
                suffix = Hasher.hash(config_kwargs_to_add_to_suffix)

        if custom_features is not None:
            # Fold the custom features into the suffix hash so different features get different caches.
            m = Hasher()
            if suffix:
                m.update(suffix)
            m.update(custom_features)
            suffix = m.hexdigest()

        if suffix:
            config_id = self.name + "-" + suffix
            if len(config_id) > config.MAX_DATASET_CONFIG_ID_READABLE_LENGTH:
                # Keep the directory name short enough for the filesystem.
                config_id = self.name + "-" + Hasher.hash(suffix)
            return config_id
        else:
            return self.name

    def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None:
        # Resolve data-file patterns to concrete files in place, relative to data_dir when set.
        if isinstance(self.data_files, DataFilesPatternsDict):
            base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path
            self.data_files = self.data_files.resolve(base_path, download_config)
+
209
+
210
+ class DatasetBuilder:
211
+ """Abstract base class for all datasets.
212
+
213
+ `DatasetBuilder` has 3 key methods:
214
+
215
+ - [`DatasetBuilder.info`]: Documents the dataset, including feature
216
+ names, types, shapes, version, splits, citation, etc.
217
+ - [`DatasetBuilder.download_and_prepare`]: Downloads the source data
218
+ and writes it to disk.
219
+ - [`DatasetBuilder.as_dataset`]: Generates a [`Dataset`].
220
+
221
+ Some `DatasetBuilder`s expose multiple variants of the
222
+ dataset by defining a [`BuilderConfig`] subclass and accepting a
223
+ config object (or name) on construction. Configurable datasets expose a
224
+ pre-defined set of configurations in [`DatasetBuilder.builder_configs`].
225
+
226
+ Args:
227
+ cache_dir (`str`, *optional*):
228
+ Directory to cache data. Defaults to `"~/.cache/huggingface/datasets"`.
229
+ dataset_name (`str`, *optional*):
230
+ Name of the dataset, if different from the builder name. Useful for packaged builders
231
+ like csv, imagefolder, audiofolder, etc. to reflect the difference between datasets
232
+ that use the same packaged builder.
233
+ config_name (`str`, *optional*):
234
+ Name of the dataset configuration.
235
+ It affects the data generated on disk. Different configurations will have their own subdirectories and
236
+ versions.
237
+ If not provided, the default configuration is used (if it exists).
238
+
239
+ <Added version="2.3.0">
240
+
241
+ Parameter `name` was renamed to `config_name`.
242
+
243
+ </Added>
244
+ hash (`str`, *optional*):
245
+ Hash specific to the dataset builder code. Used to update the caching directory when the
246
+ dataset builder code is updated (to avoid reusing old data).
247
+ The typical caching directory (defined in `self._relative_data_dir`) is `name/version/hash/`.
248
+ base_path (`str`, *optional*):
249
+ Base path for relative paths that are used to download files.
250
+ This can be a remote URL.
251
+ features ([`Features`], *optional*):
252
+ Features types to use with this dataset.
253
+ It can be used to change the [`Features`] types of a dataset, for example.
254
+ token (`str` or `bool`, *optional*):
255
+ String or boolean to use as Bearer token for remote files on the
256
+ Datasets Hub. If `True`, will get token from `"~/.huggingface"`.
257
+ repo_id (`str`, *optional*):
258
+ ID of the dataset repository.
259
+ Used to distinguish builders with the same name but not coming from the same namespace, for example "rajpurkar/squad"
260
+ and "lhoestq/squad" repo IDs. In the latter, the builder name would be "lhoestq___squad".
261
+ data_files (`str` or `Sequence` or `Mapping`, *optional*):
262
+ Path(s) to source data file(s).
263
+ For builders like "csv" or "json" that need the user to specify data files. They can be either
264
+ local or remote files. For convenience, you can use a `DataFilesDict`.
265
+ data_dir (`str`, *optional*):
266
+ Path to directory containing source data file(s).
267
+ Use only if `data_files` is not passed, in which case it is equivalent to passing
268
+ `os.path.join(data_dir, "**")` as `data_files`.
269
+ For builders that require manual download, it must be the path to the local directory containing the
270
+ manually downloaded data.
271
+ storage_options (`dict`, *optional*):
272
+ Key/value pairs to be passed on to the dataset file-system backend, if any.
273
+ writer_batch_size (`int`, *optional*):
274
+ Batch size used by the ArrowWriter.
275
+ It defines the number of samples that are kept in memory before writing them
276
+ and also the length of the arrow chunks.
277
+ None means that the ArrowWriter will use its default value.
278
+ **config_kwargs (additional keyword arguments): Keyword arguments to be passed to the corresponding builder
279
+ configuration class, set on the class attribute [`DatasetBuilder.BUILDER_CONFIG_CLASS`]. The builder
280
+ configuration class is [`BuilderConfig`] or a subclass of it.
281
+ """
282
+
283
+ # Default version
284
+ VERSION = None # Default version set in BuilderConfig
285
+
286
+ # Class for the builder config.
287
+ BUILDER_CONFIG_CLASS = BuilderConfig
288
+
289
+ # Named configurations that modify the data generated by download_and_prepare.
290
+ BUILDER_CONFIGS = []
291
+
292
+ # Optional default config name to be used when name is None
293
+ DEFAULT_CONFIG_NAME = None
294
+
295
+ # Default batch size used by the ArrowWriter
296
+ # It defines the number of samples that are kept in memory before writing them
297
+ # and also the length of the arrow chunks
298
+ # None means that the ArrowWriter will use its default value
299
+ DEFAULT_WRITER_BATCH_SIZE = None
300
+
301
    def __init__(
        self,
        cache_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        config_name: Optional[str] = None,
        hash: Optional[str] = None,
        base_path: Optional[str] = None,
        info: Optional[DatasetInfo] = None,
        features: Optional[Features] = None,
        token: Optional[Union[bool, str]] = None,
        repo_id: Optional[str] = None,
        data_files: Optional[Union[str, list, dict, DataFilesDict]] = None,
        data_dir: Optional[str] = None,
        storage_options: Optional[dict] = None,
        writer_batch_size: Optional[int] = None,
        config_id: Optional[str] = None,
        **config_kwargs,
    ):
        """Initialize the builder: resolve config, info and cache directories. See the class docstring for parameters."""
        # DatasetBuilder name
        self.name: str = camelcase_to_snakecase(self.__module__.split(".")[-1])
        self.hash: Optional[str] = hash
        self.base_path = base_path
        self.token = token
        self.repo_id = repo_id
        self.storage_options = storage_options or {}
        self.dataset_name = camelcase_to_snakecase(dataset_name) if dataset_name else self.name
        self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE

        if data_files is not None and not isinstance(data_files, DataFilesDict):
            # Normalize any user-provided str/list/dict of patterns into a DataFilesDict.
            data_files = DataFilesDict.from_patterns(
                sanitize_patterns(data_files),
                base_path=base_path,
                download_config=DownloadConfig(token=token, storage_options=self.storage_options),
            )

        # Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset
        if "features" in inspect.signature(self.BUILDER_CONFIG_CLASS.__init__).parameters and features is not None:
            config_kwargs["features"] = features
        if data_files is not None:
            config_kwargs["data_files"] = data_files
        if data_dir is not None:
            config_kwargs["data_dir"] = data_dir
        self.config_kwargs = config_kwargs
        self.config, self.config_id = self._create_builder_config(
            config_name=config_name,
            custom_features=features,
            config_id=config_id,
            **config_kwargs,
        )

        # prepare info: DatasetInfo are a standardized dataclass across all datasets
        # Prefill datasetinfo
        if info is None:
            info = self._info()
            info.builder_name = self.name
            info.dataset_name = self.dataset_name
            info.config_name = self.config.name
            info.version = self.config.version
        self.info = info
        # update info with user specified infos
        if features is not None:
            self.info.features = features

        # Prepare data dirs:
        # cache_dir can be a remote bucket on GCS or S3
        self._cache_dir_root = str(cache_dir or config.HF_DATASETS_CACHE)
        self._cache_dir_root = (
            self._cache_dir_root if is_remote_url(self._cache_dir_root) else os.path.expanduser(self._cache_dir_root)
        )
        self._cache_downloaded_dir = (
            posixpath.join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR)
            if cache_dir
            else str(config.DOWNLOADED_DATASETS_PATH)
        )
        self._cache_downloaded_dir = (
            self._cache_downloaded_dir
            if is_remote_url(self._cache_downloaded_dir)
            else os.path.expanduser(self._cache_downloaded_dir)
        )

        # In case there exists a legacy cache directory
        self._legacy_relative_data_dir = None

        self._cache_dir = self._build_cache_dir()
        if not is_remote_url(self._cache_dir_root):
            os.makedirs(self._cache_dir_root, exist_ok=True)
            # Guard cache-dir inspection/cleanup against concurrent builders with a file lock.
            lock_path = os.path.join(
                self._cache_dir_root, Path(self._cache_dir).as_posix().replace("/", "_") + ".lock"
            )
            with FileLock(lock_path):
                if os.path.exists(self._cache_dir):  # check if data exist
                    if len(os.listdir(self._cache_dir)) > 0:
                        if os.path.exists(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)):
                            logger.debug("Overwrite dataset info from restored data version if exists.")
                            self.info = DatasetInfo.from_directory(self._cache_dir)
                    else:  # dir exists but no data, remove the empty dir as data aren't available anymore
                        logger.warning(
                            f"Old caching folder {self._cache_dir} for dataset {self.dataset_name} exists but no data were found. Removing it. "
                        )
                        os.rmdir(self._cache_dir)

        # Store in the cache by default unless the user specifies a custom output_dir to download_and_prepare
        self._output_dir = self._cache_dir
        self._fs: fsspec.AbstractFileSystem = fsspec.filesystem("file")

        # Set download manager
        self.dl_manager = None

        # Set to True by "datasets-cli test" to generate file checksums for (deprecated) dataset_infos.json independently of verification_mode value.
        self._record_infos = False

        # Set in `.download_and_prepare` once the format of the generated dataset is known
        self._file_format = None

        # Enable streaming (e.g. it patches "open" to work with remote files)
        extend_dataset_builder_for_streaming(self)
417
+
418
    def __getstate__(self):
        """Pickle the builder through its attribute dict (streaming patches are restored in `__setstate__`)."""
        return self.__dict__
420
+
421
    def __setstate__(self, d):
        """Restore the builder from its pickled attribute dict and re-apply streaming patches."""
        self.__dict__ = d
        # Re-enable streaming, since patched functions are not kept when pickling
        extend_dataset_builder_for_streaming(self)
425
+
426
+ # Must be set for datasets that use 'data_dir' functionality - the ones
427
+ # that require users to do additional steps to download the data
428
+ # (this is usually due to some external regulations / rules).
429
+ # This field should contain a string with user instructions, including
430
+ # the list of files that should be present. It will be
431
+ # displayed in the dataset documentation.
432
+ @property
433
+ def manual_download_instructions(self) -> Optional[str]:
434
+ return None
435
+
436
+ def _check_legacy_cache(self) -> Optional[str]:
437
+ """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13"""
438
+ if (
439
+ self.__module__.startswith("datasets.")
440
+ and not is_remote_url(self._cache_dir_root)
441
+ and self.config.name == "default"
442
+ ):
443
+ from .packaged_modules import _PACKAGED_DATASETS_MODULES
444
+
445
+ namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
446
+ config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
447
+ config_id = config_name + self.config_id[len(self.config.name) :]
448
+ hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
449
+ legacy_relative_data_dir = posixpath.join(
450
+ self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
451
+ config_id,
452
+ "0.0.0",
453
+ hash,
454
+ )
455
+ legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
456
+ if os.path.isdir(legacy_cache_dir):
457
+ return legacy_relative_data_dir
458
+
459
+ def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]:
460
+ """Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15"""
461
+ if (
462
+ self.__module__.startswith("datasets.")
463
+ and not is_remote_url(self._cache_dir_root)
464
+ and not (set(self.config_kwargs) - {"data_files", "data_dir"})
465
+ ):
466
+ from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES
467
+ from .utils._dill import Pickler
468
+
469
+ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
470
+ """
471
+ Used to update hash of packaged modules which is used for creating unique cache directories to reflect
472
+ different config parameters which are passed in metadata from readme.
473
+ """
474
+ params_to_exclude = {"config_name", "version", "description"}
475
+ params_to_add_to_hash = {
476
+ param: value
477
+ for param, value in sorted(config_parameters.items())
478
+ if param not in params_to_exclude
479
+ }
480
+ m = Hasher()
481
+ m.update(hash)
482
+ m.update(params_to_add_to_hash)
483
+ return m.hexdigest()
484
+
485
+ namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
486
+ with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
487
+ config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
488
+ hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing")
489
+ if (
490
+ dataset_module.builder_configs_parameters.metadata_configs
491
+ and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
492
+ ):
493
+ hash = update_hash_with_config_parameters(
494
+ hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name]
495
+ )
496
+ legacy_relative_data_dir = posixpath.join(
497
+ self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
498
+ config_id,
499
+ "0.0.0",
500
+ hash,
501
+ )
502
+ legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
503
+ if os.path.isdir(legacy_cache_dir):
504
+ return legacy_relative_data_dir
505
+
506
+ def _create_builder_config(
507
+ self, config_name=None, custom_features=None, config_id=None, **config_kwargs
508
+ ) -> tuple[BuilderConfig, str]:
509
+ """Create and validate BuilderConfig object as well as a unique config id for this config.
510
+ Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None.
511
+ config_kwargs override the defaults kwargs in config
512
+ """
513
+ builder_config = None
514
+
515
+ # try default config
516
+ if config_name is None and self.BUILDER_CONFIGS:
517
+ if self.DEFAULT_CONFIG_NAME is not None:
518
+ builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME)
519
+ logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}")
520
+ else:
521
+ if len(self.BUILDER_CONFIGS) > 1:
522
+ if not config_kwargs:
523
+ example_of_usage = (
524
+ f"load_dataset('{self.repo_id or self.dataset_name}', '{self.BUILDER_CONFIGS[0].name}')"
525
+ )
526
+ raise ValueError(
527
+ "Config name is missing."
528
+ f"\nPlease pick one among the available configs: {list(self.builder_configs.keys())}"
529
+ + f"\nExample of usage:\n\t`{example_of_usage}`"
530
+ )
531
+ else:
532
+ builder_config = self.BUILDER_CONFIGS[0]
533
+ logger.info(
534
+ f"No config specified, defaulting to the single config: {self.dataset_name}/{builder_config.name}"
535
+ )
536
+
537
+ # try to get config by name
538
+ if isinstance(config_name, str):
539
+ builder_config = self.builder_configs.get(config_name)
540
+ if builder_config is None and self.BUILDER_CONFIGS:
541
+ raise ValueError(
542
+ f"BuilderConfig '{config_name}' not found. Available: {list(self.builder_configs.keys())}"
543
+ )
544
+
545
+ # if not using an existing config, then create a new config on the fly
546
+ if not builder_config:
547
+ if config_name is not None:
548
+ config_kwargs["name"] = config_name
549
+ elif self.DEFAULT_CONFIG_NAME and not config_kwargs:
550
+ # Use DEFAULT_CONFIG_NAME only if no config_kwargs are passed
551
+ config_kwargs["name"] = self.DEFAULT_CONFIG_NAME
552
+ if "version" not in config_kwargs and hasattr(self, "VERSION") and self.VERSION:
553
+ config_kwargs["version"] = self.VERSION
554
+ builder_config = self.BUILDER_CONFIG_CLASS(**config_kwargs)
555
+
556
+ # otherwise use the config_kwargs to overwrite the attributes
557
+ else:
558
+ builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
559
+ for key, value in config_kwargs.items():
560
+ if value is not None:
561
+ if not hasattr(builder_config, key):
562
+ raise ValueError(f"BuilderConfig {builder_config} doesn't have a '{key}' key.")
563
+ setattr(builder_config, key, value)
564
+
565
+ if not builder_config.name:
566
+ raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")
567
+
568
+ # resolve data files if needed
569
+ builder_config._resolve_data_files(
570
+ base_path=self.base_path,
571
+ download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
572
+ )
573
+
574
+ # compute the config id that is going to be used for caching
575
+ if config_id is None:
576
+ config_id = builder_config.create_config_id(
577
+ config_kwargs,
578
+ custom_features=custom_features,
579
+ )
580
+ is_custom = (config_id not in self.builder_configs) and config_id != "default"
581
+ if is_custom:
582
+ logger.info(f"Using custom data configuration {config_id}")
583
+ else:
584
+ if (
585
+ builder_config.name in self.builder_configs
586
+ and builder_config != self.builder_configs[builder_config.name]
587
+ ):
588
+ raise ValueError(
589
+ "Cannot name a custom BuilderConfig the same as an available "
590
+ f"BuilderConfig. Change the name. Available BuilderConfigs: {list(self.builder_configs.keys())}"
591
+ )
592
+ if not builder_config.version:
593
+ raise ValueError(f"BuilderConfig {builder_config.name} must have a version")
594
+
595
+ return builder_config, config_id
596
+
597
+ @classproperty
598
+ @classmethod
599
+ @memoize()
600
+ def builder_configs(cls) -> dict[str, BuilderConfig]:
601
+ """Dictionary of pre-defined configurations for this builder class."""
602
+ configs = {config.name: config for config in cls.BUILDER_CONFIGS}
603
+ if len(configs) != len(cls.BUILDER_CONFIGS):
604
+ names = [config.name for config in cls.BUILDER_CONFIGS]
605
+ raise ValueError(f"Names in BUILDER_CONFIGS must not be duplicated. Got {names}")
606
+ return configs
607
+
608
+ @property
609
+ def cache_dir(self):
610
+ return self._cache_dir
611
+
612
+ def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"):
613
+ # Check for the legacy cache directory template (datasets<3.0.0)
614
+ self._legacy_relative_data_dir = (
615
+ self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None
616
+ )
617
+ self._cache_dir = self._build_cache_dir()
618
+ self._output_dir = self._cache_dir
619
+
620
+ def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
621
+ """Relative path of this dataset in cache_dir:
622
+ Will be:
623
+ self.dataset_name/self.config.version/self.hash/
624
+ or if a repo_id with a namespace has been specified:
625
+ self.namespace___self.dataset_name/self.config.version/self.hash/
626
+ If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
627
+ """
628
+ if self._legacy_relative_data_dir is not None and with_version and with_hash:
629
+ return self._legacy_relative_data_dir
630
+
631
+ namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
632
+ builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}"
633
+ builder_data_dir = posixpath.join(builder_data_dir, self.config_id)
634
+ if with_version:
635
+ builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version))
636
+ if with_hash and self.hash and isinstance(self.hash, str):
637
+ builder_data_dir = posixpath.join(builder_data_dir, self.hash)
638
+ return builder_data_dir
639
+
640
+ def _build_cache_dir(self):
641
+ """Return the data directory for the current version."""
642
+ builder_data_dir = posixpath.join(self._cache_dir_root, self._relative_data_dir(with_version=False))
643
+ version_data_dir = posixpath.join(self._cache_dir_root, self._relative_data_dir(with_version=True))
644
+
645
+ def _other_versions_on_disk():
646
+ """Returns previous versions on disk."""
647
+ if not os.path.exists(builder_data_dir):
648
+ return []
649
+
650
+ version_dirnames = []
651
+ for dir_name in os.listdir(builder_data_dir):
652
+ try:
653
+ version_dirnames.append((utils.Version(dir_name), dir_name))
654
+ except ValueError: # Invalid version (ex: incomplete data dir)
655
+ pass
656
+ version_dirnames.sort(reverse=True)
657
+ return version_dirnames
658
+
659
+ # Check and warn if other versions exist
660
+ if not is_remote_url(builder_data_dir):
661
+ version_dirs = _other_versions_on_disk()
662
+ if version_dirs:
663
+ other_version = version_dirs[0][0]
664
+ if other_version != self.config.version:
665
+ warn_msg = (
666
+ f"Found a different version {str(other_version)} of dataset {self.dataset_name} in "
667
+ f"cache_dir {self._cache_dir_root}. Using currently defined version "
668
+ f"{str(self.config.version)}."
669
+ )
670
+ logger.warning(warn_msg)
671
+
672
+ return version_data_dir
673
+
674
+ @abc.abstractmethod
675
+ def _info(self) -> DatasetInfo:
676
+ """Construct the DatasetInfo object. See `DatasetInfo` for details.
677
+
678
+ Warning: This function is only called once and the result is cached for all
679
+ following .info() calls.
680
+
681
+ Returns:
682
+ info: (DatasetInfo) The dataset information
683
+ """
684
+ raise NotImplementedError
685
+
686
+ @classmethod
687
+ def get_imported_module_dir(cls):
688
+ """Return the path of the module of this class or subclass."""
689
+ return os.path.dirname(inspect.getfile(inspect.getmodule(cls)))
690
+
691
+ def _rename(self, src: str, dst: str):
692
+ rename(self._fs, src, dst)
693
+
694
+ def download_and_prepare(
695
+ self,
696
+ output_dir: Optional[str] = None,
697
+ download_config: Optional[DownloadConfig] = None,
698
+ download_mode: Optional[Union[DownloadMode, str]] = None,
699
+ verification_mode: Optional[Union[VerificationMode, str]] = None,
700
+ dl_manager: Optional[DownloadManager] = None,
701
+ base_path: Optional[str] = None,
702
+ file_format: str = "arrow",
703
+ max_shard_size: Optional[Union[int, str]] = None,
704
+ num_proc: Optional[int] = None,
705
+ storage_options: Optional[dict] = None,
706
+ **download_and_prepare_kwargs,
707
+ ):
708
+ """Downloads and prepares dataset for reading.
709
+
710
+ Args:
711
+ output_dir (`str`, *optional*):
712
+ Output directory for the dataset.
713
+ Default to this builder's `cache_dir`, which is inside `~/.cache/huggingface/datasets` by default.
714
+
715
+ <Added version="2.5.0"/>
716
+ download_config (`DownloadConfig`, *optional*):
717
+ Specific download configuration parameters.
718
+ download_mode ([`DownloadMode`] or `str`, *optional*):
719
+ Select the download/generate mode, default to `REUSE_DATASET_IF_EXISTS`.
720
+ verification_mode ([`VerificationMode`] or `str`, defaults to `BASIC_CHECKS`):
721
+ Verification mode determining the checks to run on the downloaded/processed dataset information (checksums/size/splits/...).
722
+
723
+ <Added version="2.9.1"/>
724
+ dl_manager (`DownloadManager`, *optional*):
725
+ Specific `DownloadManger` to use.
726
+ base_path (`str`, *optional*):
727
+ Base path for relative paths that are used to download files. This can be a remote url.
728
+ If not specified, the value of the `base_path` attribute (`self.base_path`) will be used instead.
729
+ file_format (`str`, *optional*):
730
+ Format of the data files in which the dataset will be written.
731
+ Supported formats: "arrow", "parquet". Default to "arrow" format.
732
+ If the format is "parquet", then image and audio data are embedded into the Parquet files instead of pointing to local files.
733
+
734
+ <Added version="2.5.0"/>
735
+ max_shard_size (`Union[str, int]`, *optional*):
736
+ Maximum number of bytes written per shard, default is "500MB".
737
+ The size is based on uncompressed data size, so in practice your shard files may be smaller than
738
+ `max_shard_size` thanks to Parquet compression for example.
739
+
740
+ <Added version="2.5.0"/>
741
+ num_proc (`int`, *optional*, defaults to `None`):
742
+ Number of processes when downloading and generating the dataset locally.
743
+ Multiprocessing is disabled by default.
744
+
745
+ <Added version="2.7.0"/>
746
+ storage_options (`dict`, *optional*):
747
+ Key/value pairs to be passed on to the caching file-system backend, if any.
748
+
749
+ <Added version="2.5.0"/>
750
+ **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments.
751
+
752
+ Example:
753
+
754
+ Download and prepare the dataset as Arrow files that can be loaded as a Dataset using `builder.as_dataset()`:
755
+
756
+ ```py
757
+ >>> from datasets import load_dataset_builder
758
+ >>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")
759
+ >>> builder.download_and_prepare()
760
+ ```
761
+
762
+ Download and prepare the dataset as sharded Parquet files locally:
763
+
764
+ ```py
765
+ >>> from datasets import load_dataset_builder
766
+ >>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")
767
+ >>> builder.download_and_prepare("./output_dir", file_format="parquet")
768
+ ```
769
+
770
+ Download and prepare the dataset as sharded Parquet files in a cloud storage:
771
+
772
+ ```py
773
+ >>> from datasets import load_dataset_builder
774
+ >>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key}
775
+ >>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")
776
+ >>> builder.download_and_prepare("s3://my-bucket/my_rotten_tomatoes", storage_options=storage_options, file_format="parquet")
777
+ ```
778
+ """
779
+ output_dir = output_dir if output_dir is not None else self._cache_dir
780
+ # output_dir can be a remote bucket on GCS or S3
781
+ fs, output_dir = url_to_fs(output_dir, **(storage_options or {}))
782
+ self._fs = fs
783
+ self._output_dir = output_dir if not is_remote_filesystem(self._fs) else self._fs.unstrip_protocol(output_dir)
784
+
785
+ download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS)
786
+ verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS)
787
+ base_path = base_path if base_path is not None else self.base_path
788
+
789
+ if file_format is not None and file_format not in ["arrow", "parquet"]:
790
+ raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'")
791
+ self._file_format = file_format
792
+
793
+ if self._fs._strip_protocol(self._output_dir) == "":
794
+ # We don't support the root directory, because it has no dirname,
795
+ # and we need a dirname to use a <dirname>.incomplete directory
796
+ # when the dataset is being written
797
+ raise RuntimeError(
798
+ f"Unable to download and prepare the dataset at the root {self._output_dir}. "
799
+ f"Please specify a subdirectory, e.g. '{self._output_dir + self.dataset_name}'"
800
+ )
801
+
802
+ if dl_manager is None:
803
+ if download_config is None:
804
+ download_config = DownloadConfig(
805
+ cache_dir=self._cache_downloaded_dir,
806
+ force_download=download_mode == DownloadMode.FORCE_REDOWNLOAD,
807
+ force_extract=download_mode == DownloadMode.FORCE_REDOWNLOAD,
808
+ use_etag=False,
809
+ num_proc=num_proc,
810
+ token=self.token,
811
+ storage_options=self.storage_options,
812
+ ) # We don't use etag for data files to speed up the process
813
+
814
+ dl_manager = DownloadManager(
815
+ dataset_name=self.dataset_name,
816
+ download_config=download_config,
817
+ data_dir=self.config.data_dir,
818
+ base_path=base_path,
819
+ record_checksums=(self._record_infos or verification_mode == VerificationMode.ALL_CHECKS),
820
+ )
821
+
822
+ is_local = not is_remote_filesystem(self._fs)
823
+ self.dl_manager = dl_manager
824
+
825
+ # Prevent parallel local disk operations
826
+ if is_local:
827
+ # Create parent directory of the output_dir to put the lock file in there
828
+ Path(self._output_dir).parent.mkdir(parents=True, exist_ok=True)
829
+ lock_path = self._output_dir + "_builder.lock"
830
+
831
+ # File locking only with local paths; no file locking on GCS or S3
832
+ with FileLock(lock_path) if is_local else contextlib.nullcontext():
833
+ # Check if the data already exists
834
+ data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME))
835
+ if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
836
+ logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})")
837
+ # We need to update the info in case some splits were added in the meantime
838
+ # for example when calling load_dataset from multiple workers.
839
+ self.info = self._load_info()
840
+ self.download_post_processing_resources(dl_manager)
841
+ return
842
+
843
+ logger.info(f"Generating dataset {self.dataset_name} ({self._output_dir})")
844
+ if is_local: # if cache dir is local, check for available space
845
+ if not has_sufficient_disk_space(
846
+ self.info.size_in_bytes or 0, directory=Path(self._output_dir).parent
847
+ ):
848
+ raise OSError(
849
+ f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})"
850
+ )
851
+
852
+ @contextlib.contextmanager
853
+ def incomplete_dir(dirname):
854
+ """Create temporary dir for dirname and rename on exit."""
855
+ if not is_local:
856
+ self._fs.makedirs(dirname, exist_ok=True)
857
+ yield dirname
858
+ else:
859
+ tmp_dir = dirname + ".incomplete"
860
+ os.makedirs(tmp_dir, exist_ok=True)
861
+ try:
862
+ yield tmp_dir
863
+ if os.path.isdir(dirname):
864
+ shutil.rmtree(dirname)
865
+ # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory
866
+ shutil.move(tmp_dir, dirname)
867
+ finally:
868
+ if os.path.exists(tmp_dir):
869
+ shutil.rmtree(tmp_dir)
870
+
871
+ # Print is intentional: we want this to always go to stdout so user has
872
+ # information needed to cancel download/preparation if needed.
873
+ # This comes right before the progress bar.
874
+ if self.info.size_in_bytes:
875
+ logger.info(
876
+ f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} "
877
+ f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, "
878
+ f"post-processed: {size_str(self.info.post_processing_size)}, "
879
+ f"total: {size_str(self.info.size_in_bytes)}) to {self._output_dir}..."
880
+ )
881
+ else:
882
+ _dest = self._fs._strip_protocol(self._output_dir) if is_local else self._output_dir
883
+ logger.info(f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} to {_dest}...")
884
+
885
+ self._check_manual_download(dl_manager)
886
+
887
+ # Create a tmp dir and rename to self._output_dir on successful exit.
888
+ with incomplete_dir(self._output_dir) as tmp_output_dir:
889
+ # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward
890
+ # it to every sub function.
891
+ with temporary_assignment(self, "_output_dir", tmp_output_dir):
892
+ prepare_split_kwargs = {"file_format": file_format}
893
+ if max_shard_size is not None:
894
+ prepare_split_kwargs["max_shard_size"] = max_shard_size
895
+ if num_proc is not None:
896
+ prepare_split_kwargs["num_proc"] = num_proc
897
+ self._download_and_prepare(
898
+ dl_manager=dl_manager,
899
+ verification_mode=verification_mode,
900
+ **prepare_split_kwargs,
901
+ **download_and_prepare_kwargs,
902
+ )
903
+ # Sync info
904
+ self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
905
+ self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
906
+ if self.info.download_size is not None:
907
+ self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
908
+ # Save info
909
+ self._save_info()
910
+
911
+ # Download post processing resources
912
+ self.download_post_processing_resources(dl_manager)
913
+
914
+ logger.info(
915
+ f"Dataset {self.dataset_name} downloaded and prepared to {self._output_dir}. "
916
+ f"Subsequent calls will reuse this data."
917
+ )
918
+
919
+ def _check_manual_download(self, dl_manager):
920
+ if self.manual_download_instructions is not None and dl_manager.manual_dir is None:
921
+ raise ManualDownloadError(
922
+ textwrap.dedent(
923
+ f"""\
924
+ The dataset {self.dataset_name} with config {self.config.name} requires manual data.
925
+ Please follow the manual download instructions:
926
+ {self.manual_download_instructions}
927
+ Manual data can be loaded with:
928
+ datasets.load_dataset("{self.repo_id or self.dataset_name}", data_dir="<path/to/manual/data>")"""
929
+ )
930
+ )
931
+
932
+ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_kwargs):
933
+ """Downloads and prepares dataset for reading.
934
+
935
+ This is the internal implementation to overwrite called when user calls
936
+ `download_and_prepare`. It should download all required data and generate
937
+ the pre-processed datasets files.
938
+
939
+ Args:
940
+ dl_manager ([`DownloadManager`]):
941
+ `DownloadManager` used to download and cache data.
942
+ verification_mode ([`VerificationMode`]):
943
+ if `ALL_CHECKS`, perform all the verifications including checksums.
944
+ if `BASIC_CHECKS`, do not perform checksums, only perform split tests.
945
+ if `NO_CHECKS`, do not perform any verification.
946
+ prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size`
947
+ """
948
+ # Generating data for all splits
949
+ split_dict = SplitDict(dataset_name=self.dataset_name)
950
+ split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs)
951
+ split_generators = self._split_generators(dl_manager, **split_generators_kwargs)
952
+
953
+ # Checksums verification
954
+ if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums:
955
+ verify_checksums(
956
+ self.info.download_checksums, dl_manager.get_recorded_sizes_checksums(), "dataset source files"
957
+ )
958
+
959
+ # Build splits
960
+ for split_generator in split_generators:
961
+ if str(split_generator.split_info.name).lower() == "all":
962
+ raise ValueError(
963
+ "`all` is a special split keyword corresponding to the "
964
+ "union of all splits, so cannot be used as key in "
965
+ "._split_generator()."
966
+ )
967
+
968
+ logger.info(f"Generating {split_generator.split_info.name} split")
969
+ split_dict.add(split_generator.split_info)
970
+
971
+ try:
972
+ # Prepare split will record examples associated to the split
973
+ self._prepare_split(split_generator, **prepare_split_kwargs)
974
+ except OSError as e:
975
+ raise OSError(
976
+ "Cannot find data file. "
977
+ + (self.manual_download_instructions or "")
978
+ + "\nOriginal error:\n"
979
+ + str(e)
980
+ ) from None
981
+ # If check_duplicates is set to True , then except DuplicatedKeysError
982
+ except DuplicatedKeysError as e:
983
+ raise DuplicatedKeysError(
984
+ e.key,
985
+ e.duplicate_key_indices,
986
+ fix_msg=f"To avoid duplicate keys, please fix the dataset splits for {self.name}",
987
+ ) from None
988
+ dl_manager.manage_extracted_files()
989
+
990
+ if verification_mode == VerificationMode.BASIC_CHECKS or verification_mode == VerificationMode.ALL_CHECKS:
991
+ verify_splits(self.info.splits, split_dict)
992
+
993
+ # Update the info object with the splits.
994
+ self.info.splits = split_dict
995
+ self.info.download_size = dl_manager.downloaded_size
996
+
997
+ def download_post_processing_resources(self, dl_manager):
998
+ for split in self.info.splits or []:
999
+ for resource_name, resource_file_name in self._post_processing_resources(split).items():
1000
+ if not not is_remote_filesystem(self._fs):
1001
+ raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}")
1002
+ if os.sep in resource_file_name:
1003
+ raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
1004
+ resource_path = os.path.join(self._output_dir, resource_file_name)
1005
+ if not os.path.exists(resource_path):
1006
+ downloaded_resource_path = self._download_post_processing_resources(
1007
+ split, resource_name, dl_manager
1008
+ )
1009
+ if downloaded_resource_path:
1010
+ logger.info(f"Downloaded post-processing resource {resource_name} as {resource_file_name}")
1011
+ shutil.move(downloaded_resource_path, resource_path)
1012
+
1013
+ def _load_info(self) -> DatasetInfo:
1014
+ return DatasetInfo.from_directory(self._output_dir, storage_options=self._fs.storage_options)
1015
+
1016
+ def _save_info(self):
1017
+ file_lock = (
1018
+ FileLock(self._output_dir + "_info.lock")
1019
+ if not is_remote_filesystem(self._fs)
1020
+ else contextlib.nullcontext()
1021
+ )
1022
+ with file_lock:
1023
+ self.info.write_to_directory(self._output_dir, storage_options=self._fs.storage_options)
1024
+
1025
+ def _make_split_generators_kwargs(self, prepare_split_kwargs):
1026
+ """Get kwargs for `self._split_generators()` from `prepare_split_kwargs`."""
1027
+ del prepare_split_kwargs
1028
+ return {}
1029
+
1030
+ def as_dataset(
1031
+ self,
1032
+ split: Optional[Union[str, Split, list[str], list[Split]]] = None,
1033
+ run_post_process=True,
1034
+ verification_mode: Optional[Union[VerificationMode, str]] = None,
1035
+ in_memory=False,
1036
+ ) -> Union[Dataset, DatasetDict]:
1037
+ """Return a Dataset for the specified split.
1038
+
1039
+ Args:
1040
+ split (`datasets.Split`):
1041
+ Which subset of the data to return.
1042
+ run_post_process (`bool`, defaults to `True`):
1043
+ Whether to run post-processing dataset transforms and/or add
1044
+ indexes.
1045
+ verification_mode ([`VerificationMode`] or `str`, defaults to `BASIC_CHECKS`):
1046
+ Verification mode determining the checks to run on the
1047
+ downloaded/processed dataset information (checksums/size/splits/...).
1048
+
1049
+ <Added version="2.9.1"/>
1050
+ in_memory (`bool`, defaults to `False`):
1051
+ Whether to copy the data in-memory.
1052
+
1053
+ Returns:
1054
+ datasets.Dataset
1055
+
1056
+ Example:
1057
+
1058
+ ```py
1059
+ >>> from datasets import load_dataset_builder
1060
+ >>> builder = load_dataset_builder('cornell-movie-review-data/rotten_tomatoes')
1061
+ >>> builder.download_and_prepare()
1062
+ >>> ds = builder.as_dataset(split='train')
1063
+ >>> ds
1064
+ Dataset({
1065
+ features: ['text', 'label'],
1066
+ num_rows: 8530
1067
+ })
1068
+ ```
1069
+ """
1070
+ if self._file_format is not None and self._file_format != "arrow":
1071
+ raise FileFormatError('Loading a dataset not written in the "arrow" format is not supported.')
1072
+ if is_remote_filesystem(self._fs):
1073
+ raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.")
1074
+ if not os.path.exists(self._output_dir):
1075
+ raise FileNotFoundError(
1076
+ f"Dataset {self.dataset_name}: could not find data in {self._output_dir}. Please make sure to call "
1077
+ "builder.download_and_prepare(), or use "
1078
+ "datasets.load_dataset() before trying to access the Dataset object."
1079
+ )
1080
+
1081
+ logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}")
1082
+
1083
+ # By default, return all splits
1084
+ if split is None:
1085
+ split = {s: s for s in self.info.splits}
1086
+
1087
+ verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS)
1088
+
1089
+ # Create a dataset for each of the given splits
1090
+ datasets = map_nested(
1091
+ partial(
1092
+ self._build_single_dataset,
1093
+ run_post_process=run_post_process,
1094
+ verification_mode=verification_mode,
1095
+ in_memory=in_memory,
1096
+ ),
1097
+ split,
1098
+ map_tuple=True,
1099
+ disable_tqdm=True,
1100
+ )
1101
+ if isinstance(datasets, dict):
1102
+ datasets = DatasetDict(datasets)
1103
+ return datasets
1104
+
1105
+ def _build_single_dataset(
1106
+ self,
1107
+ split: Union[str, ReadInstruction, Split],
1108
+ run_post_process: bool,
1109
+ verification_mode: VerificationMode,
1110
+ in_memory: bool = False,
1111
+ ):
1112
+ """as_dataset for a single split."""
1113
+ if not isinstance(split, ReadInstruction):
1114
+ split = str(split)
1115
+ if split == "all":
1116
+ split = "+".join(self.info.splits.keys())
1117
+ split = Split(split)
1118
+
1119
+ # Build base dataset
1120
+ ds = self._as_dataset(
1121
+ split=split,
1122
+ in_memory=in_memory,
1123
+ )
1124
+ if run_post_process:
1125
+ for resource_file_name in self._post_processing_resources(split).values():
1126
+ if os.sep in resource_file_name:
1127
+ raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
1128
+ resources_paths = {
1129
+ resource_name: os.path.join(self._output_dir, resource_file_name)
1130
+ for resource_name, resource_file_name in self._post_processing_resources(split).items()
1131
+ }
1132
+ post_processed = self._post_process(ds, resources_paths)
1133
+ if post_processed is not None:
1134
+ ds = post_processed
1135
+ recorded_checksums = {}
1136
+ record_checksums = False
1137
+ for resource_name, resource_path in resources_paths.items():
1138
+ size_checksum = get_size_checksum_dict(resource_path)
1139
+ recorded_checksums[resource_name] = size_checksum
1140
+ if verification_mode == VerificationMode.ALL_CHECKS and record_checksums:
1141
+ if self.info.post_processed is None or self.info.post_processed.resources_checksums is None:
1142
+ expected_checksums = None
1143
+ else:
1144
+ expected_checksums = self.info.post_processed.resources_checksums.get(split)
1145
+ verify_checksums(expected_checksums, recorded_checksums, "post processing resources")
1146
+ if self.info.post_processed is None:
1147
+ self.info.post_processed = PostProcessedInfo()
1148
+ if self.info.post_processed.resources_checksums is None:
1149
+ self.info.post_processed.resources_checksums = {}
1150
+ self.info.post_processed.resources_checksums[str(split)] = recorded_checksums
1151
+ self.info.post_processing_size = sum(
1152
+ checksums_dict["num_bytes"]
1153
+ for split_checksums_dicts in self.info.post_processed.resources_checksums.values()
1154
+ for checksums_dict in split_checksums_dicts.values()
1155
+ )
1156
+ if self.info.dataset_size is not None and self.info.download_size is not None:
1157
+ self.info.size_in_bytes = (
1158
+ self.info.dataset_size + self.info.download_size + self.info.post_processing_size
1159
+ )
1160
+ self._save_info()
1161
+ ds._info.post_processed = self.info.post_processed
1162
+ ds._info.post_processing_size = self.info.post_processing_size
1163
+ ds._info.size_in_bytes = self.info.size_in_bytes
1164
+ if self.info.post_processed.features is not None:
1165
+ if self.info.post_processed.features.type != ds.features.type:
1166
+ raise ValueError(
1167
+ f"Post-processed features info don't match the dataset:\nGot\n{self.info.post_processed.features}\nbut expected something like\n{ds.features}"
1168
+ )
1169
+ else:
1170
+ ds.info.features = self.info.post_processed.features
1171
+
1172
+ return ds
1173
+
1174
def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_memory: bool = False) -> Dataset:
    """Read the prepared arrow files and construct a `Dataset`.

    This is the internal implementation to overwrite called when user calls
    `as_dataset`. It reads the pre-processed dataset files and builds
    the `Dataset` object.

    Args:
        split (`datasets.Split`):
            which subset of the data to read.
        in_memory (`bool`, defaults to `False`):
            Whether to copy the data in-memory.

    Returns:
        `Dataset`
    """
    local_cache_dir = self._fs._strip_protocol(self._output_dir)
    # Older cache layouts were keyed on the builder name instead of the dataset name.
    name = self.name if self._check_legacy_cache() else self.dataset_name
    reader = ArrowReader(local_cache_dir, self.info)
    dataset_kwargs = reader.read(
        name=name,
        instructions=split,
        split_infos=self.info.splits.values(),
        in_memory=in_memory,
    )
    return Dataset(fingerprint=self._get_dataset_fingerprint(split), **dataset_kwargs)
1202
+
1203
def _get_dataset_fingerprint(self, split: Union[ReadInstruction, Split]) -> str:
    """The dataset fingerprint is the hash of the relative directory dataset_name/config_name/version/hash, as well as the split specs."""
    hasher = Hasher()
    # Feed the relative cache directory first, then the split spec
    # (e.g. train, train+test, train[:10%], test[:33%](pct1_dropremainder)).
    for component in (Path(self._relative_data_dir()).as_posix(), str(split)):
        hasher.update(component)
    return hasher.hexdigest()
1210
+
1211
def as_streaming_dataset(
    self,
    split: Optional[str] = None,
    base_path: Optional[str] = None,
) -> Union[dict[str, IterableDataset], IterableDataset]:
    """Build streaming dataset(s) for the requested split, or for all splits when `split` is None."""
    if is_remote_filesystem(self._fs):
        raise NotImplementedError(
            f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet."
        )

    dl_manager = StreamingDownloadManager(
        base_path=base_path or self.base_path,
        download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
        dataset_name=self.dataset_name,
        data_dir=self.config.data_dir,
    )
    self._check_manual_download(dl_manager)
    available_splits = {sg.name: sg for sg in self._split_generators(dl_manager)}

    # No split requested -> stream every available split.
    if split is None:
        selected = available_splits
    else:
        if split not in available_splits:
            raise ValueError(f"Bad split: {split}. Available splits: {list(available_splits)}")
        selected = available_splits[split]

    # Build one IterableDataset per selected split generator.
    datasets = map_nested(
        self._as_streaming_dataset_single,
        selected,
        map_tuple=True,
    )
    return IterableDatasetDict(datasets) if isinstance(datasets, dict) else datasets
1246
+
1247
def _as_streaming_dataset_single(
    self,
    splits_generator,
) -> IterableDataset:
    """Wrap a single split generator into a streaming `IterableDataset`."""
    # Pass auth along so audio/image files hosted in private repositories can be accessed and decoded.
    token_per_repo_id = {self.repo_id: self.token} if self.repo_id else {}
    examples = self._get_examples_iterable_for_split(splits_generator)
    return IterableDataset(
        examples, info=self.info, split=splits_generator.name, token_per_repo_id=token_per_repo_id
    )
1257
+
1258
def _post_process(self, dataset: Dataset, resources_paths: Mapping[str, str]) -> Optional[Dataset]:
    """Hook for subclasses: run dataset transforms or add indexes after preparation.

    Returns the transformed dataset, or `None` (the default) when no post-processing applies.
    """
    return None
1261
+
1262
+ def _post_processing_resources(self, split: str) -> dict[str, str]:
1263
+ """Mapping resource_name -> resource_file_name"""
1264
+ return {}
1265
+
1266
def _download_post_processing_resources(
    self, split: str, resource_name: str, dl_manager: DownloadManager
) -> Optional[str]:
    """Hook for subclasses: download a post-processing resource with `dl_manager`.

    Returns the downloaded path, or `None` (the default) when there is nothing to download.
    """
    return None
1271
+
1272
@abc.abstractmethod
def _split_generators(self, dl_manager: Union[DownloadManager, StreamingDownloadManager]):
    """Declare the dataset splits and how each one is generated.

    Subclasses return a list of `SplitGenerator`s, each naming a split and
    carrying the `gen_kwargs` later forwarded to `_generate_examples`.

    Example:

    return [
        datasets.SplitGenerator(
            name=datasets.Split.TRAIN,
            gen_kwargs={'file': 'train_data.zip'},
        ),
        datasets.SplitGenerator(
            name=datasets.Split.TEST,
            gen_kwargs={'file': 'test_data.zip'},
        ),
    ]

    With the example above, `_generate_examples(file='train_data.zip')` is
    called first to write the train data, then
    `_generate_examples(file='test_data.zip')` for the test data.

    Datasets are typically split into different subsets used at various
    stages of training and evaluation. When a dataset lacks a `VALIDATION`
    split, a fraction of `TRAIN` can serve for evaluation while iterating on
    a model, to avoid overfitting to `TEST`.

    Use the provided download manager for downloads and extractions; since
    it caches downloads, each generator may safely attempt to fetch the
    source data. A good practice is to download everything here and route
    the relevant parts to each split via `gen_kwargs`.

    Args:
        dl_manager (`Union[DownloadManager, StreamingDownloadManager]`):
            Download manager to download the data

    Returns:
        `list<SplitGenerator>`.
    """
    raise NotImplementedError()
1318
+
1319
@abc.abstractmethod
def _prepare_split(
    self,
    split_generator: SplitGenerator,
    file_format: str = "arrow",
    max_shard_size: Optional[Union[str, int]] = None,
    num_proc: Optional[int] = None,
    **kwargs,
):
    """Generate the split's examples and write them to disk.

    Args:
        split_generator (`SplitGenerator`):
            Split generator to process
        file_format (`str`, *optional*):
            format of the data files in which the dataset will be written.
            Supported formats: "arrow", "parquet". Default to "arrow" format.
        max_shard_size (`Union[str, int]`, *optional*):
            Maximum number of bytes written per shard, default is "500MB".
            The size is based on uncompressed data size, so in practice your
            shard files may be smaller than `max_shard_size` thanks to
            Parquet compression for example.
        num_proc (`int`, *optional*, defaults to `None`):
            Number of processes when downloading and generating the dataset locally.
            Multiprocessing is disabled by default.

            <Added version="2.7.0"/>
        **kwargs: Additional kwargs forwarded from _download_and_prepare
    """
    raise NotImplementedError()
1348
+
1349
def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
    """Produce the split's examples lazily (streaming mode); concrete builders override this.

    Args:
        split_generator (`SplitGenerator`):
            Split generator to process
    """
    raise NotImplementedError()
1357
+
1358
+
1359
class GeneratorBasedBuilder(DatasetBuilder):
    """Base class for datasets with data generation based on dict generators.

    `GeneratorBasedBuilder` is a convenience class that abstracts away much
    of the data writing and reading of `DatasetBuilder`. It expects subclasses to
    implement generators of feature dictionaries across the dataset splits
    (`_split_generators`). See the method docstrings for details.
    """

    @abc.abstractmethod
    def _generate_examples(self, **kwargs):
        """Default function generating examples for each `SplitGenerator`.

        This function preprocess the examples from the raw data to the preprocessed
        dataset files.
        This function is called once for each `SplitGenerator` defined in
        `_split_generators`. The examples yielded here will be written on
        disk.

        Args:
            **kwargs (additional keyword arguments):
                Arguments forwarded from the SplitGenerator.gen_kwargs

        Yields:
            key: `str` or `int`, a unique deterministic example identification key.
                * Unique: An error will be raised if two examples are yield with the
                    same key.
                * Deterministic: When generating the dataset twice, the same example
                    should have the same key.
                Good keys can be the image id, or line number if examples are extracted
                from a text file.
                The key will be hashed and sorted to shuffle examples deterministically,
                such as generating the dataset multiple times keep examples in the
                same order.
            example: `dict<str feature_name, feature_value>`, a feature dictionary
                ready to be encoded and written to disk. The example will be
                encoded with `self.info.features.encode_example({...})`.
        """
        raise NotImplementedError()

    def _prepare_split(
        self,
        split_generator: SplitGenerator,
        check_duplicate_keys: bool,
        file_format="arrow",
        num_proc: Optional[int] = None,
        max_shard_size: Optional[Union[int, str]] = None,
    ):
        """Generate the split's examples and write them to disk, optionally in parallel.

        Dispatches to `_prepare_split_single` either directly (single process) or
        through a `Pool` with one job per group of input shards, then renames the
        produced shard files into their final `-SSSSS-of-NNNNN` names and records
        split statistics (num_examples, num_bytes, shard_lengths).
        """
        max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)

        if self.info.splits is not None:
            split_info = self.info.splits[split_generator.name]
        else:
            split_info = split_generator.split_info

        # Filename placeholders: JJJJJ = job id, SSSSS = shard id within the job,
        # NNNNN = total number of shards; they are substituted during the rename step below.
        SUFFIX = "-JJJJJ-SSSSS-of-NNNNN"
        fname = f"{self.dataset_name}-{split_generator.name}{SUFFIX}.{file_format}"
        fpath = posixpath.join(self._output_dir, fname)

        # Multiprocessing can use at most one process per input shard.
        if num_proc and num_proc > 1:
            num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs)
            if num_input_shards <= 1:
                logger.warning(
                    f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard."
                )
                num_proc = 1
            elif num_input_shards < num_proc:
                logger.warning(
                    f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards."
                )
                num_proc = num_input_shards

        pbar = hf_tqdm(
            unit=" examples",
            total=split_info.num_examples,
            desc=f"Generating {split_info.name} split",
        )

        _prepare_split_args = {
            "fpath": fpath,
            "file_format": file_format,
            "max_shard_size": max_shard_size,
            "split_info": split_info,
            "check_duplicate_keys": check_duplicate_keys,
        }

        if num_proc is None or num_proc == 1:
            # Single-process path: one job (job_id=0) handles all gen_kwargs.
            result = None
            gen_kwargs = split_generator.gen_kwargs
            job_id = 0
            with pbar:
                for job_id, done, content in self._prepare_split_single(
                    gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
                ):
                    if done:
                        result = content
                    else:
                        pbar.update(content)
            # wrapping everything into lists for consistency with the multiprocessed code path
            assert result is not None, "Failed to retrieve results from prepare_split"
            examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = (
                [item] for item in result
            )
        else:
            # Multiprocess path: split gen_kwargs into up to num_proc jobs.
            kwargs_per_job = [
                {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
                for job_id, gen_kwargs in enumerate(
                    _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
                )
            ]
            num_jobs = len(kwargs_per_job)

            examples_per_job = [None] * num_jobs
            bytes_per_job = [None] * num_jobs
            features_per_job = [None] * num_jobs
            shards_per_job = [None] * num_jobs
            shard_lengths_per_job = [None] * num_jobs

            with Pool(num_proc) as pool:
                with pbar:
                    # Results arrive out of order; index them by job_id.
                    for job_id, done, content in iflatmap_unordered(
                        pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
                    ):
                        if done:
                            # the content is the result of the job
                            (
                                examples_per_job[job_id],
                                bytes_per_job[job_id],
                                features_per_job[job_id],
                                shards_per_job[job_id],
                                shard_lengths_per_job[job_id],
                            ) = content
                        else:
                            # the content is the number of examples progress update
                            pbar.update(content)

            assert None not in examples_per_job, (
                f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
            )

        total_shards = sum(shards_per_job)
        total_num_examples = sum(examples_per_job)
        total_num_bytes = sum(bytes_per_job)
        features = features_per_job[0]

        split_generator.split_info.num_examples = total_num_examples
        split_generator.split_info.num_bytes = total_num_bytes

        # should rename everything at the end
        logger.debug(f"Renaming {total_shards} shards.")
        if total_shards > 1:
            # use the -SSSSS-of-NNNNN pattern

            def _rename_shard(shard_and_job: tuple[int]):
                # Map each per-job shard id to a global shard id by offsetting
                # with the shard counts of the preceding jobs.
                shard_id, job_id = shard_and_job
                global_shard_id = sum(shards_per_job[:job_id]) + shard_id
                self._rename(
                    fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                    fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
                )

            shards_and_jobs = [
                (shard_id, job_id)
                for job_id, num_shards in enumerate(shards_per_job)
                for shard_id in range(num_shards)
            ]
            thread_map(_rename_shard, shards_and_jobs, disable=True, max_workers=64)

            split_generator.split_info.shard_lengths = [
                shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths
            ]
        else:
            # don't use any pattern
            shard_id, job_id = 0, 0
            self._rename(
                fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                fpath.replace(SUFFIX, ""),
            )

        # Keep the features inferred during writing if none were declared up front.
        if self.info.features is None:
            self.info.features = features

    def _prepare_split_single(
        self,
        gen_kwargs: dict,
        fpath: str,
        file_format: str,
        max_shard_size: int,
        split_info: SplitInfo,
        check_duplicate_keys: bool,
        job_id: int,
    ) -> Iterable[tuple[int, bool, Union[int, tuple]]]:
        """Write one job's examples to shard files.

        Yields `(job_id, False, num_new_examples)` progress updates while writing,
        then a final `(job_id, True, (total_num_examples, total_num_bytes,
        features, num_shards, shard_lengths))` once the job is finished.
        """
        generator = self._generate_examples(**gen_kwargs)
        writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
        embed_local_files = file_format == "parquet"
        shard_lengths = []
        total_num_examples, total_num_bytes = 0, 0

        shard_id = 0
        num_examples_progress_update = 0
        try:
            writer = writer_class(
                features=self.info.features,
                path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                writer_batch_size=self._writer_batch_size,
                hash_salt=split_info.name,
                check_duplicates=check_duplicate_keys,
                storage_options=self._fs.storage_options,
                embed_local_files=embed_local_files,
            )
            try:
                _time = time.time()
                for key, record in generator:
                    # Current shard reached max_shard_size: finalize it and roll over to a new shard file.
                    if max_shard_size is not None and writer._num_bytes > max_shard_size:
                        num_examples, num_bytes = writer.finalize()
                        writer.close()
                        shard_lengths.append(num_examples)
                        total_num_examples += num_examples
                        total_num_bytes += num_bytes
                        shard_id += 1
                        writer = writer_class(
                            features=writer._features,
                            path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                            writer_batch_size=self._writer_batch_size,
                            hash_salt=split_info.name,
                            check_duplicates=check_duplicate_keys,
                            storage_options=self._fs.storage_options,
                            embed_local_files=embed_local_files,
                        )
                    example = self.info.features.encode_example(record) if self.info.features is not None else record
                    writer.write(example, key)
                    num_examples_progress_update += 1
                    # Throttle progress updates to the configured refresh interval.
                    if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
                        _time = time.time()
                        yield job_id, False, num_examples_progress_update
                        num_examples_progress_update = 0
            finally:
                # Flush the remaining progress and finalize the last (possibly only) shard.
                yield job_id, False, num_examples_progress_update
                num_shards = shard_id + 1
                num_examples, num_bytes = writer.finalize()
                writer.close()
                shard_lengths.append(num_examples)
                total_num_examples += num_examples
                total_num_bytes += num_bytes
        except Exception as e:
            # Ignore the writer's error for no examples written to the file if this error was caused by the error in _generate_examples before the first example was yielded
            if isinstance(e, SchemaInferenceError) and e.__context__ is not None:
                e = e.__context__
            raise DatasetGenerationError("An error occurred while generating the dataset") from e

        yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)

    def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_kwargs):
        """Download and prepare the dataset, enabling duplicate-key checking when
        `verification_mode` is `BASIC_CHECKS` or `ALL_CHECKS`."""
        super()._download_and_prepare(
            dl_manager,
            verification_mode,
            check_duplicate_keys=verification_mode == VerificationMode.BASIC_CHECKS
            or verification_mode == VerificationMode.ALL_CHECKS,
            **prepare_splits_kwargs,
        )

    def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
        """Stream examples lazily from `_generate_examples` for this split."""
        return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs)
1622
+
1623
+
1624
class ArrowBasedBuilder(DatasetBuilder):
    """Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet)."""

    @abc.abstractmethod
    def _generate_tables(self, **kwargs):
        """Default function generating examples for each `SplitGenerator`.

        This function preprocess the examples from the raw data to the preprocessed
        dataset files.
        This function is called once for each `SplitGenerator` defined in
        `_split_generators`. The examples yielded here will be written on
        disk.

        Args:
            **kwargs (additional keyword arguments):
                Arguments forwarded from the SplitGenerator.gen_kwargs

        Yields:
            key: `str` or `int`, a unique deterministic example identification key.
                * Unique: An error will be raised if two examples are yield with the
                    same key.
                * Deterministic: When generating the dataset twice, the same example
                    should have the same key.
                Good keys can be the image id, or line number if examples are extracted
                from a text file.
                The key will be hashed and sorted to shuffle examples deterministically,
                such as generating the dataset multiple times keep examples in the
                same order.
            example: `pyarrow.Table`, a feature table
                ready to be encoded and written to disk.
        """
        raise NotImplementedError()

    def _prepare_split(
        self,
        split_generator: SplitGenerator,
        file_format: str = "arrow",
        num_proc: Optional[int] = None,
        max_shard_size: Optional[Union[str, int]] = None,
    ):
        """Generate the split's tables and write them to disk, optionally in parallel.

        Mirrors `GeneratorBasedBuilder._prepare_split` but delegates to the
        table-based `_prepare_split_single` and performs no duplicate-key checking.
        """
        max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)

        try:
            split_info = self.info.splits[split_generator.name]
        except Exception:
            # Fall back to the generator's own split info when the split is unknown.
            split_info = split_generator.split_info

        # Filename placeholders: JJJJJ = job id, SSSSS = shard id within the job,
        # NNNNN = total number of shards; they are substituted during the rename step below.
        SUFFIX = "-JJJJJ-SSSSS-of-NNNNN"
        fname = f"{self.dataset_name}-{split_generator.name}{SUFFIX}.{file_format}"
        fpath = posixpath.join(self._output_dir, fname)

        # Multiprocessing can use at most one process per input shard.
        if num_proc and num_proc > 1:
            num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs)
            if num_input_shards <= 1:
                logger.warning(
                    f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard."
                )
                num_proc = 1
            elif num_input_shards < num_proc:
                logger.warning(
                    f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards."
                )
                num_proc = num_input_shards

        pbar = hf_tqdm(
            unit=" examples",
            total=split_info.num_examples,
            desc=f"Generating {split_info.name} split",
        )

        _prepare_split_args = {
            "fpath": fpath,
            "file_format": file_format,
            "max_shard_size": max_shard_size,
        }

        if num_proc is None or num_proc == 1:
            # Single-process path: one job (job_id=0) handles all gen_kwargs.
            result = None
            gen_kwargs = split_generator.gen_kwargs
            job_id = 0
            with pbar:
                for job_id, done, content in self._prepare_split_single(
                    gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
                ):
                    if done:
                        result = content
                    else:
                        pbar.update(content)
            # wrapping everything into lists for consistency with the multiprocessed code path
            assert result is not None, "Failed to retrieve results from prepare_split"
            examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = (
                [item] for item in result
            )
        else:
            # Multiprocess path: split gen_kwargs into up to num_proc jobs.
            kwargs_per_job = [
                {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
                for job_id, gen_kwargs in enumerate(
                    _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
                )
            ]
            num_jobs = len(kwargs_per_job)

            examples_per_job = [None] * num_jobs
            bytes_per_job = [None] * num_jobs
            features_per_job = [None] * num_jobs
            shards_per_job = [None] * num_jobs
            shard_lengths_per_job = [None] * num_jobs

            with Pool(num_proc) as pool:
                with pbar:
                    # Results arrive out of order; index them by job_id.
                    for job_id, done, content in iflatmap_unordered(
                        pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
                    ):
                        if done:
                            # the content is the result of the job
                            (
                                examples_per_job[job_id],
                                bytes_per_job[job_id],
                                features_per_job[job_id],
                                shards_per_job[job_id],
                                shard_lengths_per_job[job_id],
                            ) = content
                        else:
                            # the content is the number of examples progress update
                            pbar.update(content)

            assert None not in examples_per_job, (
                f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
            )

        total_shards = sum(shards_per_job)
        total_num_examples = sum(examples_per_job)
        total_num_bytes = sum(bytes_per_job)
        features = features_per_job[0]

        split_generator.split_info.num_examples = total_num_examples
        split_generator.split_info.num_bytes = total_num_bytes

        # should rename everything at the end
        logger.debug(f"Renaming {total_shards} shards.")
        if total_shards > 1:
            # use the -SSSSS-of-NNNNN pattern

            def _rename_shard(shard_id_and_job: tuple[int]):
                # Map each per-job shard id to a global shard id by offsetting
                # with the shard counts of the preceding jobs.
                shard_id, job_id = shard_id_and_job
                global_shard_id = sum(shards_per_job[:job_id]) + shard_id
                self._rename(
                    fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                    fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
                )

            shard_ids_and_jobs = [
                (shard_id, job_id)
                for job_id, num_shards in enumerate(shards_per_job)
                for shard_id in range(num_shards)
            ]
            thread_map(_rename_shard, shard_ids_and_jobs, disable=True, max_workers=64)

            split_generator.split_info.shard_lengths = [
                shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths
            ]
        else:
            # don't use any pattern
            shard_id, job_id = 0, 0
            self._rename(
                fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                fpath.replace(SUFFIX, ""),
            )

        # Keep the features inferred during writing if none were declared up front.
        if self.info.features is None:
            self.info.features = features

    def _prepare_split_single(
        self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int
    ) -> Iterable[tuple[int, bool, Union[int, tuple]]]:
        """Write one job's tables to shard files.

        Yields `(job_id, False, num_new_examples)` progress updates while writing,
        then a final `(job_id, True, (total_num_examples, total_num_bytes,
        features, num_shards, shard_lengths))` once the job is finished.
        """
        # Wrap list arguments in tracked_list — presumably so shard-level progress
        # can be reported in error messages; TODO confirm tracked_list semantics.
        gen_kwargs = {k: tracked_list(v) if isinstance(v, list) else v for k, v in gen_kwargs.items()}
        generator = self._generate_tables(**gen_kwargs)
        writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
        embed_local_files = file_format == "parquet"
        shard_lengths = []
        total_num_examples, total_num_bytes = 0, 0

        shard_id = 0
        num_examples_progress_update = 0
        try:
            writer = writer_class(
                features=self.info.features,
                path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                writer_batch_size=self._writer_batch_size,
                storage_options=self._fs.storage_options,
                embed_local_files=embed_local_files,
            )
            try:
                _time = time.time()
                for _, table in generator:
                    # Current shard reached max_shard_size: finalize it and roll over to a new shard file.
                    if max_shard_size is not None and writer._num_bytes > max_shard_size:
                        num_examples, num_bytes = writer.finalize()
                        writer.close()
                        shard_lengths.append(num_examples)
                        total_num_examples += num_examples
                        total_num_bytes += num_bytes
                        shard_id += 1
                        writer = writer_class(
                            features=writer._features,
                            path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                            writer_batch_size=self._writer_batch_size,
                            storage_options=self._fs.storage_options,
                            embed_local_files=embed_local_files,
                        )
                    try:
                        writer.write_table(table)
                    except CastError as cast_error:
                        # A table failed to cast to the expected features: re-raise with builder context.
                        raise DatasetGenerationCastError.from_cast_error(
                            cast_error=cast_error,
                            builder_name=self.info.builder_name,
                            gen_kwargs=gen_kwargs,
                            token=self.token,
                        )
                    num_examples_progress_update += len(table)
                    # Throttle progress updates to the configured refresh interval.
                    if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
                        _time = time.time()
                        yield job_id, False, num_examples_progress_update
                        num_examples_progress_update = 0
            finally:
                # Flush the remaining progress and finalize the last (possibly only) shard.
                yield job_id, False, num_examples_progress_update
                num_shards = shard_id + 1
                num_examples, num_bytes = writer.finalize()
                writer.close()
                shard_lengths.append(num_examples)
                total_num_examples += num_examples
                total_num_bytes += num_bytes
        except Exception as e:
            # Ignore the writer's error for no examples written to the file if this error was caused by the error in _generate_examples before the first example was yielded
            if isinstance(e, SchemaInferenceError) and e.__context__ is not None:
                e = e.__context__
            # Don't re-wrap errors that are already DatasetGenerationErrors (e.g. cast errors above).
            if isinstance(e, DatasetGenerationError):
                raise
            raise DatasetGenerationError("An error occurred while generating the dataset") from e

        yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)

    def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
        """Stream examples lazily from `_generate_tables` for this split."""
        return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs)
.venv/lib/python3.11/site-packages/datasets/combine.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, TypeVar
2
+
3
+ from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets
4
+ from .dataset_dict import DatasetDict, IterableDatasetDict
5
+ from .info import DatasetInfo
6
+ from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
7
+ from .splits import NamedSplit
8
+ from .utils import logging
9
+ from .utils.py_utils import Literal
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
16
+
17
+
18
def interleave_datasets(
    datasets: list[DatasetType],
    probabilities: Optional[list[float]] = None,
    seed: Optional[int] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    stopping_strategy: Literal[
        "first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
    ] = "first_exhausted",
) -> DatasetType:
    """
    Interleave several datasets (sources) into a single dataset.
    The new dataset is constructed by alternating between the sources to get the examples.

    You can use this function on a list of [`Dataset`] objects, or on a list of [`IterableDataset`] objects.

    - If `probabilities` is `None` (default) the new dataset is constructed by cycling between each source to get the examples.
    - If `probabilities` is not `None`, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.

    The resulting dataset ends when one of the source datasets runs out of examples, except with the
    oversampling strategies, in which case it ends when all datasets have run out of examples at least once.

    Note for iterable datasets:

    In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
    Therefore the "first_exhausted" strategy on a sharded iterable dataset can generate fewer samples in total (up to 1 missing sample per subdataset per worker).

    Args:
        datasets (`List[Dataset]` or `List[IterableDataset]`):
            List of datasets to interleave.
        probabilities (`List[float]`, *optional*, defaults to `None`):
            If specified, the new dataset is constructed by sampling
            examples from one source at a time according to these probabilities.
        seed (`int`, *optional*, defaults to `None`):
            The random seed used to choose a source for each example.
        info ([`DatasetInfo`], *optional*):
            Dataset information, like description, citation, etc.
            <Added version="2.4.0"/>
        split ([`NamedSplit`], *optional*):
            Name of the dataset split.
            <Added version="2.4.0"/>
        stopping_strategy (`str`, defaults to `first_exhausted`):
            One of `first_exhausted`, `all_exhausted` or `all_exhausted_without_replacement`.
            By default, `first_exhausted` is an undersampling strategy, i.e the dataset construction is stopped as soon as one dataset has ran out of samples.
            If the strategy is `all_exhausted`, we use an oversampling strategy, i.e the dataset construction is stopped as soon as every samples of every dataset has been added at least once.
            When strategy is `all_exhausted_without_replacement` we make sure that each sample in each dataset is sampled only once.
            Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
            - with no probabilities, the resulting dataset will have `max_length_datasets*nb_dataset` samples.
            - with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting.

    Returns:
        [`Dataset`] or [`IterableDataset`]: Return type depends on the input `datasets`
        parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of
        `IterableDataset`.

    Example:

    ```python
    >>> from datasets import Dataset, interleave_datasets
    >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
    >>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
    >>> dataset = interleave_datasets([d1, d2])
    >>> dataset["a"]
    [0, 10, 1, 11, 2, 12]
    ```
    """
    # NOTE: `Dataset` and `IterableDataset` are imported at module level; the previous
    # function-local re-imports were redundant shadows and have been removed.
    if not datasets:
        raise ValueError("Unable to interleave an empty list of datasets.")
    for i, dataset in enumerate(datasets):
        # Reject anything that is not a plain Dataset/IterableDataset, with a dedicated
        # message when the caller accidentally passed a dict of splits.
        if not isinstance(dataset, (Dataset, IterableDataset)):
            if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
                if not dataset:
                    raise ValueError(
                        f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} "
                        "is an empty dataset dictionary."
                    )
                raise ValueError(
                    f"Dataset at position {i} has at least one split: {list(dataset)}\n"
                    f"Please pick one to interleave with the other datasets, for example: dataset['{next(iter(dataset))}']"
                )
            raise ValueError(
                f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} is a {type(dataset).__name__}."
            )
        if i == 0:
            # The first element decides whether we are in map-style or iterable mode;
            # the remaining elements must all match it.
            dataset_type, other_type = (
                (Dataset, IterableDataset) if isinstance(dataset, Dataset) else (IterableDataset, Dataset)
            )
        elif not isinstance(dataset, dataset_type):
            raise ValueError(
                f"Unable to interleave a {dataset_type.__name__} (at position 0) with a {other_type.__name__} (at position {i}). Expected a list of Dataset objects or a list of IterableDataset objects."
            )
    if stopping_strategy not in ["first_exhausted", "all_exhausted", "all_exhausted_without_replacement"]:
        raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
    # Dispatch to the map-style or iterable implementation.
    if dataset_type is Dataset:
        return _interleave_map_style_datasets(
            datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
        )
    else:
        return _interleave_iterable_datasets(
            datasets,
            probabilities,
            seed,
            info=info,
            split=split,
            stopping_strategy=stopping_strategy,
        )
164
+
165
+
166
def concatenate_datasets(
    dsets: list[DatasetType],
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    axis: int = 0,
) -> DatasetType:
    """
    Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].

    Args:
        dsets (`List[datasets.Dataset]`):
            List of Datasets to concatenate.
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.
        axis (`{0, 1}`, defaults to `0`):
            Axis to concatenate over, where `0` means over rows (vertically) and `1` means over columns
            (horizontally).

            <Added version="1.6.0"/>

    Returns:
        [`Dataset`] or [`IterableDataset`], matching the type of the inputs.

    Example:

    ```py
    >>> ds3 = concatenate_datasets([ds1, ds2])
    ```
    """

    if not dsets:
        raise ValueError("Unable to concatenate an empty list of datasets.")
    for i, dataset in enumerate(dsets):
        # Reject anything that is not a plain Dataset/IterableDataset, with a dedicated
        # message when the caller accidentally passed a dict of splits.
        if not isinstance(dataset, (Dataset, IterableDataset)):
            if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
                if not dataset:
                    raise ValueError(
                        f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} "
                        "is an empty dataset dictionary."
                    )
                # Fixed copy-paste from `interleave_datasets`: this message used to say "interleave".
                raise ValueError(
                    f"Dataset at position {i} has at least one split: {list(dataset)}\n"
                    f"Please pick one to concatenate with the other datasets, for example: dataset['{next(iter(dataset))}']"
                )
            raise ValueError(
                f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} is a {type(dataset).__name__}."
            )
        if i == 0:
            # The first element decides whether we are in map-style or iterable mode;
            # the remaining elements must all match it.
            dataset_type, other_type = (
                (Dataset, IterableDataset) if isinstance(dataset, Dataset) else (IterableDataset, Dataset)
            )
        elif not isinstance(dataset, dataset_type):
            # Fixed copy-paste from `interleave_datasets`: this message used to say "interleave".
            raise ValueError(
                f"Unable to concatenate a {dataset_type.__name__} (at position 0) with a {other_type.__name__} (at position {i}). Expected a list of Dataset objects or a list of IterableDataset objects."
            )
    # Dispatch to the map-style or iterable implementation.
    if dataset_type is Dataset:
        return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis)
    else:
        return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis)
.venv/lib/python3.11/site-packages/datasets/commands/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from argparse import ArgumentParser
3
+
4
+
5
class BaseDatasetsCLICommand(ABC):
    """Abstract interface for `datasets-cli` subcommands.

    Concrete commands register their argparse subparser via `register_subcommand`
    (typically calling `set_defaults(func=...)` with a factory) and implement `run`
    to execute the command once constructed from parsed arguments.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's subparser and arguments to the root CLI parser."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
.venv/lib/python3.11/site-packages/datasets/commands/datasets_cli.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ from argparse import ArgumentParser
3
+
4
+ from datasets.commands.delete_from_hub import DeleteFromHubCommand
5
+ from datasets.commands.env import EnvironmentCommand
6
+ from datasets.commands.test import TestCommand
7
+ from datasets.utils.logging import set_verbosity_info
8
+
9
+
10
def parse_unknown_args(unknown_args):
    """Turn a flat `[--flag, value, --flag, value, ...]` list into a dict.

    Leading dashes are stripped from flag names. A trailing flag without a
    value is silently dropped, mirroring pairwise consumption.
    """
    parsed = {}
    # Walk the list two items at a time; stop before a dangling final flag.
    for idx in range(0, len(unknown_args) - 1, 2):
        parsed[unknown_args[idx].lstrip("-")] = unknown_args[idx + 1]
    return parsed
12
+
13
+
14
def main():
    """Entry point of the `datasets-cli` tool.

    Registers the available subcommands, parses the command line, builds the
    selected command via its `func` factory (forwarding any unknown `--key value`
    pairs as keyword arguments) and runs it.
    """
    parser = ArgumentParser(
        "HuggingFace Datasets CLI tool", usage="datasets-cli <command> [<args>]", allow_abbrev=False
    )
    commands_parser = parser.add_subparsers(help="datasets-cli command helpers")
    set_verbosity_info()

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    TestCommand.register_subcommand(commands_parser)
    DeleteFromHubCommand.register_subcommand(commands_parser)

    # Parse args
    args, unknown_args = parser.parse_known_args()
    # `func` is set by the chosen subcommand's `set_defaults`; absent means no subcommand given.
    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)
    kwargs = parse_unknown_args(unknown_args)

    # Run
    service = args.func(args, **kwargs)
    service.run()


if __name__ == "__main__":
    main()
.venv/lib/python3.11/site-packages/datasets/commands/delete_from_hub.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ from typing import Optional
3
+
4
+ from datasets.commands import BaseDatasetsCLICommand
5
+ from datasets.hub import delete_from_hub
6
+
7
+
8
def _command_factory(args):
    """Build a `DeleteFromHubCommand` from parsed CLI arguments."""
    return DeleteFromHubCommand(args.dataset_id, args.config_name, args.token, args.revision)
15
+
16
+
17
class DeleteFromHubCommand(BaseDatasetsCLICommand):
    """CLI command that deletes one dataset config from a repository on the Hub."""

    @staticmethod
    def register_subcommand(parser):
        """Attach the `delete_from_hub` subcommand and its arguments to the root parser."""
        subparser: ArgumentParser = parser.add_parser("delete_from_hub", help="Delete dataset config from the Hub")
        subparser.add_argument(
            "dataset_id", help="source dataset ID, e.g. USERNAME/DATASET_NAME or ORGANIZATION/DATASET_NAME"
        )
        subparser.add_argument("config_name", help="config name to delete")
        subparser.add_argument("--token", help="access token to the Hugging Face Hub")
        subparser.add_argument("--revision", help="source revision")
        subparser.set_defaults(func=_command_factory)

    def __init__(
        self,
        dataset_id: str,
        config_name: str,
        token: Optional[str],
        revision: Optional[str],
    ):
        # Store everything needed to perform the deletion in `run`.
        self._dataset_id = dataset_id
        self._config_name = config_name
        self._token = token
        self._revision = revision

    def run(self) -> None:
        """Delegate the deletion to `datasets.hub.delete_from_hub`; the result is discarded."""
        _ = delete_from_hub(self._dataset_id, self._config_name, revision=self._revision, token=self._token)
.venv/lib/python3.11/site-packages/datasets/commands/env.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ from argparse import ArgumentParser
3
+
4
+ import fsspec
5
+ import huggingface_hub
6
+ import pandas
7
+ import pyarrow
8
+
9
+ from datasets import __version__ as version
10
+ from datasets.commands import BaseDatasetsCLICommand
11
+
12
+
13
def info_command_factory(_):
    """Build an `EnvironmentCommand`; the parsed args are ignored (the command takes none)."""
    return EnvironmentCommand()
15
+
16
+
17
class EnvironmentCommand(BaseDatasetsCLICommand):
    """CLI command that prints environment/version info for GitHub bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the `env` subcommand to the root parser."""
        env_parser = parser.add_parser("env", help="Print relevant system environment info.")
        env_parser.set_defaults(func=info_command_factory)

    def run(self):
        """Collect version info, print it in copy-pasteable form, and return the raw dict."""
        info = {
            "`datasets` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "`huggingface_hub` version": huggingface_hub.__version__,
            "PyArrow version": pyarrow.__version__,
            "Pandas version": pandas.__version__,
            "`fsspec` version": fsspec.__version__,
        }

        print("\nCopy-and-paste the text below in your GitHub issue.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        """Render a dict as newline-separated `- key: value` bullets with a trailing newline."""
        bullet_lines = [f"- {prop}: {val}" for prop, val in d.items()]
        return "\n".join(bullet_lines) + "\n"
.venv/lib/python3.11/site-packages/datasets/commands/test.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from argparse import ArgumentParser
4
+ from collections.abc import Generator
5
+ from shutil import rmtree
6
+
7
+ import datasets.config
8
+ from datasets.builder import DatasetBuilder
9
+ from datasets.commands import BaseDatasetsCLICommand
10
+ from datasets.download.download_manager import DownloadMode
11
+ from datasets.info import DatasetInfosDict
12
+ from datasets.load import dataset_module_factory, get_dataset_builder_class
13
+ from datasets.utils.info_utils import VerificationMode
14
+ from datasets.utils.logging import ERROR, get_logger
15
+
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
def _test_command_factory(args):
    """Build a `TestCommand` from parsed CLI arguments, merging the `--save_info`/`--save_infos` aliases."""
    save_infos = args.save_info or args.save_infos
    return TestCommand(
        args.dataset,
        args.name,
        args.cache_dir,
        args.data_dir,
        args.all_configs,
        save_infos,
        args.ignore_verifications,
        args.force_redownload,
        args.clear_cache,
        args.num_proc,
    )
33
+
34
+
35
class TestCommand(BaseDatasetsCLICommand):
    """CLI command that builds a dataset end-to-end to test that it loads correctly.

    Optionally iterates over all configs, saves the computed dataset infos into the
    dataset card (README.md), and clears the download/builder caches between configs.
    """

    __test__ = False  # to tell pytest it's not a test class

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the `test` subcommand and its arguments to the root parser."""
        test_parser = parser.add_parser("test", help="Test dataset loading.")
        test_parser.add_argument("--name", type=str, default=None, help="Dataset processing name")
        test_parser.add_argument(
            "--cache_dir",
            type=str,
            default=None,
            help="Cache directory where the datasets are stored.",
        )
        test_parser.add_argument(
            "--data_dir",
            type=str,
            default=None,
            help="Can be used to specify a manual directory to get the files from.",
        )
        test_parser.add_argument("--all_configs", action="store_true", help="Test all dataset configurations")
        test_parser.add_argument(
            "--save_info", action="store_true", help="Save the dataset infos in the dataset card (README.md)"
        )
        test_parser.add_argument(
            "--ignore_verifications",
            action="store_true",
            help="Run the test without checksums and splits checks.",
        )
        test_parser.add_argument("--force_redownload", action="store_true", help="Force dataset redownload")
        test_parser.add_argument(
            "--clear_cache",
            action="store_true",
            help="Remove downloaded files and cached datasets after each config test",
        )
        test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes")
        # aliases
        test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info")
        test_parser.add_argument("dataset", type=str, help="Name of the dataset to download")
        test_parser.set_defaults(func=_test_command_factory)

    def __init__(
        self,
        dataset: str,
        name: str,
        cache_dir: str,
        data_dir: str,
        all_configs: bool,
        save_infos: bool,
        ignore_verifications: bool,
        force_redownload: bool,
        clear_cache: bool,
        num_proc: int,
    ):
        """Store the CLI options; exits early if --clear_cache is used without --cache_dir."""
        self._dataset = dataset
        self._name = name
        self._cache_dir = cache_dir
        self._data_dir = data_dir
        self._all_configs = all_configs
        self._save_infos = save_infos
        self._ignore_verifications = ignore_verifications
        self._force_redownload = force_redownload
        self._clear_cache = clear_cache
        self._num_proc = num_proc
        # Refusing to clear an implicit (shared) cache directory protects the user's
        # global datasets cache from deletion.
        if clear_cache and not cache_dir:
            print(
                "When --clear_cache is used, specifying a cache directory is mandatory.\n"
                "The 'download' folder of the cache directory and the dataset builder cache will be deleted after each configuration test.\n"
                "Please provide a --cache_dir that will be used to test the dataset."
            )
            exit(1)
        # Saving infos implies recomputing them, so verification checks are skipped.
        if save_infos:
            self._ignore_verifications = True

    def run(self):
        """Download and prepare the dataset (each selected config), optionally saving infos and clearing caches."""
        logging.getLogger("filelock").setLevel(ERROR)
        if self._name is not None and self._all_configs:
            print("Both parameters `config` and `all_configs` can't be used at once.")
            exit(1)
        path, config_name = self._dataset, self._name
        module = dataset_module_factory(path)
        builder_cls = get_dataset_builder_class(module)
        n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1

        def get_builders() -> Generator[DatasetBuilder, None, None]:
            """Yield one builder per config to test (or a single builder for --name / the default config).

            When the module's own builder_kwargs already pin a config_name, it takes
            precedence and is not overridden.
            """
            if self._all_configs and builder_cls.BUILDER_CONFIGS:
                for i, config in enumerate(builder_cls.BUILDER_CONFIGS):
                    if "config_name" in module.builder_kwargs:
                        yield builder_cls(
                            cache_dir=self._cache_dir,
                            data_dir=self._data_dir,
                            **module.builder_kwargs,
                        )
                    else:
                        yield builder_cls(
                            config_name=config.name,
                            cache_dir=self._cache_dir,
                            data_dir=self._data_dir,
                            **module.builder_kwargs,
                        )
            else:
                if "config_name" in module.builder_kwargs:
                    yield builder_cls(cache_dir=self._cache_dir, data_dir=self._data_dir, **module.builder_kwargs)
                else:
                    yield builder_cls(
                        config_name=config_name,
                        cache_dir=self._cache_dir,
                        data_dir=self._data_dir,
                        **module.builder_kwargs,
                    )

        for j, builder in enumerate(get_builders()):
            print(f"Testing builder '{builder.config.name}' ({j + 1}/{n_builders})")
            builder._record_infos = os.path.exists(
                os.path.join(builder.get_imported_module_dir(), datasets.config.DATASETDICT_INFOS_FILENAME)
            )  # record checksums only if we need to update a (deprecated) dataset_infos.json
            builder.download_and_prepare(
                download_mode=DownloadMode.REUSE_CACHE_IF_EXISTS
                if not self._force_redownload
                else DownloadMode.FORCE_REDOWNLOAD,
                verification_mode=VerificationMode.NO_CHECKS
                if self._ignore_verifications
                else VerificationMode.ALL_CHECKS,
                num_proc=self._num_proc,
            )
            builder.as_dataset()

            # If save_infos=True, we create the dataset card (README.md)
            # The dataset_infos are saved in the YAML part of the README.md
            # This is to allow the user to upload them on HF afterwards.
            if self._save_infos:
                save_infos_dir = os.path.basename(path) if not os.path.isdir(path) else path
                os.makedirs(save_infos_dir, exist_ok=True)
                DatasetInfosDict(**{builder.config.name: builder.info}).write_to_directory(save_infos_dir)
                print(f"Dataset card saved at {os.path.join(save_infos_dir, datasets.config.REPOCARD_FILENAME)}")

            # If clear_cache=True, the download folder and the dataset builder cache directory are deleted
            if self._clear_cache:
                if os.path.isdir(builder._cache_dir):
                    logger.warning(f"Clearing cache at {builder._cache_dir}")
                    rmtree(builder._cache_dir)
                download_dir = os.path.join(self._cache_dir, datasets.config.DOWNLOADED_DATASETS_DIR)
                if os.path.isdir(download_dir):
                    logger.warning(f"Clearing cache at {download_dir}")
                    rmtree(download_dir)

        print("Test successful.")
.venv/lib/python3.11/site-packages/datasets/config.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import importlib.metadata
3
+ import logging
4
+ import os
5
+ import platform
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ from huggingface_hub import constants
10
+ from packaging import version
11
+
12
+
13
# Module-level configuration for `datasets`: URLs, optional-dependency detection,
# cache locations, and tuning constants. Everything here is computed once at import time.
logger = logging.getLogger(__name__.split(".", 1)[0])  # to avoid circular import from .utils.logging

# Datasets
S3_DATASETS_BUCKET_PREFIX = "https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets"
CLOUDFRONT_DATASETS_DISTRIB_PREFIX = "https://cdn-datasets.huggingface.co/datasets/datasets"
REPO_DATASETS_URL = "https://raw.githubusercontent.com/huggingface/datasets/{revision}/datasets/{path}/{name}"

# Hub
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
HUB_DATASETS_URL = HF_ENDPOINT + "/datasets/{repo_id}/resolve/{revision}/{path}"
HUB_DATASETS_HFFS_URL = "hf://datasets/{repo_id}@{revision}/{path}"
HUB_DEFAULT_VERSION = "main"

PY_VERSION = version.parse(platform.python_version())

# General environment variables accepted values for booleans
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_FALSE_VALUES = {"0", "OFF", "NO", "FALSE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
ENV_VARS_FALSE_AND_AUTO_VALUES = ENV_VARS_FALSE_VALUES.union({"AUTO"})


# Imports
# Hard dependencies: a missing package here raises PackageNotFoundError at import time.
DILL_VERSION = version.parse(importlib.metadata.version("dill"))
FSSPEC_VERSION = version.parse(importlib.metadata.version("fsspec"))
PANDAS_VERSION = version.parse(importlib.metadata.version("pandas"))
PYARROW_VERSION = version.parse(importlib.metadata.version("pyarrow"))
HF_HUB_VERSION = version.parse(importlib.metadata.version("huggingface_hub"))

# USE_TF / USE_TORCH / USE_JAX let the user force-enable or force-disable a framework.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_JAX", "AUTO").upper()

TORCH_VERSION = "N/A"
TORCH_AVAILABLE = False

# Torch is probed unless the user explicitly opted into TF instead.
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None
    if TORCH_AVAILABLE:
        try:
            TORCH_VERSION = version.parse(importlib.metadata.version("torch"))
            logger.debug(f"PyTorch version {TORCH_VERSION} available.")
        except importlib.metadata.PackageNotFoundError:
            pass
else:
    logger.info("Disabling PyTorch because USE_TF is set")

POLARS_VERSION = "N/A"
POLARS_AVAILABLE = importlib.util.find_spec("polars") is not None

if POLARS_AVAILABLE:
    try:
        POLARS_VERSION = version.parse(importlib.metadata.version("polars"))
        logger.debug(f"Polars version {POLARS_VERSION} available.")
    except importlib.metadata.PackageNotFoundError:
        pass


DUCKDB_VERSION = "N/A"
DUCKDB_AVAILABLE = importlib.util.find_spec("duckdb") is not None

if DUCKDB_AVAILABLE:
    try:
        DUCKDB_VERSION = version.parse(importlib.metadata.version("duckdb"))
        logger.debug(f"Duckdb version {DUCKDB_VERSION} available.")
    except importlib.metadata.PackageNotFoundError:
        pass

TF_VERSION = "N/A"
TF_AVAILABLE = False

# TensorFlow is probed unless the user explicitly opted into Torch instead.
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    TF_AVAILABLE = importlib.util.find_spec("tensorflow") is not None
    if TF_AVAILABLE:
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for package in [
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
            "tensorflow-rocm",
            "tensorflow-macos",
        ]:
            try:
                TF_VERSION = version.parse(importlib.metadata.version(package))
            except importlib.metadata.PackageNotFoundError:
                continue
            else:
                break
        else:
            # for/else: no distribution provided metadata, so treat TF as unavailable.
            TF_AVAILABLE = False
    if TF_AVAILABLE:
        if TF_VERSION.major < 2:
            logger.info(f"TensorFlow found but with version {TF_VERSION}. `datasets` requires version 2 minimum.")
            TF_AVAILABLE = False
        else:
            logger.info(f"TensorFlow version {TF_VERSION} available.")
else:
    logger.info("Disabling Tensorflow because USE_TORCH is set")


JAX_VERSION = "N/A"
JAX_AVAILABLE = False

if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    # Both jax and jaxlib must be installed for JAX support.
    JAX_AVAILABLE = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("jaxlib") is not None
    if JAX_AVAILABLE:
        try:
            JAX_VERSION = version.parse(importlib.metadata.version("jax"))
            logger.info(f"JAX version {JAX_VERSION} available.")
        except importlib.metadata.PackageNotFoundError:
            pass
else:
    logger.info("Disabling JAX because USE_JAX is set to False")


# Optional tools for data loading
SQLALCHEMY_AVAILABLE = importlib.util.find_spec("sqlalchemy") is not None

# Optional tools for feature decoding
PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None
IS_OPUS_SUPPORTED = True
IS_MP3_SUPPORTED = True
TORCHCODEC_AVAILABLE = importlib.util.find_spec("torchcodec") is not None
TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None
NIBABEL_AVAILABLE = importlib.util.find_spec("nibabel") is not None

# Optional compression tools
RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None
LZ4_AVAILABLE = importlib.util.find_spec("lz4") is not None
PY7ZR_AVAILABLE = importlib.util.find_spec("py7zr") is not None

# Cache location
DEFAULT_XDG_CACHE_HOME = "~/.cache"
XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME)
DEFAULT_HF_CACHE_HOME = os.path.join(XDG_CACHE_HOME, "huggingface")
HF_CACHE_HOME = os.path.expanduser(os.getenv("HF_HOME", DEFAULT_HF_CACHE_HOME))

DEFAULT_HF_DATASETS_CACHE = os.path.join(HF_CACHE_HOME, "datasets")
HF_DATASETS_CACHE = Path(os.getenv("HF_DATASETS_CACHE", DEFAULT_HF_DATASETS_CACHE))

DEFAULT_HF_MODULES_CACHE = os.path.join(HF_CACHE_HOME, "modules")
HF_MODULES_CACHE = Path(os.getenv("HF_MODULES_CACHE", DEFAULT_HF_MODULES_CACHE))

DOWNLOADED_DATASETS_DIR = "downloads"
DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(HF_DATASETS_CACHE, DOWNLOADED_DATASETS_DIR)
DOWNLOADED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_DOWNLOADED_DATASETS_PATH", DEFAULT_DOWNLOADED_DATASETS_PATH))

EXTRACTED_DATASETS_DIR = "extracted"
DEFAULT_EXTRACTED_DATASETS_PATH = os.path.join(DEFAULT_DOWNLOADED_DATASETS_PATH, EXTRACTED_DATASETS_DIR)
EXTRACTED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_EXTRACTED_DATASETS_PATH", DEFAULT_EXTRACTED_DATASETS_PATH))

# Download count for the website
HF_UPDATE_DOWNLOAD_COUNTS = (
    os.environ.get("HF_UPDATE_DOWNLOAD_COUNTS", "AUTO").upper() in ENV_VARS_TRUE_AND_AUTO_VALUES
)

# For downloads and to check remote files metadata
HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16

# Dataset viewer API
USE_PARQUET_EXPORT = True

# Batch size constants. For more info, see:
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
DEFAULT_MAX_BATCH_SIZE = 1000

# Content-defined-chunking options; sizes are in bytes.
DEFAULT_CDC_OPTIONS = {"min_chunk_size": 256 * 1024, "max_chunk_size": 1024 * 1024, "norm_level": 0}

# Size of the preloaded record batch in `Dataset.__iter__`
ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10

# Max uncompressed shard size in bytes (e.g. to shard parquet datasets in push_to_hub or download_and_prepare)
MAX_SHARD_SIZE = "500MB"

# Max uncompressed row group size in bytes (e.g. for parquet files in push_to_hub or download_and_prepare)
MAX_ROW_GROUP_SIZE = "100MB"

# Parquet configuration
PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = None
PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = None
PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = None
PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = None

# Arrow configuration
ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS = 100
ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS = 100
ARROW_RECORD_BATCH_SIZE_FOR_BINARY_DATASETS = 100
ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS = 10

# Offline mode
# HF_DATASETS_OFFLINE, when set, overrides huggingface_hub's HF_HUB_OFFLINE constant.
_offline = os.environ.get("HF_DATASETS_OFFLINE")
HF_HUB_OFFLINE = constants.HF_HUB_OFFLINE if _offline is None else _offline.upper() in ENV_VARS_TRUE_VALUES
HF_DATASETS_OFFLINE = HF_HUB_OFFLINE  # kept for backward-compatibility

# Here, `True` will disable progress bars globally without possibility of enabling it
# programmatically. `False` will enable them without possibility of disabling them.
# If environment variable is not set (None), then the user is free to enable/disable
# them programmatically.
# TL;DR: env variable has priority over code
__HF_DATASETS_DISABLE_PROGRESS_BARS = os.environ.get("HF_DATASETS_DISABLE_PROGRESS_BARS")
HF_DATASETS_DISABLE_PROGRESS_BARS: Optional[bool] = (
    __HF_DATASETS_DISABLE_PROGRESS_BARS.upper() in ENV_VARS_TRUE_VALUES
    if __HF_DATASETS_DISABLE_PROGRESS_BARS is not None
    else None
)

# In-memory
DEFAULT_IN_MEMORY_MAX_SIZE = 0  # Disabled
IN_MEMORY_MAX_SIZE = float(os.environ.get("HF_DATASETS_IN_MEMORY_MAX_SIZE", DEFAULT_IN_MEMORY_MAX_SIZE))

# File names
DATASET_ARROW_FILENAME = "dataset.arrow"
DATASET_INDICES_FILENAME = "indices.arrow"
DATASET_STATE_JSON_FILENAME = "state.json"
DATASET_INFO_FILENAME = "dataset_info.json"
DATASETDICT_INFOS_FILENAME = "dataset_infos.json"
LICENSE_FILENAME = "LICENSE"
DATASETDICT_JSON_FILENAME = "dataset_dict.json"
METADATA_CONFIGS_FIELD = "configs"
REPOCARD_FILENAME = "README.md"
REPOYAML_FILENAME = ".huggingface.yaml"

MODULE_NAME_FOR_DYNAMIC_MODULES = "datasets_modules"

MAX_DATASET_CONFIG_ID_READABLE_LENGTH = 255

# Temporary cache directory prefix
TEMP_CACHE_DIR_PREFIX = "hf_datasets-"

# Streaming
# Retry intervals are presumably in seconds — TODO confirm against the streaming readers.
STREAMING_READ_MAX_RETRIES = 20
STREAMING_READ_RETRY_INTERVAL = 5
STREAMING_READ_SERVER_UNAVAILABLE_RETRY_INTERVAL = 20
STREAMING_READ_RATE_LIMIT_RETRY_INTERVAL = 60
STREAMING_OPEN_MAX_RETRIES = 20
STREAMING_OPEN_RETRY_INTERVAL = 5

# Datasets repositories exploration
DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 200
GLOBBED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 10
ARCHIVED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 200

# Async map functions
MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL = 1000

# Progress bars
PBAR_REFRESH_TIME_INTERVAL = 0.05  # 20 progress updates per sec

# Maximum number of uploaded files per commit
UPLOADS_MAX_NUMBER_PER_COMMIT = 50

# Backward compatibility
MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30
.venv/lib/python3.11/site-packages/datasets/data_files.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from functools import partial
4
+ from glob import has_magic
5
+ from pathlib import Path, PurePath
6
+ from typing import Callable, Optional, Union
7
+
8
+ import huggingface_hub
9
+ from fsspec.core import url_to_fs
10
+ from huggingface_hub import HfFileSystem
11
+ from packaging import version
12
+ from tqdm.contrib.concurrent import thread_map
13
+
14
+ from . import config
15
+ from .download import DownloadConfig
16
+ from .naming import _split_re
17
+ from .splits import Split
18
+ from .utils import logging
19
+ from .utils import tqdm as hf_tqdm
20
+ from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin
21
+ from .utils.py_utils import string_to_dict
22
+
23
+
24
+ SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]]
25
+
26
+
27
+ SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class Url(str):
34
+ pass
35
+
36
+
37
+ class EmptyDatasetError(FileNotFoundError):
38
+ pass
39
+
40
+
41
+ SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"
42
+
43
+ SPLIT_KEYWORDS = {
44
+ Split.TRAIN: ["train", "training"],
45
+ Split.VALIDATION: ["validation", "valid", "dev", "val"],
46
+ Split.TEST: ["test", "testing", "eval", "evaluation"],
47
+ }
48
+ NON_WORDS_CHARS = "-._ 0-9"
49
+ if config.FSSPEC_VERSION < version.parse("2023.9.0"):
50
+ KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
51
+ KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
52
+ "{keyword}/**",
53
+ "{keyword}[{sep}]*/**",
54
+ "**[{sep}/]{keyword}/**",
55
+ "**[{sep}/]{keyword}[{sep}]*/**",
56
+ ]
57
+ elif config.FSSPEC_VERSION < version.parse("2023.12.0"):
58
+ KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/*[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
59
+ KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
60
+ "{keyword}/**/*",
61
+ "{keyword}[{sep}]*/**/*",
62
+ "**/*[{sep}/]{keyword}/**/*",
63
+ "**/*[{sep}/]{keyword}[{sep}]*/**/*",
64
+ ]
65
+ else:
66
+ KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/{keyword}[{sep}]*", "**/*[{sep}]{keyword}[{sep}]*"]
67
+ KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
68
+ "**/{keyword}/**",
69
+ "**/{keyword}[{sep}]*/**",
70
+ "**/*[{sep}]{keyword}/**",
71
+ "**/*[{sep}]{keyword}[{sep}]*/**",
72
+ ]
73
+
74
+ DEFAULT_SPLITS = [Split.TRAIN, Split.VALIDATION, Split.TEST]
75
+ DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
76
+ split: [
77
+ pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
78
+ for keyword in SPLIT_KEYWORDS[split]
79
+ for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
80
+ ]
81
+ for split in DEFAULT_SPLITS
82
+ }
83
+ DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = {
84
+ split: [
85
+ pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
86
+ for keyword in SPLIT_KEYWORDS[split]
87
+ for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
88
+ ]
89
+ for split in DEFAULT_SPLITS
90
+ }
91
+
92
+
93
+ DEFAULT_PATTERNS_ALL = {
94
+ Split.TRAIN: ["**"],
95
+ }
96
+
97
+ ALL_SPLIT_PATTERNS = [SPLIT_PATTERN_SHARDED]
98
+ ALL_DEFAULT_PATTERNS = [
99
+ DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME,
100
+ DEFAULT_PATTERNS_SPLIT_IN_FILENAME,
101
+ DEFAULT_PATTERNS_ALL,
102
+ ]
103
+ WILDCARD_CHARACTERS = "*[]"
104
+ FILES_TO_IGNORE = [
105
+ "README.md",
106
+ "config.json",
107
+ "dataset_info.json",
108
+ "dataset_infos.json",
109
+ "dummy_data.zip",
110
+ "dataset_dict.json",
111
+ ]
112
+
113
+
114
+ def contains_wildcards(pattern: str) -> bool:
115
+ return any(wildcard_character in pattern for wildcard_character in WILDCARD_CHARACTERS)
116
+
117
+
118
+ def sanitize_patterns(patterns: Union[dict, list, str]) -> dict[str, Union[list[str], "DataFilesList"]]:
119
+ """
120
+ Take the data_files patterns from the user, and format them into a dictionary.
121
+ Each key is the name of the split, and each value is a list of data files patterns (paths or urls).
122
+ The default split is "train".
123
+
124
+ Returns:
125
+ patterns: dictionary of split_name -> list of patterns
126
+ """
127
+ if isinstance(patterns, dict):
128
+ return {str(key): value if isinstance(value, list) else [value] for key, value in patterns.items()}
129
+ elif isinstance(patterns, str):
130
+ return {SANITIZED_DEFAULT_SPLIT: [patterns]}
131
+ elif isinstance(patterns, list):
132
+ if any(isinstance(pattern, dict) for pattern in patterns):
133
+ for pattern in patterns:
134
+ if not (
135
+ isinstance(pattern, dict)
136
+ and len(pattern) == 2
137
+ and "split" in pattern
138
+ and isinstance(pattern.get("path"), (str, list))
139
+ ):
140
+ raise ValueError(
141
+ f"Expected each split to have a 'path' key which can be a string or a list of strings, but got {pattern}"
142
+ )
143
+ splits = [pattern["split"] for pattern in patterns]
144
+ if len(set(splits)) != len(splits):
145
+ raise ValueError(f"Some splits are duplicated in data_files: {splits}")
146
+ return {
147
+ str(pattern["split"]): pattern["path"] if isinstance(pattern["path"], list) else [pattern["path"]]
148
+ for pattern in patterns
149
+ }
150
+ else:
151
+ return {SANITIZED_DEFAULT_SPLIT: patterns}
152
+ else:
153
+ return sanitize_patterns(list(patterns))
154
+
155
+
156
+ def _is_inside_unrequested_special_dir(matched_rel_path: str, pattern: str) -> bool:
157
+ """
158
+ When a path matches a pattern, we additionally check if it's inside a special directory
159
+ we ignore by default (if it starts with a double underscore).
160
+
161
+ Users can still explicitly request a filepath inside such a directory if "__pycache__" is
162
+ mentioned explicitly in the requested pattern.
163
+
164
+ Some examples:
165
+
166
+ base directory:
167
+
168
+ ./
169
+ └── __pycache__
170
+ └── b.txt
171
+
172
+ >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "**")
173
+ True
174
+ >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "*/b.txt")
175
+ True
176
+ >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__pycache__/*")
177
+ False
178
+ >>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__*/*")
179
+ False
180
+ """
181
+ # We just need to check if every special directories from the path is present explicitly in the pattern.
182
+ # Since we assume that the path matches the pattern, it's equivalent to counting that both
183
+ # the parent path and the parent pattern have the same number of special directories.
184
+ data_dirs_to_ignore_in_path = [part for part in PurePath(matched_rel_path).parent.parts if part.startswith("__")]
185
+ data_dirs_to_ignore_in_pattern = [part for part in PurePath(pattern).parent.parts if part.startswith("__")]
186
+ return len(data_dirs_to_ignore_in_path) != len(data_dirs_to_ignore_in_pattern)
187
+
188
+
189
+ def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_path: str, pattern: str) -> bool:
190
+ """
191
+ When a path matches a pattern, we additionally check if it's a hidden file or if it's inside
192
+ a hidden directory we ignore by default, i.e. if the file name or a parent directory name starts with a dot.
193
+
194
+ Users can still explicitly request a filepath that is hidden or is inside a hidden directory
195
+ if the hidden part is mentioned explicitly in the requested pattern.
196
+
197
+ Some examples:
198
+
199
+ base directory:
200
+
201
+ ./
202
+ └── .hidden_file.txt
203
+
204
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", "**")
205
+ True
206
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", ".*")
207
+ False
208
+
209
+ base directory:
210
+
211
+ ./
212
+ └── .hidden_dir
213
+ └── a.txt
214
+
215
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", "**")
216
+ True
217
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".*/*")
218
+ False
219
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".hidden_dir/*")
220
+ False
221
+
222
+ base directory:
223
+
224
+ ./
225
+ └── .hidden_dir
226
+ └── .hidden_file.txt
227
+
228
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", "**")
229
+ True
230
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/*")
231
+ True
232
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/.*")
233
+ False
234
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/*")
235
+ True
236
+ >>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/.*")
237
+ False
238
+ """
239
+ # We just need to check if every hidden part from the path is present explicitly in the pattern.
240
+ # Since we assume that the path matches the pattern, it's equivalent to counting that both
241
+ # the path and the pattern have the same number of hidden parts.
242
+ hidden_directories_in_path = [
243
+ part for part in PurePath(matched_rel_path).parts if part.startswith(".") and not set(part) == {"."}
244
+ ]
245
+ hidden_directories_in_pattern = [
246
+ part for part in PurePath(pattern).parts if part.startswith(".") and not set(part) == {"."}
247
+ ]
248
+ return len(hidden_directories_in_path) != len(hidden_directories_in_pattern)
249
+
250
+
251
+ def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]:
252
+ """
253
+ Get the default pattern from a directory or repository by testing all the supported patterns.
254
+ The first patterns to return a non-empty list of data files is returned.
255
+
256
+ In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
257
+ """
258
+ # first check the split patterns like data/{split}-00000-of-00001.parquet
259
+ for split_pattern in ALL_SPLIT_PATTERNS:
260
+ pattern = split_pattern.replace("{split}", "*")
261
+ try:
262
+ data_files = pattern_resolver(pattern)
263
+ except FileNotFoundError:
264
+ continue
265
+ if len(data_files) > 0:
266
+ splits: set[str] = set()
267
+ for p in data_files:
268
+ p_parts = string_to_dict(xbasename(p), xbasename(split_pattern))
269
+ assert p_parts is not None
270
+ splits.add(p_parts["split"])
271
+
272
+ if any(not re.match(_split_re, split) for split in splits):
273
+ raise ValueError(f"Split name should match '{_split_re}'' but got '{splits}'.")
274
+ sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
275
+ splits - {str(split) for split in DEFAULT_SPLITS}
276
+ )
277
+ return {split: [split_pattern.format(split=split)] for split in sorted_splits}
278
+ # then check the default patterns based on train/valid/test splits
279
+ for patterns_dict in ALL_DEFAULT_PATTERNS:
280
+ non_empty_splits = []
281
+ for split, patterns in patterns_dict.items():
282
+ for pattern in patterns:
283
+ try:
284
+ data_files = pattern_resolver(pattern)
285
+ except FileNotFoundError:
286
+ continue
287
+ if len(data_files) > 0:
288
+ non_empty_splits.append(split)
289
+ break
290
+ if non_empty_splits:
291
+ return {split: patterns_dict[split] for split in non_empty_splits}
292
+ raise FileNotFoundError(f"Couldn't resolve pattern {pattern} with resolver {pattern_resolver}")
293
+
294
+
295
+ def resolve_pattern(
296
+ pattern: str,
297
+ base_path: str,
298
+ allowed_extensions: Optional[list[str]] = None,
299
+ download_config: Optional[DownloadConfig] = None,
300
+ ) -> list[str]:
301
+ """
302
+ Resolve the paths and URLs of the data files from the pattern passed by the user.
303
+
304
+ You can use patterns to resolve multiple local files. Here are a few examples:
305
+ - *.csv to match all the CSV files at the first level
306
+ - **.csv to match all the CSV files at any level
307
+ - data/* to match all the files inside "data"
308
+ - data/** to match all the files inside "data" and its subdirectories
309
+
310
+ The patterns are resolved using the fsspec glob. In fsspec>=2023.12.0 this is equivalent to
311
+ Python's glob.glob, Path.glob, Path.match and fnmatch where ** is unsupported with a prefix/suffix
312
+ other than a forward slash /.
313
+
314
+ More generally:
315
+ - '*' matches any character except a forward-slash (to match just the file or directory name)
316
+ - '**' matches any character including a forward-slash /
317
+
318
+ Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
319
+ The same applies to special directories that start with a double underscore like "__pycache__".
320
+ You can still include one if the pattern explicitly mentions it:
321
+ - to include a hidden file: "*/.hidden.txt" or "*/.*"
322
+ - to include a hidden directory: ".hidden/*" or ".*/*"
323
+ - to include a special directory: "__special__/*" or "__*/*"
324
+
325
+ Example::
326
+
327
+ >>> from datasets.data_files import resolve_pattern
328
+ >>> base_path = "."
329
+ >>> resolve_pattern("docs/**/*.py", base_path)
330
+ [/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
331
+
332
+ Args:
333
+ pattern (str): Unix pattern or paths or URLs of the data files to resolve.
334
+ The paths can be absolute or relative to base_path.
335
+ Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
336
+ base_path (str): Base path to use when resolving relative paths.
337
+ allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
338
+ For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
339
+ download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.
340
+ Returns:
341
+ List[str]: List of paths or URLs to the local or remote files that match the patterns.
342
+ """
343
+ if is_relative_path(pattern):
344
+ pattern = xjoin(base_path, pattern)
345
+ elif is_local_path(pattern):
346
+ base_path = os.path.splitdrive(pattern)[0] + os.sep
347
+ else:
348
+ base_path = ""
349
+ pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config)
350
+ fs, fs_pattern = url_to_fs(pattern, **storage_options)
351
+ files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
352
+ protocol = (
353
+ pattern.split("://")[0]
354
+ if "://" in pattern
355
+ else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0])
356
+ )
357
+ protocol_prefix = protocol + "://" if protocol != "file" else ""
358
+ glob_kwargs = {}
359
+ if protocol == "hf":
360
+ # 10 times faster glob with detail=True (ignores costly info like lastCommit)
361
+ glob_kwargs["expand_info"] = False
362
+ matched_paths = [
363
+ filepath if "://" in filepath else protocol_prefix + filepath
364
+ for filepath, info in fs.glob(pattern, detail=True, **glob_kwargs).items()
365
+ if (info["type"] == "file" or (info.get("islink") and os.path.isfile(os.path.realpath(filepath))))
366
+ and (xbasename(filepath) not in files_to_ignore)
367
+ and not _is_inside_unrequested_special_dir(filepath, fs_pattern)
368
+ and not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(filepath, fs_pattern)
369
+ ] # ignore .ipynb and __pycache__, but keep /../
370
+ if allowed_extensions is not None:
371
+ out = [
372
+ filepath
373
+ for filepath in matched_paths
374
+ if any("." + suffix in allowed_extensions for suffix in xbasename(filepath).split(".")[1:])
375
+ ]
376
+ if len(out) < len(matched_paths):
377
+ invalid_matched_files = list(set(matched_paths) - set(out))
378
+ logger.info(
379
+ f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: {invalid_matched_files}"
380
+ )
381
+ else:
382
+ out = matched_paths
383
+ if not out:
384
+ error_msg = f"Unable to find '{pattern}'"
385
+ if allowed_extensions is not None:
386
+ error_msg += f" with any supported extension {list(allowed_extensions)}"
387
+ raise FileNotFoundError(error_msg)
388
+ return out
389
+
390
+
391
+ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig] = None) -> dict[str, list[str]]:
392
+ """
393
+ Get the default pattern from a directory testing all the supported patterns.
394
+ The first patterns to return a non-empty list of data files is returned.
395
+
396
+ Some examples of supported patterns:
397
+
398
+ Input:
399
+
400
+ my_dataset_repository/
401
+ ├── README.md
402
+ └── dataset.csv
403
+
404
+ Output:
405
+
406
+ {'train': ['**']}
407
+
408
+ Input:
409
+
410
+ my_dataset_repository/
411
+ ├── README.md
412
+ ├── train.csv
413
+ └── test.csv
414
+
415
+ my_dataset_repository/
416
+ ├── README.md
417
+ └── data/
418
+ ├── train.csv
419
+ └── test.csv
420
+
421
+ my_dataset_repository/
422
+ ├── README.md
423
+ ├── train_0.csv
424
+ ├── train_1.csv
425
+ ├── train_2.csv
426
+ ├── train_3.csv
427
+ ├── test_0.csv
428
+ └── test_1.csv
429
+
430
+ Output:
431
+
432
+ {'train': ['**/train[-._ 0-9]*', '**/*[-._ 0-9]train[-._ 0-9]*', '**/training[-._ 0-9]*', '**/*[-._ 0-9]training[-._ 0-9]*'],
433
+ 'test': ['**/test[-._ 0-9]*', '**/*[-._ 0-9]test[-._ 0-9]*', '**/testing[-._ 0-9]*', '**/*[-._ 0-9]testing[-._ 0-9]*', ...]}
434
+
435
+ Input:
436
+
437
+ my_dataset_repository/
438
+ ├── README.md
439
+ └── data/
440
+ ├── train/
441
+ │ ├── shard_0.csv
442
+ │ ├── shard_1.csv
443
+ │ ├── shard_2.csv
444
+ │ └── shard_3.csv
445
+ └── test/
446
+ ├── shard_0.csv
447
+ └── shard_1.csv
448
+
449
+ Output:
450
+
451
+ {'train': ['**/train/**', '**/train[-._ 0-9]*/**', '**/*[-._ 0-9]train/**', '**/*[-._ 0-9]train[-._ 0-9]*/**', ...],
452
+ 'test': ['**/test/**', '**/test[-._ 0-9]*/**', '**/*[-._ 0-9]test/**', '**/*[-._ 0-9]test[-._ 0-9]*/**', ...]}
453
+
454
+ Input:
455
+
456
+ my_dataset_repository/
457
+ ├── README.md
458
+ └── data/
459
+ ├── train-00000-of-00003.csv
460
+ ├── train-00001-of-00003.csv
461
+ ├── train-00002-of-00003.csv
462
+ ├── test-00000-of-00001.csv
463
+ ├── random-00000-of-00003.csv
464
+ ├── random-00001-of-00003.csv
465
+ └── random-00002-of-00003.csv
466
+
467
+ Output:
468
+
469
+ {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
470
+ 'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
471
+ 'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}
472
+
473
+ In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
474
+ """
475
+ resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
476
+ try:
477
+ return _get_data_files_patterns(resolver)
478
+ except FileNotFoundError:
479
+ raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None
480
+
481
+
482
+ def _get_single_origin_metadata(
483
+ data_file: str,
484
+ download_config: Optional[DownloadConfig] = None,
485
+ ) -> SingleOriginMetadata:
486
+ data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
487
+ fs, *_ = url_to_fs(data_file, **storage_options)
488
+ if isinstance(fs, HfFileSystem):
489
+ resolved_path = fs.resolve_path(data_file)
490
+ return resolved_path.repo_id, resolved_path.revision
491
+ elif data_file.startswith(config.HF_ENDPOINT):
492
+ hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
493
+ data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
494
+ resolved_path = hffs.resolve_path(data_file)
495
+ return resolved_path.repo_id, resolved_path.revision
496
+ info = fs.info(data_file)
497
+ # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
498
+ for key in ["ETag", "etag", "mtime"]:
499
+ if key in info:
500
+ return (str(info[key]),)
501
+ return ()
502
+
503
+
504
+ def _get_origin_metadata(
505
+ data_files: list[str],
506
+ download_config: Optional[DownloadConfig] = None,
507
+ max_workers: Optional[int] = None,
508
+ ) -> list[SingleOriginMetadata]:
509
+ max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
510
+ if all("hf://" in data_file for data_file in data_files):
511
+ # No need for multithreading here since the origin metadata of HF files
512
+ # is (repo_id, revision) and is cached after first .info() call.
513
+ return [
514
+ _get_single_origin_metadata(data_file, download_config=download_config)
515
+ for data_file in hf_tqdm(
516
+ data_files,
517
+ desc="Resolving data files",
518
+ # set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
519
+ disable=len(data_files) <= 16 or None,
520
+ )
521
+ ]
522
+ return thread_map(
523
+ partial(_get_single_origin_metadata, download_config=download_config),
524
+ data_files,
525
+ max_workers=max_workers,
526
+ tqdm_class=hf_tqdm,
527
+ desc="Resolving data files",
528
+ # set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
529
+ disable=len(data_files) <= 16 or None,
530
+ )
531
+
532
+
533
+ class DataFilesList(list[str]):
534
+ """
535
+ List of data files (absolute local paths or URLs).
536
+ It has two construction methods given the user's data files patterns:
537
+ - ``from_hf_repo``: resolve patterns inside a dataset repository
538
+ - ``from_local_or_remote``: resolve patterns from a local path
539
+
540
+ Moreover, DataFilesList has an additional attribute ``origin_metadata``.
541
+ It can store:
542
+ - the last modified time of local files
543
+ - ETag of remote files
544
+ - commit sha of a dataset repository
545
+
546
+ Thanks to this additional attribute, it is possible to hash the list
547
+ and get a different hash if and only if at least one file changed.
548
+ This is useful for caching Dataset objects that are obtained from a list of data files.
549
+ """
550
+
551
+ def __init__(self, data_files: list[str], origin_metadata: list[SingleOriginMetadata]) -> None:
552
+ super().__init__(data_files)
553
+ self.origin_metadata = origin_metadata
554
+
555
+ def __add__(self, other: "DataFilesList") -> "DataFilesList":
556
+ return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)
557
+
558
+ @classmethod
559
+ def from_hf_repo(
560
+ cls,
561
+ patterns: list[str],
562
+ dataset_info: huggingface_hub.hf_api.DatasetInfo,
563
+ base_path: Optional[str] = None,
564
+ allowed_extensions: Optional[list[str]] = None,
565
+ download_config: Optional[DownloadConfig] = None,
566
+ ) -> "DataFilesList":
567
+ base_path = f"hf://datasets/{dataset_info.id}@{dataset_info.sha}/{base_path or ''}".rstrip("/")
568
+ return cls.from_patterns(
569
+ patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config
570
+ )
571
+
572
+ @classmethod
573
+ def from_local_or_remote(
574
+ cls,
575
+ patterns: list[str],
576
+ base_path: Optional[str] = None,
577
+ allowed_extensions: Optional[list[str]] = None,
578
+ download_config: Optional[DownloadConfig] = None,
579
+ ) -> "DataFilesList":
580
+ base_path = base_path if base_path is not None else Path().resolve().as_posix()
581
+ return cls.from_patterns(
582
+ patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config
583
+ )
584
+
585
+ @classmethod
586
+ def from_patterns(
587
+ cls,
588
+ patterns: list[str],
589
+ base_path: Optional[str] = None,
590
+ allowed_extensions: Optional[list[str]] = None,
591
+ download_config: Optional[DownloadConfig] = None,
592
+ ) -> "DataFilesList":
593
+ base_path = base_path if base_path is not None else Path().resolve().as_posix()
594
+ data_files = []
595
+ for pattern in patterns:
596
+ try:
597
+ data_files.extend(
598
+ resolve_pattern(
599
+ pattern,
600
+ base_path=base_path,
601
+ allowed_extensions=allowed_extensions,
602
+ download_config=download_config,
603
+ )
604
+ )
605
+ except FileNotFoundError:
606
+ if not has_magic(pattern):
607
+ raise
608
+ origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
609
+ return cls(data_files, origin_metadata)
610
+
611
+ def filter(
612
+ self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None
613
+ ) -> "DataFilesList":
614
+ patterns = []
615
+ if extensions:
616
+ ext_pattern = "|".join(re.escape(ext) for ext in extensions)
617
+ patterns.append(re.compile(f".*({ext_pattern})(\\..+)?$"))
618
+ if file_names:
619
+ fn_pattern = "|".join(re.escape(fn) for fn in file_names)
620
+ patterns.append(re.compile(rf".*[\/]?({fn_pattern})$"))
621
+ if patterns:
622
+ return DataFilesList(
623
+ [data_file for data_file in self if any(pattern.match(data_file) for pattern in patterns)],
624
+ origin_metadata=self.origin_metadata,
625
+ )
626
+ else:
627
+ return DataFilesList(list(self), origin_metadata=self.origin_metadata)
628
+
629
+
630
+ class DataFilesDict(dict[str, DataFilesList]):
631
+ """
632
+ Dict of split_name -> list of data files (absolute local paths or URLs).
633
+ It has two construction methods given the user's data files patterns :
634
+ - ``from_hf_repo``: resolve patterns inside a dataset repository
635
+ - ``from_local_or_remote``: resolve patterns from a local path
636
+
637
+ Moreover, each list is a DataFilesList. It is possible to hash the dictionary
638
+ and get a different hash if and only if at least one file changed.
639
+ For more info, see [`DataFilesList`].
640
+
641
+ This is useful for caching Dataset objects that are obtained from a list of data files.
642
+
643
+ Changing the order of the keys of this dictionary also doesn't change its hash.
644
+ """
645
+
646
+ @classmethod
647
+ def from_local_or_remote(
648
+ cls,
649
+ patterns: dict[str, Union[list[str], DataFilesList]],
650
+ base_path: Optional[str] = None,
651
+ allowed_extensions: Optional[list[str]] = None,
652
+ download_config: Optional[DownloadConfig] = None,
653
+ ) -> "DataFilesDict":
654
+ out = cls()
655
+ for key, patterns_for_key in patterns.items():
656
+ out[key] = (
657
+ patterns_for_key
658
+ if isinstance(patterns_for_key, DataFilesList)
659
+ else DataFilesList.from_local_or_remote(
660
+ patterns_for_key,
661
+ base_path=base_path,
662
+ allowed_extensions=allowed_extensions,
663
+ download_config=download_config,
664
+ )
665
+ )
666
+ return out
667
+
668
+ @classmethod
669
+ def from_hf_repo(
670
+ cls,
671
+ patterns: dict[str, Union[list[str], DataFilesList]],
672
+ dataset_info: huggingface_hub.hf_api.DatasetInfo,
673
+ base_path: Optional[str] = None,
674
+ allowed_extensions: Optional[list[str]] = None,
675
+ download_config: Optional[DownloadConfig] = None,
676
+ ) -> "DataFilesDict":
677
+ out = cls()
678
+ for key, patterns_for_key in patterns.items():
679
+ out[key] = (
680
+ patterns_for_key
681
+ if isinstance(patterns_for_key, DataFilesList)
682
+ else DataFilesList.from_hf_repo(
683
+ patterns_for_key,
684
+ dataset_info=dataset_info,
685
+ base_path=base_path,
686
+ allowed_extensions=allowed_extensions,
687
+ download_config=download_config,
688
+ )
689
+ )
690
+ return out
691
+
692
+ @classmethod
693
+ def from_patterns(
694
+ cls,
695
+ patterns: dict[str, Union[list[str], DataFilesList]],
696
+ base_path: Optional[str] = None,
697
+ allowed_extensions: Optional[list[str]] = None,
698
+ download_config: Optional[DownloadConfig] = None,
699
+ ) -> "DataFilesDict":
700
+ out = cls()
701
+ for key, patterns_for_key in patterns.items():
702
+ out[key] = (
703
+ patterns_for_key
704
+ if isinstance(patterns_for_key, DataFilesList)
705
+ else DataFilesList.from_patterns(
706
+ patterns_for_key,
707
+ base_path=base_path,
708
+ allowed_extensions=allowed_extensions,
709
+ download_config=download_config,
710
+ )
711
+ )
712
+ return out
713
+
714
+ def filter(
715
+ self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None
716
+ ) -> "DataFilesDict":
717
+ out = type(self)()
718
+ for key, data_files_list in self.items():
719
+ out[key] = data_files_list.filter(extensions=extensions, file_names=file_names)
720
+ return out
721
+
722
+
723
+ class DataFilesPatternsList(list[str]):
724
+ """
725
+ List of data files patterns (absolute local paths or URLs).
726
+ For each pattern there should also be a list of allowed extensions
727
+ to keep, or a None ot keep all the files for the pattern.
728
+ """
729
+
730
+ def __init__(
731
+ self,
732
+ patterns: list[str],
733
+ allowed_extensions: list[Optional[list[str]]],
734
+ ):
735
+ super().__init__(patterns)
736
+ self.allowed_extensions = allowed_extensions
737
+
738
+ def __add__(self, other):
739
+ return DataFilesList([*self, *other], self.allowed_extensions + other.allowed_extensions)
740
+
741
+ @classmethod
742
+ def from_patterns(
743
+ cls, patterns: list[str], allowed_extensions: Optional[list[str]] = None
744
+ ) -> "DataFilesPatternsList":
745
+ return cls(patterns, [allowed_extensions] * len(patterns))
746
+
747
+ def resolve(
748
+ self,
749
+ base_path: str,
750
+ download_config: Optional[DownloadConfig] = None,
751
+ ) -> "DataFilesList":
752
+ base_path = base_path if base_path is not None else Path().resolve().as_posix()
753
+ data_files = []
754
+ for pattern, allowed_extensions in zip(self, self.allowed_extensions):
755
+ try:
756
+ data_files.extend(
757
+ resolve_pattern(
758
+ pattern,
759
+ base_path=base_path,
760
+ allowed_extensions=allowed_extensions,
761
+ download_config=download_config,
762
+ )
763
+ )
764
+ except FileNotFoundError:
765
+ if not has_magic(pattern):
766
+ raise
767
+ origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
768
+ return DataFilesList(data_files, origin_metadata)
769
+
770
+ def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsList":
771
+ return DataFilesPatternsList(
772
+ self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
773
+ )
774
+
775
+
776
+ class DataFilesPatternsDict(dict[str, DataFilesPatternsList]):
777
+ """
778
+ Dict of split_name -> list of data files patterns (absolute local paths or URLs).
779
+ """
780
+
781
+ @classmethod
782
+ def from_patterns(
783
+ cls, patterns: dict[str, list[str]], allowed_extensions: Optional[list[str]] = None
784
+ ) -> "DataFilesPatternsDict":
785
+ out = cls()
786
+ for key, patterns_for_key in patterns.items():
787
+ out[key] = (
788
+ patterns_for_key
789
+ if isinstance(patterns_for_key, DataFilesPatternsList)
790
+ else DataFilesPatternsList.from_patterns(
791
+ patterns_for_key,
792
+ allowed_extensions=allowed_extensions,
793
+ )
794
+ )
795
+ return out
796
+
797
+ def resolve(
798
+ self,
799
+ base_path: str,
800
+ download_config: Optional[DownloadConfig] = None,
801
+ ) -> "DataFilesDict":
802
+ out = DataFilesDict()
803
+ for key, data_files_patterns_list in self.items():
804
+ out[key] = data_files_patterns_list.resolve(base_path, download_config)
805
+ return out
806
+
807
+ def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsDict":
808
+ out = type(self)()
809
+ for key, data_files_patterns_list in self.items():
810
+ out[key] = data_files_patterns_list.filter_extensions(extensions)
811
+ return out
.venv/lib/python3.11/site-packages/datasets/dataset_dict.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/datasets/distributed.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypeVar
2
+
3
+ from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
4
+ from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset
5
+
6
+
7
+ DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
8
+
9
+
10
def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
    """Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

    For map-style datasets:

    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

    For iterable datasets:

    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
    then the shards are evenly assigned across the nodes, which is the most optimized.
    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.

    Args:
        dataset ([`Dataset`] or [`IterableDataset`]):
            The dataset to split by node.
        rank (`int`):
            Rank of the current node.
        world_size (`int`):
            Total number of nodes.

    Returns:
        [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
    """
    # Map-style and iterable datasets use different sharding strategies: dispatch on the concrete type.
    if isinstance(dataset, Dataset):
        return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
    return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)
.venv/lib/python3.11/site-packages/datasets/download/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "DownloadConfig",
3
+ "DownloadManager",
4
+ "DownloadMode",
5
+ "StreamingDownloadManager",
6
+ ]
7
+
8
+ from .download_config import DownloadConfig
9
+ from .download_manager import DownloadManager, DownloadMode
10
+ from .streaming_download_manager import StreamingDownloadManager
.venv/lib/python3.11/site-packages/datasets/download/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (515 Bytes). View file
 
.venv/lib/python3.11/site-packages/datasets/download/__pycache__/download_config.cpython-311.pyc ADDED
Binary file (5.63 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/download/__pycache__/download_manager.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/download/__pycache__/streaming_download_manager.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/download/download_config.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from typing import Any, Optional, Union
5
+
6
+ from .. import config
7
+
8
+
9
@dataclass
class DownloadConfig:
    """Configuration for our cached path manager.

    Attributes:
        cache_dir (`str` or `Path`, *optional*):
            Specify a cache directory to save the file to (overwrite the
            default cache dir).
        force_download (`bool`, defaults to `False`):
            If `True`, re-download the file even if it's already cached in
            the cache dir.
        resume_download (`bool`, defaults to `False`):
            If `True`, resume the download if an incompletely received file is
            found.
        local_files_only (`bool`, defaults to `False`):
            If `True`, only look at cached files and never attempt a remote download.
        proxies (`dict`, *optional*):
            Proxy servers passed on to the underlying HTTP requests.
        user_agent (`str`, *optional*):
            Optional string or dict that will be appended to the user-agent on remote
            requests.
        extract_compressed_file (`bool`, defaults to `False`):
            If `True` and the path point to a zip or tar file,
            extract the compressed file in a folder along the archive.
        force_extract (`bool`, defaults to `False`):
            If `True` when `extract_compressed_file` is `True` and the archive
            was already extracted, re-extract the archive and override the folder where it was extracted.
        delete_extracted (`bool`, defaults to `False`):
            Whether to delete (or keep) the extracted files.
        extract_on_the_fly (`bool`, defaults to `False`):
            If `True`, extract compressed files while they are being read.
        use_etag (`bool`, defaults to `True`):
            Whether to use the ETag HTTP response header to validate the cached files.
        num_proc (`int`, *optional*):
            The number of processes to launch to download the files in parallel.
        max_retries (`int`, default to `1`):
            The number of times to retry an HTTP request if it fails.
        token (`str` or `bool`, *optional*):
            Optional string or boolean to use as Bearer token
            for remote files on the Datasets Hub. If `True`, or not specified, will get token from `~/.huggingface`.
        storage_options (`dict`, *optional*):
            Key/value pairs to be passed on to the dataset file-system backend, if any.
        download_desc (`str`, *optional*):
            A description to be displayed alongside with the progress bar while downloading the files.
        disable_tqdm (`bool`, defaults to `False`):
            Whether to disable the individual files download progress bar
    """

    cache_dir: Optional[Union[str, Path]] = None
    force_download: bool = False
    resume_download: bool = False
    local_files_only: bool = False
    proxies: Optional[dict] = None
    user_agent: Optional[str] = None
    extract_compressed_file: bool = False
    force_extract: bool = False
    delete_extracted: bool = False
    extract_on_the_fly: bool = False
    use_etag: bool = True
    num_proc: Optional[int] = None
    max_retries: int = 1
    token: Optional[Union[str, bool]] = None
    storage_options: dict[str, Any] = field(default_factory=dict)
    download_desc: Optional[str] = None
    disable_tqdm: bool = False

    def copy(self) -> "DownloadConfig":
        """Return a deep copy of this config (nested mutable values are not shared)."""
        return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})

    def __setattr__(self, name, value):
        # Keep `storage_options["hf"]` in sync whenever the Hub token is set:
        # the "hf" fsspec filesystem reads its credentials from there.
        if name == "token" and getattr(self, "storage_options", None) is not None:
            if "hf" not in self.storage_options:
                self.storage_options["hf"] = {"endpoint": config.HF_ENDPOINT, "token": value}
            elif self.storage_options["hf"].get("token") is None:
                # Bug fix: `storage_options["hf"]` is a plain dict, so the previous
                # `getattr(..., "token", None)` always returned None and silently
                # overwrote an explicitly-provided token on every assignment.
                # Dict lookup preserves an existing token and only fills in a
                # missing one.
                self.storage_options["hf"]["token"] = value
        super().__setattr__(name, value)
.venv/lib/python3.11/site-packages/datasets/download/download_manager.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The TensorFlow Datasets Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Lint as: python3
16
+ """Download manager interface."""
17
+
18
+ import enum
19
+ import io
20
+ import multiprocessing
21
+ import os
22
+ from datetime import datetime
23
+ from functools import partial
24
+ from typing import Optional, Union
25
+
26
+ import fsspec
27
+ from fsspec.core import url_to_fs
28
+ from tqdm.contrib.concurrent import thread_map
29
+
30
+ from .. import config
31
+ from ..utils import tqdm as hf_tqdm
32
+ from ..utils.file_utils import (
33
+ ArchiveIterable,
34
+ FilesIterable,
35
+ cached_path,
36
+ is_relative_path,
37
+ stack_multiprocessing_download_progress_bars,
38
+ url_or_path_join,
39
+ )
40
+ from ..utils.info_utils import get_size_checksum_dict
41
+ from ..utils.logging import get_logger, tqdm
42
+ from ..utils.py_utils import NestedDataStructure, map_nested
43
+ from ..utils.track import tracked_str
44
+ from .download_config import DownloadConfig
45
+
46
+
47
+ logger = get_logger(__name__)
48
+
49
+
50
class DownloadMode(enum.Enum):
    """`Enum` describing how pre-existing downloads and prepared data are treated.

    The default mode, `REUSE_DATASET_IF_EXISTS`, reuses both the raw downloads
    and the prepared dataset when they already exist.

    The generations modes:

    |                                     | Downloads | Dataset |
    |-------------------------------------|-----------|---------|
    | `REUSE_DATASET_IF_EXISTS` (default) | Reuse     | Reuse   |
    | `REUSE_CACHE_IF_EXISTS`             | Reuse     | Fresh   |
    | `FORCE_REDOWNLOAD`                  | Fresh     | Fresh   |

    """

    REUSE_DATASET_IF_EXISTS = "reuse_dataset_if_exists"
    REUSE_CACHE_IF_EXISTS = "reuse_cache_if_exists"
    FORCE_REDOWNLOAD = "force_redownload"
69
+
70
+
71
class DownloadManager:
    """Downloads (and optionally extracts) dataset files, recording their sizes and checksums."""

    is_streaming = False

    def __init__(
        self,
        dataset_name: Optional[str] = None,
        data_dir: Optional[str] = None,
        download_config: Optional[DownloadConfig] = None,
        base_path: Optional[str] = None,
        record_checksums=True,
    ):
        """Download manager constructor.

        Args:
            data_dir:
                can be used to specify a manual directory to get the files from.
            dataset_name (`str`):
                name of the dataset this instance will be used for. If provided,
                downloads will contain which datasets they were used for.
            download_config (`DownloadConfig`):
                to specify the cache directory and other download options.
            base_path (`str`):
                base path used when relative paths are used to download files.
                This can be a remote url.
            record_checksums (`bool`, defaults to `True`):
                Whether to record the checksums of the downloaded files. If None, the value is inferred from the builder.
        """
        self._dataset_name = dataset_name
        self._data_dir = data_dir
        self._base_path = base_path or os.path.abspath(".")
        # Bookkeeping of what was fetched: {url: {num_bytes: int, checksum: str}}
        self._recorded_sizes_checksums: dict[str, dict[str, Optional[Union[int, str]]]] = {}
        self.record_checksums = record_checksums
        self.download_config = download_config or DownloadConfig()
        self.downloaded_paths = {}
        self.extracted_paths = {}

    @property
    def manual_dir(self):
        # Directory for manually-downloaded data, if one was provided.
        return self._data_dir

    @property
    def downloaded_size(self):
        """Returns the total size of downloaded files."""
        return sum(entry["num_bytes"] for entry in self._recorded_sizes_checksums.values())

    def _record_sizes_checksums(self, url_or_urls: NestedDataStructure, downloaded_path_or_paths: NestedDataStructure):
        """Record size/checksum of downloaded files."""
        pairs = list(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten()))
        # delay=5: only show the progress bar if this takes more than a few seconds
        for url, path in hf_tqdm(pairs, delay=5, desc="Computing checksums"):
            # call str to support PathLike objects
            self._recorded_sizes_checksums[str(url)] = get_size_checksum_dict(
                path, record_checksum=self.record_checksums
            )

    def download(self, url_or_urls):
        """Download given URL(s).

        By default, only one process is used for download. Pass customized `download_config.num_proc` to change this behavior.

        Args:
            url_or_urls (`str` or `list` or `dict`):
                URL or `list` or `dict` of URLs to download. Each URL is a `str`.

        Returns:
            `str` or `list` or `dict`:
                The downloaded paths matching the given input `url_or_urls`.

        Example:

        ```py
        >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
        ```
        """
        download_config = self.download_config.copy()
        download_config.extract_compressed_file = False
        if download_config.download_desc is None:
            download_config.download_desc = "Downloading data"

        batched_download = partial(self._download_batched, download_config=download_config)

        start_time = datetime.now()
        with stack_multiprocessing_download_progress_bars():
            downloaded_path_or_paths = map_nested(
                batched_download,
                url_or_urls,
                map_tuple=True,
                num_proc=download_config.num_proc,
                desc="Downloading data files",
                batched=True,
                batch_size=-1,
            )
        duration = datetime.now() - start_time
        logger.info(f"Downloading took {duration.total_seconds() // 60} min")

        url_or_urls = NestedDataStructure(url_or_urls)
        downloaded_path_or_paths = NestedDataStructure(downloaded_path_or_paths)
        self.downloaded_paths.update(dict(zip(url_or_urls.flatten(), downloaded_path_or_paths.flatten())))

        start_time = datetime.now()
        self._record_sizes_checksums(url_or_urls, downloaded_path_or_paths)
        duration = datetime.now() - start_time
        logger.info(f"Checksum Computation took {duration.total_seconds() // 60} min")

        return downloaded_path_or_paths.data

    def _download_batched(
        self,
        url_or_filenames: list[str],
        download_config: DownloadConfig,
    ) -> list[str]:
        if len(url_or_filenames) < 16:
            # Small batch: download sequentially, keeping per-file progress bars.
            return [
                self._download_single(url_or_filename, download_config=download_config)
                for url_or_filename in url_or_filenames
            ]

        download_config = download_config.copy()
        download_config.disable_tqdm = True
        single_download = partial(self._download_single, download_config=download_config)

        # Sample the first file's size to decide whether multithreading pays off.
        fs: fsspec.AbstractFileSystem
        path = str(url_or_filenames[0])
        if is_relative_path(path):
            # append the relative path to the base_path
            path = url_or_path_join(self._base_path, path)
        fs, path = url_to_fs(path, **download_config.storage_options)
        try:
            size = fs.info(path).get("size", 0)
        except Exception:
            size = 0
        # enable multithreading if files are small
        max_workers = config.HF_DATASETS_MULTITHREADING_MAX_WORKERS if size < (20 << 20) else 1

        # _identity contains the ranks of subprocesses; used to stack progress bars
        stack_bars = os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
        process_ranks = multiprocessing.current_process()._identity
        position = process_ranks[-1] if stack_bars and process_ranks else None

        return thread_map(
            single_download,
            url_or_filenames,
            desc=download_config.download_desc or "Downloading",
            unit="files",
            position=position,
            max_workers=max_workers,
            tqdm_class=tqdm,
        )

    def _download_single(self, url_or_filename: str, download_config: DownloadConfig) -> str:
        url_or_filename = str(url_or_filename)
        if is_relative_path(url_or_filename):
            # append the relative path to the base_path
            url_or_filename = url_or_path_join(self._base_path, url_or_filename)
        downloaded = tracked_str(cached_path(url_or_filename, download_config=download_config))
        downloaded.set_origin(url_or_filename)
        return downloaded

    def iter_archive(self, path_or_buf: Union[str, io.BufferedReader]):
        """Iterate over files within an archive.

        Args:
            path_or_buf (`str` or `io.BufferedReader`):
                Archive path or archive binary file object.

        Yields:
            `tuple[str, io.BufferedReader]`:
                2-tuple (path_within_archive, file_object).
                File object is opened in binary mode.

        Example:

        ```py
        >>> archive = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
        >>> files = dl_manager.iter_archive(archive)
        ```
        """
        if hasattr(path_or_buf, "read"):
            return ArchiveIterable.from_buf(path_or_buf)
        return ArchiveIterable.from_urlpath(path_or_buf)

    def iter_files(self, paths: Union[str, list[str]]):
        """Iterate over file paths.

        Args:
            paths (`str` or `list` of `str`):
                Root paths.

        Yields:
            `str`: File path.

        Example:

        ```py
        >>> files = dl_manager.download_and_extract('https://huggingface.co/datasets/beans/resolve/main/data/train.zip')
        >>> files = dl_manager.iter_files(files)
        ```
        """
        return FilesIterable.from_urlpaths(paths)

    def extract(self, path_or_paths):
        """Extract given path(s).

        Args:
            path_or_paths (path or `list` or `dict`):
                Path of file to extract. Each path is a `str`.

        Returns:
            extracted_path(s): `str`, The extracted paths matching the given input
            path_or_paths.

        Example:

        ```py
        >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
        >>> extracted_files = dl_manager.extract(downloaded_files)
        ```
        """
        download_config = self.download_config.copy()
        download_config.extract_compressed_file = True
        extracted_paths = map_nested(
            partial(self._download_single, download_config=download_config),
            path_or_paths,
            num_proc=download_config.num_proc,
            desc="Extracting data files",
        )
        path_or_paths = NestedDataStructure(path_or_paths)
        extracted_paths = NestedDataStructure(extracted_paths)
        self.extracted_paths.update(dict(zip(path_or_paths.flatten(), extracted_paths.flatten())))
        return extracted_paths.data

    def download_and_extract(self, url_or_urls):
        """Download and extract given `url_or_urls`.

        Is roughly equivalent to:

        ```
        extracted_paths = dl_manager.extract(dl_manager.download(url_or_urls))
        ```

        Args:
            url_or_urls (`str` or `list` or `dict`):
                URL or `list` or `dict` of URLs to download and extract. Each URL is a `str`.

        Returns:
            extracted_path(s): `str`, extracted paths of given URL(s).
        """
        return self.extract(self.download(url_or_urls))

    def get_recorded_sizes_checksums(self):
        # Shallow copy so callers can't mutate the internal record.
        return self._recorded_sizes_checksums.copy()

    def delete_extracted_files(self):
        # Only delete extracted files that are not also raw downloaded files.
        extracted_only = set(self.extracted_paths.values()) - set(self.downloaded_paths.values())
        for key, path in list(self.extracted_paths.items()):
            if path in extracted_only and os.path.isfile(path):
                os.remove(path)
                del self.extracted_paths[key]

    def manage_extracted_files(self):
        if self.download_config.delete_extracted:
            self.delete_extracted_files()
.venv/lib/python3.11/site-packages/datasets/download/streaming_download_manager.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from collections.abc import Iterable
4
+ from typing import Optional, Union
5
+
6
+ from ..utils.file_utils import ( # noqa: F401 # backward compatibility
7
+ SINGLE_FILE_COMPRESSION_PROTOCOLS,
8
+ ArchiveIterable,
9
+ FilesIterable,
10
+ _get_extraction_protocol,
11
+ _get_path_extension,
12
+ _prepare_path_and_storage_options,
13
+ is_relative_path,
14
+ url_or_path_join,
15
+ xbasename,
16
+ xdirname,
17
+ xet_parse,
18
+ xexists,
19
+ xgetsize,
20
+ xglob,
21
+ xgzip_open,
22
+ xisdir,
23
+ xisfile,
24
+ xjoin,
25
+ xlistdir,
26
+ xnumpy_load,
27
+ xopen,
28
+ xpandas_read_csv,
29
+ xpandas_read_excel,
30
+ xPath,
31
+ xpyarrow_parquet_read_table,
32
+ xrelpath,
33
+ xsio_loadmat,
34
+ xsplit,
35
+ xsplitext,
36
+ xwalk,
37
+ xxml_dom_minidom_parse,
38
+ )
39
+ from ..utils.logging import get_logger
40
+ from ..utils.py_utils import map_nested
41
+ from .download_config import DownloadConfig
42
+
43
+
44
+ logger = get_logger(__name__)
45
+
46
+
47
class StreamingDownloadManager:
    """
    Download manager that uses the "::" separator to navigate through (possibly remote) compressed archives.
    Unlike the regular `DownloadManager`, its `download` and `extract` methods don't actually download nor extract
    anything: they return paths or urls that can later be opened with the `xopen` function, which extends the
    built-in `open` function to stream data from remote files.
    """

    is_streaming = True

    def __init__(
        self,
        dataset_name: Optional[str] = None,
        data_dir: Optional[str] = None,
        download_config: Optional[DownloadConfig] = None,
        base_path: Optional[str] = None,
    ):
        self._dataset_name = dataset_name
        self._data_dir = data_dir
        self._base_path = base_path or os.path.abspath(".")
        self.download_config = download_config or DownloadConfig()
        # Streaming never materializes files locally, so nothing is measured or checksummed.
        self.downloaded_size = None
        self.record_checksums = False

    @property
    def manual_dir(self):
        # Directory for manually-downloaded data, if one was provided.
        return self._data_dir

    def download(self, url_or_urls):
        """Normalize URL(s) of files to stream data from.
        This is the lazy version of `DownloadManager.download` for streaming.

        Args:
            url_or_urls (`str` or `list` or `dict`):
                URL(s) of files to stream data from. Each url is a `str`.

        Returns:
            url(s): (`str` or `list` or `dict`), URL(s) to stream data from matching the given input url_or_urls.

        Example:

        ```py
        >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
        ```
        """
        return map_nested(self._download_single, url_or_urls, map_tuple=True)

    def _download_single(self, urlpath: str) -> str:
        urlpath = str(urlpath)
        if is_relative_path(urlpath):
            # append the relative path to the base_path
            urlpath = url_or_path_join(self._base_path, urlpath)
        return urlpath

    def extract(self, url_or_urls):
        """Add extraction protocol for given url(s) for streaming.

        This is the lazy version of `DownloadManager.extract` for streaming.

        Args:
            url_or_urls (`str` or `list` or `dict`):
                URL(s) of files to stream data from. Each url is a `str`.

        Returns:
            url(s): (`str` or `list` or `dict`), URL(s) to stream data from matching the given input `url_or_urls`.

        Example:

        ```py
        >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
        >>> extracted_files = dl_manager.extract(downloaded_files)
        ```
        """
        return map_nested(self._extract, url_or_urls, map_tuple=True)

    def _extract(self, urlpath: str) -> str:
        urlpath = str(urlpath)
        protocol = _get_extraction_protocol(urlpath, download_config=self.download_config)
        # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz
        path = urlpath.split("::")[0]
        extension = _get_path_extension(path)
        if extension in ["tgz", "tar"] or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz")):
            raise NotImplementedError(
                f"Extraction protocol for TAR archives like '{urlpath}' is not implemented in streaming mode. "
                f"Please use `dl_manager.iter_archive` instead.\n\n"
                f"Example usage:\n\n"
                f"\turl = dl_manager.download(url)\n"
                f"\ttar_archive_iterator = dl_manager.iter_archive(url)\n\n"
                f"\tfor filename, file in tar_archive_iterator:\n"
                f"\t\t..."
            )
        if protocol is None:
            # no extraction
            return urlpath
        if protocol in SINGLE_FILE_COMPRESSION_PROTOCOLS:
            # one single uncompressed file: its name is the archive name minus the last extension
            inner_file = os.path.basename(path)
            if "." in inner_file:
                inner_file = inner_file[: inner_file.rindex(".")]
            return f"{protocol}://{inner_file}::{urlpath}"
        return f"{protocol}://::{urlpath}"

    def download_and_extract(self, url_or_urls):
        """Prepare given `url_or_urls` for streaming (add extraction protocol).

        This is the lazy version of `DownloadManager.download_and_extract` for streaming.

        Is equivalent to:

        ```
        urls = dl_manager.extract(dl_manager.download(url_or_urls))
        ```

        Args:
            url_or_urls (`str` or `list` or `dict`):
                URL(s) to stream from data from. Each url is a `str`.

        Returns:
            url(s): (`str` or `list` or `dict`), URL(s) to stream data from matching the given input `url_or_urls`.
        """
        return self.extract(self.download(url_or_urls))

    def iter_archive(self, urlpath_or_buf: Union[str, io.BufferedReader]) -> Iterable[tuple]:
        """Iterate over files within an archive.

        Args:
            urlpath_or_buf (`str` or `io.BufferedReader`):
                Archive path or archive binary file object.

        Yields:
            `tuple[str, io.BufferedReader]`:
                2-tuple (path_within_archive, file_object).
                File object is opened in binary mode.

        Example:

        ```py
        >>> archive = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
        >>> files = dl_manager.iter_archive(archive)
        ```
        """
        if hasattr(urlpath_or_buf, "read"):
            return ArchiveIterable.from_buf(urlpath_or_buf)
        return ArchiveIterable.from_urlpath(urlpath_or_buf, download_config=self.download_config)

    def iter_files(self, urlpaths: Union[str, list[str]]) -> Iterable[str]:
        """Iterate over files.

        Args:
            urlpaths (`str` or `list` of `str`):
                Root paths.

        Yields:
            str: File URL path.

        Example:

        ```py
        >>> files = dl_manager.download_and_extract('https://huggingface.co/datasets/beans/resolve/main/data/train.zip')
        >>> files = dl_manager.iter_files(files)
        ```
        """
        return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config)

    def manage_extracted_files(self):
        # Nothing is extracted on disk in streaming mode, so there is nothing to clean up.
        pass

    def get_recorded_sizes_checksums(self):
        # No sizes/checksums are recorded in streaming mode.
        pass
.venv/lib/python3.11/site-packages/datasets/exceptions.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright 2023 The HuggingFace Authors.
3
+ from typing import Any, Optional, Union
4
+
5
+ from huggingface_hub import HfFileSystem
6
+
7
+ from . import config
8
+ from .table import CastError
9
+ from .utils.track import TrackedIterableFromGenerator, tracked_list, tracked_str
10
+
11
+
12
class DatasetsError(Exception):
    """Root of the exception hierarchy raised by this library."""
14
+
15
+
16
class DefunctDatasetError(DatasetsError):
    """Raised when the dataset has been defunct and is no longer usable."""
18
+
19
+
20
class FileNotFoundDatasetsError(DatasetsError, FileNotFoundError):
    """`FileNotFoundError` variant raised by this library (also catchable as `DatasetsError`)."""
22
+
23
+
24
class DataFilesNotFoundError(FileNotFoundDatasetsError):
    """Raised when no (supported) data files are found."""
26
+
27
+
28
class DatasetNotFoundError(FileNotFoundDatasetsError):
    """Dataset not found.

    Raised when trying to access:
    - a missing dataset, or
    - a private/gated dataset and the user is not authenticated.
    """
35
+
36
+
37
class DatasetBuildError(DatasetsError):
    """Base class for errors that occur while building a dataset."""
39
+
40
+
41
class ManualDownloadError(DatasetBuildError):
    """Raised when a dataset requires its files to be downloaded manually by the user."""
43
+
44
+
45
class FileFormatError(DatasetBuildError):
    """Raised when a data file has an invalid or unexpected format."""
47
+
48
+
49
class DatasetGenerationError(DatasetBuildError):
    """Raised when an error occurs while generating the dataset examples."""
51
+
52
+
53
class DatasetGenerationCastError(DatasetGenerationError):
    """Raised when generated data cannot be cast to the expected features (mismatched columns)."""

    @classmethod
    def from_cast_error(
        cls,
        cast_error: CastError,
        builder_name: str,
        gen_kwargs: dict[str, Any],
        token: Optional[Union[bool, str]],
    ) -> "DatasetGenerationCastError":
        """Build a user-friendly error from a `CastError`, pointing at the data file being read if trackable."""
        explanation_message = (
            f"\n\nAll the data files must have the same columns, but at some point {cast_error.details()}"
        )
        tracked_sources: list[str] = []
        for gen_kwarg in gen_kwargs.values():
            if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterableFromGenerator)):
                continue
            # Drill down into tracked containers to the last item that was consumed.
            while (
                isinstance(gen_kwarg, (tracked_list, TrackedIterableFromGenerator)) and gen_kwarg.last_item is not None
            ):
                gen_kwarg = gen_kwarg.last_item
            if isinstance(gen_kwarg, tracked_str):
                gen_kwarg = gen_kwarg.get_origin()
            if isinstance(gen_kwarg, str) and gen_kwarg.startswith("hf://"):
                # Re-resolve Hub paths into a readable repo path, pulling the revision out of the url.
                resolved_path = HfFileSystem(endpoint=config.HF_ENDPOINT, token=token).resolve_path(gen_kwarg)
                gen_kwarg = "hf://" + resolved_path.unresolve()
                if "@" + resolved_path.revision in gen_kwarg:
                    gen_kwarg = (
                        gen_kwarg.replace("@" + resolved_path.revision, "", 1)
                        + f" (at revision {resolved_path.revision})"
                    )
            tracked_sources.append(str(gen_kwarg))
        if tracked_sources:
            explanation_message += f"\n\nThis happened while the {builder_name} dataset builder was generating data using\n\n{', '.join(tracked_sources)}"
        help_message = "\n\nPlease either edit the data files to have matching columns, or separate them into different configurations (see docs at https://hf.co/docs/hub/datasets-manual-configuration#multiple-configurations)"
        return cls("An error occurred while generating the dataset" + explanation_message + help_message)
88
+
89
+
90
class ChecksumVerificationError(DatasetsError):
    """Base class for errors raised while verifying checksums of downloaded files."""
92
+
93
+
94
class UnexpectedDownloadedFileError(ChecksumVerificationError):
    """Raised when some downloaded files were not expected."""
96
+
97
+
98
class ExpectedMoreDownloadedFilesError(ChecksumVerificationError):
    """Raised when some files were supposed to be downloaded but were not."""
100
+
101
+
102
class NonMatchingChecksumError(ChecksumVerificationError):
    """Raised when a downloaded file's checksum doesn't match the expected checksum."""
104
+
105
+
106
class SplitsVerificationError(DatasetsError):
    """Base class for errors raised while verifying dataset splits."""
108
+
109
+
110
class UnexpectedSplitsError(SplitsVerificationError):
    """Raised when the downloaded data contains splits that were not expected."""
112
+
113
+
114
class ExpectedMoreSplitsError(SplitsVerificationError):
    """Raised when some recorded splits are missing from the generated dataset."""
116
+
117
+
118
class NonMatchingSplitsSizesError(SplitsVerificationError):
    """Raised when the generated splits' sizes don't match the expected splits' sizes."""
.venv/lib/python3.11/site-packages/datasets/features/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "Audio",
3
+ "Array2D",
4
+ "Array3D",
5
+ "Array4D",
6
+ "Array5D",
7
+ "ClassLabel",
8
+ "Features",
9
+ "LargeList",
10
+ "List",
11
+ "Sequence",
12
+ "Value",
13
+ "Image",
14
+ "Translation",
15
+ "TranslationVariableLanguages",
16
+ "Video",
17
+ "Pdf",
18
+ "Nifti",
19
+ ]
20
+ from .audio import Audio
21
+ from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, List, Sequence, Value
22
+ from .image import Image
23
+ from .nifti import Nifti
24
+ from .pdf import Pdf
25
+ from .translation import Translation, TranslationVariableLanguages
26
+ from .video import Video
.venv/lib/python3.11/site-packages/datasets/features/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (951 Bytes). View file
 
.venv/lib/python3.11/site-packages/datasets/features/__pycache__/audio.cpython-311.pyc ADDED
Binary file (19.8 kB). View file
 
.venv/lib/python3.11/site-packages/datasets/features/__pycache__/features.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11f718da0d59a989bb790b038f4b9473b5bed9942e92a1d9b2bc9e25b3e954a3
3
+ size 137118