koichi12 committed on
Commit
3eb4a70
·
verified ·
1 Parent(s): 36383c5

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py +57 -0
  11. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py +42 -0
  12. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py +129 -0
  13. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py +156 -0
  14. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py +24 -0
  15. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py +349 -0
  16. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py +36 -0
  17. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py +187 -0
  18. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py +126 -0
  19. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py +87 -0
  20. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py +176 -0
  21. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py +261 -0
  22. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py +24 -0
  23. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py +175 -0
  24. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasink.py +36 -0
  25. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasource.py +154 -0
  26. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py +129 -0
  27. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasink.py +48 -0
  28. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasource.py +140 -0
  29. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasink.py +23 -0
  30. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasource.py +41 -0
  31. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_bulk_datasource.py +51 -0
  32. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasink.py +172 -0
  33. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasource.py +731 -0
  34. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py +139 -0
  35. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py +35 -0
  36. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py +42 -0
  37. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py +205 -0
  38. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasource.py +434 -0
  39. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py +62 -0
  40. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py +59 -0
  41. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py +53 -0
  42. .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py +385 -0
  43. .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__init__.py +0 -0
  44. .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py +41 -0
  48. .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py +290 -0
  49. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/__init__.py +0 -0
  50. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (191 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc ADDED
Binary file (22.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc ADDED
Binary file (34.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc ADDED
Binary file (8.89 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc ADDED
Binary file (9.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc ADDED
Binary file (38.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import TYPE_CHECKING, Iterator, List, Union
3
+
4
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
5
+ from ray.data._internal.util import _check_import
6
+ from ray.data.block import Block
7
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
8
+
9
+ if TYPE_CHECKING:
10
+ import pyarrow
11
+
12
+
13
class AudioDatasource(FileBasedDatasource):
    """File-based datasource that decodes audio files via ``soundfile``.

    Every file is emitted as a single-row block with two columns:
    ``"amplitude"`` (the decoded samples) and ``"sample_rate"``.
    """

    # File extensions (without the leading dot) handled by this datasource.
    _FILE_EXTENSIONS = [
        "mp3",
        "wav",
        "aac",
        "flac",
        "ogg",
        "m4a",
        "wma",
        "alac",
        "aiff",
        "pcm",
        "amr",
        "opus",
        "ra",
        "rm",
        "au",
        "mid",
        "midi",
        "caf",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        **file_based_datasource_kwargs,
    ):
        """Set up the file-based datasource and verify ``soundfile`` is importable.

        Args:
            paths: A single path or a list of paths to read.
            **file_based_datasource_kwargs: Forwarded unchanged to
                ``FileBasedDatasource.__init__``.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        _check_import(self, module="soundfile", package="soundfile")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Decode the audio file ``f`` and yield it as one single-row block."""
        import soundfile

        # `soundfile` cannot consume a `pyarrow.NativeFile` directly, so the
        # whole file is buffered in memory before decoding.
        in_memory_file = io.BytesIO(f.read())
        samples, rate = soundfile.read(
            in_memory_file, always_2d=True, dtype="float32"
        )

        row = {
            # Swap the two axes: (amplitude, channels) -> (channels, amplitude)
            "amplitude": samples.transpose((1, 0)),
            "sample_rate": rate,
        }

        block_builder = DelegatingBlockBuilder()
        block_builder.add(row)
        yield block_builder.build()
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Iterator, List, Union
2
+
3
+ from ray.data._internal.output_buffer import BlockOutputBuffer
4
+ from ray.data._internal.util import _check_import
5
+ from ray.data.block import Block
6
+ from ray.data.context import DataContext
7
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
8
+
9
+ if TYPE_CHECKING:
10
+ import pyarrow
11
+
12
+
13
class AvroDatasource(FileBasedDatasource):
    """A datasource that reads Avro files."""

    _FILE_EXTENSIONS = ["avro"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        **file_based_datasource_kwargs,
    ):
        """Set up the file-based datasource and verify ``fastavro`` is importable.

        Args:
            paths: A single path or a list of paths to read.
            **file_based_datasource_kwargs: Forwarded unchanged to
                ``FileBasedDatasource.__init__``.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        _check_import(self, module="fastavro", package="fastavro")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Yield blocks of records read from the Avro file ``f``.

        Records are accumulated in a ``BlockOutputBuffer`` created with the
        current ``DataContext``'s ``target_max_block_size``; blocks are yielded
        as soon as the buffer reports one is available.
        """
        import fastavro

        # The reader relies on the schema embedded in the Avro file itself.
        avro_records = fastavro.reader(f)

        max_block_size = DataContext.get_current().target_max_block_size
        buffer = BlockOutputBuffer(max_block_size)

        def _drain() -> Iterator[Block]:
            # Emit every block the buffer currently has ready.
            while buffer.has_next():
                yield buffer.next()

        for avro_record in avro_records:
            buffer.add(avro_record)
            yield from _drain()

        # Flush whatever is still pending once the input is exhausted.
        buffer.finalize()
        yield from _drain()
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+ import uuid
6
+ from typing import Iterable, Optional
7
+
8
+ import pyarrow.parquet as pq
9
+
10
+ import ray
11
+ from ray.data._internal.execution.interfaces import TaskContext
12
+ from ray.data._internal.datasource import bigquery_datasource
13
+ from ray.data._internal.remote_fn import cached_remote_fn
14
+ from ray.data._internal.util import _check_import
15
+ from ray.data.block import Block, BlockAccessor
16
+ from ray.data.datasource.datasink import Datasink
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ DEFAULT_MAX_RETRY_CNT = 10
21
+ RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
22
+
23
+
24
class BigQueryDatasink(Datasink[None]):
    """Datasink that appends blocks to a BigQuery table via Parquet load jobs."""

    def __init__(
        self,
        project_id: str,
        dataset: str,
        max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
        overwrite_table: Optional[bool] = True,
    ) -> None:
        """Store the write configuration and verify the Google client libraries.

        Args:
            project_id: Project used to create the BigQuery client and to
                qualify dataset/table identifiers.
            dataset: Target identifier; the part before the first ``.`` is
                used as the dataset id and the full value as the load target.
            max_retry_cnt: Maximum number of retries per block when the load
                job fails with ``Forbidden`` (logged as a rate-limit error).
            overwrite_table: If truthy, the target table is deleted in
                ``on_write_start`` before any block is written.
        """
        _check_import(self, module="google.cloud", package="bigquery")
        _check_import(self, module="google.cloud", package="bigquery_storage")
        _check_import(self, module="google.api_core", package="exceptions")

        self.project_id = project_id
        self.dataset = dataset
        self.max_retry_cnt = max_retry_cnt
        self.overwrite_table = overwrite_table

    def on_write_start(self) -> None:
        """Ensure the dataset exists and optionally drop the target table.

        Raises:
            ValueError: If ``project_id`` or ``dataset`` is ``None``.
        """
        from google.api_core import exceptions

        if self.project_id is None or self.dataset is None:
            raise ValueError("project_id and dataset are required args")

        # Set up datasets to write
        client = bigquery_datasource._create_client(project_id=self.project_id)
        dataset_id = self.dataset.split(".", 1)[0]
        try:
            client.get_dataset(dataset_id)
        except exceptions.NotFound:
            client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
            logger.info("Created dataset " + dataset_id)

        # Delete table if overwrite_table is True
        if self.overwrite_table:
            logger.info(
                f"Attempting to delete table {self.dataset}"
                + " if it already exists since kwarg overwrite_table = True."
            )
            client.delete_table(f"{self.project_id}.{self.dataset}", not_found_ok=True)
        else:
            logger.info(
                f"The write will append to table {self.dataset}"
                + " if it already exists since kwarg overwrite_table = False."
            )

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Write every block to BigQuery, one remote task per block.

        Each block is serialized to a temporary Parquet file and loaded with
        ``WRITE_APPEND`` disposition. ``Forbidden`` errors are retried up to
        ``max_retry_cnt`` times with a fixed sleep in between.

        Raises:
            RuntimeError: (from the remote task) if the retry budget for a
                block is exhausted.
        """

        def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
            from google.api_core import exceptions
            from google.cloud import bigquery

            block = BlockAccessor.for_block(block).to_arrow()

            # `_create_client` takes `project_id`; passing `project=` raised
            # TypeError and made every block write fail.
            client = bigquery_datasource._create_client(project_id=project_id)
            job_config = bigquery.LoadJobConfig(autodetect=True)
            job_config.source_format = bigquery.SourceFormat.PARQUET
            job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

            with tempfile.TemporaryDirectory() as temp_dir:
                fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
                pq.write_table(block, fp, compression="SNAPPY")

                retry_cnt = 0
                while retry_cnt <= self.max_retry_cnt:
                    # Re-open the file on each attempt so the upload always
                    # starts from the beginning.
                    with open(fp, "rb") as source_file:
                        job = client.load_table_from_file(
                            source_file, dataset, job_config=job_config
                        )
                    try:
                        logger.info(job.result())
                        break
                    except exceptions.Forbidden as e:
                        retry_cnt += 1
                        if retry_cnt > self.max_retry_cnt:
                            break
                        logger.info(
                            "A block write encountered a rate limit exceeded error"
                            + f" {retry_cnt} time(s). Sleeping to try again."
                        )
                        # Use the module logger (not the root logger) so the
                        # message follows this module's logging configuration.
                        logger.debug(e)
                        time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

                # Raise exception if retry_cnt exceeds max_retry_cnt
                if retry_cnt > self.max_retry_cnt:
                    logger.info(
                        f"Maximum ({self.max_retry_cnt}) retry count exceeded. Ray"
                        + " will attempt to retry the block write via fault tolerance."
                    )
                    raise RuntimeError(
                        f"Write failed due to {retry_cnt}"
                        + " repeated API rate limit exceeded responses. Consider"
                        + " specifiying the max_retry_cnt kwarg with a higher value."
                    )

        _write_single_block = cached_remote_fn(_write_single_block)

        # Launch a remote task for each block within this write task
        ray.get(
            [
                _write_single_block.remote(block, self.project_id, self.dataset)
                for block in blocks
            ]
        )
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ from ray.data._internal.util import _check_import
5
+ from ray.data.block import Block, BlockMetadata
6
+ from ray.data.datasource.datasource import Datasource, ReadTask
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
def _create_user_agent() -> str:
    """Return the user-agent string ``ray/<installed Ray version>``."""
    import ray

    ray_version = ray.__version__
    return f"ray/{ray_version}"
15
+
16
+
17
def _create_client_info():
    """Build a ``google.api_core`` ``ClientInfo`` carrying Ray's user agent."""
    from google.api_core.client_info import ClientInfo

    agent = _create_user_agent()
    return ClientInfo(user_agent=agent)
23
+
24
+
25
def _create_client_info_gapic():
    """Build a ``gapic_v1`` ``ClientInfo`` carrying Ray's user agent."""
    from google.api_core.gapic_v1.client_info import ClientInfo

    agent = _create_user_agent()
    return ClientInfo(user_agent=agent)
31
+
32
+
33
def _create_client(project_id: str):
    """Create a ``bigquery.Client`` for ``project_id`` tagged with Ray's client info.

    Args:
        project_id: Passed to ``bigquery.Client`` as its ``project`` argument.
    """
    from google.cloud import bigquery

    client_options = {
        "project": project_id,
        "client_info": _create_client_info(),
    }
    return bigquery.Client(**client_options)
40
+
41
+
42
def _create_read_client():
    """Create a ``BigQueryReadClient`` tagged with Ray's gapic client info."""
    from google.cloud import bigquery_storage

    gapic_info = _create_client_info_gapic()
    return bigquery_storage.BigQueryReadClient(client_info=gapic_info)
48
+
49
+
50
class BigQueryDatasource(Datasource):
    """Datasource that reads a BigQuery table or query result as Arrow blocks."""

    def __init__(
        self,
        project_id: str,
        dataset: Optional[str] = None,
        query: Optional[str] = None,
    ):
        """Store the read configuration and verify the Google client libraries.

        Args:
            project_id: Project used for the BigQuery clients and read session.
            dataset: Identifier of the form ``"<dataset_id>.<table_id>"``.
            query: Query whose destination table is read instead of ``dataset``.

        Raises:
            ValueError: If both ``query`` and ``dataset`` are provided.
        """
        _check_import(self, module="google.cloud", package="bigquery")
        _check_import(self, module="google.cloud", package="bigquery_storage")
        _check_import(self, module="google.api_core", package="exceptions")

        self._project_id = project_id
        self._dataset = dataset
        self._query = query

        if query is not None and dataset is not None:
            raise ValueError(
                "Query and dataset kwargs cannot both be provided "
                + "(must be mutually exclusive)."
            )

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Create one read task per stream of a BigQuery Storage read session.

        Args:
            parallelism: Requested maximum number of streams. ``-1`` is
                translated to ``None`` before being passed as
                ``max_stream_count``.

        Returns:
            A list with one ``ReadTask`` per stream in the created session.
        """
        from google.cloud import bigquery_storage

        def _read_single_partition(stream) -> Block:
            # A fresh read client is created inside the task rather than
            # shared with the driver-side session setup.
            client = _create_read_client()
            reader = client.read_rows(stream.name)
            return reader.to_arrow()

        if self._query:
            # Run the query and read from the table it was materialized into.
            query_client = _create_client(project_id=self._project_id)
            query_job = query_client.query(self._query)
            query_job.result()
            destination_parts = str(query_job.destination).split(".")
            dataset_id = destination_parts[-2]
            table_id = destination_parts[-1]
        else:
            self._validate_dataset_table_exist(self._project_id, self._dataset)
            dataset_parts = self._dataset.split(".")
            dataset_id = dataset_parts[0]
            table_id = dataset_parts[1]

        bqs_client = _create_read_client()
        table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"

        if parallelism == -1:
            parallelism = None
        requested_session = bigquery_storage.types.ReadSession(
            table=table,
            data_format=bigquery_storage.types.DataFormat.ARROW,
        )
        read_session = bqs_client.create_read_session(
            parent=f"projects/{self._project_id}",
            read_session=requested_session,
            max_stream_count=parallelism,
        )

        read_tasks = []
        logger.info("Created streams: " + str(len(read_session.streams)))
        # `parallelism` is None when the caller passed -1; comparing an int
        # with None would raise TypeError, so only compare real requests.
        if parallelism is not None and len(read_session.streams) < parallelism:
            logger.info(
                "The number of streams created by the "
                + "BigQuery Storage Read API is less than the requested "
                + "parallelism due to the size of the dataset."
            )

        for stream in read_session.streams:
            # Create a metadata block object to store schema, etc.
            metadata = BlockMetadata(
                num_rows=None,
                size_bytes=None,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

            # Create the read task and pass the no-arg wrapper and metadata in.
            # `stream=stream` binds the current stream at definition time,
            # avoiding the late-binding closure pitfall.
            read_task = ReadTask(
                lambda stream=stream: [_read_single_partition(stream)],
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``: no size estimate is computed for BigQuery reads."""
        return None

    def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
        """Check that both the dataset and the table exist.

        Args:
            project_id: Project used to create the BigQuery client.
            dataset: Identifier whose part before the first ``.`` is the
                dataset id; the full value is looked up as the table.

        Raises:
            ValueError: If the dataset or the table is not found.
        """
        from google.api_core import exceptions

        client = _create_client(project_id=project_id)
        dataset_id = dataset.split(".")[0]
        try:
            client.get_dataset(dataset_id)
        except exceptions.NotFound as e:
            raise ValueError(
                "Dataset {} is not found. Please ensure that it exists.".format(
                    dataset_id
                )
            ) from e

        try:
            client.get_table(dataset)
        except exceptions.NotFound as e:
            raise ValueError(
                "Table {} is not found. Please ensure that it exists.".format(dataset)
            ) from e
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ray.data._internal.arrow_block import ArrowBlockBuilder
4
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
5
+
6
+ if TYPE_CHECKING:
7
+ import pyarrow
8
+
9
+
10
class BinaryDatasource(FileBasedDatasource):
    """Binary datasource, for reading and writing binary files."""

    # Name of the single column that holds each file's raw contents.
    _COLUMN_NAME = "bytes"

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Yield one single-row block containing the full contents of ``f``."""
        payload = f.readall()

        block_builder = ArrowBlockBuilder()
        block_builder.add({self._COLUMN_NAME: payload})
        yield block_builder.build()

    def _rows_per_file(self):
        """Return 1: each file produces exactly one row."""
        return 1
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
4
+
5
+ from ray.data._internal.util import _check_import
6
+ from ray.data.block import Block, BlockAccessor, BlockMetadata
7
+ from ray.data.datasource.datasource import Datasource, ReadTask
8
+ from ray.util.annotations import DeveloperAPI
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def _is_filter_string_safe(filter_str: str) -> bool:
14
+ in_string = False
15
+ escape_next = False
16
+ for c in filter_str:
17
+ if in_string:
18
+ # If we're inside a string, check if we're closing it.
19
+ if c == "'" and not escape_next:
20
+ in_string = False
21
+ escape_next = (c == "\\") and not escape_next
22
+ else:
23
+ # If we're not in a string, entering one if we see a single quote
24
+ if c == "'":
25
+ in_string = True
26
+ escape_next = False
27
+ # Disallow semicolon if we're not in a string
28
+ elif c == ";":
29
+ return False
30
+ else:
31
+ escape_next = False
32
+ # If we end inside a string, it's suspicious, but let's allow
33
+ # it to be further validated by the DB. Just return True here.
34
+ return True
35
+
36
+
37
@DeveloperAPI
class ClickHouseDatasource(Datasource):
    """
    A Ray datasource for reading from ClickHouse.

    Args:
        table: Fully qualified table or view identifier (e.g.,
            "default.table_name").
        dsn: A string in DSN (Data Source Name) HTTP format (e.g.,
            "clickhouse+http://username:password@host:8124/default").
            For more information, see `ClickHouse Connection String doc
            <https://clickhouse.com/docs/en/integrations/sql-clients/cli#connection_string>`_.
        columns: Optional List of columns to select from the data source.
            If no columns are specified, all columns will be selected by default.
        filter: Optional SQL filter string that will be used in the
            WHERE statement (e.g., "label = 2 AND text IS NOT NULL").
            The filter must be valid for use in a ClickHouse SQL WHERE clause.
            Note: Parallel reads are not currently supported when a filter is set.
            Specifying a filter forces the parallelism to 1 to ensure deterministic
            and consistent results. For more information, see
            `ClickHouse SQL WHERE Clause doc
            <https://clickhouse.com/docs/en/sql-reference/statements/select/where>`_.
        order_by: Optional Tuple containing a list of columns to order by
            and a boolean indicating the order. Note: order_by is required to
            support parallelism.
        client_settings: Optional ClickHouse server settings to be used with the
            session/every request. For more information, see
            `ClickHouse Client Settings doc
            <https://clickhouse.com/docs/en/integrations/python#settings-argument>`_.
        client_kwargs: Optional Additional keyword arguments to pass to the
            ClickHouse client. For more information,
            see `ClickHouse Core Settings doc
            <https://clickhouse.com/docs/en/integrations/python#additional-options>`_.
    """

    # Number of rows fetched to estimate per-row size and the schema.
    NUM_SAMPLE_ROWS = 100
    # Lower bound on rows per read task; caps the effective parallelism.
    MIN_ROWS_PER_READ_TASK = 50
    _BASE_QUERY = "SELECT {select_clause} FROM {table}"
    _EXPLAIN_FILTERS_QUERY = "EXPLAIN SELECT 1 FROM {table} WHERE {filter_clause}"
    _SIZE_ESTIMATE_QUERY = "SELECT SUM(byteSize(*)) AS estimate FROM ({query})"
    _COUNT_ESTIMATE_QUERY = "SELECT COUNT(*) AS estimate FROM ({query})"
    _SAMPLE_BLOCK_QUERY = "{query} LIMIT {limit_row_count}"
    _FIRST_BLOCK_QUERY = """
        {query}
        FETCH FIRST {fetch_row_count} {fetch_row_or_rows} ONLY
    """
    _NEXT_BLOCK_QUERY = """
        {query}
        OFFSET {offset_row_count} {offset_row_or_rows}
        FETCH NEXT {fetch_row_count} {fetch_row_or_rows} ONLY
    """

    def __init__(
        self,
        table: str,
        dsn: str,
        columns: Optional[List[str]] = None,
        filter: Optional[str] = None,
        order_by: Optional[Tuple[List[str], bool]] = None,
        client_settings: Optional[Dict[str, Any]] = None,
        client_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self._table = table
        self._dsn = dsn
        self._columns = columns
        self._filter = filter
        self._order_by = order_by
        self._client_settings = client_settings or {}
        self._client_kwargs = client_kwargs or {}
        # Building the query also validates the filter, which (when a filter
        # is set) opens a client connection to run an EXPLAIN query.
        self._query = self._generate_query()

    def _init_client(self):
        """Create a new ``clickhouse_connect`` client from the stored settings."""
        _check_import(self, module="clickhouse_connect", package="clickhouse-connect")
        import clickhouse_connect

        return clickhouse_connect.get_client(
            dsn=self._dsn,
            settings=self._client_settings or {},
            **self._client_kwargs or {},
        )

    def _validate_filter(self):
        """Validate the filter lexically and by running an EXPLAIN query.

        Raises:
            ValueError: If the filter has a semicolon outside a string
                literal, or if the EXPLAIN query fails for any reason.
        """
        if not self._filter:
            return
        # Minimal lexical check (regex or manual approach for semicolons, etc.).
        if not _is_filter_string_safe(self._filter):
            raise ValueError(
                f"Invalid characters outside of "
                f"string literals in filter: {self._filter}"
            )
        # Test "EXPLAIN" query to confirm parse-ability and catch expression errors.
        client = self._init_client()
        try:
            test_query = self._EXPLAIN_FILTERS_QUERY.format(
                table=self._table,
                filter_clause=self._filter,
            )
            client.query(test_query)
        except Exception as e:
            # Chain the original error so the underlying cause stays visible.
            raise ValueError(
                f"Invalid filter expression: {self._filter}. Error: {e}",
            ) from e
        finally:
            client.close()

    def _generate_query(self) -> str:
        """Assemble the base SELECT with optional WHERE and ORDER BY clauses."""
        query = self._BASE_QUERY.format(
            select_clause=", ".join(self._columns) if self._columns else "*",
            table=self._table,
        )
        if self._filter:
            self._validate_filter()
            query += f" WHERE {self._filter}"
        if self._order_by:
            columns, desc = self._order_by
            direction = " DESC" if desc else ""
            if len(columns) == 1:
                query += f" ORDER BY {columns[0]}{direction}"
            elif len(columns) > 1:
                columns_clause = ", ".join(columns)
                query += f" ORDER BY ({columns_clause}){direction}"
        return query

    def _build_block_query(self, limit_row_count: int, offset_row_count: int) -> str:
        """Return the query that reads ``limit_row_count`` rows at the offset."""
        if offset_row_count == 0:
            # The first block query is optimized to use FETCH FIRST clause
            # without an OFFSET specified.
            return self._FIRST_BLOCK_QUERY.format(
                query=self._query,
                fetch_row_count=limit_row_count,
                fetch_row_or_rows="ROWS" if limit_row_count > 1 else "ROW",
            )
        # Subsequent block queries use OFFSET and FETCH NEXT clauses to read the
        # next block of data.
        return self._NEXT_BLOCK_QUERY.format(
            query=self._query,
            offset_row_count=offset_row_count,
            offset_row_or_rows="ROWS" if offset_row_count > 1 else "ROW",
            fetch_row_count=limit_row_count,
            fetch_row_or_rows="ROWS" if limit_row_count > 1 else "ROW",
        )

    def _create_read_fn(
        self,
        query: str,
    ) -> Callable[[], Iterable[Block]]:
        """Wrap ``query`` in a no-arg callable returning a one-block list."""

        def read_fn() -> Iterable[Block]:
            return [self._execute_block_query(query)]

        return read_fn

    def _get_sampled_estimates(self):
        """Sample up to ``NUM_SAMPLE_ROWS`` rows to estimate row size and schema.

        Returns:
            A tuple ``(estimated_size_bytes_per_row, sample_block_schema)``.
            The size estimate is 0 when the sample contains no rows.
        """
        if self._order_by is not None:
            # If the query is ordered, we can use a FETCH clause to get a sample.
            # This reduces the CPU overhead on ClickHouse and speeds up the
            # estimation query.
            query = self._FIRST_BLOCK_QUERY.format(
                query=self._query,
                fetch_row_count=self.NUM_SAMPLE_ROWS,
                fetch_row_or_rows="ROWS" if self.NUM_SAMPLE_ROWS > 1 else "ROW",
            )
        else:
            # If the query is not ordered, we need to use a LIMIT clause to
            # get a sample.
            query = self._SAMPLE_BLOCK_QUERY.format(
                query=self._query,
                limit_row_count=self.NUM_SAMPLE_ROWS,
            )
        sample_block_accessor = BlockAccessor.for_block(
            self._execute_block_query(query)
        )
        sample_num_rows = sample_block_accessor.num_rows()
        if sample_num_rows:
            estimated_size_bytes_per_row = math.ceil(
                sample_block_accessor.size_bytes() / sample_num_rows
            )
        else:
            # The sample can be empty even though the count query reported
            # rows (the two queries are not atomic); avoid dividing by zero.
            estimated_size_bytes_per_row = 0
        sample_block_schema = sample_block_accessor.schema()
        return estimated_size_bytes_per_row, sample_block_schema

    def _get_estimate_count(self) -> Optional[int]:
        """Return the row count of the query, or None if it cannot be computed."""
        return self._execute_estimate_query(self._COUNT_ESTIMATE_QUERY)

    def _get_estimate_size(self) -> Optional[int]:
        """Return the byte size of the query result, or None if unavailable."""
        return self._execute_estimate_query(self._SIZE_ESTIMATE_QUERY)

    def _execute_estimate_query(self, estimate_query: str) -> Optional[int]:
        """Run an estimate query template around ``self._query``.

        Failures are logged as warnings and reported as ``None`` (best effort).
        """
        client = self._init_client()
        try:
            # Estimate queries wrap around the primary query, self._query.
            # This allows us to use self._query as a sub-query to efficiently
            # and accurately estimate the size or count of the result set.
            query = estimate_query.format(query=self._query)
            result = client.query(query)
            if result and len(result.result_rows) > 0:
                estimate = result.result_rows[0][0]
                return int(estimate) if estimate is not None else None
        except Exception as e:
            logger.warning(f"Failed to execute estimate query: {e}")
        finally:
            client.close()
        return None

    def _execute_block_query(self, query: str) -> Block:
        """Execute ``query`` and return all record batches as one Arrow table.

        Raises:
            RuntimeError: If the query or the Arrow conversion fails.
        """
        import pyarrow as pa

        client = self._init_client()
        try:
            with client.query_arrow_stream(query) as stream:
                record_batches = list(stream)  # Collect all record batches
            return pa.Table.from_batches(record_batches)
        except Exception as e:
            # Chain the original error so the underlying cause stays visible.
            raise RuntimeError(f"Failed to execute block query: {e}") from e
        finally:
            client.close()

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """
        Estimate the in-memory data size for the query.

        Returns:
            Estimated in-memory data size in bytes, or
            None if the estimation cannot be performed.
        """
        return self._get_estimate_size()

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """
        Create read tasks for the ClickHouse query.

        Args:
            parallelism: The desired number of partitions to read the data into.
                - If ``order_by`` is not set, parallelism will be forced to 1.
                - If ``filter`` is set, parallelism will also be forced to 1
                  to ensure deterministic results.
                - Values below 1 are treated as 1.

        Returns:
            A list of read tasks to be executed.
        """
        num_rows_total = self._get_estimate_count()
        if num_rows_total == 0 or num_rows_total is None:
            return []
        # Clamp to at least 1: a non-positive value would otherwise divide by
        # zero below (0) or silently produce no read tasks at all (negative).
        parallelism = max(
            1,
            min(parallelism, math.ceil(num_rows_total / self.MIN_ROWS_PER_READ_TASK)),
        )
        # To ensure consistent order of query results, self._order_by
        # must be specified and self.filter must not be specified
        # in order to support parallelism.
        if self._filter is not None and parallelism > 1:
            logger.warning(
                "ClickHouse datasource does not currently support parallel reads "
                "when a filter is set; falling back to parallelism of 1."
            )
            # When filter is specified and parallelism is greater than 1,
            # we need to reduce parallelism to 1 to ensure consistent results.
            parallelism = 1
        # To ensure consistent order of query results, self._order_by
        # must be specified in order to support parallelism.
        if self._order_by is None and parallelism > 1:
            logger.warning(
                "ClickHouse datasource requires dataset to be explicitly ordered "
                "to support parallelism; falling back to parallelism of 1."
            )
            # When order_by is not specified and parallelism is greater than 1,
            # we need to reduce parallelism to 1 to ensure consistent results.
            parallelism = 1
        # By reducing parallelism to 1 when either of the conditions above are met,
        # we ensure the downstream process is treated exactly as a non-parallelized
        # (single block) process would be, thus ensuring output consistency.
        num_rows_per_block = num_rows_total // parallelism
        num_blocks_with_extra_row = num_rows_total % parallelism
        (
            estimated_size_bytes_per_row,
            sample_block_schema,
        ) = self._get_sampled_estimates()

        def _get_read_task(
            block_rows: int, offset_rows: int, parallelized: bool
        ) -> ReadTask:
            if parallelized:
                # When parallelized, we need to build a block query with OFFSET
                # and FETCH clauses.
                query = self._build_block_query(block_rows, offset_rows)
            else:
                # When not parallelized, we can use the original query without
                # OFFSET and FETCH clauses.
                query = self._query
            return ReadTask(
                self._create_read_fn(query),
                BlockMetadata(
                    num_rows=block_rows,
                    size_bytes=estimated_size_bytes_per_row * block_rows,
                    schema=sample_block_schema,
                    input_files=None,
                    exec_stats=None,
                ),
            )

        if parallelism == 1:
            # When parallelism is 1, we can read the entire dataset in a single task.
            # We then optimize this scenario by using self._query directly without
            # unnecessary OFFSET and FETCH clauses.
            return [_get_read_task(num_rows_total, 0, False)]

        # Otherwise we need to split the dataset into multiple tasks.
        # Each task will include OFFSET and FETCH clauses to efficiently
        # read a subset of the dataset.
        read_tasks = []
        offset = 0
        for i in range(parallelism):
            this_block_size = num_rows_per_block
            # Spread the remainder over the first blocks, one extra row each.
            if i < num_blocks_with_extra_row:
                this_block_size += 1
            read_tasks.append(_get_read_task(this_block_size, offset, True))
            offset += this_block_size
        return read_tasks
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, Optional
2
+
3
+ import pyarrow
4
+
5
+ from ray.data.block import BlockAccessor
6
+ from ray.data.datasource.file_based_datasource import _resolve_kwargs
7
+ from ray.data.datasource.file_datasink import BlockBasedFileDatasink
8
+
9
+
10
class CSVDatasink(BlockBasedFileDatasink):
    """File datasink that serializes each block as CSV through ``pyarrow.csv``."""

    def __init__(
        self,
        path: str,
        *,
        arrow_csv_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
        arrow_csv_args: Optional[Dict[str, Any]] = None,
        file_format="csv",
        **file_datasink_kwargs,
    ):
        """Store the CSV writer arguments and configure the base datasink.

        Args:
            path: Destination path, forwarded to the base datasink.
            arrow_csv_args_fn: Zero-argument callable producing writer keyword
                arguments; defaults to a callable returning an empty dict.
            arrow_csv_args: Static writer keyword arguments; defaults to an
                empty dict.
            file_format: File format forwarded to the base datasink.
            **file_datasink_kwargs: Extra arguments for the base datasink.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.arrow_csv_args_fn = (
            arrow_csv_args_fn if arrow_csv_args_fn is not None else (lambda: {})
        )
        self.arrow_csv_args = arrow_csv_args if arrow_csv_args is not None else {}

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write ``block`` to ``file`` as CSV.

        The ``write_options`` entry, if any, is removed from the resolved
        arguments and passed positionally; the rest are passed as keywords.
        """
        from pyarrow import csv as arrow_csv

        # NOTE(review): presumably merges the callable's output with the static
        # arguments - confirm against `_resolve_kwargs`.
        resolved = _resolve_kwargs(self.arrow_csv_args_fn, **self.arrow_csv_args)
        options = resolved.pop("write_options", None)
        table = block.to_arrow()
        arrow_csv.write_csv(table, file, options, **resolved)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import time
5
+ from typing import List, Optional
6
+ from urllib.parse import urljoin
7
+
8
+ import numpy as np
9
+ import pyarrow
10
+ import requests
11
+
12
+ from ray.data.block import BlockMetadata
13
+ from ray.data.datasource.datasource import Datasource, ReadTask
14
+ from ray.util.annotations import PublicAPI
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ _STATEMENT_EXEC_POLL_TIME_S = 1
20
+
21
+
22
@PublicAPI(stability="alpha")
class DatabricksUCDatasource(Datasource):
    """Datasource that runs a SQL statement on a Databricks warehouse.

    The constructor submits ``query`` to the ``/api/2.0/sql/statements/``
    endpoint of ``host``, blocks until the statement leaves the
    ``PENDING``/``RUNNING`` states, and records the chunk manifest of the
    result. Read tasks later download the chunks as Arrow streams.
    """

    def __init__(
        self,
        host: str,
        token: str,
        warehouse_id: str,
        catalog: str,
        schema: str,
        query: str,
    ):
        """Submit the statement and wait for its result manifest.

        Args:
            host: Host name used to build the ``https://{host}/api/...`` URL.
            token: Bearer token sent in the ``Authorization`` header.
            warehouse_id: Sent as ``warehouse_id`` in the request payload.
            catalog: Sent as ``catalog`` in the request payload.
            schema: Sent as ``schema`` in the request payload.
            query: The SQL statement to execute.

        Raises:
            requests.HTTPError: If submitting or polling the statement returns
                an error status.
            RuntimeError: If the statement ends in a state other than
                ``SUCCEEDED``.
            KeyboardInterrupt: Re-raised after a best-effort cancel request
                when the user interrupts the wait.
        """
        self.host = host
        self.token = token
        self.warehouse_id = warehouse_id
        self.catalog = catalog
        self.schema = schema
        self.query = query

        url_base = f"https://{self.host}/api/2.0/sql/statements/"

        payload = json.dumps(
            {
                "statement": self.query,
                "warehouse_id": self.warehouse_id,
                "wait_timeout": "0s",
                "disposition": "EXTERNAL_LINKS",
                "format": "ARROW_STREAM",
                "catalog": self.catalog,
                "schema": self.schema,
            }
        )

        req_headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.token,
        }

        response = requests.post(
            url_base,
            headers=req_headers,
            data=payload,
        )
        response.raise_for_status()
        submitted = response.json()
        statement_id = submitted["statement_id"]
        state = submitted["status"]["state"]

        logger.info(f"Waiting for query {query!r} execution result.")
        try:
            while state in ["PENDING", "RUNNING"]:
                time.sleep(_STATEMENT_EXEC_POLL_TIME_S)
                response = requests.get(
                    urljoin(url_base, statement_id) + "/",
                    headers=req_headers,
                )
                response.raise_for_status()
                state = response.json()["status"]["state"]
        except KeyboardInterrupt:
            # The user cancelled the command, so cancel the query execution.
            # The cancel is best effort: check the status of the *cancel*
            # response (not the last polling response) and never let a failure
            # here replace the KeyboardInterrupt being propagated.
            try:
                cancel_response = requests.post(
                    urljoin(url_base, f"{statement_id}/cancel"),
                    headers=req_headers,
                )
                cancel_response.raise_for_status()
            except Exception as e:
                logger.warning(
                    f"Canceling query {query!r} execution failed, reason: {repr(e)}."
                )
            raise

        if state != "SUCCEEDED":
            raise RuntimeError(f"Query {self.query!r} execution failed.")

        manifest = response.json()["manifest"]
        is_truncated = manifest["truncated"]

        if is_truncated:
            logger.warning(
                f"The resulting size of the dataset of '{query!r}' exceeds "
                "100GiB and it is truncated."
            )

        chunks = manifest["chunks"]

        # Make sure the chunk metadata is ordered by index.
        chunks = sorted(chunks, key=lambda x: x["chunk_index"])
        num_chunks = len(chunks)
        self.num_chunks = num_chunks
        self._estimate_inmemory_data_size = sum(chunk["byte_count"] for chunk in chunks)

        def get_read_task(task_index, parallelism):
            # Get the chunk list to be read in this task while preserving the
            # original chunk order.
            chunk_index_list = list(
                np.array_split(range(num_chunks), parallelism)[task_index]
            )

            num_rows = sum(
                chunks[chunk_index]["row_count"] for chunk_index in chunk_index_list
            )
            size_bytes = sum(
                chunks[chunk_index]["byte_count"] for chunk_index in chunk_index_list
            )

            metadata = BlockMetadata(
                num_rows=num_rows,
                size_bytes=size_bytes,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

            def _read_fn():
                for chunk_index in chunk_index_list:
                    resolve_external_link_url = urljoin(
                        url_base, f"{statement_id}/result/chunks/{chunk_index}"
                    )

                    resolve_response = requests.get(
                        resolve_external_link_url, headers=req_headers
                    )
                    resolve_response.raise_for_status()
                    external_url = resolve_response.json()["external_links"][0][
                        "external_link"
                    ]
                    # NOTE: do _NOT_ send the authorization header to external urls
                    raw_response = requests.get(external_url, auth=None, headers=None)
                    raw_response.raise_for_status()

                    with pyarrow.ipc.open_stream(raw_response.content) as reader:
                        arrow_table = reader.read_all()

                    yield arrow_table

            def read_fn():
                if mock_setup_fn_path := os.environ.get(
                    "RAY_DATABRICKS_UC_DATASOURCE_READ_FN_MOCK_TEST_SETUP_FN_PATH"
                ):
                    import ray.cloudpickle as pickle

                    # This is for testing.
                    # NOTE(review): this unpickles whatever file the environment
                    # variable points to; it must never reference untrusted data.
                    with open(mock_setup_fn_path, "rb") as f:
                        mock_setup = pickle.load(f)
                    with mock_setup():
                        yield from _read_fn()
                else:
                    yield from _read_fn()

            return ReadTask(read_fn=read_fn, metadata=metadata)

        self._get_read_task = get_read_task

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the sum of ``byte_count`` over all result chunks."""
        return self._estimate_inmemory_data_size

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Build one read task per split of the result chunks.

        ``parallelism`` is capped at the number of chunks so that no task is
        created without a chunk to read.
        """
        assert parallelism > 0, f"Invalid parallelism {parallelism}"

        if parallelism > self.num_chunks:
            parallelism = self.num_chunks
            logger.info(
                "The parallelism is reduced to chunk number due to "
                "insufficient chunk parallelism."
            )

        return [self._get_read_task(index, parallelism) for index in range(parallelism)]
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from json import loads
3
+ from typing import List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from ray.data._internal.util import _check_import
8
+ from ray.data.block import BlockMetadata
9
+ from ray.data.datasource.datasource import Datasource, ReadTask
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class DeltaSharingDatasource(Datasource):
    """Datasource that reads the files of a Delta Sharing table."""

    def __init__(
        self,
        url: str,
        json_predicate_hints: Optional[str] = None,
        limit: Optional[int] = None,
        version: Optional[int] = None,
        timestamp: Optional[str] = None,
    ):
        """Validate ``limit`` and remember the query parameters.

        Args:
            url: Table location in the form "<profile>#<share>.<schema>.<table>".
            json_predicate_hints: Passed as ``jsonPredicateHints`` when listing
                the table's files.
            limit: Passed as ``limitHint``; must be a non-negative int when set.
            version: Passed as ``version`` when listing the table's files.
            timestamp: Passed as ``timestamp`` when listing the table's files.
        """
        _check_import(self, module="delta_sharing", package="delta-sharing")

        assert limit is None or (
            isinstance(limit, int) and limit >= 0
        ), "'limit' must be a non-negative int"

        self._url = url
        self._json_predicate_hints = json_predicate_hints
        self._limit = limit
        self._version = version
        self._timestamp = timestamp

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """No size estimate is available for this datasource."""
        return None

    def _read_files(self, files, converters):
        """Yield one pandas result per entry of ``files``."""
        from delta_sharing.reader import DeltaSharingReader

        yield from (
            DeltaSharingReader._to_pandas(
                action=action, converters=converters, for_cdf=False, limit=None
            )
            for action in files
        )

    def setup_delta_sharing_connections(self, url: str):
        """Create the table handle and REST client described by ``url``.

        Args:
            url: a url under the format "<profile>#<share>.<schema>.<table>"

        Returns:
            A ``(table, rest_client)`` tuple.
        """
        from delta_sharing.protocol import DeltaSharingProfile, Table
        from delta_sharing.rest_client import DataSharingRestClient

        profile_path, share, schema, table_name = _parse_delta_sharing_url(url)
        table = Table(name=table_name, share=share, schema=schema)
        rest_client = DataSharingRestClient(
            DeltaSharingProfile.read_from_file(profile_path)
        )
        return table, rest_client

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """List the table's files and split them into ``parallelism`` tasks."""
        assert parallelism > 0, f"Invalid parallelism {parallelism}"
        from delta_sharing.converter import to_converters

        self._table, self._rest_client = self.setup_delta_sharing_connections(self._url)
        self._response = self._rest_client.list_files_in_table(
            self._table,
            jsonPredicateHints=self._json_predicate_hints,
            limitHint=self._limit,
            version=self._version,
            timestamp=self._timestamp,
        )

        if len(self._response.add_files) == 0 or self._limit == 0:
            logger.warning("No files found from the delta sharing table or limit is 0")

        self._converters = to_converters(loads(self._response.metadata.schema_string))
        converters = self._converters

        def _task_for(files) -> ReadTask:
            # Each task carries its own file list, bound as a lambda default.
            block_metadata = BlockMetadata(
                num_rows=None,
                schema=None,
                input_files=files,
                size_bytes=None,
                exec_stats=None,
            )
            return ReadTask(
                lambda f=files: self._read_files(f, converters),
                block_metadata,
            )

        # Split the file list into contiguous groups, preserving file order.
        return [
            _task_for(group.tolist())
            for group in np.array_split(self._response.add_files, parallelism)
        ]
103
+
104
+
105
+ def _parse_delta_sharing_url(url: str) -> Tuple[str, str, str, str]:
106
+ """
107
+ Developed from delta_sharing's _parse_url function.
108
+ https://github.com/delta-io/delta-sharing/blob/main/python/delta_sharing/delta_sharing.py#L36
109
+
110
+ Args:
111
+ url: a url under the format "<profile>#<share>.<schema>.<table>"
112
+
113
+ Returns:
114
+ a tuple with parsed (profile, share, schema, table)
115
+ """
116
+ shape_index = url.rfind("#")
117
+ if shape_index < 0:
118
+ raise ValueError(f"Invalid 'url': {url}")
119
+ profile = url[0:shape_index]
120
+ fragments = url[shape_index + 1 :].split(".")
121
+ if len(fragments) != 3:
122
+ raise ValueError(f"Invalid 'url': {url}")
123
+ share, schema, table = fragments
124
+ if len(profile) == 0 or len(share) == 0 or len(schema) == 0 or len(table) == 0:
125
+ raise ValueError(f"Invalid 'url': {url}")
126
+ return (profile, share, schema, table)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Dict, Iterator, List, Optional
4
+
5
+ from ray.data._internal.util import _check_import
6
+ from ray.data.block import BlockMetadata
7
+ from ray.data.datasource.datasource import Datasource, ReadTask
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class HudiDatasource(Datasource):
    """Hudi datasource, for reading Apache Hudi table."""

    def __init__(
        self,
        table_uri: str,
        storage_options: Optional[Dict[str, str]] = None,
    ):
        """Remember the table location and storage options.

        Args:
            table_uri: URI passed to ``HudiTable`` and the file group reader.
            storage_options: Options passed to ``HudiTable``.
        """
        _check_import(self, module="hudi", package="hudi-python")

        self._table_uri = table_uri
        self._storage_options = storage_options

    def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
        """Create one read task per file-slice split of the table."""
        import pyarrow
        from hudi import HudiTable

        def _perform_read(
            table_uri: str,
            base_file_paths: List[str],
            options: Dict[str, str],
        ) -> Iterator["pyarrow.Table"]:
            from hudi import HudiFileGroupReader

            for base_path in base_file_paths:
                reader = HudiFileGroupReader(table_uri, options)
                record_batch = reader.read_file_slice_by_base_file_path(base_path)
                yield pyarrow.Table.from_batches([record_batch])

        hudi_table = HudiTable(self._table_uri, self._storage_options)

        reader_options = {
            **hudi_table.storage_options(),
            **hudi_table.hudi_options(),
        }

        schema = hudi_table.get_schema()

        def _build_task(file_slices) -> ReadTask:
            # A file slice in a Hudi table is a logical group of data files
            # within a physical partition. Records stored in a file slice
            # are associated with a commit on the Hudi table's timeline.
            # For more info, see https://hudi.apache.org/docs/file_layouts
            slices = list(file_slices)
            relative_paths = [fs.base_file_relative_path() for fs in slices]
            block_metadata = BlockMetadata(
                num_rows=sum(fs.num_records for fs in slices),
                schema=schema,
                input_files=[
                    os.path.join(self._table_uri, rel) for rel in relative_paths
                ],
                size_bytes=sum(fs.base_file_size for fs in slices),
                exec_stats=None,
            )
            return ReadTask(
                read_fn=lambda paths=relative_paths: _perform_read(
                    self._table_uri, paths, reader_options
                ),
                metadata=block_metadata,
            )

        return [
            _build_task(split)
            for split in hudi_table.get_file_slices_splits(parallelism)
        ]

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """No size estimate is available yet."""
        # TODO(xushiyan) add APIs to provide estimated in-memory size
        return None
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import TYPE_CHECKING, Iterable, List, Optional, Union
3
+
4
+ from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
5
+ from ray.data._internal.util import _check_pyarrow_version
6
+ from ray.data.block import Block, BlockAccessor, BlockMetadata
7
+ from ray.data.dataset import Dataset
8
+ from ray.data.datasource import Datasource, ReadTask
9
+
10
+ if TYPE_CHECKING:
11
+ import datasets
12
+
13
+
14
# Holds the ImportError raised while importing `transformers`/`datasets` below,
# or None when the imports succeeded. `HuggingFaceDatasource.__init__` re-raises
# it, so the failure surfaces when the datasource is used rather than at import.
TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None

try:
    # Due to HF Dataset's dynamic module system, we need to dynamically import the
    # datasets_modules module on every actor when training.
    # We accomplish this by simply running the following bit of code directly
    # in the module you are currently viewing. This ensures that when we
    # unpickle the Dataset, it runs before pickle tries to
    # import datasets_modules and prevents an exception from being thrown.
    # Same logic is present inside HF Transformers Ray
    # integration: https://github.com/huggingface/transformers/blob/\
    # 7d5fde991d598370d961be8cb7add6541e2b59ce/src/transformers/integrations.py#L271
    # Also see https://github.com/ray-project/ray/issues/28084
    from transformers.utils import is_datasets_available

    # Only register the module once per process, and only if `datasets` exists.
    if "datasets_modules" not in sys.modules and is_datasets_available():
        import importlib
        import os

        import datasets.load

        dynamic_modules_path = os.path.join(
            datasets.load.init_dynamic_modules(), "__init__.py"
        )
        # load dynamic_modules from path
        spec = importlib.util.spec_from_file_location(
            "datasets_modules", dynamic_modules_path
        )
        datasets_modules = importlib.util.module_from_spec(spec)
        # Register under the spec name ("datasets_modules") before executing it.
        sys.modules[spec.name] = datasets_modules
        spec.loader.exec_module(datasets_modules)
except ImportError as e:
    # Defer the failure; see the note on TRANSFORMERS_IMPORT_ERROR above.
    TRANSFORMERS_IMPORT_ERROR = e
47
+
48
+
49
class HuggingFaceDatasource(Datasource):
    """Hugging Face Dataset datasource, for reading from a
    `Hugging Face Datasets Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset/>`_.
    This Datasource implements a streamed read using a
    single read task, most beneficial for a
    `Hugging Face Datasets IterableDataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDataset/>`_
    or datasets which are too large to fit in-memory.
    For an in-memory Hugging Face Dataset (`datasets.Dataset`), use :meth:`~ray.data.from_huggingface`
    directly for faster performance.
    """  # noqa: E501

    def __init__(
        self,
        dataset: Union["datasets.Dataset", "datasets.IterableDataset"],
        batch_size: int = 4096,
    ):
        """Store the dataset and the batch size used while iterating it.

        Args:
            dataset: The Hugging Face dataset to read.
            batch_size: Number of rows requested per batch from the dataset.

        Raises:
            ImportError: If importing `transformers`/`datasets` failed when
                this module was loaded.
        """
        if TRANSFORMERS_IMPORT_ERROR is not None:
            raise TRANSFORMERS_IMPORT_ERROR

        self._dataset = dataset
        self._batch_size = batch_size

    @classmethod
    def list_parquet_urls_from_dataset(
        cls, dataset: Union["datasets.Dataset", "datasets.IterableDataset"]
    ) -> List[str]:
        """Return list of Hugging Face hosted parquet file URLs if they
        exist for the data (i.e. if the dataset is a public dataset that
        has not been transformed) else return an empty list."""
        import datasets

        # We can use the dataset name, config name, and split name to load
        # public hugging face datasets from the Hugging Face Hub. More info
        # here: https://huggingface.co/docs/datasets-server/parquet
        dataset_name = dataset.info.dataset_name
        config_name = dataset.info.config_name
        split_name = str(dataset.split)

        # If a dataset is not an iterable dataset, we will check if the
        # dataset with the matching dataset name, config name, and split name
        # on the Hugging Face Hub has the same fingerprint as the
        # dataset passed into this function. If it is not matching, transforms
        # or other operations have been performed so we cannot use the parquet
        # files on the Hugging Face Hub, so we return an empty list.
        if not isinstance(dataset, datasets.IterableDataset):
            from datasets import load_dataset

            try:
                ds = load_dataset(dataset_name, config_name, split=split_name)
                if ds._fingerprint != dataset._fingerprint:
                    return []
            except Exception:
                # If an exception is thrown when trying to reload the dataset
                # we should exit gracefully by returning an empty list.
                return []

        import requests

        public_url = (
            f"https://huggingface.co/api/datasets/{dataset_name}"
            f"/parquet/{config_name}/{split_name}"
        )
        resp = requests.get(public_url)
        if resp.status_code == requests.codes["ok"]:
            # dataset corresponds to a public dataset, return list of parquet_files
            return resp.json()
        else:
            return []

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the ``dataset_size`` attribute of the wrapped dataset."""
        return self._dataset.dataset_size

    def get_read_tasks(
        self,
        parallelism: int,
    ) -> List[ReadTask]:
        """Return a single read task that streams the dataset in batches.

        Raises (when the task is executed):
            ValueError: If the dataset yields a batch that is not a
                ``pyarrow.Table``, ``pd.DataFrame``, ``dict`` or ``np.ndarray``.
        """
        # Note: `parallelism` arg is currently not used by HuggingFaceDatasource.
        # We always generate a single ReadTask to perform the read.
        _check_pyarrow_version()
        import numpy as np
        import pandas as pd
        import pyarrow

        def _read_dataset(dataset: "datasets.IterableDataset") -> Iterable[Block]:
            for batch in dataset.with_format("arrow").iter(batch_size=self._batch_size):
                # HuggingFace IterableDatasets do not fully support methods like
                # `set_format`, `with_format`, and `formatted_as`, so the dataset
                # can return whatever is the default configured batch type, even if
                # the format is manually overriden before iterating above.
                # Therefore, we limit support to batch formats which have native
                # block types in Ray Data (pyarrow.Table, pd.DataFrame),
                # or can easily be converted to such (dict, np.array).
                # See: https://github.com/huggingface/datasets/issues/3444
                #
                # `np.ndarray` is the array *type*; `np.array` is a factory
                # function and makes `isinstance` raise TypeError.
                if not isinstance(
                    batch, (pyarrow.Table, pd.DataFrame, dict, np.ndarray)
                ):
                    raise ValueError(
                        f"Batch format {type(batch)} isn't supported. Only the "
                        f"following batch formats are supported: "
                        f"dict (corresponds to `None` in `dataset.with_format()`), "
                        f"pyarrow.Table, np.array, pd.DataFrame."
                    )
                # Ensure np.arrays are wrapped in a dict
                # (subsequently converted to a pyarrow.Table).
                if isinstance(batch, np.ndarray):
                    batch = {"item": batch}
                if isinstance(batch, dict):
                    batch = pyarrow_table_from_pydict(batch)
                # Ensure that we return the default block type.
                block = BlockAccessor.for_block(batch).to_default()
                yield block

        # TODO(scottjlee): IterableDataset doesn't provide APIs
        # for getting number of rows, byte size, etc., so the
        # BlockMetadata is currently empty. Properly retrieve
        # or calculate these so that progress bars have meaning.
        meta = BlockMetadata(
            num_rows=None,
            size_bytes=None,
            schema=None,
            input_files=None,
            exec_stats=None,
        )
        read_tasks: List[ReadTask] = [
            ReadTask(
                lambda hfds=self._dataset: _read_dataset(hfds),
                meta,
            )
        ]
        return read_tasks
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module to read an iceberg table into a Ray Dataset, by using the Ray Datasource API.
3
+ """
4
+
5
+ import heapq
6
+ import itertools
7
+ import logging
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
10
+
11
+ from ray.data._internal.util import _check_import
12
+ from ray.data.block import Block, BlockMetadata
13
+ from ray.data.datasource.datasource import Datasource, ReadTask
14
+ from ray.util.annotations import DeveloperAPI
15
+
16
+ if TYPE_CHECKING:
17
+ from pyiceberg.catalog import Catalog
18
+ from pyiceberg.expressions import BooleanExpression
19
+ from pyiceberg.io import FileIO
20
+ from pyiceberg.manifest import DataFile
21
+ from pyiceberg.schema import Schema
22
+ from pyiceberg.table import DataScan, FileScanTask, Table
23
+ from pyiceberg.table.metadata import TableMetadata
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
def _get_read_task(
    tasks: Iterable["FileScanTask"],
    table_io: "FileIO",
    table_metadata: "TableMetadata",
    row_filter: "BooleanExpression",
    case_sensitive: bool,
    limit: Optional[int],
    schema: "Schema",
) -> Iterable[Block]:
    """Yield the single table produced by projecting ``tasks`` with PyIceberg.

    All arguments are forwarded to ``pyiceberg.io.pyarrow.project_table``;
    ``schema`` is passed as its ``projected_schema``.
    """
    from pyiceberg.io.pyarrow import project_table

    # Use the PyIceberg API to read only a single task (specifically, a
    # FileScanTask) - note that this is not as simple as reading a single
    # parquet file, as there might be delete files, etc. associated, so we
    # must use the PyIceberg API for the projection.
    projected_table = project_table(
        tasks=tasks,
        table_metadata=table_metadata,
        io=table_io,
        row_filter=row_filter,
        projected_schema=schema,
        case_sensitive=case_sensitive,
        limit=limit,
    )
    yield projected_table
52
+
53
+
54
@DeveloperAPI
class IcebergDatasource(Datasource):
    """
    Iceberg datasource to read Iceberg tables into a Ray Dataset. This module heavily
    uses PyIceberg to read iceberg tables. All the routines in this class override
    `ray.data.Datasource`.
    """

    def __init__(
        self,
        table_identifier: str,
        row_filter: Union[str, "BooleanExpression"] = None,
        selected_fields: Tuple[str, ...] = ("*",),
        snapshot_id: Optional[int] = None,
        scan_kwargs: Optional[Dict[str, Any]] = None,
        catalog_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize an IcebergDatasource.

        Args:
            table_identifier: Fully qualified table identifier (i.e.,
                "db_name.table_name")
            row_filter: A PyIceberg BooleanExpression to use to filter the data *prior*
                 to reading
            selected_fields: Which columns from the data to read, passed directly to
                PyIceberg's load functions
            snapshot_id: Optional snapshot ID for the Iceberg table
            scan_kwargs: Optional arguments to pass to PyIceberg's Table.scan()
                function
            catalog_kwargs: Optional arguments to use when setting up the Iceberg
                catalog
        """
        _check_import(self, module="pyiceberg", package="pyiceberg")
        from pyiceberg.expressions import AlwaysTrue

        # Work on shallow copies: "name" is popped and "snapshot_id" is added
        # below, and the caller's dictionaries must not be modified.
        self._scan_kwargs = dict(scan_kwargs) if scan_kwargs is not None else {}
        self._catalog_kwargs = (
            dict(catalog_kwargs) if catalog_kwargs is not None else {}
        )

        # The catalog name is passed positionally to `load_catalog`, so it is
        # taken out of the keyword arguments.
        self._catalog_name = self._catalog_kwargs.pop("name", "default")

        self.table_identifier = table_identifier

        self._row_filter = row_filter if row_filter is not None else AlwaysTrue()
        self._selected_fields = selected_fields

        if snapshot_id:
            self._scan_kwargs["snapshot_id"] = snapshot_id

        # Lazily populated caches, see the `plan_files` and `table` properties.
        self._plan_files = None
        self._table = None

    def _get_catalog(self) -> "Catalog":
        """Load the catalog named at construction with the catalog kwargs."""
        from pyiceberg import catalog

        return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)

    @property
    def table(self) -> "Table":
        """
        Return the table reference from the catalog
        """
        if self._table is None:
            catalog = self._get_catalog()
            self._table = catalog.load_table(self.table_identifier)
        return self._table

    @property
    def plan_files(self) -> List["FileScanTask"]:
        """
        Return the plan files specified by this query
        """
        # Calculate and cache the plan_files if they don't already exist.
        # Materialize them into a list: the cached value is iterated several
        # times (size estimate, counting, chunking), which would silently
        # yield nothing on later passes if the scan returned a one-shot
        # iterator.
        if self._plan_files is None:
            data_scan = self._get_data_scan()
            self._plan_files = list(data_scan.plan_files())

        return self._plan_files

    def _get_data_scan(self) -> "DataScan":
        """Create a table scan from the row filter, fields and scan kwargs."""
        data_scan = self.table.scan(
            row_filter=self._row_filter,
            selected_fields=self._selected_fields,
            **self._scan_kwargs,
        )

        return data_scan

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the summed on-disk size of the planned data files."""
        # Approximate the size by using the plan files - this will not
        # incorporate the deletes, but that's a reasonable approximation
        # task
        return sum(task.file.file_size_in_bytes for task in self.plan_files)

    @staticmethod
    def _distribute_tasks_into_equal_chunks(
        plan_files: Iterable["FileScanTask"], n_chunks: int
    ) -> List[List["FileScanTask"]]:
        """
        Implement a greedy knapsack algorithm to distribute the files in the scan
        across tasks, based on their file size, as evenly as possible
        """
        chunks = [[] for _ in range(n_chunks)]

        # Min-heap of (accumulated size, chunk index); the root is always the
        # currently smallest chunk.
        chunk_sizes = [(0, chunk_id) for chunk_id in range(n_chunks)]
        heapq.heapify(chunk_sizes)

        # From largest to smallest, add the plan files to the smallest chunk one at a
        # time
        for plan_file in sorted(
            plan_files, key=lambda f: f.file.file_size_in_bytes, reverse=True
        ):
            smallest_chunk = heapq.heappop(chunk_sizes)
            chunks[smallest_chunk[1]].append(plan_file)
            heapq.heappush(
                chunk_sizes,
                (
                    smallest_chunk[0] + plan_file.file.file_size_in_bytes,
                    smallest_chunk[1],
                ),
            )

        return chunks

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Create read tasks, each covering a size-balanced group of plan files.

        ``parallelism`` is reduced to the number of plan files when it exceeds
        it, so no task is created without a file to read.
        """
        from pyiceberg.io import pyarrow as pyi_pa_io
        from pyiceberg.manifest import DataFileContent

        # Get the PyIceberg scan
        data_scan = self._get_data_scan()
        # Get the plan files in this query
        plan_files = self.plan_files

        # Get the projected schema for this scan, given all the row filters,
        # snapshot ID, etc.
        projected_schema = data_scan.projection()
        # Get the arrow schema, to set in the metadata
        pya_schema = pyi_pa_io.schema_to_pyarrow(projected_schema)

        # Set the n_chunks to the min of the number of plan files and the actual
        # requested n_chunks, so that there are no empty tasks
        num_plan_files = len(plan_files)
        if parallelism > num_plan_files:
            parallelism = num_plan_files
            logger.warning(
                f"Reducing the parallelism to {parallelism}, as that is the "
                "number of files"
            )

        # Get required properties for reading tasks - table IO, table metadata,
        # row filter, case sensitivity,limit and projected schema to pass
        # them directly to `_get_read_task` to avoid capture of `self` reference
        # within the closure carrying substantial overhead invoking these tasks
        #
        # See https://github.com/ray-project/ray/issues/49107 for more context
        table_io = self.table.io
        table_metadata = self.table.metadata
        row_filter = self._row_filter
        case_sensitive = self._scan_kwargs.get("case_sensitive", True)
        limit = self._scan_kwargs.get("limit")

        get_read_task = partial(
            _get_read_task,
            table_io=table_io,
            table_metadata=table_metadata,
            row_filter=row_filter,
            case_sensitive=case_sensitive,
            limit=limit,
            schema=projected_schema,
        )

        read_tasks = []
        # Chunk the plan files based on the requested parallelism
        for chunk_tasks in IcebergDatasource._distribute_tasks_into_equal_chunks(
            plan_files, parallelism
        ):
            unique_deletes: Set[DataFile] = set(
                itertools.chain.from_iterable(
                    [task.delete_files for task in chunk_tasks]
                )
            )
            # Get a rough estimate of the number of deletes by just looking at
            # position deletes. Equality deletes are harder to estimate, as they
            # can delete multiple rows.
            position_delete_count = sum(
                delete.record_count
                for delete in unique_deletes
                if delete.content == DataFileContent.POSITION_DELETES
            )
            metadata = BlockMetadata(
                num_rows=sum(task.file.record_count for task in chunk_tasks)
                - position_delete_count,
                size_bytes=sum(task.length for task in chunk_tasks),
                schema=pya_schema,
                input_files=[task.file.file_path for task in chunk_tasks],
                exec_stats=None,
            )
            read_tasks.append(
                ReadTask(
                    # Bind the chunk as a default to avoid late binding.
                    read_fn=lambda tasks=chunk_tasks: get_read_task(tasks),
                    metadata=metadata,
                )
            )

        return read_tasks
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Any, Dict
3
+
4
+ import pyarrow
5
+
6
+ from ray.data.datasource.file_datasink import RowBasedFileDatasink
7
+
8
+
9
class ImageDatasink(RowBasedFileDatasink):
    """Datasink that writes one image file per row.

    The pixel data is taken from a single column of each row and encoded with
    PIL using the configured file format.
    """

    def __init__(
        self, path: str, column: str, file_format: str, **file_datasink_kwargs
    ):
        """Remember which column holds the image and which format to encode to.

        Args:
            path: Destination forwarded to the base datasink.
            column: Name of the row field that contains the image array.
            file_format: Format name handed both to the base datasink and to
                ``PIL.Image.Image.save``.
            **file_datasink_kwargs: Extra options forwarded to the base class.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.file_format = file_format
        self.column = column

    def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
        """Encode ``row[self.column]`` as an image and write the bytes to ``file``."""
        # PIL is imported lazily so the module can be imported without Pillow.
        from PIL import Image

        # Encode into memory first, then hand the finished payload to `file`
        # in a single write.
        encoded = io.BytesIO()
        Image.fromarray(row[self.column]).save(encoded, format=self.file_format)
        file.write(encoded.getvalue())
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+ import time
4
+ from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+
8
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
9
+ from ray.data._internal.util import _check_import
10
+ from ray.data.block import Block, BlockMetadata
11
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
12
+ from ray.data.datasource.file_meta_provider import DefaultFileMetadataProvider
13
+
14
+ if TYPE_CHECKING:
15
+ import pyarrow
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # The default size multiplier for reading image data source.
21
+ # This essentially is using image on-disk file size to estimate
22
+ # in-memory data size.
23
+ IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT = 1
24
+
25
+ # The lower bound value to estimate image encoding ratio.
26
+ IMAGE_ENCODING_RATIO_ESTIMATE_LOWER_BOUND = 0.5
27
+
28
+
29
class ImageDatasource(FileBasedDatasource):
    """A datasource that lets you read images.

    Each file is decoded with PIL and emitted as a one-row block whose
    ``"image"`` column holds the pixels as a NumPy array. Images can optionally
    be resized and converted to a PIL mode while they are read.
    """

    _WRITE_FILE_PER_ROW = True
    _FILE_EXTENSIONS = ["png", "jpg", "jpeg", "tif", "tiff", "bmp", "gif"]
    # Use 8 threads per task to read image files.
    _NUM_THREADS_PER_TASK = 8

    def __init__(
        self,
        paths: Union[str, List[str]],
        size: Optional[Tuple[int, int]] = None,
        mode: Optional[str] = None,
        **file_based_datasource_kwargs,
    ):
        """Validate the arguments and set up the size-estimation ratio.

        Args:
            paths: A path or list of paths forwarded to the base datasource.
            size: Optional ``(height, width)`` every image is resized to.
            mode: Optional PIL mode every image is converted to.
            **file_based_datasource_kwargs: Extra options forwarded to the base
                class. If ``meta_provider`` is an ``ImageFileMetadataProvider``
                the estimated encoding ratio is pushed into it.

        Raises:
            ValueError: If ``size`` does not have exactly two entries or
                contains a negative value.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        _check_import(self, module="PIL", package="Pillow")

        if size is not None and len(size) != 2:
            raise ValueError(
                "Expected `size` to contain two integers for height and width, "
                f"but got {len(size)} integers instead."
            )

        if size is not None and (size[0] < 0 or size[1] < 0):
            raise ValueError(
                f"Expected `size` to contain positive integers, but got {size} instead."
            )

        self.size = size
        self.mode = mode

        meta_provider = file_based_datasource_kwargs.get("meta_provider", None)
        if isinstance(meta_provider, ImageFileMetadataProvider):
            # Share the estimate with the provider so the block metadata it
            # produces uses the same ratio as this datasource.
            self._encoding_ratio = self._estimate_files_encoding_ratio()
            meta_provider._set_encoding_ratio(self._encoding_ratio)
        else:
            self._encoding_ratio = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

    def _read_stream(
        self,
        f: "pyarrow.NativeFile",
        path: str,
    ) -> Iterator[Block]:
        """Decode one image file and yield it as a single-row block.

        Raises:
            ValueError: If PIL cannot identify the file as an image.
        """
        from PIL import Image, UnidentifiedImageError

        data = f.readall()

        try:
            image = Image.open(io.BytesIO(data))
        except UnidentifiedImageError as e:
            raise ValueError(f"PIL couldn't load image file at path '{path}'.") from e

        if self.size is not None:
            # `size` is stored as (height, width) while PIL expects
            # (width, height).
            height, width = self.size
            image = image.resize((width, height), resample=Image.BILINEAR)
        if self.mode is not None:
            image = image.convert(self.mode)

        builder = DelegatingBlockBuilder()
        builder.add({"image": np.array(image)})
        yield builder.build()

    def _rows_per_file(self):
        """Every image file yields exactly one row."""
        return 1

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Estimate the in-memory size as on-disk size times the encoding ratio."""
        # NOTE: skip `None` sizes, because some metadata provider such as
        # FastFileMetadataProvider does not provide file size information.
        total_size = sum(
            file_size for file_size in self._file_sizes() if file_size is not None
        )
        # The ratio may be a float; the declared return type is an integer
        # number of bytes.
        return int(total_size * self._encoding_ratio)

    def _estimate_files_encoding_ratio(self) -> float:
        """Return an estimate of the image files encoding ratio.

        The ratio is estimated in-memory bytes divided by on-disk bytes. It can
        only be derived when both ``size`` and ``mode`` are set; otherwise the
        default ratio is used. The result is never below
        ``IMAGE_ENCODING_RATIO_ESTIMATE_LOWER_BOUND``.
        """
        start_time = time.perf_counter()
        # Filter out empty files to avoid noise. Files whose size is unknown
        # (`None`) are skipped as well instead of failing on `None > 0`.
        non_empty_path_and_size = [
            (file_path, file_size)
            for file_path, file_size in zip(self._paths(), self._file_sizes())
            if file_size is not None and file_size > 0
        ]
        num_files = len(non_empty_path_and_size)
        if num_files == 0:
            # `Logger.warn` is a deprecated alias; use `Logger.warning`.
            logger.warning(
                "All input image files are empty. "
                "Use on-disk file size to estimate images in-memory size."
            )
            return IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

        if self.size is not None and self.mode is not None:
            # Use image size and mode to calculate data size for all images,
            # because all images are homogeneous with same size after resizing.
            # Resizing is enforced when reading every image in `ImageDatasource`
            # when `size` argument is provided.
            # `dimension` is multiplied with height * width and compared against
            # file sizes, i.e. it acts as the per-pixel size factor of the mode.
            if self.mode in ["1", "L", "P"]:
                dimension = 1
            elif self.mode in ["RGB", "YCbCr", "LAB", "HSV"]:
                dimension = 3
            elif self.mode in ["RGBA", "CMYK", "I", "F"]:
                dimension = 4
            else:
                logger.warning(f"Found unknown image mode: {self.mode}.")
                return IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT
            height, width = self.size
            single_image_size = height * width * dimension
            total_estimated_size = single_image_size * num_files
            total_file_size = sum(
                file_size for _, file_size in non_empty_path_and_size
            )
            ratio = total_estimated_size / total_file_size
        else:
            # TODO(chengsu): sample images to estimate data size
            ratio = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

        sampling_duration = time.perf_counter() - start_time
        if sampling_duration > 5:
            logger.warning(
                "Image input size estimation took "
                f"{round(sampling_duration, 2)} seconds."
            )
        logger.debug(f"Estimated image encoding ratio from sampling is {ratio}.")
        return max(ratio, IMAGE_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)
155
+
156
+
157
class ImageFileMetadataProvider(DefaultFileMetadataProvider):
    """Metadata provider that scales on-disk sizes by an image encoding ratio."""

    # Class-level fallback so `_get_block_metadata` also works when
    # `_set_encoding_ratio` was never called on this instance.
    _encoding_ratio: float = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

    def _set_encoding_ratio(self, encoding_ratio: float):
        """Set image file encoding ratio, to provide accurate size in bytes metadata."""
        self._encoding_ratio = encoding_ratio

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        *,
        rows_per_file: Optional[int],
        file_sizes: List[Optional[int]],
    ) -> BlockMetadata:
        """Return the base metadata with ``size_bytes`` scaled by the ratio.

        ``size_bytes`` is left untouched when the base provider reports
        ``None`` for it.
        """
        metadata = super()._get_block_metadata(
            paths, schema, rows_per_file=rows_per_file, file_sizes=file_sizes
        )
        if metadata.size_bytes is not None:
            metadata.size_bytes = int(metadata.size_bytes * self._encoding_ratio)
        return metadata
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasink.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, Optional
2
+
3
+ import pyarrow
4
+
5
+ from ray.data.block import BlockAccessor
6
+ from ray.data.datasource.file_based_datasource import _resolve_kwargs
7
+ from ray.data.datasource.file_datasink import BlockBasedFileDatasink
8
+
9
+
10
class JSONDatasink(BlockBasedFileDatasink):
    """Datasink that writes each block to a JSON file through pandas."""

    def __init__(
        self,
        path: str,
        *,
        pandas_json_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
        pandas_json_args: Optional[Dict[str, Any]] = None,
        file_format: str = "json",
        **file_datasink_kwargs,
    ):
        """Store the writer options.

        Args:
            path: Destination forwarded to the base datasink.
            pandas_json_args_fn: Optional callable returning writer keyword
                arguments; it is invoked for every block that is written.
            pandas_json_args: Optional static writer keyword arguments.
            file_format: File format forwarded to the base datasink.
            **file_datasink_kwargs: Extra options forwarded to the base class.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        # Normalise the optional arguments so later code never sees `None`.
        self.pandas_json_args_fn = (
            (lambda: {}) if pandas_json_args_fn is None else pandas_json_args_fn
        )
        self.pandas_json_args = {} if pandas_json_args is None else pandas_json_args

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Serialise ``block`` with ``DataFrame.to_json`` into ``file``.

        ``orient`` defaults to ``"records"`` and ``lines`` to ``True`` unless
        the resolved writer arguments say otherwise.
        """
        resolved = _resolve_kwargs(self.pandas_json_args_fn, **self.pandas_json_args)
        # Pull out the two options that have defaults; everything left over is
        # forwarded verbatim.
        orient = resolved.pop("orient", "records")
        lines = resolved.pop("lines", True)

        frame = block.to_pandas()
        frame.to_json(file, orient=orient, lines=lines, **resolved)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasource.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from io import BytesIO
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
4
+
5
+ from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
6
+ from ray.data.context import DataContext
7
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
8
+
9
+ if TYPE_CHECKING:
10
+ import pyarrow
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class JSONDatasource(FileBasedDatasource):
    """JSON datasource, for reading and writing JSON and JSONL files."""

    _FILE_EXTENSIONS = [
        "json",
        "jsonl",
        # gzip-compressed files
        "json.gz",
        "jsonl.gz",
        # Brotli-compressed files
        "json.br",
        "jsonl.br",
        # Zstandard-compressed files
        "json.zst",
        "jsonl.zst",
        # lz4-compressed files
        "json.lz4",
        "jsonl.lz4",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        *,
        arrow_json_args: Optional[Dict[str, Any]] = None,
        **file_based_datasource_kwargs,
    ):
        """Store the PyArrow JSON reader options.

        Args:
            paths: A path or list of paths forwarded to the base datasource.
            arrow_json_args: Optional keyword arguments for
                ``pyarrow.json.read_json``. A ``read_options`` entry, if
                present, is split off and kept separately.
            **file_based_datasource_kwargs: Extra options forwarded to the base
                class.
        """
        from pyarrow import json

        super().__init__(paths, **file_based_datasource_kwargs)

        # Work on a shallow copy: popping `read_options` below must not remove
        # the entry from the dictionary owned by the caller.
        arrow_json_args = dict(arrow_json_args) if arrow_json_args else {}

        self.read_options = arrow_json_args.pop(
            "read_options", json.ReadOptions(use_threads=False)
        )
        self.arrow_json_args = arrow_json_args

    def _read_with_pyarrow_read_json(self, buffer: "pyarrow.lib.Buffer"):
        """Read with PyArrow JSON reader, trying to auto-increase the
        read block size in the case of the read object
        straddling block boundaries.

        Yields a single table. ``self.read_options.block_size`` is always
        restored to its initial value before this generator yields or raises.

        Raises:
            pyarrow.ArrowInvalid: If the read fails for an unrelated reason, or
                if the block size reached the maximum without success.
        """
        import pyarrow as pa

        # Import the submodule explicitly so it is loaded regardless of what
        # else has been imported in this process.
        from pyarrow import json as pa_json

        # When reading large files, the default block size configured in PyArrow can be
        # too small, resulting in the following error: `pyarrow.lib.ArrowInvalid:
        # straddling object straddles two block boundaries (try to increase block
        # size?)`. More information on this issue can be found here:
        # https://github.com/apache/arrow/issues/25674
        # The read will be retried with geometrically increasing block size
        # until the size reaches `DataContext.get_current().target_max_block_size`.
        # The initial block size will start at the PyArrow default block size
        # or it can be manually set through the `read_options` parameter as follows.
        # >>> import pyarrow.json as pajson
        # >>> block_size = 10 << 20 # Set block size to 10MB
        # >>> ray.data.read_json(  # doctest: +SKIP
        # ...     "s3://anonymous@ray-example-data/log.json",
        # ...     read_options=pajson.ReadOptions(block_size=block_size)
        # ... )

        init_block_size = self.read_options.block_size
        max_block_size = DataContext.get_current().target_max_block_size
        try:
            while True:
                try:
                    table = pa_json.read_json(
                        BytesIO(buffer),
                        read_options=self.read_options,
                        **self.arrow_json_args,
                    )
                    break
                except pa.ArrowInvalid as e:
                    if "straddling object straddles two block boundaries" not in str(
                        e
                    ):
                        # unrelated error, simply reraise
                        raise
                    if self.read_options.block_size >= max_block_size:
                        raise pa.ArrowInvalid(
                            f"{e} - Auto-increasing block size to "
                            f"{self.read_options.block_size} bytes failed. "
                            f"Please try manually increasing the block size through "
                            f"the `read_options` parameter to a larger size. "
                            f"For example: `read_json(..., read_options="
                            f"pyarrow.json.ReadOptions(block_size=10 << 25))`. "
                            f"More information on this issue can be found here: "
                            f"https://github.com/apache/arrow/issues/25674"
                        ) from e
                    # Increase the block size in case it was too small.
                    logger.debug(
                        f"JSONDatasource read failed with "
                        f"block_size={self.read_options.block_size}. Retrying with "
                        f"block_size={self.read_options.block_size * 2}."
                    )
                    self.read_options.block_size *= 2
        finally:
            # `read_options` is shared by every file this datasource reads, so
            # the enlarged size must not leak out, on success or on failure.
            self.read_options.block_size = init_block_size

        # Yield outside the try blocks so exceptions thrown into the generator
        # by the consumer are not mistaken for read errors.
        yield table

    def _read_with_python_json(self, buffer: "pyarrow.lib.Buffer"):
        """Fallback method to read JSON files with Python's native json.load(),
        in case the default pyarrow json reader fails.

        Yields nothing for an empty buffer, otherwise a single table.
        """
        import json

        import pyarrow as pa

        # Check if the buffer is empty
        if buffer.size == 0:
            return

        parsed_json = json.load(BytesIO(buffer))
        try:
            table = pa.Table.from_pylist(parsed_json)
        except AttributeError as e:
            # For PyArrow < 7.0.0, `pa.Table.from_pylist()` is not available.
            # Construct a dict from the list and call
            # `pa.Table.from_pydict()` instead.
            # Any other AttributeError is a genuine failure and is re-raised;
            # an `assert` would be stripped under `python -O`.
            if "no attribute 'from_pylist'" not in str(e):
                raise
            from collections import defaultdict

            dct = defaultdict(list)
            for row in parsed_json:
                for k, v in row.items():
                    dct[k].append(v)
            table = pyarrow_table_from_pydict(dct)
        yield table

    # TODO(ekl) The PyArrow JSON reader doesn't support streaming reads.
    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Read the whole file, preferring PyArrow and falling back to ``json``."""
        import pyarrow as pa

        buffer: pa.lib.Buffer = f.read_buffer()

        try:
            yield from self._read_with_pyarrow_read_json(buffer)
        except pa.ArrowInvalid as e:
            # If read with PyArrow fails, try falling back to native json.load().
            logger.warning(
                f"Error reading with pyarrow.json.read_json(). "
                f"Falling back to native json.load(), which may be slower. "
                f"PyArrow error was:\n{e}"
            )
            yield from self._read_with_python_json(buffer)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
3
+
4
+ import numpy as np
5
+
6
+ from ray.data._internal.util import _check_import, call_with_retry
7
+ from ray.data.block import BlockMetadata
8
+ from ray.data.context import DataContext
9
+ from ray.data.datasource.datasource import Datasource, ReadTask
10
+
11
+ if TYPE_CHECKING:
12
+ import pyarrow
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class LanceDatasource(Datasource):
    """Lance datasource, for reading Lance dataset."""

    # Errors to retry when reading Lance fragments.
    READ_FRAGMENTS_ERRORS_TO_RETRY = ["LanceError(IO)"]
    # Maximum number of attempts to read Lance fragments.
    READ_FRAGMENTS_MAX_ATTEMPTS = 10
    # Maximum backoff seconds between attempts to read Lance fragments.
    READ_FRAGMENTS_RETRY_MAX_BACKOFF_SECONDS = 32

    def __init__(
        self,
        uri: str,
        columns: Optional[List[str]] = None,
        filter: Optional[str] = None,
        storage_options: Optional[Dict[str, str]] = None,
        scanner_options: Optional[Dict[str, Any]] = None,
    ):
        """Open the Lance dataset and prepare scanner and retry settings.

        Args:
            uri: Location of the Lance dataset.
            columns: Optional column list; stored as the ``"columns"`` scanner
                option.
            filter: Optional filter expression; stored as the ``"filter"``
                scanner option.
            storage_options: Optional options forwarded to ``lance.dataset``.
            scanner_options: Optional extra scanner keyword arguments.
        """
        _check_import(self, module="lance", package="pylance")

        import lance

        self.uri = uri
        # Copy the options so that adding `columns`/`filter` below never
        # modifies the dictionary owned by the caller.
        self.scanner_options = dict(scanner_options) if scanner_options else {}
        if columns is not None:
            self.scanner_options["columns"] = columns
        if filter is not None:
            self.scanner_options["filter"] = filter
        self.storage_options = storage_options
        self.lance_ds = lance.dataset(uri=uri, storage_options=storage_options)

        match = []
        match.extend(self.READ_FRAGMENTS_ERRORS_TO_RETRY)
        match.extend(DataContext.get_current().retried_io_errors)
        self._retry_params = {
            "description": "read lance fragments",
            "match": match,
            "max_attempts": self.READ_FRAGMENTS_MAX_ATTEMPTS,
            "max_backoff_s": self.READ_FRAGMENTS_RETRY_MAX_BACKOFF_SECONDS,
        }

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Split the dataset's fragments into at most ``parallelism`` read tasks.

        Empty groups produced by the split are skipped, so fewer tasks than
        ``parallelism`` may be returned.
        """
        # Bind the values the tasks need to locals so the task closures do not
        # capture `self`.
        scanner_options = self.scanner_options
        lance_ds = self.lance_ds
        retry_params = self._retry_params

        read_tasks = []
        for fragments in np.array_split(self.lance_ds.get_fragments(), parallelism):
            if len(fragments) == 0:
                continue

            fragment_ids = [f.metadata.id for f in fragments]
            num_rows = sum(f.count_rows() for f in fragments)
            input_files = [
                data_file.path() for f in fragments for data_file in f.data_files()
            ]

            # TODO(chengsu): Take column projection into consideration for schema.
            metadata = BlockMetadata(
                num_rows=num_rows,
                schema=fragments[0].schema,
                input_files=input_files,
                size_bytes=None,
                exec_stats=None,
            )

            # `fragment_ids` is bound as a default argument so every task keeps
            # its own ids instead of the value from the last loop iteration.
            read_task = ReadTask(
                lambda f=fragment_ids: _read_fragments_with_retry(
                    f,
                    lance_ds,
                    scanner_options,
                    retry_params,
                ),
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``: no in-memory size estimate is available."""
        # TODO(chengsu): Add memory size estimation to improve auto-tune of parallelism.
        return None
99
+
100
+
101
def _read_fragments_with_retry(
    fragment_ids,
    lance_ds,
    scanner_options,
    retry_params,
) -> Iterator["pyarrow.Table"]:
    """Invoke ``_read_fragments`` through ``call_with_retry``.

    ``retry_params`` is expanded into keyword arguments of ``call_with_retry``.

    NOTE(review): ``_read_fragments`` is a generator function, so the callable
    handed to ``call_with_retry`` returns as soon as the generator object is
    created. Whether errors raised later, while the generator is iterated, are
    covered by the retry depends on ``call_with_retry`` -- confirm.
    """

    def _attempt():
        return _read_fragments(fragment_ids, lance_ds, scanner_options)

    return call_with_retry(_attempt, **retry_params)
111
+
112
+
113
def _read_fragments(
    fragment_ids,
    lance_ds,
    scanner_options,
) -> Iterator["pyarrow.Table"]:
    """Read Lance fragments in batches.

    Yields one ``pyarrow.Table`` per record batch produced by the scanner.

    NOTE: Use fragment ids, instead of fragments as parameter, because pickling
    LanceFragment is expensive.
    """
    import pyarrow

    fragments = [lance_ds.get_fragment(fragment_id) for fragment_id in fragment_ids]
    # Build a per-call options dict instead of writing `fragments` into
    # `scanner_options`: the same dictionary object is handed to every read
    # task created by the datasource and must stay untouched.
    scanner = lance_ds.scanner(**{**scanner_options, "fragments": fragments})
    for batch in scanner.to_reader():
        yield pyarrow.Table.from_batches([batch])
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasink.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Iterable
3
+
4
+ from ray.data._internal.datasource.mongo_datasource import (
5
+ _validate_database_collection_exist,
6
+ )
7
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
8
+ from ray.data._internal.execution.interfaces import TaskContext
9
+ from ray.data._internal.util import _check_import
10
+ from ray.data.block import Block, BlockAccessor
11
+ from ray.data.datasource.datasink import Datasink
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class MongoDatasink(Datasink[None]):
    """Datasink that writes blocks into an existing MongoDB collection."""

    def __init__(self, uri: str, database: str, collection: str) -> None:
        """Check the required packages are importable and store the target.

        Args:
            uri: MongoDB connection URI.
            database: Name of the target database.
            collection: Name of the target collection.
        """
        _check_import(self, module="pymongo", package="pymongo")
        _check_import(self, module="pymongoarrow", package="pymongoarrow")

        self.uri = uri
        self.database = database
        self.collection = collection

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Combine ``blocks`` into one Arrow table and write it to MongoDB.

        Raises:
            ValueError: If the target database or collection does not exist.
        """
        import pymongo
        from pymongoarrow.api import write

        # One client serves both validation and the write; the `with` block
        # closes it (and its connections) even when an error is raised.
        with pymongo.MongoClient(self.uri) as client:
            _validate_database_collection_exist(
                client, self.database, self.collection
            )

            builder = DelegatingBlockBuilder()
            for block in blocks:
                builder.add_block(block)
            table = BlockAccessor.for_block(builder.build()).to_arrow()

            write(client[self.database][self.collection], table)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasource.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import TYPE_CHECKING, Dict, List, Optional
3
+
4
+ from ray.data.block import Block, BlockMetadata
5
+ from ray.data.datasource.datasource import Datasource, ReadTask
6
+
7
+ if TYPE_CHECKING:
8
+ import pymongoarrow.api
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class MongoDatasource(Datasource):
    """Datasource for reading from MongoDB.

    The collection is partitioned by ``_id`` ranges and every partition is read
    by its own read task.
    """

    def __init__(
        self,
        uri: str,
        database: str,
        collection: str,
        pipeline: Optional[List[Dict]] = None,
        schema: Optional["pymongoarrow.api.Schema"] = None,
        **mongo_args,
    ):
        """Store the connection details and the aggregation pipeline.

        Args:
            uri: MongoDB connection URI.
            database: Name of the database to read.
            collection: Name of the collection to read.
            pipeline: Optional aggregation pipeline. When empty or ``None`` a
                pipeline matching every document that has an ``_id`` is used.
            schema: Optional schema forwarded to ``aggregate_arrow_all``.
            **mongo_args: Extra keyword arguments forwarded to
                ``aggregate_arrow_all``.
        """
        self._uri = uri
        self._database = database
        self._collection = collection
        self._pipeline = pipeline
        self._schema = schema
        self._mongo_args = mongo_args
        # If pipeline is unspecified, read the entire collection.
        if not pipeline:
            self._pipeline = [{"$match": {"_id": {"$exists": "true"}}}]
        # Initialize Mongo client lazily later when creating read tasks.
        self._client = None

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``: no in-memory size estimate is available."""
        # TODO(jian): Add memory size estimation to improve auto-tune of parallelism.
        return None

    def _get_match_query(self, pipeline: List[Dict]) -> Dict:
        """Return the leading ``$match`` stage of ``pipeline``, or ``{}``."""
        if len(pipeline) == 0 or "$match" not in pipeline[0]:
            return {}
        return pipeline[0]["$match"]

    def _get_or_create_client(self):
        """Create ``self._client`` on first use.

        On creation the database and collection are validated and the
        collection's average object size is cached in ``self._avg_obj_size``.
        """
        import pymongo

        if self._client is None:
            self._client = pymongo.MongoClient(self._uri)
            _validate_database_collection_exist(
                self._client, self._database, self._collection
            )
            self._avg_obj_size = self._client[self._database].command(
                "collstats", self._collection
            )["avgObjSize"]

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Create one read task per ``_id`` bucket computed by MongoDB.

        The documents selected by the pipeline's leading ``$match`` stage are
        grouped into ``parallelism`` buckets with ``$bucketAuto``.
        """
        from bson.objectid import ObjectId

        self._get_or_create_client()
        coll = self._client[self._database][self._collection]
        match_query = self._get_match_query(self._pipeline)
        partitions_ids = list(
            coll.aggregate(
                [
                    {"$match": match_query},
                    {"$bucketAuto": {"groupBy": "$_id", "buckets": parallelism}},
                ],
                allowDiskUse=True,
            )
        )

        def make_block(
            uri: str,
            database: str,
            collection: str,
            pipeline: List[Dict],
            min_id: ObjectId,
            max_id: ObjectId,
            right_closed: bool,
            schema: "pymongoarrow.api.Schema",
            kwargs: dict,
        ) -> Block:
            """Run ``pipeline`` restricted to one ``_id`` range and return the result."""
            import pymongo
            from pymongoarrow.api import aggregate_arrow_all

            # A range query over the partition.
            match = [
                {
                    "$match": {
                        "_id": {
                            "$gte": min_id,
                            "$lte" if right_closed else "$lt": max_id,
                        }
                    }
                }
            ]
            # Each task opens its own client; the `with` block closes it once
            # the result has been returned, so connections are not leaked.
            with pymongo.MongoClient(uri) as client:
                return aggregate_arrow_all(
                    client[database][collection],
                    match + pipeline,
                    schema=schema,
                    **kwargs,
                )

        read_tasks: List[ReadTask] = []

        for i, partition in enumerate(partitions_ids):
            metadata = BlockMetadata(
                num_rows=partition["count"],
                size_bytes=partition["count"] * self._avg_obj_size,
                schema=None,
                input_files=None,
                exec_stats=None,
            )
            make_block_args = (
                self._uri,
                self._database,
                self._collection,
                self._pipeline,
                partition["_id"]["min"],
                partition["_id"]["max"],
                # Only the last bucket includes its upper bound.
                i == len(partitions_ids) - 1,
                self._schema,
                self._mongo_args,
            )
            # Bind the arguments as a default so every task keeps its own
            # partition instead of the one from the last loop iteration.
            read_task = ReadTask(
                lambda args=make_block_args: [make_block(*args)],
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks
132
+
133
+
134
+ def _validate_database_collection_exist(client, database: str, collection: str):
135
+ db_names = client.list_database_names()
136
+ if database not in db_names:
137
+ raise ValueError(f"The destination database {database} doesn't exist.")
138
+ collection_names = client[database].list_collection_names()
139
+ if collection not in collection_names:
140
+ raise ValueError(f"The destination collection {collection} doesn't exist.")
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasink.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pyarrow
3
+
4
+ from ray.data.block import BlockAccessor
5
+ from ray.data.datasource.file_datasink import BlockBasedFileDatasink
6
+
7
+
8
class NumpyDatasink(BlockBasedFileDatasink):
    """Datasink that saves one column of each block as a NumPy file."""

    def __init__(
        self,
        path: str,
        column: str,
        *,
        file_format: str = "npy",
        **file_datasink_kwargs,
    ):
        """Remember the column to export.

        Args:
            path: Destination forwarded to the base datasink.
            column: Name of the column whose values are written.
            file_format: File format forwarded to the base datasink.
            **file_datasink_kwargs: Extra options forwarded to the base class.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.column = column

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Convert ``self.column`` of ``block`` to NumPy and ``np.save`` it to ``file``."""
        np.save(file, block.to_numpy(self.column))
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasource.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
3
+
4
+ import numpy as np
5
+
6
+ from ray.data.block import Block, BlockAccessor
7
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
8
+
9
+ if TYPE_CHECKING:
10
+ import pyarrow
11
+
12
+
13
class NumpyDatasource(FileBasedDatasource):
    """Numpy datasource, for reading and writing Numpy files."""

    _COLUMN_NAME = "data"
    _FILE_EXTENSIONS = ["npy"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        numpy_load_args: Optional[Dict[str, Any]] = None,
        **file_based_datasource_kwargs,
    ):
        """Store the ``numpy.load`` options.

        Args:
            paths: A path or list of paths forwarded to the base datasource.
            numpy_load_args: Optional keyword arguments for ``numpy.load``.
            **file_based_datasource_kwargs: Extra options forwarded to the base
                class.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        if numpy_load_args is None:
            numpy_load_args = {}

        self.numpy_load_args = numpy_load_args

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Load one file with ``numpy.load`` and yield it as a single block.

        The array is stored under the ``_COLUMN_NAME`` column.
        """
        # TODO(ekl) Ideally numpy can read directly from the file, but it
        # seems like it requires the file to be seekable.
        buf = BytesIO(f.readall())
        # `allow_pickle=True` stays the default, but an explicit value in
        # `numpy_load_args` now overrides it instead of raising a
        # duplicate-keyword TypeError.
        # SECURITY NOTE(review): loading with pickle enabled can execute
        # arbitrary code from a crafted file; only read trusted inputs or pass
        # `allow_pickle=False` through `numpy_load_args`.
        load_args = {"allow_pickle": True, **self.numpy_load_args}
        yield BlockAccessor.batch_to_block(
            {self._COLUMN_NAME: np.load(buf, **load_args)}
        )
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_bulk_datasource.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
3
+
4
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
5
+
6
+ if TYPE_CHECKING:
7
+ import pyarrow
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class ParquetBulkDatasource(FileBasedDatasource):
    """Minimal Parquet datasource, for reading and writing Parquet files."""

    _FILE_EXTENSIONS = ["parquet"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        read_table_args: Optional[Dict[str, Any]] = None,
        **file_based_datasource_kwargs,
    ):
        """Store the ``pyarrow.parquet.read_table`` options.

        Args:
            paths: A path or list of paths forwarded to the base datasource.
            read_table_args: Optional keyword arguments for ``read_table``.
            **file_based_datasource_kwargs: Extra options forwarded to the base
                class.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        if read_table_args is None:
            read_table_args = {}

        self.read_table_args = read_table_args

    def get_name(self):
        """Return a human-readable name for this datasource.
        This will be used as the names of the read tasks.
        Note: overrides the base `FileBasedDatasource` method.
        """
        return "ParquetBulk"

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Read one Parquet file and yield it as a single table.

        ``use_threads`` defaults to ``False`` unless ``read_table_args`` sets it.
        """
        import pyarrow.parquet as pq

        # Merge into a fresh dict rather than popping from `self.read_table_args`:
        # popping removed `use_threads` from the instance after the first file,
        # so later files silently fell back to the default.
        read_table_args = {"use_threads": False, **self.read_table_args}
        yield pq.read_table(f, **read_table_args)

    def _open_input_source(
        self,
        filesystem: "pyarrow.fs.FileSystem",
        path: str,
        **open_args,
    ) -> "pyarrow.NativeFile":
        """Open ``path`` as a random-access file instead of a stream."""
        # Parquet requires `open_input_file` due to random access reads
        return filesystem.open_input_file(path, **open_args)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasink.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import posixpath
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
4
+
5
+ from ray.data._internal.arrow_ops.transform_pyarrow import concat
6
+ from ray.data._internal.execution.interfaces import TaskContext
7
+ from ray.data._internal.util import call_with_retry
8
+ from ray.data.block import Block, BlockAccessor
9
+ from ray.data.context import DataContext
10
+ from ray.data.datasource.file_based_datasource import _resolve_kwargs
11
+ from ray.data.datasource.file_datasink import _FileDatasink
12
+ from ray.data.datasource.filename_provider import FilenameProvider
13
+
14
+ if TYPE_CHECKING:
15
+ import pyarrow
16
+
17
+ WRITE_FILE_MAX_ATTEMPTS = 10
18
+ WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class ParquetDatasink(_FileDatasink):
24
def __init__(
    self,
    path: str,
    *,
    partition_cols: Optional[List[str]] = None,
    arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
    arrow_parquet_args: Optional[Dict[str, Any]] = None,
    min_rows_per_file: Optional[int] = None,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    try_create_dir: bool = True,
    open_stream_args: Optional[Dict[str, Any]] = None,
    filename_provider: Optional[FilenameProvider] = None,
    dataset_uuid: Optional[str] = None,
):
    """Store the Parquet writer options and initialise the base file datasink.

    Args:
        path: Destination forwarded to the base datasink.
        partition_cols: Optional column names; when set, ``write`` takes the
            partitioned-write path.
        arrow_parquet_args_fn: Optional callable returning writer keyword
            arguments, resolved at write time.
        arrow_parquet_args: Optional static writer keyword arguments.
        min_rows_per_file: Stored on the instance; not used in this method.
        filesystem: Forwarded to the base datasink.
        try_create_dir: Forwarded to the base datasink.
        open_stream_args: Forwarded to the base datasink.
        filename_provider: Forwarded to the base datasink.
        dataset_uuid: Forwarded to the base datasink.
    """
    # Normalise the optional writer arguments so later code never sees `None`.
    self.arrow_parquet_args_fn = (
        (lambda: {}) if arrow_parquet_args_fn is None else arrow_parquet_args_fn
    )
    self.arrow_parquet_args = {} if arrow_parquet_args is None else arrow_parquet_args
    self.min_rows_per_file = min_rows_per_file
    self.partition_cols = partition_cols

    # The attributes above are assigned before the base initialiser runs,
    # matching the original statement order.
    super().__init__(
        path,
        filesystem=filesystem,
        try_create_dir=try_create_dir,
        open_stream_args=open_stream_args,
        filename_provider=filename_provider,
        dataset_uuid=dataset_uuid,
        file_format="parquet",
    )
58
+
59
+ def write(
60
+ self,
61
+ blocks: Iterable[Block],
62
+ ctx: TaskContext,
63
+ ) -> None:
64
+ import pyarrow as pa
65
+
66
+ blocks = list(blocks)
67
+
68
+ if all(BlockAccessor.for_block(block).num_rows() == 0 for block in blocks):
69
+ return
70
+
71
+ filename = self.filename_provider.get_filename_for_block(
72
+ blocks[0], ctx.task_idx, 0
73
+ )
74
+ write_kwargs = _resolve_kwargs(
75
+ self.arrow_parquet_args_fn, **self.arrow_parquet_args
76
+ )
77
+ user_schema = write_kwargs.pop("schema", None)
78
+
79
+ def write_blocks_to_path():
80
+ tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
81
+ if user_schema is None:
82
+ output_schema = pa.unify_schemas([table.schema for table in tables])
83
+ else:
84
+ output_schema = user_schema
85
+
86
+ if not self.partition_cols:
87
+ self._write_single_file(tables, filename, output_schema, write_kwargs)
88
+ else: # partition writes
89
+ self._write_partition_files(
90
+ tables, filename, output_schema, write_kwargs
91
+ )
92
+
93
+ logger.debug(f"Writing {filename} file to {self.path}.")
94
+
95
+ call_with_retry(
96
+ write_blocks_to_path,
97
+ description=f"write '{filename}' to '{self.path}'",
98
+ match=DataContext.get_current().retried_io_errors,
99
+ max_attempts=WRITE_FILE_MAX_ATTEMPTS,
100
+ max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
101
+ )
102
+
103
+ def _write_single_file(
104
+ self,
105
+ tables: List["pyarrow.Table"],
106
+ filename: str,
107
+ output_schema: "pyarrow.Schema",
108
+ write_kwargs: Dict[str, Any],
109
+ ) -> None:
110
+ import pyarrow.parquet as pq
111
+
112
+ write_path = posixpath.join(self.path, filename)
113
+ with self.open_output_stream(write_path) as file:
114
+ with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
115
+ for table in tables:
116
+ table = table.cast(output_schema)
117
+ writer.write_table(table)
118
+
119
+ def _write_partition_files(
120
+ self,
121
+ tables: List["pyarrow.Table"],
122
+ filename: str,
123
+ output_schema: "pyarrow.Schema",
124
+ write_kwargs: Dict[str, Any],
125
+ ) -> None:
126
+ import pyarrow as pa
127
+ import pyarrow.parquet as pq
128
+
129
+ table = concat(tables)
130
+ # Create unique combinations of the partition columns
131
+ table_fields = [
132
+ field for field in output_schema if field.name not in self.partition_cols
133
+ ]
134
+ non_partition_cols = [f.name for f in table_fields]
135
+ output_schema = pa.schema(
136
+ [field for field in output_schema if field.name not in self.partition_cols]
137
+ )
138
+ # Group the table by partition keys
139
+ # For each partition key combination fetch list of values
140
+ # for the non partition columns
141
+ # Ex: Here original table contain
142
+ # two columns (a, b). We are paritioning by column a. The schema
143
+ # of `groups` grouped Table is as follows
144
+ # b_list: [[[0,0],[1,1],[2,2]]]
145
+ # a: [[1,2,3]]
146
+ groups = table.group_by(self.partition_cols).aggregate(
147
+ [(col_name, "list") for col_name in non_partition_cols]
148
+ )
149
+ grouped_keys = [groups.column(k) for k in self.partition_cols]
150
+
151
+ for i in range(groups.num_rows):
152
+ # See https://github.com/apache/arrow/issues/14882 for recommended approach
153
+ values = [
154
+ groups.column(f"{col.name}_list")[i].values for col in table_fields
155
+ ]
156
+ group_table = pa.Table.from_arrays(values, names=non_partition_cols)
157
+ partition_path = "/".join(
158
+ [
159
+ f"{col}={values[i]}"
160
+ for col, values in zip(self.partition_cols, grouped_keys)
161
+ ]
162
+ )
163
+ write_path = posixpath.join(self.path, partition_path)
164
+ self._create_dir(write_path)
165
+ write_path = posixpath.join(write_path, filename)
166
+ with self.open_output_stream(write_path) as file:
167
+ with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
168
+ writer.write_table(group_table)
169
+
170
+ @property
171
+ def min_rows_per_write(self) -> Optional[int]:
172
+ return self.min_rows_per_file
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasource.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ Iterator,
9
+ List,
10
+ Literal,
11
+ Optional,
12
+ Union,
13
+ )
14
+
15
+ import numpy as np
16
+ from packaging.version import parse as parse_version
17
+
18
+ import ray
19
+ import ray.cloudpickle as cloudpickle
20
+ from ray._private.utils import _get_pyarrow_version
21
+ from ray.data._internal.progress_bar import ProgressBar
22
+ from ray.data._internal.remote_fn import cached_remote_fn
23
+ from ray.data._internal.util import (
24
+ _check_pyarrow_version,
25
+ _is_local_scheme,
26
+ call_with_retry,
27
+ iterate_with_retry,
28
+ )
29
+ from ray.data.block import Block
30
+ from ray.data.context import DataContext
31
+ from ray.data.datasource import Datasource
32
+ from ray.data.datasource.datasource import ReadTask
33
+ from ray.data.datasource.file_based_datasource import FileShuffleConfig
34
+ from ray.data.datasource.file_meta_provider import (
35
+ DefaultFileMetadataProvider,
36
+ _handle_read_os_error,
37
+ )
38
+ from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
39
+ from ray.data.datasource.partitioning import (
40
+ PartitionDataType,
41
+ Partitioning,
42
+ PathPartitionFilter,
43
+ PathPartitionParser,
44
+ )
45
+ from ray.data.datasource.path_util import (
46
+ _has_file_extension,
47
+ _resolve_paths_and_filesystem,
48
+ )
49
+
50
+ if TYPE_CHECKING:
51
+ import pyarrow
52
+ from pyarrow.dataset import ParquetFileFragment
53
+
54
+
55
logger = logging.getLogger(__name__)

# CPU request of a single metadata-prefetch task. Prefetching is cheaper than a
# regular read task, hence 0.5 rather than 1.
NUM_CPUS_FOR_META_FETCH_TASK = 0.5

# Rows per read batch. With rows of roughly 1KiB this produces batches of about
# 10MiB.
PARQUET_READER_ROW_BATCH_SIZE = 10_000
FILE_READING_RETRY = 8

# Fallback Parquet -> Arrow size multiplier.
# Parquet applies encodings such as dictionary, RLE and delta, so the Arrow
# in-memory representation is considerably larger than the encoded one, while
# the Parquet file statistics only report the encoded (i.e. uncompressed) size.
#
# Datasets tries to estimate the real inflation ratio from Parquet to Arrow at
# runtime and uses this constant as the safe default.
# See https://github.com/ray-project/ray/pull/26516 for more context.
PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT = 5

# Floor applied to the estimated Parquet encoding ratio.
PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND = 1

# Fraction of the dataset's files (1% by default) sampled to estimate the
# Parquet encoding ratio.
PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO = 0.01

# Lower and upper limit on the number of sampled files; they keep the count
# derived from `PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO` within a
# sensible range.
PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES = 2
PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES = 10

# Rows read from each sampled file. Kept small so that sampling does not pull
# much data into memory.
PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024
94
+
95
+
96
@dataclass(frozen=True)
class _SampleInfo:
    """Per-row size measurements taken from one sampled Parquet fragment.

    Both fields are ``None`` when the sampled fragment yielded no rows.
    """

    # In-memory (Arrow) bytes per row of the sampled batch. This is the result
    # of a true division, hence a float rather than an int.
    actual_bytes_per_row: Optional[float]
    # Bytes per row according to the Parquet row-group metadata (also a float).
    estimated_bytes_per_row: Optional[float]
100
+
101
+
102
+ # TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a
103
+ # raw pyarrow file fragment causes S3 network calls.
104
class SerializedFragment:
    """Pickled stand-in for a PyArrow ``ParquetFileFragment``.

    Only the pieces needed to rebuild the fragment (format, path, filesystem
    and partition expression) are stored, serialized with cloudpickle.
    """

    def __init__(self, frag: "ParquetFileFragment"):
        """Capture the pieces of ``frag`` needed to rebuild it later."""
        fragment_parts = (
            frag.format,
            frag.path,
            frag.filesystem,
            frag.partition_expression,
        )
        self._data = cloudpickle.dumps(fragment_parts)

    def deserialize(self) -> "ParquetFileFragment":
        """Rebuild the fragment from the stored pieces."""
        # Importing pyarrow.fs implicitly triggers the S3 subsystem
        # initialization.
        import pyarrow.fs  # noqa: F401

        fragment_parts = cloudpickle.loads(self._data)
        file_format, path, filesystem, partition_expression = fragment_parts
        return file_format.make_fragment(path, filesystem, partition_expression)
119
+
120
+
121
+ # Visible for test mocking.
122
def _deserialize_fragments(
    serialized_fragments: List[SerializedFragment],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
    """Deserialize every fragment in ``serialized_fragments``, keeping order."""
    fragments = []
    for serialized in serialized_fragments:
        fragments.append(serialized.deserialize())
    return fragments
126
+
127
+
128
def check_for_legacy_tensor_type(schema):
    """Check for the legacy tensor extension type and raise an error if found.

    Ray Data uses an extension type to represent tensors in Arrow tables. Previously,
    the extension type extended `PyExtensionType`. However, this base type can expose
    users to arbitrary code execution. To prevent this, we don't load the type by
    default.

    Args:
        schema: Schema whose ``names`` and ``types`` are inspected.

    Raises:
        RuntimeError: If a column's type is both an ``UnknownExtensionType``
            and a ``PyExtensionType``.
    """
    import pyarrow as pa

    # `column_type` rather than `type`, to avoid shadowing the builtin.
    for name, column_type in zip(schema.names, schema.types):
        if isinstance(column_type, pa.UnknownExtensionType) and isinstance(
            column_type, pa.PyExtensionType
        ):
            raise RuntimeError(
                f"Ray Data couldn't infer the type of column '{name}'. This might mean "
                "you're trying to read data written with an older version of Ray. "
                "Reading data written with older versions of Ray might expose you to "
                "arbitrary code execution. To try reading the data anyway, set "
                "`RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE=1` on *all* nodes. "
                "To learn more, see https://github.com/ray-project/ray/issues/41314."
            )
150
+
151
+
152
class ParquetDatasource(Datasource):
    """Parquet datasource, for reading and writing Parquet files.

    The primary difference from ParquetBulkDatasource is that this uses
    PyArrow's `ParquetDataset` abstraction for dataset reads, and thus offers
    automatic Arrow dataset schema inference and row count collection at the
    cost of some potential performance and/or compatibility penalties.
    """

    def __init__(
        self,
        paths: Union[str, List[str]],
        *,
        columns: Optional[List[str]] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        to_batch_kwargs: Optional[Dict[str, Any]] = None,
        _block_udf: Optional[Callable[[Block], Block]] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
        meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(),
        partition_filter: PathPartitionFilter = None,
        partitioning: Optional[Partitioning] = Partitioning("hive"),
        shuffle: Union[Literal["files"], None] = None,
        include_paths: bool = False,
        file_extensions: Optional[List[str]] = None,
    ):
        """Resolve the input files, derive the schema and sample the data.

        Args:
            paths: Path or paths to read.
            columns: If given, the schema is narrowed to these fields and the
                list is forwarded to the fragment readers.
            dataset_kwargs: Extra keyword arguments for ``pq.ParquetDataset``.
                Must not contain ``"partitioning"``. The caller's dict is not
                modified.
            to_batch_kwargs: Keyword arguments forwarded to
                ``fragment.to_batches``.
            _block_udf: Optional function applied to every non-empty table that
                is read.
            filesystem: Filesystem to read from; resolved together with
                ``paths``.
            schema: User-specified schema. When ``None``, the dataset schema
                (plus partition fields) is used for the reported schema only.
            meta_provider: Used to prefetch file metadata and to build the
                metadata of each read task.
            partition_filter: Callable applied to the expanded paths to select
                the files to read.
            partitioning: Describes how partition values are parsed from file
                paths; ``None`` disables partition parsing.
            shuffle: ``"files"`` or a ``FileShuffleConfig`` to shuffle the file
                order when read tasks are created.
            include_paths: Whether to add a ``"path"`` column to the output.
            file_extensions: If given, only paths with one of these extensions
                are read.

        Raises:
            ValueError: If local paths are used while connected through Ray
                Client, or if ``dataset_kwargs`` contains ``"partitioning"``.
        """
        _check_pyarrow_version()

        import pyarrow as pa

        self._supports_distributed_reads = not _is_local_scheme(paths)
        if not self._supports_distributed_reads and ray.util.client.ray.is_connected():
            raise ValueError(
                "Because you're using Ray Client, read tasks scheduled on the Ray "
                "cluster can't access your local files. To fix this issue, store "
                "files in cloud storage or a distributed filesystem like NFS."
            )

        self._local_scheduling = None
        if not self._supports_distributed_reads:
            from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

            # Local files are only visible on this node: pin all tasks to it.
            self._local_scheduling = NodeAffinitySchedulingStrategy(
                ray.get_runtime_context().get_node_id(), soft=False
            )

        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)

        # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet
        # files. To avoid this, we expand the input paths with the default metadata
        # provider and then apply the partition filter or file extensions.
        if partition_filter is not None or file_extensions is not None:
            default_meta_provider = DefaultFileMetadataProvider()
            expanded_paths, _ = map(
                list, zip(*default_meta_provider.expand_paths(paths, filesystem))
            )

            paths = list(expanded_paths)
            if partition_filter is not None:
                paths = partition_filter(paths)
            if file_extensions is not None:
                paths = [
                    path for path in paths if _has_file_extension(path, file_extensions)
                ]

            filtered_paths = set(expanded_paths) - set(paths)
            if filtered_paths:
                logger.info(f"Filtered out {len(filtered_paths)} paths")

        # Work on a private copy: a "partitioning" entry is added below and
        # must not leak into the dict owned by the caller.
        dataset_kwargs = {} if dataset_kwargs is None else dict(dataset_kwargs)

        if "partitioning" in dataset_kwargs:
            raise ValueError(
                "The 'partitioning' parameter isn't supported in 'dataset_kwargs'. "
                "Use the top-level 'partitioning' parameter instead."
            )

        # This datasource manually adds partition data at the Ray Data-level. To avoid
        # duplicating the partition data, we disable PyArrow's partitioning.
        dataset_kwargs["partitioning"] = None

        # `read_schema` is the schema object that will be used to perform
        # read operations.
        # It should be None, unless user has specified the schema or columns.
        # We don't use the inferred schema for read, because the pyarrow only infers
        # schema based on the first file. Thus, files with different schemas will end
        # up producing blocks with wrong schema.
        # See https://github.com/ray-project/ray/issues/47960 for more context.
        read_schema = schema
        pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)

        if schema is None:
            schema = pq_ds.schema
            schema = _add_partition_fields_to_schema(partitioning, schema, pq_ds)

        if columns:
            schema = pa.schema(
                [schema.field(column) for column in columns], schema.metadata
            )
            read_schema = schema

        check_for_legacy_tensor_type(schema)

        if _block_udf is not None:
            # Try to infer dataset schema by passing dummy table through UDF.
            dummy_table = schema.empty_table()
            try:
                schema = _block_udf(dummy_table).schema.with_metadata(schema.metadata)
            except Exception:
                # Best effort only: keep the pre-UDF schema if inference fails.
                logger.debug(
                    "Failed to infer schema of dataset by passing dummy table "
                    "through UDF due to the following exception:",
                    exc_info=True,
                )

        try:
            prefetch_remote_args = {}
            prefetch_remote_args["num_cpus"] = NUM_CPUS_FOR_META_FETCH_TASK
            if self._local_scheduling:
                prefetch_remote_args["scheduling_strategy"] = self._local_scheduling
            else:
                # Use the scheduling strategy ("SPREAD" by default) provided in
                # `DataContext`, to spread out prefetch tasks in cluster, avoid
                # AWS S3 throttling error.
                # Note: this is the same scheduling strategy used by read tasks.
                prefetch_remote_args[
                    "scheduling_strategy"
                ] = DataContext.get_current().scheduling_strategy

            self._metadata = (
                meta_provider.prefetch_file_metadata(
                    pq_ds.fragments, **prefetch_remote_args
                )
                or []
            )
        except OSError as e:
            _handle_read_os_error(e, paths)

        if to_batch_kwargs is None:
            to_batch_kwargs = {}

        # NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected
        # network calls when this datasource is serialized. See `SerializedFragment`
        # for more details.
        self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments]
        self._pq_paths = [p.path for p in pq_ds.fragments]
        self._meta_provider = meta_provider
        self._block_udf = _block_udf
        self._to_batches_kwargs = to_batch_kwargs
        self._columns = columns
        self._read_schema = read_schema
        self._schema = schema
        self._file_metadata_shuffler = None
        self._include_paths = include_paths
        self._partitioning = partitioning
        if shuffle == "files":
            self._file_metadata_shuffler = np.random.default_rng()
        elif isinstance(shuffle, FileShuffleConfig):
            self._file_metadata_shuffler = np.random.default_rng(shuffle.seed)

        sample_infos = sample_fragments(
            self._pq_fragments,
            to_batches_kwargs=to_batch_kwargs,
            columns=columns,
            schema=self._read_schema,
            local_scheduling=self._local_scheduling,
        )
        self._encoding_ratio = estimate_files_encoding_ratio(sample_infos)
        self._default_read_batch_size_rows = estimate_default_read_batch_size_rows(
            sample_infos
        )

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Estimate the in-memory size of the dataset.

        Returns:
            The sum of ``total_byte_size`` over the prefetched file metadata,
            multiplied by the estimated encoding ratio.
        """
        total_size = 0
        for file_metadata in self._metadata:
            total_size += file_metadata.total_byte_size
        return total_size * self._encoding_ratio

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Split the fragments into at most ``parallelism`` read tasks.

        Args:
            parallelism: Number of chunks the fragments are split into; empty
                chunks produce no task.

        Returns:
            One ``ReadTask`` per non-empty chunk of fragments.
        """
        # NOTE: We override the base class FileBasedDatasource.get_read_tasks()
        # method in order to leverage pyarrow's ParquetDataset abstraction,
        # which simplifies partitioning logic. We still use
        # FileBasedDatasource's write side, however.
        pq_metadata = self._metadata
        num_missing = len(self._pq_fragments) - len(pq_metadata)
        if num_missing > 0:
            # Pad `pq_metadata` to be same length of `self._pq_fragments`.
            # This can happen when no file metadata being prefetched.
            # Build a new list instead of extending in place: `pq_metadata`
            # aliases `self._metadata`, which `estimate_inmemory_data_size`
            # iterates expecting real metadata entries only.
            pq_metadata = list(pq_metadata) + [None] * num_missing

        if self._file_metadata_shuffler is not None:
            files_metadata = list(zip(self._pq_fragments, self._pq_paths, pq_metadata))
            shuffled_files_metadata = [
                files_metadata[i]
                for i in self._file_metadata_shuffler.permutation(len(files_metadata))
            ]
            pq_fragments, pq_paths, pq_metadata = list(
                map(list, zip(*shuffled_files_metadata))
            )
        else:
            pq_fragments, pq_paths = self._pq_fragments, self._pq_paths

        # Bind plain locals so that the read function below closes over these
        # values rather than over `self`.
        block_udf = self._block_udf
        to_batches_kwargs = self._to_batches_kwargs
        default_read_batch_size_rows = self._default_read_batch_size_rows
        columns = self._columns
        read_schema = self._read_schema
        include_paths = self._include_paths
        partitioning = self._partitioning

        read_tasks = []
        for fragments, paths, metadata in zip(
            np.array_split(pq_fragments, parallelism),
            np.array_split(pq_paths, parallelism),
            np.array_split(pq_metadata, parallelism),
        ):
            if len(fragments) <= 0:
                continue

            meta = self._meta_provider(
                paths,
                self._schema,
                num_fragments=len(fragments),
                prefetched_metadata=metadata,
            )
            # If there is a filter operation, reset the calculated row count,
            # since the resulting row count is unknown.
            if self._to_batches_kwargs.get("filter") is not None:
                meta.num_rows = None

            if meta.size_bytes is not None:
                meta.size_bytes = int(meta.size_bytes * self._encoding_ratio)

            read_tasks.append(
                ReadTask(
                    # `f=fragments` binds this iteration's chunk at definition
                    # time (avoids the late-binding closure pitfall).
                    lambda f=fragments: read_fragments(
                        block_udf,
                        to_batches_kwargs,
                        default_read_batch_size_rows,
                        columns,
                        read_schema,
                        f,
                        include_paths,
                        partitioning,
                    ),
                    meta,
                )
            )

        return read_tasks

    def get_name(self):
        """Return a human-readable name for this datasource.

        This will be used as the names of the read tasks.
        """
        return "Parquet"

    @property
    def supports_distributed_reads(self) -> bool:
        """Whether the input paths can be read from any node (non-local scheme)."""
        return self._supports_distributed_reads
426
+
427
+
428
def read_fragments(
    block_udf,
    to_batches_kwargs,
    default_read_batch_size_rows,
    columns,
    schema,
    serialized_fragments: List[SerializedFragment],
    include_paths: bool,
    partitioning: Partitioning,
) -> Iterator["pyarrow.Table"]:
    """Read the given fragments batch by batch and yield the non-empty tables.

    Args:
        block_udf: Optional function applied to each table before it is yielded.
        to_batches_kwargs: Keyword arguments for ``fragment.to_batches``.
            ``use_threads`` (default ``False``) and ``batch_size`` (default
            ``default_read_batch_size_rows``) are extracted from it. The dict
            passed in is left untouched.
        default_read_batch_size_rows: Batch size used when the caller did not
            specify one.
        columns: Columns to read; also restricts which partition fields are
            added.
        schema: Schema passed to the reader and used to build each table.
        serialized_fragments: Fragments to read; must not be empty.
        include_paths: Whether to append a ``"path"`` column holding the
            fragment path.
        partitioning: How to parse partition values from the fragment path;
            ``None`` disables partition columns.

    Yields:
        One table per non-empty batch (after ``block_udf``, if given).
    """
    # This import is necessary to load the tensor extension type.
    from ray.data.extensions.tensor_extension import ArrowTensorType  # noqa

    # Deserialize after loading the filesystem class.
    fragments: List[
        "pyarrow._dataset.ParquetFileFragment"
    ] = _deserialize_fragments_with_retry(serialized_fragments)

    # Ensure that we're reading at least one dataset fragment.
    assert len(fragments) > 0

    import pyarrow as pa

    logger.debug(f"Reading {len(fragments)} parquet fragments")
    # Pop from a copy: the same dict may be shared with other read tasks or
    # reused if this function runs again in the same process, and mutating it
    # would silently drop the user's `use_threads` / `batch_size` settings.
    to_batches_kwargs = dict(to_batches_kwargs)
    use_threads = to_batches_kwargs.pop("use_threads", False)
    batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)
    for fragment in fragments:
        partitions = {}
        if partitioning is not None:
            parse = PathPartitionParser(partitioning)
            partitions = parse(fragment.path)

        # Filter out partitions that aren't in the user-specified columns list.
        if columns is not None:
            partitions = {
                field_name: value
                for field_name, value in partitions.items()
                if field_name in columns
            }

        def get_batch_iterable():
            return fragment.to_batches(
                use_threads=use_threads,
                columns=columns,
                schema=schema,
                batch_size=batch_size,
                **to_batches_kwargs,
            )

        # S3 can raise transient errors during iteration, and PyArrow doesn't expose a
        # way to retry specific batches.
        ctx = ray.data.DataContext.get_current()
        for batch in iterate_with_retry(
            get_batch_iterable, "load batch", match=ctx.retried_io_errors
        ):
            table = pa.Table.from_batches([batch], schema=schema)
            if include_paths:
                table = table.append_column("path", [[fragment.path]] * len(table))
            if partitions:
                table = _add_partitions_to_table(partitions, table)

            # If the table is empty, drop it.
            if table.num_rows > 0:
                if block_udf is not None:
                    yield block_udf(table)
                else:
                    yield table
495
+
496
+
497
def _deserialize_fragments_with_retry(fragments):
    """Deserialize ``fragments``, retrying up to ``FILE_READING_RETRY`` times.

    Retrying helps when the upstream datasource cannot cope with an overloaded
    read request or fails with a retriable error. For example, when reading
    from an HA HDFS service, the connection may be lost for unknown reasons,
    especially when many hyperparameter tuning jobs run simultaneously with a
    high ray.data parallelism (such as the default 200). Such connection
    failures can be recovered by waiting and retrying.
    """

    def _attempt():
        # Looked up at call time so that `_deserialize_fragments` can still be
        # replaced by tests.
        return _deserialize_fragments(fragments)

    return call_with_retry(
        _attempt,
        description="deserialize fragments",
        max_attempts=FILE_READING_RETRY,
    )
510
+
511
+
512
def _sample_fragment(
    to_batches_kwargs,
    columns,
    schema,
    file_fragment: SerializedFragment,
) -> _SampleInfo:
    """Sample the first batch of ``file_fragment`` to measure per-row sizes.

    Args:
        to_batches_kwargs: Keyword arguments for ``fragment.to_batches``. A
            user-provided ``batch_size`` is ignored; the dict itself is not
            modified.
        columns: Columns to read.
        schema: Schema passed to the reader.
        file_fragment: The serialized fragment to sample.

    Returns:
        The in-memory and metadata-based bytes per row, or a ``_SampleInfo``
        with both fields ``None`` when no rows could be read.
    """
    # Sample the first rows batch from file fragment `file_fragment`.
    fragment = _deserialize_fragments_with_retry([file_fragment])[0]

    # Only sample the first row group.
    fragment = fragment.subset(row_group_ids=[0])
    batch_size = max(
        min(fragment.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1
    )
    # Use the batch_size calculated above, and ignore the one specified by user if set.
    # This is to avoid sampling too few or too many rows.
    # Filter into a new dict instead of popping, so the caller's kwargs stay
    # intact.
    to_batches_kwargs = {
        key: value for key, value in to_batches_kwargs.items() if key != "batch_size"
    }
    batches = fragment.to_batches(
        columns=columns,
        schema=schema,
        batch_size=batch_size,
        **to_batches_kwargs,
    )

    no_sample = _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None)

    # Use first batch in-memory size for estimation.
    try:
        batch = next(batches)
    except StopIteration:
        return no_sample

    if batch.num_rows <= 0:
        return no_sample

    metadata = fragment.metadata
    total_size = sum(
        metadata.row_group(idx).total_byte_size
        for idx in range(metadata.num_row_groups)
    )
    return _SampleInfo(
        actual_bytes_per_row=batch.nbytes / batch.num_rows,
        estimated_bytes_per_row=total_size / metadata.num_rows,
    )
557
+
558
+
559
def estimate_files_encoding_ratio(sample_infos: List[_SampleInfo]) -> float:
    """Return an estimate of the Parquet files encoding ratio.

    To avoid OOMs, it is safer to return an over-estimate than an underestimate.

    Args:
        sample_infos: Per-file samples produced by ``sample_fragments``.

    Returns:
        ``PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT`` when size estimation is
        disabled or there are no samples; otherwise the mean per-sample ratio,
        floored at ``PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND``.
    """
    if not DataContext.get_current().decoding_size_estimation:
        return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT

    # Without samples `np.mean([])` is NaN, which would then leak out as the
    # ratio. Fall back to the (over-estimating) default instead.
    if len(sample_infos) == 0:
        return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT

    def compute_encoding_ratio(sample_info: _SampleInfo) -> float:
        # `None` means the sampled file had no rows. An estimated size of 0
        # can't be used as a divisor either, so both cases use the lower bound.
        if (
            sample_info.actual_bytes_per_row is None
            or not sample_info.estimated_bytes_per_row
        ):
            return PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
        return sample_info.actual_bytes_per_row / sample_info.estimated_bytes_per_row

    ratio = np.mean([compute_encoding_ratio(info) for info in sample_infos])
    logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.")
    return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)
581
+
582
+
583
def estimate_default_read_batch_size_rows(sample_infos: List[_SampleInfo]) -> int:
    """Estimate the default number of rows per read batch from the samples.

    Each sample proposes a batch size such that a batch stays below one tenth
    of the context's ``target_max_block_size``, capped at
    ``PARQUET_READER_ROW_BATCH_SIZE`` and at least 1. The proposals are
    averaged.

    Args:
        sample_infos: Per-file samples produced by ``sample_fragments``.

    Returns:
        The averaged batch size as an ``int`` (truncated), or
        ``PARQUET_READER_ROW_BATCH_SIZE`` when there are no samples.
    """
    # `np.mean([])` is NaN and can't serve as a batch size.
    if len(sample_infos) == 0:
        return PARQUET_READER_ROW_BATCH_SIZE

    def compute_batch_size_rows(sample_info: _SampleInfo) -> int:
        # 'actual_bytes_per_row' is None if the sampled file was empty and 0 if the data
        # was all null.
        if not sample_info.actual_bytes_per_row:
            return PARQUET_READER_ROW_BATCH_SIZE
        max_parquet_reader_row_batch_size_bytes = (
            DataContext.get_current().target_max_block_size // 10
        )
        return max(
            1,
            min(
                PARQUET_READER_ROW_BATCH_SIZE,
                max_parquet_reader_row_batch_size_bytes
                // sample_info.actual_bytes_per_row,
            ),
        )

    # The mean is a float; the result is used as a row count, so return an
    # integer as the signature promises. Every proposal is >= 1, hence so is
    # the truncated mean.
    return int(np.mean([compute_batch_size_rows(info) for info in sample_infos]))
603
+
604
+
605
def get_parquet_dataset(paths, filesystem, dataset_kwargs):
    """Build a ``pyarrow.parquet.ParquetDataset`` for ``paths``.

    Args:
        paths: List of paths to read.
        filesystem: Filesystem passed to ``ParquetDataset``.
        dataset_kwargs: Additional keyword arguments for ``ParquetDataset``.

    Returns:
        The constructed dataset. An ``OSError`` raised during construction is
        handed to ``_handle_read_os_error``.
    """
    import pyarrow.parquet as pq

    # A list holding a single directory path makes PyArrow fail with
    # 'IsADirectoryError: Path ... points to a directory, but only file paths
    # are supported', so a lone path is passed on its own instead.
    if len(paths) == 1:
        paths = paths[0]

    try:
        # `use_legacy_dataset` is deprecated in Arrow 15, so it is only passed
        # to older versions.
        version_specific_kwargs = {}
        if parse_version(_get_pyarrow_version()) < parse_version("15.0.0"):
            version_specific_kwargs["use_legacy_dataset"] = False

        dataset = pq.ParquetDataset(
            paths,
            **dataset_kwargs,
            filesystem=filesystem,
            **version_specific_kwargs,
        )
    except OSError as e:
        _handle_read_os_error(e, paths)

    return dataset
633
+
634
+
635
def sample_fragments(
    serialized_fragments,
    *,
    to_batches_kwargs,
    columns,
    schema,
    local_scheduling=None,
) -> List[_SampleInfo]:
    """Sample a few files remotely and return their per-row size measurements.

    Args:
        serialized_fragments: All fragments of the dataset.
        to_batches_kwargs: Forwarded to each sampling task.
        columns: Forwarded to each sampling task.
        schema: Forwarded to each sampling task.
        local_scheduling: Scheduling strategy to use instead of the one from
            the current ``DataContext``.

    Returns:
        One ``_SampleInfo`` per sampled file.
    """
    # Sample a few rows from Parquet files to estimate the encoding ratio.
    # Launch tasks to sample multiple files remotely in parallel.
    # Evenly distributed to sample N rows in i-th row group in i-th file.
    # TODO(ekl/cheng) take into account column pruning.
    num_files = len(serialized_fragments)
    lower_limit = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files)
    upper_limit = min(PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES, num_files)
    requested = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
    num_samples = max(min(requested, upper_limit), lower_limit)

    # Pick evenly spaced files, so that skewed data doesn't bias the estimate.
    sample_indices = np.linspace(0, num_files - 1, num_samples).astype(int).tolist()
    file_samples = [serialized_fragments[idx] for idx in sample_indices]

    sample_fragment = cached_remote_fn(_sample_fragment)
    scheduling = local_scheduling or DataContext.get_current().scheduling_strategy
    # One task per sampled file, each sampling the first rows batch of its file.
    # The SPREAD scheduling strategy avoids packing many sampling tasks on the
    # same machine, which could cause OOM since sampling can be memory-intensive.
    futures = [
        sample_fragment.options(
            scheduling_strategy=scheduling,
            # Retry in case of transient errors during sampling.
            retry_exceptions=[OSError],
        ).remote(
            to_batches_kwargs,
            columns,
            schema,
            sample,
        )
        for sample in file_samples
    ]

    sample_bar = ProgressBar("Parquet Files Sample", len(futures), unit="file")
    sample_infos = sample_bar.fetch_until_complete(futures)
    sample_bar.close()

    return sample_infos
684
+
685
+
686
def _add_partitions_to_table(
    partitions: Dict[str, PartitionDataType], table: "pyarrow.Table"
) -> "pyarrow.Table":
    """Return ``table`` with one constant column per entry of ``partitions``.

    A column that already exists under the partition's name is replaced;
    otherwise the partition column is appended.
    """
    import pyarrow as pa

    for name, partition_value in partitions.items():
        constant_column = pa.array([partition_value] * len(table))
        position = table.schema.get_field_index(name)
        if position == -1:
            table = table.append_column(name, constant_column)
        else:
            table = table.set_column(position, name, constant_column)

    return table
700
+
701
+
702
def _add_partition_fields_to_schema(
    partitioning: Partitioning,
    schema: "pyarrow.Schema",
    parquet_dataset: "pyarrow.dataset.Dataset",
) -> "pyarrow.Schema":
    """Return a new schema with partition fields added.

    This function infers the partition fields from the first file path in the dataset.
    A partition field whose name already exists in ``schema`` replaces that
    field instead of being appended a second time, mirroring how
    ``_add_partitions_to_table`` replaces an existing column.

    Args:
        partitioning: How to parse partitions from a path; ``None`` means the
            dataset isn't partitioned.
        schema: The schema to extend.
        parquet_dataset: Dataset whose first fragment path is parsed.

    Returns:
        The extended schema, or ``schema`` unchanged if there is nothing to add.
    """
    import pyarrow as pa

    # If the dataset is empty, we can't infer the partitioning.
    if len(parquet_dataset.fragments) == 0:
        return schema

    # If the dataset isn't partitioned, we don't need to add any fields.
    if partitioning is None:
        return schema

    first_path = parquet_dataset.fragments[0].path
    parse = PathPartitionParser(partitioning)
    partitions = parse(first_path)
    for field_name in partitions:
        if field_name in partitioning.field_types:
            field_type = pa.from_numpy_dtype(partitioning.field_types[field_name])
        else:
            field_type = pa.string()

        partition_field = pa.field(field_name, field_type)
        existing_index = schema.get_field_index(field_name)
        if existing_index != -1:
            # The files already carry a column with this name. Appending would
            # create a duplicate field name and make later lookups by name
            # ambiguous.
            schema = schema.set(existing_index, partition_field)
        else:
            schema = schema.append(partition_field)

    return schema
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import builtins
2
+ import functools
3
+ from copy import copy
4
+ from typing import Iterable, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+
8
+ from ray.data._internal.util import _check_pyarrow_version
9
+ from ray.data.block import Block, BlockAccessor, BlockMetadata
10
+ from ray.data.context import DataContext
11
+ from ray.data.datasource import Datasource, ReadTask
12
+
13
+
14
class RangeDatasource(Datasource):
    """An example datasource that generates ranges of numbers from [0..n)."""

    def __init__(
        self,
        n: int,
        block_format: str = "arrow",
        tensor_shape: Tuple = (1,),
        column_name: Optional[str] = None,
    ):
        """
        Args:
            n: Number of rows to generate; coerced with ``int()``.
            block_format: ``"arrow"``, ``"tensor"`` or ``"list"``.
            tensor_shape: Per-row shape used when ``block_format`` is ``"tensor"``.
            column_name: Optional name of the generated column.
        """
        self._n = int(n)
        self._block_format = block_format
        self._tensor_shape = tensor_shape
        self._column_name = column_name

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Estimate the dataset size as 8 bytes per generated element."""
        elements_per_row = (
            int(np.prod(self._tensor_shape)) if self._block_format == "tensor" else 1
        )
        return 8 * self._n * elements_per_row

    def get_read_tasks(
        self,
        parallelism: int,
    ) -> List[ReadTask]:
        """Split ``[0, n)`` into consecutive chunks and return one task per chunk.

        Each task covers ``max(1, n // parallelism)`` rows (the last one may be
        shorter) and yields its rows in blocks of at most
        ``target_rows_per_block`` rows.
        """
        tasks: List[ReadTask] = []
        total_rows = self._n
        fmt = self._block_format
        shape = self._tensor_shape
        rows_per_task = max(1, total_rows // parallelism)
        # TODO(swang): This target block size may not match the driver's
        # context if it was overridden. Set target max block size during
        # optimizer stage to fix this.
        ctx = DataContext.get_current()
        if self._n == 0:
            target_rows_per_block = 0
        else:
            bytes_per_row = max(self.estimate_inmemory_data_size() // self._n, 1)
            target_rows_per_block = max(1, ctx.target_max_block_size // bytes_per_row)

        # Example of a read task. In a real datasource, this would pull data
        # from an external system instead of generating dummy data.
        def make_block(start: int, count: int) -> Block:
            stop = start + count
            if fmt == "arrow":
                import pyarrow as pa

                return pa.Table.from_arrays(
                    [np.arange(start, stop)],
                    names=[self._column_name or "value"],
                )
            if fmt == "tensor":
                import pyarrow as pa  # noqa: F401

                # Broadcast each row index over a tensor of the configured shape.
                trailing_axes = tuple(range(1, 1 + len(shape)))
                tensor = np.ones(shape, dtype=np.int64) * np.expand_dims(
                    np.arange(start, stop), trailing_axes
                )
                if self._column_name:
                    return BlockAccessor.batch_to_block({self._column_name: tensor})
                return BlockAccessor.batch_to_block(tensor)
            return list(builtins.range(start, stop))

        def make_blocks(
            start: int, count: int, target_rows_per_block: int
        ) -> Iterable[Block]:
            remaining = count
            while remaining > 0:
                num_rows = min(remaining, target_rows_per_block)
                yield make_block(start, num_rows)
                start += num_rows
                remaining -= num_rows

        element_size = int(np.prod(shape)) if fmt == "tensor" else 1

        for offset in range(0, total_rows, rows_per_task):
            count = min(rows_per_task, total_rows - offset)
            meta = BlockMetadata(
                num_rows=count,
                size_bytes=8 * count * element_size,
                schema=copy(self._schema),
                input_files=None,
                exec_stats=None,
            )
            # Bind the loop values as defaults so each task keeps its own range.
            tasks.append(
                ReadTask(
                    lambda start=offset, rows=count: make_blocks(
                        start, rows, target_rows_per_block
                    ),
                    meta,
                )
            )

        return tasks

    @functools.cached_property
    def _schema(self):
        """Schema of the generated blocks, or ``None`` when ``n`` is 0.

        Raises:
            ValueError: If the block format is not arrow, tensor or list.
        """
        if self._n == 0:
            return None

        fmt = self._block_format
        if fmt == "arrow":
            _check_pyarrow_version()
            import pyarrow as pa

            return pa.Table.from_pydict({self._column_name or "value": [0]}).schema

        if fmt == "tensor":
            _check_pyarrow_version()
            import pyarrow as pa  # noqa: F401

            # Build a small sample batch purely to read its schema.
            tensor = np.ones(self._tensor_shape, dtype=np.int64) * np.expand_dims(
                np.arange(0, 10), tuple(range(1, 1 + len(self._tensor_shape)))
            )
            sample = {self._column_name: tensor} if self._column_name else tensor
            return BlockAccessor.batch_to_block(sample).schema

        if fmt == "list":
            return int

        raise ValueError("Unsupported block type", self._block_format)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Iterable
2
+
3
+ from ray.data._internal.datasource.sql_datasource import Connection, _connect
4
+ from ray.data._internal.execution.interfaces import TaskContext
5
+ from ray.data.block import Block, BlockAccessor
6
+ from ray.data.datasource.datasink import Datasink
7
+
8
+
9
class SQLDatasink(Datasink[None]):
    """Datasink that runs a SQL statement against rows in batches.

    Rows are buffered and handed to ``cursor.executemany`` at most
    ``_MAX_ROWS_PER_WRITE`` at a time.
    """

    _MAX_ROWS_PER_WRITE = 128

    def __init__(self, sql: str, connection_factory: Callable[[], Connection]):
        """
        Args:
            sql: The statement passed to ``executemany`` for each batch of rows.
            connection_factory: Zero-argument callable used to open the connection.
        """
        self.sql = sql
        self.connection_factory = connection_factory

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Execute ``self.sql`` for every row of every block."""
        batch_limit = self._MAX_ROWS_PER_WRITE
        with _connect(self.connection_factory) as cursor:
            for block in blocks:
                rows = BlockAccessor.for_block(block).iter_rows(
                    public_row_format=False
                )

                pending = []
                for row in rows:
                    pending.append(tuple(row.values()))
                    assert len(pending) <= batch_limit, len(pending)
                    if len(pending) == batch_limit:
                        cursor.executemany(self.sql, pending)
                        pending = []

                # Flush whatever is left over for this block.
                if pending:
                    cursor.executemany(self.sql, pending)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Iterator, List
2
+
3
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
4
+ from ray.data.block import Block
5
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
6
+
7
+ if TYPE_CHECKING:
8
+ import pyarrow
9
+
10
+
11
class TextDatasource(FileBasedDatasource):
    """Text datasource, for reading and writing text files."""

    _COLUMN_NAME = "text"

    def __init__(
        self,
        paths: List[str],
        *,
        drop_empty_lines: bool = False,
        encoding: str = "utf-8",
        **file_based_datasource_kwargs
    ):
        """
        Args:
            paths: Files to read.
            drop_empty_lines: Skip lines that are empty after ``str.strip()``.
            encoding: Codec used to decode the file contents.
            **file_based_datasource_kwargs: Forwarded to the parent constructor.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self.drop_empty_lines = drop_empty_lines
        self.encoding = encoding

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Read the whole file and yield a single block with one row per line.

        The decoded content is split on ``"\\n"``; each piece becomes a row under
        the ``"text"`` column.
        """
        raw = f.readall()
        builder = DelegatingBlockBuilder()

        for line in raw.decode(self.encoding).split("\n"):
            if self.drop_empty_lines and line.strip() == "":
                continue
            builder.add({self._COLUMN_NAME: line})

        yield builder.build()
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import struct
2
+ from typing import TYPE_CHECKING, Dict, Iterable, Optional, Union
3
+
4
+ import numpy as np
5
+
6
+ from .tfrecords_datasource import _get_single_true_type
7
+ from ray.data._internal.util import _check_import
8
+ from ray.data.block import BlockAccessor
9
+ from ray.data.datasource.file_datasink import BlockBasedFileDatasink
10
+
11
+ if TYPE_CHECKING:
12
+ import pyarrow
13
+ import tensorflow as tf
14
+ from tensorflow_metadata.proto.v0 import schema_pb2
15
+
16
+
17
class TFRecordDatasink(BlockBasedFileDatasink):
    """Datasink that serializes each block's rows as TFRecord entries."""

    def __init__(
        self,
        path: str,
        *,
        tf_schema: Optional["schema_pb2.Schema"] = None,
        file_format: str = "tar",
        **file_datasink_kwargs,
    ):
        """
        Args:
            path: Destination path, forwarded to the parent constructor.
            tf_schema: Optional TensorFlow schema used when converting rows.
            file_format: Forwarded to the parent constructor.
            **file_datasink_kwargs: Forwarded to the parent constructor.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        # Records are checksummed with crc32c, so check it is importable up front.
        _check_import(self, module="crc32c", package="crc32c")

        self.tf_schema = tf_schema

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write every row of ``block`` to ``file`` as one TFRecord entry."""
        # It seems like TFRecords are typically row-based,
        # https://www.tensorflow.org/tutorials/load_data/tfrecord#writing_a_tfrecord_file_2
        # so each row is serialized to a tf.train.Example proto and written out
        # in the TFRecord format.
        table = block.to_arrow()
        for example in _convert_arrow_table_to_examples(table, self.tf_schema):
            _write_record(file, example)
45
+
46
+
47
def _convert_arrow_table_to_examples(
    arrow_table: "pyarrow.Table",
    tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
    """Yield one ``tf.train.Example`` per row of ``arrow_table``.

    Args:
        arrow_table: The table whose rows are serialized.
        tf_schema: Optional schema; when given, every column of the table must be
            listed in it and its feature types are passed to the conversion.

    Raises:
        ValueError: If a schema is given and the table has a column it lacks.
    """
    import tensorflow as tf

    # Map feature name -> declared feature type for convenient lookup.
    declared_types = {}
    if tf_schema is not None:
        declared_types = {
            schema_feature.name: schema_feature.type
            for schema_feature in tf_schema.feature
        }

    for row_index in range(arrow_table.num_rows):
        features: Dict[str, "tf.train.Feature"] = {}
        for name in arrow_table.column_names:
            if tf_schema is not None and name not in declared_types:
                raise ValueError(
                    f"Found extra unexpected feature {name} "
                    f"not in specified schema: {tf_schema}"
                )
            features[name] = _value_to_feature(
                arrow_table[name][row_index],
                declared_types.get(name),
            )

        # Wrap the per-row feature dictionary in an Example proto.
        yield tf.train.Example(features=tf.train.Features(feature=features))
79
+
80
+
81
def _value_to_feature(
    value: Union["pyarrow.Scalar", "pyarrow.Array"],
    schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "tf.train.Feature":
    """Convert one arrow value into a ``tf.train.Feature``.

    Args:
        value: The arrow value of one cell. List scalars contribute all of their
            items; any other scalar contributes a single item (or none if null).
        schema_feature_type: Optional feature type from a user-provided schema.
            When given, it must agree with the arrow type and decides the output
            list type.

    Raises:
        ModuleNotFoundError: If a schema type is given but tensorflow-metadata
            is not installed.
        ValueError: If the schema type and arrow type disagree, or the arrow type
            cannot be stored as bytes, float or int.
    """
    import pyarrow as pa
    import tensorflow as tf

    if isinstance(value, pa.ListScalar):
        # Use the underlying type of the ListScalar's value in
        # determining the output feature's data type.
        value_type = value.type.value_type
        value = value.as_py()
    else:
        value_type = value.type
        scalar = value.as_py()
        value = [] if scalar is None else [scalar]

    underlying_value_type = {
        "bytes": pa.types.is_binary(value_type),
        "string": pa.types.is_string(value_type),
        "float": pa.types.is_floating(value_type),
        "int": pa.types.is_integer(value_type),
    }
    assert sum(map(bool, underlying_value_type.values())) <= 1

    if schema_feature_type is not None:
        try:
            from tensorflow_metadata.proto.v0 import schema_pb2
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "To use TensorFlow schemas, please install "
                "the tensorflow-metadata package."
            )
        # BYTES in the schema covers both binary and string arrow columns; the
        # arrow type decides which of the two it is.
        is_bytes = schema_feature_type == schema_pb2.FeatureType.BYTES
        specified_feature_type = {
            "bytes": is_bytes and not underlying_value_type["string"],
            "string": is_bytes and underlying_value_type["string"],
            "float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
            "int": schema_feature_type == schema_pb2.FeatureType.INT,
        }

        und_type = _get_single_true_type(underlying_value_type)
        spec_type = _get_single_true_type(specified_feature_type)
        if und_type is not None and und_type != spec_type:
            raise ValueError(
                "Schema field type mismatch during write: specified type is "
                f"{spec_type}, but underlying type is {und_type}",
            )
        # Override the underlying value type with the type in the user-specified
        # schema.
        underlying_value_type = specified_feature_type

    if underlying_value_type["int"]:
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    if underlying_value_type["float"]:
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    if underlying_value_type["bytes"]:
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    if underlying_value_type["string"]:
        # Strings are stored as bytes, so encode each item first.
        encoded = [item.encode() for item in value]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))
    if pa.types.is_null(value_type):
        raise ValueError(
            "Unable to infer type from partially missing column. "
            "Try setting read parallelism = 1, or use an input data source which "
            "explicitly specifies the schema."
        )
    raise ValueError(
        f"Value is of type {value_type}, "
        "which we cannot convert to a supported tf.train.Feature storage type "
        "(bytes, float, or int)."
    )
156
+
157
+
158
+ # Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501
159
+ #
160
+ # MIT License
161
+ #
162
+ # Copyright (c) 2020 Vahid Kazemi
163
+ #
164
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
165
+ # of this software and associated documentation files (the "Software"), to deal
166
+ # in the Software without restriction, including without limitation the rights
167
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
168
+ # copies of the Software, and to permit persons to whom the Software is
169
+ # furnished to do so, subject to the following conditions:
170
+ #
171
+ # The above copyright notice and this permission notice shall be included in all
172
+ # copies or substantial portions of the Software.
173
+ #
174
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
175
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
176
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
177
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
178
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
179
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
180
+ # SOFTWARE.
181
+
182
+
183
def _write_record(
    file: "pyarrow.NativeFile",
    example: "tf.train.Example",
) -> None:
    """Append one serialized example to ``file`` in the TFRecord layout.

    Written in order: the little-endian 64-bit payload length, the masked CRC of
    those length bytes, the payload, and the masked CRC of the payload.
    """
    payload = example.SerializeToString()
    header = struct.pack("<Q", len(payload))

    file.write(header)
    file.write(_masked_crc(header))
    file.write(payload)
    file.write(_masked_crc(payload))
194
+
195
+
196
def _masked_crc(data: bytes) -> bytes:
    """Return the masked checksum of ``data`` packed as 4 little-endian bytes.

    The checksum is rotated (right by 15 / left by 17), offset by a constant and
    truncated to 32 bits.
    """
    import crc32c

    checksum = crc32c.crc32(data)
    rotated = (checksum >> 15) | (checksum << 17)
    # Keep only the low 32 bits after adding the mask constant.
    truncated = np.uint32((rotated + 0xA282EAD8) & np.iinfo(np.uint32).max)
    return struct.pack("<I", truncated)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasource.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
5
+
6
+ import pyarrow
7
+
8
+ from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
9
+ from ray.data.aggregate import AggregateFn
10
+ from ray.data.block import Block
11
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
12
+ from ray.util.annotations import PublicAPI
13
+
14
+ if TYPE_CHECKING:
15
+ import pandas as pd
16
+ import tensorflow as tf
17
+ from tensorflow_metadata.proto.v0 import schema_pb2
18
+
19
+ from ray.data.dataset import Dataset
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
@PublicAPI(stability="alpha")
@dataclass
class TFXReadOptions:
    """Read options used when TFRecord files are read with TFX (tfx-bsl)."""

    # Number of consecutive elements of the dataset that are combined into a
    # single batch when tfx-bsl reads the tfrecord files.
    batch_size: int = 2048

    # Whether schema inference is applied. Only applicable when tfx-bsl is used
    # and no tf_schema argument is given. Defaults to True.
    auto_infer_schema: bool = True
40
+
41
+
42
class TFRecordDatasource(FileBasedDatasource):
    """TFRecord datasource, for reading and writing TFRecord files."""

    _FILE_EXTENSIONS = ["tfrecords"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        tf_schema: Optional["schema_pb2.Schema"] = None,
        tfx_read_options: Optional["TFXReadOptions"] = None,
        **file_based_datasource_kwargs,
    ):
        """
        Args:
            tf_schema: Optional TensorFlow Schema which is used to explicitly set
                the schema of the underlying Dataset.
            tfx_read_options: Optional options for enabling reading tfrecords
                using tfx-bsl.

        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self._tf_schema = tf_schema
        self._tfx_read_options = tfx_read_options

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Read ``path`` with tfx-bsl if read options were given, else natively."""
        if self._tfx_read_options:
            reader = self._tfx_read_stream
        else:
            reader = self._default_read_stream
        yield from reader(f, path)

    def _default_read_stream(
        self, f: "pyarrow.NativeFile", path: str
    ) -> Iterator[Block]:
        """Yield one single-row table per ``tf.train.Example`` record in ``f``.

        Raises:
            ValueError: If a record cannot be parsed as ``tf.train.Example``.
        """
        import tensorflow as tf
        from google.protobuf.message import DecodeError

        for raw_record in _read_records(f, path):
            example = tf.train.Example()
            try:
                example.ParseFromString(raw_record)
            except DecodeError as e:
                raise ValueError(
                    "`TFRecordDatasource` failed to parse `tf.train.Example` "
                    f"record in '{path}'. This error can occur if your TFRecord "
                    f"file contains a message type other than `tf.train.Example`: {e}"
                )

            row = _convert_example_to_dict(example, self._tf_schema)
            yield pyarrow_table_from_pydict(row)

    def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Yield one table per batch of records decoded through tfx-bsl.

        Raises:
            RuntimeError: If reading or decoding the file fails.
        """
        import tensorflow as tf
        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

        full_path = self._resolve_full_path(path)

        compression = (self._open_stream_args or {}).get("compression", None)
        if compression:
            compression = compression.upper()

        serialized_schema = None
        if self._tf_schema:
            serialized_schema = self._tf_schema.SerializeToString()

        decoder = ExamplesToRecordBatchDecoder(serialized_schema)
        failure = None
        try:
            batched_records = tf.data.TFRecordDataset(
                full_path, compression_type=compression
            ).batch(self._tfx_read_options.batch_size)
            for record in batched_records:
                record_batch = decoder.DecodeBatch(record.numpy())
                yield _cast_large_list_to_list(
                    pyarrow.Table.from_batches([record_batch])
                )
        except Exception as error:
            logger.exception(f"Failed to read TFRecord file {full_path}")
            failure = error

        # The exception is deliberately raised outside of the except block:
        # tensorflow's DataLossError cannot be pickled, and even if a runtime
        # error were raised from inside the handler, ray would keep information
        # about the original error, which would still make it unpicklable.
        if failure:
            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

    def _resolve_full_path(self, relative_path):
        """Prefix ``relative_path`` with the URI scheme of the configured filesystem.

        Paths on filesystems that are not recognized are returned unchanged.
        """
        filesystem = self._filesystem
        if isinstance(filesystem, pyarrow.fs.S3FileSystem):
            return f"s3://{relative_path}"
        if isinstance(filesystem, pyarrow.fs.GcsFileSystem):
            return f"gs://{relative_path}"
        if isinstance(filesystem, pyarrow.fs.HadoopFileSystem):
            return f"hdfs:///{relative_path}"
        if isinstance(filesystem, pyarrow.fs.PyFileSystem):
            protocol = filesystem.handler.fs.protocol
            # The wrapped filesystem may advertise several protocols; use the first.
            if isinstance(protocol, (list, tuple)):
                protocol = protocol[0]
            if protocol == "gcs":
                protocol = "gs"
            return f"{protocol}://{relative_path}"

        return relative_path
145
+
146
+
147
def _convert_example_to_dict(
    example: "tf.train.Example",
    tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, pyarrow.Array]:
    """Convert an example's features into a ``{feature name: arrow array}`` dict.

    Args:
        example: The parsed example.
        tf_schema: Optional schema; when given, every feature of the example must
            be listed in it and its feature types are passed to the conversion.

    Raises:
        ValueError: If a schema is given and the example has a feature it lacks.
    """
    # Map feature name -> declared feature type for convenient lookup.
    declared_types = {}
    if tf_schema is not None:
        declared_types = {
            schema_feature.name: schema_feature.type
            for schema_feature in tf_schema.feature
        }

    converted = {}
    for feature_name, feature in example.features.feature.items():
        if tf_schema is not None and feature_name not in declared_types:
            raise ValueError(
                f"Found extra unexpected feature {feature_name} "
                f"not in specified schema: {tf_schema}"
            )
        converted[feature_name] = _get_feature_value(
            feature, declared_types.get(feature_name)
        )
    return converted
167
+
168
+
169
def _get_single_true_type(dct) -> Optional[str]:
    """Return the first key of ``dct`` whose value is truthy.

    Used to filter a dict of ``{field_type: is_valid}`` to get the field type
    from a schema or data source.

    Args:
        dct: Mapping of field type name to a flag.

    Returns:
        The first key (in iteration order) with a truthy value, or ``None`` when
        no value is truthy.
    """
    for type_name, is_set in dct.items():
        if is_set:
            return type_name
    # No key has a truthy value.
    return None
176
+
177
+
178
def _get_feature_value(
    feature: "tf.train.Feature",
    schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> pyarrow.Array:
    """Convert one ``tf.train.Feature`` into a one-element arrow array.

    Args:
        feature: The feature to convert.
        schema_feature_type: Optional feature type from a user-provided schema.
            When given, it must agree with the list that is set on ``feature``
            and the result is always list-typed.

    Returns:
        An array holding a single entry: the bare value when the feature has
        exactly one value and no schema type was given, otherwise the list of
        values.

    Raises:
        ModuleNotFoundError: If a schema type is given but tensorflow-metadata
            is not installed.
        ValueError: If the schema type and the feature's own type disagree.
    """
    import pyarrow as pa

    underlying_feature_type = {
        "bytes": feature.HasField("bytes_list"),
        "float": feature.HasField("float_list"),
        "int": feature.HasField("int64_list"),
    }
    # At most one of `bytes_list`, `float_list`, and `int64_list`
    # should contain values. If none contain data, this indicates
    # an empty feature value.
    assert sum(map(bool, underlying_feature_type.values())) <= 1

    if schema_feature_type is not None:
        try:
            from tensorflow_metadata.proto.v0 import schema_pb2
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "To use TensorFlow schemas, please install "
                "the tensorflow-metadata package."
            )
        # If a schema is specified, compare to the underlying type
        specified_feature_type = {
            "bytes": schema_feature_type == schema_pb2.FeatureType.BYTES,
            "float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
            "int": schema_feature_type == schema_pb2.FeatureType.INT,
        }
        und_type = _get_single_true_type(underlying_feature_type)
        spec_type = _get_single_true_type(specified_feature_type)
        if und_type is not None and und_type != spec_type:
            raise ValueError(
                "Schema field type mismatch during read: specified type is "
                f"{spec_type}, but underlying type is {und_type}",
            )
        # Override the underlying value type with the type in the user-specified
        # schema.
        underlying_feature_type = specified_feature_type

    if underlying_feature_type["bytes"]:
        raw_values = feature.bytes_list.value
        arrow_type = pa.binary()
    elif underlying_feature_type["float"]:
        raw_values = feature.float_list.value
        arrow_type = pa.float32()
    elif underlying_feature_type["int"]:
        raw_values = feature.int64_list.value
        arrow_type = pa.int64()
    else:
        raw_values = []
        arrow_type = pa.null()

    values = list(raw_values)
    if len(values) == 1 and schema_feature_type is None:
        # Use the value itself if the features contains a single value.
        # This is to give better user experience when writing preprocessing UDF on
        # these single-value lists.
        return pa.array([values[0]], type=arrow_type)

    # If the feature value is empty and no type is specified in the user-provided
    # schema, set the type to null for now to allow pyarrow to construct a valid
    # Array; later, infer the type from other records which have non-empty values
    # for the feature.
    if len(values) == 0:
        arrow_type = pa.null()
    return pa.array([values], type=pa.list_(arrow_type))
245
+
246
+
247
+ # Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/reader.py#L16-L96 # noqa: E501
248
+ #
249
+ # MIT License
250
+ #
251
+ # Copyright (c) 2020 Vahid Kazemi
252
+ #
253
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
254
+ # of this software and associated documentation files (the "Software"), to deal
255
+ # in the Software without restriction, including without limitation the rights
256
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
257
+ # copies of the Software, and to permit persons to whom the Software is
258
+ # furnished to do so, subject to the following conditions:
259
+ #
260
+ # The above copyright notice and this permission notice shall be included in all
261
+ # copies or substantial portions of the Software.
262
+ #
263
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
264
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
265
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
266
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
267
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
268
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
269
+ # SOFTWARE.
270
def _read_records(
    file: "pyarrow.NativeFile",
    path: str,
) -> Iterable[memoryview]:
    """
    Read records from TFRecord file.

    A TFRecord file contains a sequence of records. The file can only be read
    sequentially. Each record is stored in the following formats:
        uint64 length
        uint32 masked_crc32_of_length
        byte   data[length]
        uint32 masked_crc32_of_data

    See https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecords_format_details
    for more details.

    Args:
        file: A binary file object supporting ``readinto``.
        path: The file's path; only used in error messages.

    Yields:
        A ``memoryview`` over each record's data. The view points into a buffer
        that is reused for later records, so consume (or copy) it before
        advancing the iterator.

    Raises:
        RuntimeError: If the file is truncated or otherwise cannot be read. The
            underlying error is attached as ``__cause__``.
    """
    length_bytes = bytearray(8)
    crc_bytes = bytearray(4)
    datum_bytes = bytearray(1024 * 1024)
    row_count = 0
    # Bound before the loop so the error handler can safely inspect it even when
    # the very first read fails.
    data_length = None
    while True:
        try:
            # Read "length" field.
            num_length_bytes_read = file.readinto(length_bytes)
            if num_length_bytes_read == 0:
                # Clean end of file.
                break
            elif num_length_bytes_read != 8:
                raise ValueError(
                    "Failed to read the length of record data. Expected 8 bytes but "
                    f"got {num_length_bytes_read} bytes."
                )

            # Read "masked_crc32_of_length" field.
            num_length_crc_bytes_read = file.readinto(crc_bytes)
            if num_length_crc_bytes_read != 4:
                raise ValueError(
                    "Failed to read the length of CRC-32C hashes. Expected 4 bytes "
                    f"but got {num_length_crc_bytes_read} bytes."
                )

            # Read "data[length]" field.
            (data_length,) = struct.unpack("<Q", length_bytes)
            if data_length > len(datum_bytes):
                # Allocate a fresh, larger buffer (with headroom) rather than
                # resizing in place: a previously yielded view may still
                # reference the old one.
                datum_bytes = bytearray(int(data_length * 1.5))
            datum_bytes_view = memoryview(datum_bytes)[:data_length]
            num_datum_bytes_read = file.readinto(datum_bytes_view)
            if num_datum_bytes_read != data_length:
                raise ValueError(
                    f"Failed to read the record. Expected {data_length} bytes but got "
                    f"{num_datum_bytes_read} bytes."
                )

            # Read "masked_crc32_of_data" field.
            # TODO(chengsu): ideally we should check CRC-32C against the actual data.
            num_crc_bytes_read = file.readinto(crc_bytes)
            if num_crc_bytes_read != 4:
                raise ValueError(
                    "Failed to read the CRC-32C hashes. Expected 4 bytes but got "
                    f"{num_crc_bytes_read} bytes."
                )

            # Return the data.
            yield datum_bytes_view

            row_count += 1
            data_length = None
        except Exception as e:
            error_message = (
                f"Failed to read TFRecord file {path}. Please ensure that the "
                f"TFRecord file has correct format. Already read {row_count} rows."
            )
            if data_length is not None:
                error_message += f" Byte size of current record data is {data_length}."
            raise RuntimeError(error_message) from e
345
+
346
+
347
def _cast_large_list_to_list(batch: pyarrow.Table):
    """
    Cast ``pyarrow.large_list`` columns to ``pyarrow.list_`` and
    ``pyarrow.large_binary`` columns to ``pyarrow.binary``, so that all types
    resulting from the tfrecord_datasource are usable with dataset.to_tf().

    Columns of any other type keep their original field.
    """
    old_schema = batch.schema
    large_binary = pyarrow.large_binary()
    fields = {}

    for column_name in old_schema.names:
        original_field = old_schema.field(column_name)
        field_type = original_field.type

        if type(field_type) is pyarrow.lib.LargeListType:
            item_type = field_type.value_type
            # A large list of large binaries becomes a list of binaries.
            if item_type == large_binary:
                item_type = pyarrow.binary()
            fields[column_name] = pyarrow.list_(item_type)
        elif field_type == large_binary:
            fields[column_name] = pyarrow.binary()
        else:
            fields[column_name] = original_field

    return batch.cast(pyarrow.schema(fields))
372
+
373
+
374
def _infer_schema_and_transform(dataset: "Dataset"):
    """Unwrap list columns of ``dataset`` that never hold more than one value.

    First aggregates the maximum list length of every column, then maps
    ``_unwrap_single_value_lists`` over the dataset in pyarrow batches.
    """
    column_names = dataset.schema().names
    list_sizes = dataset.aggregate(_MaxListSize(column_names))
    max_lengths = list_sizes["max_list_size"]

    return dataset.map_batches(
        _unwrap_single_value_lists,
        fn_kwargs={"col_lengths": max_lengths},
        batch_format="pyarrow",
    )
382
+
383
+
384
def _unwrap_single_value_lists(batch: pyarrow.Table, col_lengths: Dict[str, int]):
    """
    Transform the batch, converting list columns that always contain a single
    value to their underlying data type
    (i.e. pyarrow.int64() and pyarrow.float64()).

    Args:
        batch: The table to transform.
        col_lengths: Maximum list length observed for each column.

    Returns:
        A table in which each column with maximum length 1 holds the first item
        of every list (``None`` for empty or null lists); other columns are
        passed through.
    """
    columns = {}

    for col, max_length in col_lengths.items():
        column = batch[col]
        value_type = column.type.value_type

        if max_length != 1:
            columns[col] = column
            continue

        # NOTE(review): a falsy (empty) column is left out of the result here,
        # as before -- confirm this is intended.
        if column:
            unwrapped = [
                items[0] if items else None
                for items in (cell.as_py() for cell in column)
            ]
            columns[col] = pyarrow.array(unwrapped, type=value_type)

    return pyarrow.table(columns)
405
+
406
+
407
class _MaxListSize(AggregateFn):
    """Aggregation that tracks, per column, the longest value seen in any row."""

    def __init__(self, columns: List[str]):
        """
        Args:
            columns: Names of the columns to track.
        """
        self._columns = columns
        super().__init__(
            init=self._init,
            merge=self._merge,
            accumulate_row=self._accumulate_row,
            finalize=lambda a: a,
            name="max_list_size",
        )

    def _init(self, k: str):
        """Start every tracked column at length 0."""
        return dict.fromkeys(self._columns, 0)

    def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
        """Combine two accumulators by taking the per-column maximum."""
        return {col: max(acc1[col], acc2[col]) for col in self._columns}

    def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
        """Raise each column's maximum to the length of this row's value."""
        for column in row:
            cell = row[column]
            # Falsy cells (e.g. missing or empty) don't affect the maximum.
            if cell:
                acc[column] = max(len(cell), acc[column])

        return acc
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
4
+ from ray.data.block import BlockMetadata
5
+ from ray.data.datasource.datasource import Datasource, ReadTask
6
+
7
+ if TYPE_CHECKING:
8
+ import torch
9
+
10
+
11
+ TORCH_DATASOURCE_READER_BATCH_SIZE = 32
12
+
13
+
14
class TorchDatasource(Datasource):
    """Datasource that reads items from a `Torch
    dataset <https://pytorch.org/docs/stable/data.html/>`_.

    The whole dataset is streamed by one read task.
    """

    def __init__(
        self,
        dataset: "torch.utils.data.Dataset",
    ):
        self._dataset = dataset

    def get_read_tasks(self, parallelism):
        """Return the single read task that iterates the wrapped dataset.

        Args:
            parallelism: Must be exactly 1.
        """
        assert parallelism == 1

        dataset = self._dataset
        metadata = BlockMetadata(
            num_rows=len(dataset),
            size_bytes=None,
            schema=None,
            input_files=None,
            exec_stats=None,
        )
        # The dataset is bound as a default argument so the task does not
        # capture ``self``.
        return [
            ReadTask(
                lambda subset=dataset: _read_subset(
                    subset,
                ),
                metadata=metadata,
            )
        ]

    def estimate_inmemory_data_size(self):
        """No size estimate is available for this datasource."""
        return None
47
+
48
+
49
def _read_subset(subset: "torch.utils.data.Subset"):
    """Yield blocks built from consecutive groups of items in ``subset``.

    Items are collected into groups of ``TORCH_DATASOURCE_READER_BATCH_SIZE``;
    each full group, and a final partial group if any items remain, becomes
    one block with a single ``"item"`` column.
    """

    def build_block(items):
        block_builder = DelegatingBlockBuilder()
        block_builder.add_batch({"item": items})
        return block_builder.build()

    pending = []
    for element in subset:
        pending.append(element)
        if len(pending) == TORCH_DATASOURCE_READER_BATCH_SIZE:
            yield build_block(pending)
            pending.clear()

    # Emit whatever is left over after the last full group.
    if pending:
        yield build_block(pending)
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import TYPE_CHECKING, List, Union
3
+
4
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
5
+ from ray.data._internal.util import _check_import
6
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
7
+
8
+ if TYPE_CHECKING:
9
+ import pyarrow
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class VideoDatasource(FileBasedDatasource):
    """File-based datasource that decodes video files frame by frame.

    Decoding relies on the ``decord`` package, whose availability is checked
    when the datasource is constructed.
    """

    _FILE_EXTENSIONS = [
        "mp4", "mkv", "mov", "avi", "wmv", "flv", "webm", "m4v",
        "3gp", "mpeg", "mpg", "ts", "ogv", "rm", "rmvb", "vob",
        "asf", "f4v", "m2ts", "mts", "divx", "xvid", "mxf",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        **file_based_datasource_kwargs,
    ):
        super().__init__(paths, **file_based_datasource_kwargs)

        _check_import(self, module="decord", package="decord")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Yield one single-row block per frame of the video in ``f``.

        Each row holds the frame converted with ``asnumpy()`` under ``"frame"``
        and its position in the video under ``"frame_index"``.
        """
        from decord import VideoReader

        for index, frame in enumerate(VideoReader(f)):
            row = {"frame": frame.asnumpy(), "frame_index": index}
            block_builder = DelegatingBlockBuilder()
            block_builder.add(row)
            yield block_builder.build()
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import tarfile
3
+ import time
4
+ import uuid
5
+ from typing import Optional, Union
6
+
7
+ import pyarrow
8
+
9
+ from ray.data._internal.datasource.webdataset_datasource import (
10
+ _apply_list,
11
+ _default_encoder,
12
+ _make_iterable,
13
+ )
14
+ from ray.data.block import BlockAccessor
15
+ from ray.data.datasource.file_datasink import BlockBasedFileDatasink
16
+
17
+
18
class WebDatasetDatasink(BlockBasedFileDatasink):
    """Datasink that writes each block as a WebDataset-style tar archive."""

    def __init__(
        self,
        path: str,
        encoder: Optional[Union[bool, str, callable, list]] = True,
        *,
        file_format: str = "tar",
        **file_datasink_kwargs,
    ):
        """Create the datasink.

        Args:
            path: Destination path forwarded to the base datasink.
            encoder: Encoder, encoder format, or list of those applied to every
                sample before writing; ``None`` disables encoding.
            file_format: Accepted for interface compatibility; the base class
                is always initialised with ``"tar"``.
            **file_datasink_kwargs: Forwarded to the base datasink.
        """
        super().__init__(path, file_format="tar", **file_datasink_kwargs)

        self.encoder = encoder

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write every sample of ``block`` into ``file`` as tar members.

        Each non-``None`` field whose name does not start with ``"__"`` becomes
        one member named ``"<__key__>.<field>"``. Samples without a ``__key__``
        get a random one. Field values must be ``bytes`` or ``str`` after
        encoding; ``str`` values are stored UTF-8 encoded.
        """
        # The context manager guarantees the tar stream is closed even when
        # encoding or writing a sample raises.
        with tarfile.open(fileobj=file, mode="w|") as stream:
            samples = _make_iterable(block)
            for sample in samples:
                if not isinstance(sample, dict):
                    sample = sample.as_pydict()
                if self.encoder is not None:
                    sample = _apply_list(
                        self.encoder, sample, default=_default_encoder
                    )
                if "__key__" not in sample:
                    sample["__key__"] = uuid.uuid4().hex
                key = sample["__key__"]
                for k, v in sample.items():
                    # Metadata fields and missing values are not written.
                    if v is None or k.startswith("__"):
                        continue
                    assert isinstance(v, (bytes, str))
                    if not isinstance(v, bytes):
                        v = v.encode("utf-8")
                    ti = tarfile.TarInfo(f"{key}.{k}")
                    ti.size = len(v)
                    ti.mtime = time.time()
                    ti.mode, ti.uname, ti.gname = 0o644, "data", "data"
                    stream.addfile(ti, io.BytesIO(v))
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright NVIDIA Corporation 2023
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import fnmatch
5
+ import io
6
+ import json
7
+ import re
8
+ import tarfile
9
+ from functools import partial
10
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
11
+
12
+ import ray
13
+ from ray.data._internal.util import iterate_with_retry
14
+ from ray.data.block import BlockAccessor
15
+ from ray.data.datasource.file_based_datasource import FileBasedDatasource
16
+
17
+ if TYPE_CHECKING:
18
+ import pyarrow
19
+
20
+
21
+ def _base_plus_ext(path: str):
22
+ """Split off all file extensions.
23
+
24
+ Returns base, allext.
25
+
26
+ Args:
27
+ path: path with extensions
28
+
29
+ Returns:
30
+ str: path with all extensions removed
31
+ """
32
+ match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
33
+ if not match:
34
+ return None, None
35
+ return match.group(1), match.group(2)
36
+
37
+
38
+ def _valid_sample(sample: Dict[str, Any]):
39
+ """Check whether a sample is valid.
40
+
41
+ Args:
42
+ sample: sample to be checked
43
+ """
44
+ return (
45
+ sample is not None
46
+ and isinstance(sample, dict)
47
+ and len(list(sample.keys())) > 0
48
+ and not sample.get("__bad__", False)
49
+ )
50
+
51
+
52
+ def _apply_list(
53
+ f: Union[Callable, List[Callable]], sample: Dict[str, Any], default: Callable = None
54
+ ):
55
+ """Apply a list of functions to a sample.
56
+
57
+ Args:
58
+ f: function or list of functions
59
+ sample: sample to be modified
60
+ default: default function to be applied to all keys.
61
+ Defaults to None.
62
+
63
+ Returns:
64
+ modified sample
65
+ """
66
+ if f is None:
67
+ return sample
68
+ if not isinstance(f, list):
69
+ f = [f]
70
+ for g in f:
71
+ if default is not None and not callable(g):
72
+ g = partial(default, format=g)
73
+ sample = g(sample)
74
+ return sample
75
+
76
+
77
+ def _check_suffix(suffix: str, suffixes: Union[list, callable]):
78
+ """Check whether a suffix is valid.
79
+
80
+ Suffixes can be either None (=accept everything), a callable,
81
+ or a list of patterns. If the pattern contains */? it is treated
82
+ as a glob pattern, otherwise it is treated as a literal.
83
+
84
+ Args:
85
+ suffix: suffix to be checked
86
+ suffixes: list of valid suffixes
87
+ """
88
+ if suffixes is None:
89
+ return True
90
+ if callable(suffixes):
91
+ return suffixes(suffix)
92
+ for pattern in suffixes:
93
+ if "*" in pattern or "?" in pattern:
94
+ if fnmatch.fnmatch("." + suffix, pattern):
95
+ return True
96
+ elif suffix == pattern or "." + suffix == pattern:
97
+ return True
98
+ return False
99
+
100
+
101
def _tar_file_iterator(
    fileobj: Any,
    fileselect: Optional[Union[bool, callable, list]] = None,
    filerename: Optional[Union[bool, callable, list]] = None,
    verbose_open: bool = False,
    meta: dict = None,
):
    """Iterate over a tar stream, yielding one dict per selected regular file.

    Args:
        fileobj: file object holding the tar stream
        fileselect: patterns or function selecting
            files to be selected
        filerename: function or list of functions applied to each member name
            before selection
        verbose_open: print a message when iteration starts and finishes
        meta: metadata, only used in the verbose messages

    Yields:
        ``{"fname": <renamed member name>, "data": <member bytes>}``.
    """
    meta = meta or {}
    # The context manager closes the TarFile when iteration ends or is
    # abandoned.
    with tarfile.open(fileobj=fileobj, mode="r|*") as stream:
        if verbose_open:
            print(f"start {meta}")
        for tarinfo in stream:
            fname = tarinfo.name
            if not tarinfo.isreg() or fname is None:
                continue
            fname = _apply_list(filerename, fname)
            assert isinstance(fname, str)
            # Filter before reading so rejected members are never loaded.
            if not _check_suffix(fname, fileselect):
                continue
            data = stream.extractfile(tarinfo).read()
            yield dict(fname=fname, data=data)
    if verbose_open:
        print(f"done {meta}")
133
+
134
+
135
def _group_by_keys(
    data: List[Dict[str, Any]],
    keys: callable = _base_plus_ext,
    suffixes: Optional[Union[list, callable]] = None,
    meta: dict = None,
):
    """Group consecutive file entries that share a key into samples.

    This is a generator: entries are consumed in order and a sample is emitted
    each time the key changes, plus once more at the end of the input.

    Args:
        data: iterator over ``{"fname": ..., "data": ...}`` dicts
        keys: function that returns key, suffix for a given file name; entries
            for which it returns a ``None`` key are skipped
        suffixes: list of suffixes to be included in the sample
        meta: metadata to be added to each sample

    Yields:
        Dicts holding ``__key__``, one entry per accepted suffix, and the
        contents of ``meta``.

    Raises:
        ValueError: if the same suffix occurs twice within one sample.
    """
    meta = meta or {}
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if current_sample is None or prefix != current_sample["__key__"]:
            # The key changed: flush the finished sample and start a new one.
            if _valid_sample(current_sample):
                current_sample.update(meta)
                yield current_sample
            current_sample = dict(__key__=prefix)
            if "__url__" in filesample:
                current_sample["__url__"] = filesample["__url__"]
        if suffix in current_sample:
            # ``meta`` may lack "__url__" (it defaults to an empty dict), so
            # use .get() to avoid masking this error with a KeyError.
            raise ValueError(
                f"{fname}: duplicate file name in tar file "
                + f"{suffix} {current_sample.keys()}, tar is {meta.get('__url__')}"
            )
        if suffixes is None or _check_suffix(suffix, suffixes):
            current_sample[suffix] = value
    # Flush the last sample once the input is exhausted.
    if _valid_sample(current_sample):
        current_sample.update(meta)
        yield current_sample
174
+
175
+
176
+ def _default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
177
+ """A default decoder for webdataset.
178
+
179
+ This handles common file extensions: .txt, .cls, .cls2,
180
+ .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
181
+ These are the most common extensions used in webdataset.
182
+ For other extensions, users can provide their own decoder.
183
+
184
+ Args:
185
+ sample: sample, modified in place
186
+ """
187
+ sample = dict(sample)
188
+ for key, value in sample.items():
189
+ extension = key.split(".")[-1]
190
+ if key.startswith("__"):
191
+ continue
192
+ elif extension in ["txt", "text"]:
193
+ sample[key] = value.decode("utf-8")
194
+ elif extension in ["cls", "cls2"]:
195
+ sample[key] = int(value.decode("utf-8"))
196
+ elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
197
+ import numpy as np
198
+ import PIL.Image
199
+
200
+ if format == "PIL":
201
+ sample[key] = PIL.Image.open(io.BytesIO(value))
202
+ else:
203
+ sample[key] = np.asarray(PIL.Image.open(io.BytesIO(value)))
204
+ elif extension == "json":
205
+ sample[key] = json.loads(value)
206
+ elif extension == "npy":
207
+ import numpy as np
208
+
209
+ sample[key] = np.load(io.BytesIO(value))
210
+ elif extension == "mp":
211
+ import msgpack
212
+
213
+ sample[key] = msgpack.unpackb(value, raw=False)
214
+ elif extension in ["pt", "pth"]:
215
+ import torch
216
+
217
+ sample[key] = torch.load(io.BytesIO(value))
218
+ elif extension in ["pickle", "pkl"]:
219
+ import pickle
220
+
221
+ sample[key] = pickle.loads(value)
222
+ return sample
223
+
224
+
225
+ extension_to_format = {"jpg": "jpeg"}
226
+
227
+
228
+ def _default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True):
229
+ """A default encoder for webdataset.
230
+
231
+ This handles common file extensions: .txt, .cls, .cls2, .jpg,
232
+ .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl
233
+ These are the most common extensions used in webdataset.
234
+ For other extensions, users can provide their own encoder.
235
+
236
+ Args:
237
+ sample (Dict[str, Any]): sample
238
+ """
239
+ sample = dict(sample)
240
+ for key, value in sample.items():
241
+ extension = key.split(".")[-1]
242
+ if key.startswith("__"):
243
+ continue
244
+ elif extension in ["txt"]:
245
+ sample[key] = value.encode("utf-8")
246
+ elif extension in ["cls", "cls2"]:
247
+ sample[key] = str(value).encode("utf-8")
248
+ elif extension in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]:
249
+ import numpy as np
250
+ import PIL.Image
251
+
252
+ if isinstance(value, np.ndarray):
253
+ value = PIL.Image.fromarray(value)
254
+ assert isinstance(value, PIL.Image.Image)
255
+ stream = io.BytesIO()
256
+ value.save(
257
+ stream, format=extension_to_format.get(extension.lower(), extension)
258
+ )
259
+ sample[key] = stream.getvalue()
260
+ elif extension == "json":
261
+ sample[key] = json.dumps(value).encode("utf-8")
262
+ elif extension == "npy":
263
+ import numpy as np
264
+
265
+ stream = io.BytesIO()
266
+ np.save(stream, value)
267
+ sample[key] = stream.getvalue()
268
+ elif extension == "mp":
269
+ import msgpack
270
+
271
+ sample[key] = msgpack.dumps(value)
272
+ elif extension in ["pt", "pth"]:
273
+ import torch
274
+
275
+ stream = io.BytesIO()
276
+ torch.save(value, stream)
277
+ sample[key] = stream.getvalue()
278
+ elif extension in ["pickle", "pkl"]:
279
+ import pickle
280
+
281
+ stream = io.BytesIO()
282
+ pickle.dump(value, stream)
283
+ sample[key] = stream.getvalue()
284
+ return sample
285
+
286
+
287
def _make_iterable(block: BlockAccessor):
    """Return an iterable over the rows of a block.

    This is a placeholder for dealing with more complex blocks.

    Args:
        block: Ray Dataset block

    Returns:
        Iterable[Dict[str,Any]]: Iterable of samples
    """
    rows = block.iter_rows(public_row_format=False)
    return rows
299
+
300
+
301
class WebDatasetDatasource(FileBasedDatasource):
    """A Datasource for WebDataset datasets (tar format with naming conventions)."""

    _FILE_EXTENSIONS = ["tar"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        decoder: Optional[Union[bool, str, callable, list]] = True,
        fileselect: Optional[Union[bool, callable, list]] = None,
        filerename: Optional[Union[bool, callable, list]] = None,
        suffixes: Optional[Union[bool, callable, list]] = None,
        verbose_open: bool = False,
        expand_json: bool = False,
        **file_based_datasource_kwargs,
    ):
        """Store the read options.

        Args:
            paths: Path or paths forwarded to the base datasource.
            decoder: Decoder, decoder format, or list of those applied to each
                sample; ``None`` disables decoding.
            fileselect: Selects tar members while the archive is read.
            filerename: Renames tar members before selection.
            suffixes: Selects which suffixes are kept when grouping.
            verbose_open: Print a message when a tar stream is opened.
            expand_json: Spread the entries of each sample's ``"json"`` field
                into columns of their own.
            **file_based_datasource_kwargs: Forwarded to the base datasource.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self.decoder = decoder
        self.fileselect = fileselect
        self.filerename = filerename
        self.suffixes = suffixes
        self.verbose_open = verbose_open
        self.expand_json = expand_json

    def _read_stream(self, stream: "pyarrow.NativeFile", path: str):
        """Read and decode samples from a stream.

        ``fileselect`` filters members while the tar stream is read, whereas
        ``suffixes`` filters fields while members are grouped into samples.

        Args:
            stream: File descriptor to read from.
            path: Path to the data; attached to every sample as ``__url__``.

        Yields:
            A one-row pandas DataFrame per sample.
        """

        import pandas as pd

        def open_tar_iterator():
            return _tar_file_iterator(
                stream,
                fileselect=self.fileselect,
                filerename=self.filerename,
                verbose_open=self.verbose_open,
            )

        # S3 can raise transient errors during iteration
        data_context = ray.data.DataContext.get_current()
        members = iterate_with_retry(
            open_tar_iterator, "iterate tar file", match=data_context.retried_io_errors
        )

        grouped = _group_by_keys(
            members, meta=dict(__url__=path), suffixes=self.suffixes
        )
        for sample in grouped:
            if self.decoder is not None:
                sample = _apply_list(self.decoder, sample, default=_default_decoder)

            if self.expand_json:
                raw_json = sample["json"]
                if isinstance(raw_json, bytes):
                    parsed_json = json.loads(raw_json.decode("utf-8"))
                elif isinstance(raw_json, str):
                    parsed_json = json.loads(raw_json)
                elif isinstance(raw_json, dict):
                    parsed_json = raw_json
                else:
                    raise TypeError(
                        f"Unsupported data type {type(raw_json)} for sample"
                    )
                for field, field_value in parsed_json.items():
                    if field not in sample:
                        sample[field] = []
                    sample[field].append(field_value)

            # Wrap every value in a list so it forms a one-row column;
            # single-element lists are used as the column directly.
            columns = {}
            for name, value in sample.items():
                if isinstance(value, list) and len(value) == 1:
                    columns[name] = value
                else:
                    columns[name] = [value]
            yield pd.DataFrame(columns)
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc ADDED
Binary file (2.82 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc ADDED
Binary file (15 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union
2
+
3
+ from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
4
+ from ray.data._internal.stats import DatasetStats
5
+ from ray.data._internal.util import create_dataset_tag
6
+ from ray.data.iterator import DataIterator
7
+
8
+ if TYPE_CHECKING:
9
+ import pyarrow
10
+
11
+ from ray.data import Dataset
12
+
13
+
14
class DataIteratorImpl(DataIterator):
    """``DataIterator`` that reads directly from a single base dataset."""

    def __init__(
        self,
        base_dataset: "Dataset",
    ):
        self._base_dataset = base_dataset

    def __repr__(self) -> str:
        return f"DataIterator({self._base_dataset})"

    def _to_ref_bundle_iterator(
        self,
    ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
        """Execute the dataset's plan and expose its output as an iterator.

        The executor created for the run is stored on the dataset as
        ``_current_executor``. The third element of the result is always
        ``False``.
        """
        dataset = self._base_dataset
        bundle_iterator, stats, executor = dataset._plan.execute_to_iterator()
        dataset._current_executor = executor
        return bundle_iterator, stats, False

    def stats(self) -> str:
        """Return the base dataset's stats."""
        return self._base_dataset.stats()

    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
        """Return the base dataset's schema."""
        return self._base_dataset.schema()

    def _get_dataset_tag(self):
        """Build the dataset tag from the plan's dataset name and the uuid."""
        dataset = self._base_dataset
        return create_dataset_tag(dataset._plan._dataset_name, dataset._uuid)
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import threading
3
+ import time
4
+ from dataclasses import replace
5
+ from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
6
+
7
+ import ray
8
+ from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle
9
+ from ray.data._internal.execution.legacy_compat import execute_to_legacy_bundle_iterator
10
+ from ray.data._internal.execution.operators.output_splitter import OutputSplitter
11
+ from ray.data._internal.execution.streaming_executor import StreamingExecutor
12
+ from ray.data._internal.stats import DatasetStats
13
+ from ray.data._internal.util import create_dataset_tag
14
+ from ray.data.block import Block, BlockMetadata
15
+ from ray.data.iterator import DataIterator
16
+ from ray.types import ObjectRef
17
+ from ray.util.debug import log_once
18
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
19
+
20
+ if TYPE_CHECKING:
21
+ import pyarrow
22
+
23
+ from ray.data import Dataset
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ BLOCKED_CLIENT_WARN_TIMEOUT = 30
29
+
30
+
31
class StreamSplitDataIterator(DataIterator):
    """Implements a collection of iterators over a shared data stream.

    Every instance represents one output split (``output_split_idx``) out of
    ``world_size`` splits and pulls its blocks from a shared
    ``SplitCoordinator`` actor.
    """

    @staticmethod
    def create(
        base_dataset: "Dataset",
        n: int,
        equal: bool,
        locality_hints: Optional[List[NodeIdStr]],
    ) -> List["StreamSplitDataIterator"]:
        """Create a split iterator from the given base Dataset and options.

        One coordinator actor is started and shared by the ``n`` returned
        iterators.

        See also: `Dataset.streaming_split`.
        """
        # To avoid deadlock, the concurrency on this actor must be set to at least `n`.
        # The actor is pinned (soft=False) to the node this code runs on.
        coord_actor = SplitCoordinator.options(
            max_concurrency=n,
            scheduling_strategy=NodeAffinitySchedulingStrategy(
                ray.get_runtime_context().get_node_id(), soft=False
            ),
        ).remote(base_dataset, n, equal, locality_hints)

        return [
            StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n)
        ]

    def __init__(
        self,
        base_dataset: "Dataset",
        coord_actor: ray.actor.ActorHandle,
        output_split_idx: int,
        world_size: int,
    ):
        self._base_dataset = base_dataset
        self._coord_actor = coord_actor
        self._output_split_idx = output_split_idx
        self._world_size = world_size
        # Iteration stats recorded locally for this split.
        self._iter_stats = DatasetStats(metadata={}, parent=None)

    def _to_ref_bundle_iterator(
        self,
    ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
        """Return a generator of single-block bundles for this split.

        The third element of the result is always ``False``.
        """

        def gen_blocks() -> Iterator[RefBundle]:
            # Blocks until the coordinator has started the epoch.
            cur_epoch = ray.get(
                self._coord_actor.start_epoch.remote(self._output_split_idx)
            )
            future: ObjectRef[
                Optional[ObjectRef[Block]]
            ] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx)
            while True:
                block_ref_and_md: Optional[
                    Tuple[ObjectRef[Block], BlockMetadata]
                ] = ray.get(future)
                # A falsy result from the coordinator ends the epoch.
                if not block_ref_and_md:
                    break
                else:
                    # Request the next block before yielding the current one.
                    future = self._coord_actor.get.remote(
                        cur_epoch, self._output_split_idx
                    )
                    yield RefBundle(blocks=(block_ref_and_md,), owns_blocks=False)

        return gen_blocks(), self._iter_stats, False

    def stats(self) -> str:
        """Implements DataIterator."""
        # Merge the locally recorded iter stats and the remotely recorded
        # stream execution stats.
        stats = ray.get(self._coord_actor.stats.remote())
        summary = stats.to_summary()
        summary.iter_stats = self._iter_stats.to_summary().iter_stats
        summary.iter_stats.streaming_split_coord_time.add(
            stats.streaming_split_coordinator_s.get()
        )
        return summary.to_string()

    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
        """Implements DataIterator."""
        return self._base_dataset.schema()

    def world_size(self) -> int:
        """Returns the number of splits total."""
        return self._world_size

    def _get_dataset_tag(self):
        """Build the dataset tag from the dataset name, uuid and split index."""
        return create_dataset_tag(
            self._base_dataset._plan._dataset_name,
            self._base_dataset._uuid,
            self._output_split_idx,
        )
120
+
121
+
122
@ray.remote(num_cpus=0)
class SplitCoordinator:
    """Coordinator actor for routing blocks to output splits.

    This actor runs a streaming executor locally on its main thread. Clients can
    retrieve results via actor calls running on other threads.
    """

    def __init__(
        self,
        dataset: "Dataset",
        n: int,
        equal: bool,
        locality_hints: Optional[List[NodeIdStr]],
    ):
        """Set up the coordinator state; no epoch is started here.

        Args:
            dataset: Dataset whose plan is executed once per epoch.
            n: Number of output splits (clients).
            equal: Forwarded to ``OutputSplitter``.
            locality_hints: Optional per-split node ids, forwarded to
                ``OutputSplitter`` and used to configure
                ``locality_with_output``.
        """
        # Set current DataContext.
        self._data_context = dataset.context
        ray.data.DataContext._set_current(self._data_context)
        # Automatically set locality with output to the specified location hints.
        if locality_hints:
            self._data_context.execution_options.locality_with_output = locality_hints
            logger.info(f"Auto configuring locality_with_output={locality_hints}")

        self._base_dataset = dataset
        self._n = n
        self._equal = equal
        self._locality_hints = locality_hints
        self._lock = threading.RLock()
        self._executor = None

        # Guarded by self._lock.
        # Bundles with blocks not yet handed out, keyed by output split index.
        self._next_bundle: Dict[int, RefBundle] = {}
        self._unfinished_clients_in_epoch = n
        # -1 means no epoch has started; the first barrier advances it to 0.
        self._cur_epoch = -1

        def gen_epochs():
            # Endless generator: each ``next()`` creates a new executor and
            # yields the output iterator for one epoch.
            while True:
                executor = StreamingExecutor(
                    self._data_context,
                    create_dataset_tag(
                        self._base_dataset._name, self._base_dataset._uuid
                    ),
                )
                self._executor = executor

                def add_split_op(dag):
                    # Wrap the plan's DAG in an OutputSplitter.
                    return OutputSplitter(
                        dag,
                        n,
                        equal,
                        self._data_context,
                        locality_hints,
                    )

                output_iterator = execute_to_legacy_bundle_iterator(
                    executor,
                    dataset._plan,
                    dag_rewrite=add_split_op,
                )
                yield output_iterator

        self._next_epoch = gen_epochs()
        self._output_iterator = None
        # Store the error raised from the `gen_epoch` call.
        self._gen_epoch_error: Optional[Exception] = None

    def stats(self) -> DatasetStats:
        """Returns stats from the current executor, else from the base dataset."""
        if self._executor:
            return self._executor.get_stats()
        return self._base_dataset._plan.stats()

    def start_epoch(self, split_idx: int) -> int:
        """Called to start an epoch.

        Returns:
            Identifier of the epoch (an integer counter), which must be used
            when accessing results via get().
        """

        # Wait for all clients to arrive at the barrier before starting a new epoch.
        epoch_id = self._barrier(split_idx)
        return epoch_id

    def get(
        self, epoch_id: int, output_split_idx: int
    ) -> Optional[Tuple[ObjectRef[Block], BlockMetadata]]:
        """Blocking get operation.

        This is intended to be called concurrently from multiple clients.

        Returns:
            The next block entry for the given split, or ``None`` once the
            output iterator raises ``StopIteration``.

        Raises:
            ValueError: if ``epoch_id`` is not the current epoch.
        """
        start_time = time.perf_counter()
        if epoch_id != self._cur_epoch:
            raise ValueError(
                "Invalid iterator: the dataset has moved on to another epoch."
            )

        try:
            # Ensure there is at least one bundle.
            with self._lock:
                if output_split_idx in self._next_bundle:
                    next_bundle = self._next_bundle[output_split_idx]
                else:
                    next_bundle = None

            # Fetch next bundle if needed.
            while next_bundle is None or not next_bundle.blocks:
                # This is a BLOCKING call, so do it outside the lock.
                next_bundle = self._output_iterator.get_next(output_split_idx)

            # Hand out the last block and keep the rest of the bundle.
            block = next_bundle.blocks[-1]
            next_bundle = replace(next_bundle, blocks=next_bundle.blocks[:-1])

            # Accumulate any remaining blocks in next_bundle map as needed.
            with self._lock:
                self._next_bundle[output_split_idx] = next_bundle
                if not next_bundle.blocks:
                    del self._next_bundle[output_split_idx]

            return block
        except StopIteration:
            return None
        finally:
            # Record the time spent in this call, whatever its outcome.
            stats = self.stats()
            if stats and stats.streaming_split_coordinator_s:
                stats.streaming_split_coordinator_s.add(
                    time.perf_counter() - start_time
                )

    def _barrier(self, split_idx: int) -> int:
        """Arrive and block until the start of the given epoch.

        Returns:
            The id of the epoch that was started.

        Raises:
            Exception: whatever error was stored while creating an epoch's
                output iterator.
        """

        # Decrement and await all clients to arrive here.
        with self._lock:
            starting_epoch = self._cur_epoch
            self._unfinished_clients_in_epoch -= 1

        start_time = time.time()
        # Poll until the epoch has advanced or every client has arrived.
        while (
            self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0
        ):
            if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT:
                if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"):
                    logger.warning(
                        f"StreamSplitDataIterator(epoch={starting_epoch}, "
                        f"split={split_idx}) blocked waiting on other clients "
                        f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All "
                        "clients must read from the DataIterator splits at "
                        "the same time. This warning will not be printed again "
                        "for this epoch."
                    )
            time.sleep(0.1)

        # Advance to the next epoch.
        with self._lock:
            # Only the first thread to get here still sees the old epoch, so
            # the advance happens once per epoch.
            if self._cur_epoch == starting_epoch:
                self._cur_epoch += 1
                self._unfinished_clients_in_epoch = self._n
                try:
                    self._output_iterator = next(self._next_epoch)
                except Exception as e:
                    self._gen_epoch_error = e

        if self._gen_epoch_error is not None:
            # If there was an error when advancing to the next epoch,
            # re-raise it for all threads.
            raise self._gen_epoch_error

        assert self._output_iterator is not None
        return starting_epoch + 1
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file