Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py +57 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py +42 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py +129 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py +156 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py +24 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py +349 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py +36 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py +187 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py +126 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py +87 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py +176 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py +261 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py +24 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py +175 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasink.py +36 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasource.py +154 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py +129 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasink.py +48 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasource.py +140 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasink.py +23 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasource.py +41 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_bulk_datasource.py +51 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasink.py +172 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasource.py +731 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py +139 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py +35 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py +42 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py +205 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasource.py +434 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py +62 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py +59 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py +53 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py +385 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py +41 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py +290 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/logical/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc
ADDED
|
Binary file (34.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc
ADDED
|
Binary file (8.89 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc
ADDED
|
Binary file (9.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc
ADDED
|
Binary file (38.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from typing import TYPE_CHECKING, Iterator, List, Union
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 5 |
+
from ray.data._internal.util import _check_import
|
| 6 |
+
from ray.data.block import Block
|
| 7 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
import pyarrow
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AudioDatasource(FileBasedDatasource):
    """File-based datasource that decodes audio files with ``soundfile``.

    Every file yields one block built from a single row with two columns:
    ``amplitude`` (the decoded samples, transposed so channels come first)
    and ``sample_rate`` (as reported by ``soundfile.read``).
    """

    _FILE_EXTENSIONS = [
        "mp3",
        "wav",
        "aac",
        "flac",
        "ogg",
        "m4a",
        "wma",
        "alac",
        "aiff",
        "pcm",
        "amr",
        "opus",
        "ra",
        "rm",
        "au",
        "mid",
        "midi",
        "caf",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        **file_based_datasource_kwargs,
    ):
        """Forward ``paths`` and any extra kwargs to ``FileBasedDatasource``."""
        super().__init__(paths, **file_based_datasource_kwargs)

        # Presumably surfaces a missing `soundfile` install at construction
        # time rather than at read time -- confirm against `_check_import`.
        _check_import(self, module="soundfile", package="soundfile")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Decode one audio file and yield it as a single block.

        Args:
            f: Open file handle for the audio file.
            path: Path of the file being read (not used by this reader).

        Yields:
            One block holding the ``amplitude`` / ``sample_rate`` row.
        """
        import soundfile

        # `soundfile` can't read from a `pyarrow.NativeFile` directly, so the
        # whole file is pulled into an in-memory buffer first.
        in_memory_file = io.BytesIO(f.read())
        samples, sample_rate = soundfile.read(
            in_memory_file, always_2d=True, dtype="float32"
        )

        # (amplitude, channels) -> (channels, amplitude)
        row = {
            "amplitude": samples.transpose((1, 0)),
            "sample_rate": sample_rate,
        }

        block_builder = DelegatingBlockBuilder()
        block_builder.add(row)
        yield block_builder.build()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, Iterator, List, Union
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.output_buffer import BlockOutputBuffer
|
| 4 |
+
from ray.data._internal.util import _check_import
|
| 5 |
+
from ray.data.block import Block
|
| 6 |
+
from ray.data.context import DataContext
|
| 7 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
import pyarrow
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AvroDatasource(FileBasedDatasource):
    """A datasource that reads Avro files."""

    _FILE_EXTENSIONS = ["avro"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        **file_based_datasource_kwargs,
    ):
        """Forward ``paths`` and any extra kwargs to ``FileBasedDatasource``."""
        super().__init__(paths, **file_based_datasource_kwargs)

        _check_import(self, module="fastavro", package="fastavro")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Stream the records of one Avro file out as blocks.

        Records are pushed into a ``BlockOutputBuffer`` created with the
        current ``DataContext``'s ``target_max_block_size``; blocks are
        yielded as soon as the buffer reports one is available.

        Args:
            f: Open file handle for the Avro file.
            path: Path of the file being read (not used by this reader).

        Yields:
            Blocks produced by the output buffer.
        """
        import fastavro

        # Read the Avro file. This assumes the Avro file includes its schema.
        records = fastavro.reader(f)

        target_size = DataContext.get_current().target_max_block_size
        buffer = BlockOutputBuffer(target_size)

        def drain() -> Iterator[Block]:
            # Emit every block the buffer currently has ready.
            while buffer.has_next():
                yield buffer.next()

        for row in records:
            buffer.add(row)
            yield from drain()

        # Finalize the buffer and emit whatever it still holds.
        buffer.finalize()
        yield from drain()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
from typing import Iterable, Optional
|
| 7 |
+
|
| 8 |
+
import pyarrow.parquet as pq
|
| 9 |
+
|
| 10 |
+
import ray
|
| 11 |
+
from ray.data._internal.execution.interfaces import TaskContext
|
| 12 |
+
from ray.data._internal.datasource import bigquery_datasource
|
| 13 |
+
from ray.data._internal.remote_fn import cached_remote_fn
|
| 14 |
+
from ray.data._internal.util import _check_import
|
| 15 |
+
from ray.data.block import Block, BlockAccessor
|
| 16 |
+
from ray.data.datasource.datasink import Datasink
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
DEFAULT_MAX_RETRY_CNT = 10
|
| 21 |
+
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BigQueryDatasink(Datasink[None]):
    """Datasink that loads blocks into a BigQuery table as Parquet files.

    Args:
        project_id: Google Cloud project used for the BigQuery client and as
            the prefix of the fully qualified dataset/table ids.
        dataset: Destination in ``dataset_id.table_id`` form; the part before
            the first ``.`` is treated as the dataset id.
        max_retry_cnt: How many times a block load is retried after a
            ``Forbidden`` response before the write fails.
        overwrite_table: When truthy, the destination table is deleted (if it
            exists) before writing; otherwise the write appends to it.
    """

    def __init__(
        self,
        project_id: str,
        dataset: str,
        max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
        overwrite_table: Optional[bool] = True,
    ) -> None:
        _check_import(self, module="google.cloud", package="bigquery")
        _check_import(self, module="google.cloud", package="bigquery_storage")
        _check_import(self, module="google.api_core", package="exceptions")

        self.project_id = project_id
        self.dataset = dataset
        self.max_retry_cnt = max_retry_cnt
        self.overwrite_table = overwrite_table

    def on_write_start(self) -> None:
        """Prepare the destination before any block is written.

        Creates the dataset if it does not exist and, when
        ``overwrite_table`` is set, deletes the destination table.

        Raises:
            ValueError: If ``project_id`` or ``dataset`` is ``None``.
        """
        from google.api_core import exceptions

        if self.project_id is None or self.dataset is None:
            raise ValueError("project_id and dataset are required args")

        # Set up datasets to write
        client = bigquery_datasource._create_client(project_id=self.project_id)
        dataset_id = self.dataset.split(".", 1)[0]
        try:
            client.get_dataset(dataset_id)
        except exceptions.NotFound:
            client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
            logger.info("Created dataset " + dataset_id)

        # Delete table if overwrite_table is True
        if self.overwrite_table:
            logger.info(
                f"Attempting to delete table {self.dataset}"
                + " if it already exists since kwarg overwrite_table = True."
            )
            client.delete_table(f"{self.project_id}.{self.dataset}", not_found_ok=True)
        else:
            logger.info(
                f"The write will append to table {self.dataset}"
                + " if it already exists since kwarg overwrite_table = False."
            )

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Write every block to BigQuery, one remote task per block.

        Each task dumps its block to a temporary Parquet file and submits it
        as a BigQuery load job in append mode, retrying on ``Forbidden``
        responses (which this sink treats as rate limiting).

        Args:
            blocks: Blocks to write.
            ctx: Task context (not used by this sink).

        Raises:
            RuntimeError: From a block task, once the load has been rejected
                with ``Forbidden`` more than ``max_retry_cnt`` times.
        """
        # Read once here so the nested task function closes over a plain
        # value instead of over `self`.
        max_retry_cnt = self.max_retry_cnt

        def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
            from google.api_core import exceptions
            from google.cloud import bigquery

            block = BlockAccessor.for_block(block).to_arrow()

            # `_create_client` takes `project_id`; the previous `project=`
            # keyword did not match its signature and raised TypeError.
            client = bigquery_datasource._create_client(project_id=project_id)
            job_config = bigquery.LoadJobConfig(autodetect=True)
            job_config.source_format = bigquery.SourceFormat.PARQUET
            job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

            with tempfile.TemporaryDirectory() as temp_dir:
                fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
                pq.write_table(block, fp, compression="SNAPPY")

                retry_cnt = 0
                last_error = None
                while retry_cnt <= max_retry_cnt:
                    with open(fp, "rb") as source_file:
                        job = client.load_table_from_file(
                            source_file, dataset, job_config=job_config
                        )
                    try:
                        logger.info(job.result())
                        break
                    except exceptions.Forbidden as e:
                        last_error = e
                        retry_cnt += 1
                        if retry_cnt > max_retry_cnt:
                            break
                        logger.info(
                            "A block write encountered a rate limit exceeded error"
                            + f" {retry_cnt} time(s). Sleeping to try again."
                        )
                        logger.debug(e)
                        time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

                # Raise exception if retry_cnt exceeds max_retry_cnt
                if retry_cnt > max_retry_cnt:
                    logger.info(
                        f"Maximum ({max_retry_cnt}) retry count exceeded. Ray"
                        + " will attempt to retry the block write via fault tolerance."
                    )
                    # Chain the last Forbidden error so the root cause is
                    # visible in the traceback.
                    raise RuntimeError(
                        f"Write failed due to {retry_cnt}"
                        + " repeated API rate limit exceeded responses. Consider"
                        + " specifiying the max_retry_cnt kwarg with a higher value."
                    ) from last_error

        _write_single_block = cached_remote_fn(_write_single_block)

        # Launch a remote task for each block within this write task
        ray.get(
            [
                _write_single_block.remote(block, self.project_id, self.dataset)
                for block in blocks
            ]
        )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.util import _check_import
|
| 5 |
+
from ray.data.block import Block, BlockMetadata
|
| 6 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _create_user_agent() -> str:
    """Return the user-agent string ``ray/<installed Ray version>``."""
    import ray

    ray_version = ray.__version__
    return f"ray/{ray_version}"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _create_client_info():
    """Build a ``google.api_core`` ``ClientInfo`` carrying Ray's user agent."""
    from google.api_core.client_info import ClientInfo

    agent = _create_user_agent()
    return ClientInfo(user_agent=agent)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _create_client_info_gapic():
    """Build a ``gapic_v1`` ``ClientInfo`` carrying Ray's user agent."""
    from google.api_core.gapic_v1.client_info import ClientInfo

    agent = _create_user_agent()
    return ClientInfo(user_agent=agent)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _create_client(project_id: str):
    """Create a ``bigquery.Client`` for ``project_id`` tagged with Ray's client info.

    Args:
        project_id: Passed to the client as its ``project``.
    """
    from google.cloud import bigquery

    info = _create_client_info()
    return bigquery.Client(project=project_id, client_info=info)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _create_read_client():
    """Create a ``BigQueryReadClient`` tagged with Ray's gapic client info."""
    from google.cloud import bigquery_storage

    info = _create_client_info_gapic()
    return bigquery_storage.BigQueryReadClient(client_info=info)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class BigQueryDatasource(Datasource):
    """Datasource that reads a BigQuery table or query result.

    Args:
        project_id: Google Cloud project used for the clients and the read
            session.
        dataset: Source table in ``dataset_id.table_id`` form. Mutually
            exclusive with ``query``.
        query: SQL query whose destination table is read. Mutually exclusive
            with ``dataset``.

    Raises:
        ValueError: If both ``query`` and ``dataset`` are provided.
    """

    def __init__(
        self,
        project_id: str,
        dataset: Optional[str] = None,
        query: Optional[str] = None,
    ):
        _check_import(self, module="google.cloud", package="bigquery")
        _check_import(self, module="google.cloud", package="bigquery_storage")
        _check_import(self, module="google.api_core", package="exceptions")

        self._project_id = project_id
        self._dataset = dataset
        self._query = query

        if query is not None and dataset is not None:
            raise ValueError(
                "Query and dataset kwargs cannot both be provided "
                + "(must be mutually exclusive)."
            )

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Create one read task per stream of a BigQuery Storage read session.

        When a query is set it is executed first and its destination table is
        read; otherwise the configured table is validated and read directly.

        Args:
            parallelism: Requested maximum number of streams. ``-1`` is
                forwarded to the API as ``None``.

        Returns:
            One ``ReadTask`` per stream in the created read session.

        Raises:
            ValueError: If neither a query nor a dataset is configured, or if
                the configured dataset/table cannot be found.
        """
        from google.cloud import bigquery_storage

        def _read_single_partition(stream) -> Block:
            client = _create_read_client()
            reader = client.read_rows(stream.name)
            return reader.to_arrow()

        if self._query:
            query_client = _create_client(project_id=self._project_id)
            query_job = query_client.query(self._query)
            query_job.result()
            destination = str(query_job.destination)
            dataset_id = destination.split(".")[-2]
            table_id = destination.split(".")[-1]
        else:
            if not self._dataset:
                # Previously this fell through to `None.split(...)` and failed
                # with an opaque AttributeError.
                raise ValueError("Either a query or a dataset must be provided.")
            self._validate_dataset_table_exist(self._project_id, self._dataset)
            dataset_id = self._dataset.split(".")[0]
            table_id = self._dataset.split(".")[1]

        bqs_client = _create_read_client()
        table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"

        if parallelism == -1:
            # Presumably lets the API choose the stream count -- confirm
            # against the `create_read_session` documentation.
            parallelism = None
        requested_session = bigquery_storage.types.ReadSession(
            table=table,
            data_format=bigquery_storage.types.DataFormat.ARROW,
        )
        read_session = bqs_client.create_read_session(
            parent=f"projects/{self._project_id}",
            read_session=requested_session,
            max_stream_count=parallelism,
        )

        read_tasks = []
        num_streams = len(read_session.streams)
        logger.info("Created streams: " + str(num_streams))
        # `parallelism` may be None here (see above); comparing an int with
        # None raises TypeError, so only compare when a count was requested.
        if parallelism is not None and num_streams < parallelism:
            logger.info(
                "The number of streams created by the "
                + "BigQuery Storage Read API is less than the requested "
                + "parallelism due to the size of the dataset."
            )

        for stream in read_session.streams:
            # Create a metadata block object to store schema, etc.
            metadata = BlockMetadata(
                num_rows=None,
                size_bytes=None,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

            # Create the read task and pass the no-arg wrapper and metadata in.
            # `stream=stream` binds the current stream as a default so each
            # lambda reads its own stream rather than the loop's last one.
            read_task = ReadTask(
                lambda stream=stream: [_read_single_partition(stream)],
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``; no size estimate is computed for this source."""
        return None

    def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
        """Check that the dataset and table named by ``dataset`` both exist.

        Args:
            project_id: Project used to create the BigQuery client.
            dataset: Table reference in ``dataset_id.table_id`` form.

        Raises:
            ValueError: If the dataset or the table is not found.
        """
        from google.api_core import exceptions

        client = _create_client(project_id=project_id)
        dataset_id = dataset.split(".")[0]
        try:
            client.get_dataset(dataset_id)
        except exceptions.NotFound:
            raise ValueError(
                "Dataset {} is not found. Please ensure that it exists.".format(
                    dataset_id
                )
            )

        try:
            client.get_table(dataset)
        except exceptions.NotFound:
            raise ValueError(
                "Table {} is not found. Please ensure that it exists.".format(dataset)
            )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.arrow_block import ArrowBlockBuilder
|
| 4 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 5 |
+
|
| 6 |
+
if TYPE_CHECKING:
|
| 7 |
+
import pyarrow
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BinaryDatasource(FileBasedDatasource):
    """Datasource that loads each file's raw bytes into a single row."""

    _COLUMN_NAME = "bytes"

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Yield one block whose only row holds the full contents of ``f``.

        Args:
            f: Open file handle to read.
            path: Path of the file being read (not used by this reader).
        """
        payload = f.readall()

        block_builder = ArrowBlockBuilder()
        block_builder.add({self._COLUMN_NAME: payload})
        yield block_builder.build()

    def _rows_per_file(self):
        """Return 1: every file contributes exactly one row."""
        return 1
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from ray.data._internal.util import _check_import
|
| 6 |
+
from ray.data.block import Block, BlockAccessor, BlockMetadata
|
| 7 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 8 |
+
from ray.util.annotations import DeveloperAPI
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _is_filter_string_safe(filter_str: str) -> bool:
|
| 14 |
+
in_string = False
|
| 15 |
+
escape_next = False
|
| 16 |
+
for c in filter_str:
|
| 17 |
+
if in_string:
|
| 18 |
+
# If we're inside a string, check if we're closing it.
|
| 19 |
+
if c == "'" and not escape_next:
|
| 20 |
+
in_string = False
|
| 21 |
+
escape_next = (c == "\\") and not escape_next
|
| 22 |
+
else:
|
| 23 |
+
# If we're not in a string, entering one if we see a single quote
|
| 24 |
+
if c == "'":
|
| 25 |
+
in_string = True
|
| 26 |
+
escape_next = False
|
| 27 |
+
# Disallow semicolon if we're not in a string
|
| 28 |
+
elif c == ";":
|
| 29 |
+
return False
|
| 30 |
+
else:
|
| 31 |
+
escape_next = False
|
| 32 |
+
# If we end inside a string, it's suspicious, but let's allow
|
| 33 |
+
# it to be further validated by the DB. Just return True here.
|
| 34 |
+
return True
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@DeveloperAPI
|
| 38 |
+
class ClickHouseDatasource(Datasource):
|
| 39 |
+
"""
|
| 40 |
+
A Ray datasource for reading from ClickHouse.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
table: Fully qualified table or view identifier (e.g.,
|
| 44 |
+
"default.table_name").
|
| 45 |
+
dsn: A string in DSN (Data Source Name) HTTP format (e.g.,
|
| 46 |
+
"clickhouse+http://username:password@host:8124/default").
|
| 47 |
+
For more information, see `ClickHouse Connection String doc
|
| 48 |
+
<https://clickhouse.com/docs/en/integrations/sql-clients/cli#connection_string>`_.
|
| 49 |
+
columns: Optional List of columns to select from the data source.
|
| 50 |
+
If no columns are specified, all columns will be selected by default.
|
| 51 |
+
filter: Optional SQL filter string that will be used in the
|
| 52 |
+
WHERE statement (e.g., "label = 2 AND text IS NOT NULL").
|
| 53 |
+
The filter must be valid for use in a ClickHouse SQL WHERE clause.
|
| 54 |
+
Note: Parallel reads are not currently supported when a filter is set.
|
| 55 |
+
Specifying a filter forces the parallelism to 1 to ensure deterministic
|
| 56 |
+
and consistent results. For more information, see
|
| 57 |
+
`ClickHouse SQL WHERE Clause doc
|
| 58 |
+
<https://clickhouse.com/docs/en/sql-reference/statements/select/where>`_.
|
| 59 |
+
order_by: Optional Tuple containing a list of columns to order by
|
| 60 |
+
and a boolean indicating the order. Note: order_by is required to
|
| 61 |
+
support parallelism.
|
| 62 |
+
client_settings: Optional ClickHouse server settings to be used with the
|
| 63 |
+
session/every request. For more information, see
|
| 64 |
+
`ClickHouse Client Settings doc
|
| 65 |
+
<https://clickhouse.com/docs/en/integrations/python#settings-argument>`_.
|
| 66 |
+
client_kwargs: Optional Additional keyword arguments to pass to the
|
| 67 |
+
ClickHouse client. For more information,
|
| 68 |
+
see `ClickHouse Core Settings doc
|
| 69 |
+
<https://clickhouse.com/docs/en/integrations/python#additional-options>`_.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
NUM_SAMPLE_ROWS = 100
|
| 73 |
+
MIN_ROWS_PER_READ_TASK = 50
|
| 74 |
+
_BASE_QUERY = "SELECT {select_clause} FROM {table}"
|
| 75 |
+
_EXPLAIN_FILTERS_QUERY = "EXPLAIN SELECT 1 FROM {table} WHERE {filter_clause}"
|
| 76 |
+
_SIZE_ESTIMATE_QUERY = "SELECT SUM(byteSize(*)) AS estimate FROM ({query})"
|
| 77 |
+
_COUNT_ESTIMATE_QUERY = "SELECT COUNT(*) AS estimate FROM ({query})"
|
| 78 |
+
_SAMPLE_BLOCK_QUERY = "{query} LIMIT {limit_row_count}"
|
| 79 |
+
_FIRST_BLOCK_QUERY = """
|
| 80 |
+
{query}
|
| 81 |
+
FETCH FIRST {fetch_row_count} {fetch_row_or_rows} ONLY
|
| 82 |
+
"""
|
| 83 |
+
_NEXT_BLOCK_QUERY = """
|
| 84 |
+
{query}
|
| 85 |
+
OFFSET {offset_row_count} {offset_row_or_rows}
|
| 86 |
+
FETCH NEXT {fetch_row_count} {fetch_row_or_rows} ONLY
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
table: str,
|
| 92 |
+
dsn: str,
|
| 93 |
+
columns: Optional[List[str]] = None,
|
| 94 |
+
filter: Optional[str] = None,
|
| 95 |
+
order_by: Optional[Tuple[List[str], bool]] = None,
|
| 96 |
+
client_settings: Optional[Dict[str, Any]] = None,
|
| 97 |
+
client_kwargs: Optional[Dict[str, Any]] = None,
|
| 98 |
+
):
|
| 99 |
+
self._table = table
|
| 100 |
+
self._dsn = dsn
|
| 101 |
+
self._columns = columns
|
| 102 |
+
self._filter = filter
|
| 103 |
+
self._order_by = order_by
|
| 104 |
+
self._client_settings = client_settings or {}
|
| 105 |
+
self._client_kwargs = client_kwargs or {}
|
| 106 |
+
self._query = self._generate_query()
|
| 107 |
+
|
| 108 |
+
def _init_client(self):
|
| 109 |
+
_check_import(self, module="clickhouse_connect", package="clickhouse-connect")
|
| 110 |
+
import clickhouse_connect
|
| 111 |
+
|
| 112 |
+
return clickhouse_connect.get_client(
|
| 113 |
+
dsn=self._dsn,
|
| 114 |
+
settings=self._client_settings or {},
|
| 115 |
+
**self._client_kwargs or {},
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _validate_filter(self):
    """Check that the configured filter is usable in a WHERE clause.

    Does nothing when no filter is set. Otherwise the filter must pass the
    lexical check ``_is_filter_string_safe`` and must be accepted by
    ClickHouse in an ``EXPLAIN`` statement against the configured table.

    Raises:
        ValueError: If the lexical check fails, or if the ``EXPLAIN`` query
            raises. In the latter case the driver error is attached as the
            explicit cause.
    """
    if not self._filter:
        return
    # Minimal lexical check (regex or manual approach for semicolons, etc.).
    if not _is_filter_string_safe(self._filter):
        raise ValueError(
            f"Invalid characters outside of "
            f"string literals in filter: {self._filter}"
        )
    # Test "EXPLAIN" query to confirm parse-ability and catch expression errors.
    client = self._init_client()
    try:
        test_query = self._EXPLAIN_FILTERS_QUERY.format(
            table=self._table,
            filter_clause=self._filter,
        )
        client.query(test_query)
    except Exception as e:
        # Chain explicitly so the original driver error is the recorded cause.
        raise ValueError(
            f"Invalid filter expression: {self._filter}. Error: {e}",
        ) from e
    finally:
        client.close()
|
| 141 |
+
|
| 142 |
+
def _generate_query(self) -> str:
    """Assemble the primary SELECT statement from the stored configuration.

    Returns:
        The base query, followed by ``WHERE <filter>`` when a filter is set
        (after validating it) and by ``ORDER BY`` when ordering columns are
        configured. Multiple ordering columns are wrapped in parentheses.
    """
    select_clause = ", ".join(self._columns) if self._columns else "*"
    pieces = [
        self._BASE_QUERY.format(select_clause=select_clause, table=self._table)
    ]
    if self._filter:
        # Validate before the filter text is spliced into the statement.
        self._validate_filter()
        pieces.append(f" WHERE {self._filter}")
    if self._order_by:
        columns, desc = self._order_by
        direction = " DESC" if desc else ""
        column_count = len(columns)
        if column_count == 1:
            pieces.append(f" ORDER BY {columns[0]}{direction}")
        elif column_count > 1:
            columns_clause = ", ".join(columns)
            pieces.append(f" ORDER BY ({columns_clause}){direction}")
        # An empty column list adds no ORDER BY clause at all.
    return "".join(pieces)
|
| 159 |
+
|
| 160 |
+
def _build_block_query(self, limit_row_count: int, offset_row_count: int) -> str:
    """Render the paginated query for one block of rows.

    Args:
        limit_row_count: Number of rows the block should fetch.
        offset_row_count: Number of rows to skip before the block starts.

    Returns:
        The FETCH FIRST form (no OFFSET clause) when the offset is zero,
        otherwise the OFFSET ... FETCH NEXT form.
    """
    # The templates take the singular keyword for a count of one or less.
    fetch_unit = "ROWS" if limit_row_count > 1 else "ROW"
    if offset_row_count != 0:
        # Later blocks skip the rows already covered by earlier blocks.
        offset_unit = "ROWS" if offset_row_count > 1 else "ROW"
        return self._NEXT_BLOCK_QUERY.format(
            query=self._query,
            offset_row_count=offset_row_count,
            offset_row_or_rows=offset_unit,
            fetch_row_count=limit_row_count,
            fetch_row_or_rows=fetch_unit,
        )
    # The first block needs no OFFSET, so the shorter template is used.
    return self._FIRST_BLOCK_QUERY.format(
        query=self._query,
        fetch_row_count=limit_row_count,
        fetch_row_or_rows=fetch_unit,
    )
|
| 178 |
+
|
| 179 |
+
def _create_read_fn(
    self,
    query: str,
) -> Callable[[], Iterable[Block]]:
    """Return a zero-argument callable that runs ``query`` when invoked.

    The query is not executed here. Each call to the returned function
    executes it and returns the result wrapped in a one-element list.
    """

    def read_fn() -> Iterable[Block]:
        result_block = self._execute_block_query(query)
        return [result_block]

    return read_fn
|
| 187 |
+
|
| 188 |
+
def _get_sampled_estimates(self):
    """Sample the query result to estimate per-row size and obtain a schema.

    Up to ``NUM_SAMPLE_ROWS`` rows are fetched, using the FETCH FIRST
    template when the query is ordered and a LIMIT clause otherwise.

    Returns:
        A ``(estimated_size_bytes_per_row, sample_block_schema)`` tuple.
        The size estimate is 0 when the sample contains no rows.
    """
    if self._order_by is not None:
        # If the query is ordered, we can use a FETCH clause to get a sample.
        # This reduces the CPU overhead on ClickHouse and speeds up the
        # estimation query.
        query = self._FIRST_BLOCK_QUERY.format(
            query=self._query,
            fetch_row_count=self.NUM_SAMPLE_ROWS,
            fetch_row_or_rows="ROWS" if self.NUM_SAMPLE_ROWS > 1 else "ROW",
        )
    else:
        # If the query is not ordered, we need to use a LIMIT clause to
        # get a sample.
        query = self._SAMPLE_BLOCK_QUERY.format(
            query=self._query,
            limit_row_count=self.NUM_SAMPLE_ROWS,
        )
    sample_block_accessor = BlockAccessor.for_block(
        self._execute_block_query(query)
    )
    sample_num_rows = sample_block_accessor.num_rows()
    if sample_num_rows:
        estimated_size_bytes_per_row = math.ceil(
            sample_block_accessor.size_bytes() / sample_num_rows
        )
    else:
        # An empty sample would otherwise raise ZeroDivisionError.
        estimated_size_bytes_per_row = 0
    sample_block_schema = sample_block_accessor.schema()
    return estimated_size_bytes_per_row, sample_block_schema
|
| 213 |
+
|
| 214 |
+
def _get_estimate_count(self) -> Optional[int]:
    """Run the COUNT(*) wrapper around the primary query and return its result."""
    count_template = self._COUNT_ESTIMATE_QUERY
    return self._execute_estimate_query(count_template)
|
| 216 |
+
|
| 217 |
+
def _get_estimate_size(self) -> Optional[int]:
    """Run the byte-size wrapper around the primary query and return its result."""
    size_template = self._SIZE_ESTIMATE_QUERY
    return self._execute_estimate_query(size_template)
|
| 219 |
+
|
| 220 |
+
def _execute_estimate_query(self, estimate_query: str) -> Optional[int]:
    """Run an estimate template around the primary query.

    Args:
        estimate_query: Template with a ``{query}`` placeholder, filled with
            the primary query so it runs as a sub-query.

    Returns:
        The first column of the first result row as an ``int``, or ``None``
        when there is no row, the value is NULL, or the query raised (the
        error is logged as a warning, not propagated).
    """
    client = self._init_client()
    try:
        wrapped_query = estimate_query.format(query=self._query)
        response = client.query(wrapped_query)
        if not response or len(response.result_rows) == 0:
            return None
        value = response.result_rows[0][0]
        if value is None:
            return None
        return int(value)
    except Exception as e:
        # Estimation is best-effort: report the failure and return None.
        logger.warning(f"Failed to execute estimate query: {e}")
        return None
    finally:
        client.close()
|
| 236 |
+
|
| 237 |
+
def _execute_block_query(self, query: str) -> Block:
    """Execute ``query`` and return the whole result as one Arrow table.

    Args:
        query: The SQL statement to run.

    Returns:
        A ``pyarrow.Table`` built from every record batch of the stream.

    Raises:
        RuntimeError: If running the query or building the table fails; the
            underlying error is attached as the explicit cause.
    """
    import pyarrow as pa

    client = self._init_client()
    try:
        with client.query_arrow_stream(query) as stream:
            record_batches = list(stream)  # Collect all record batches
        return pa.Table.from_batches(record_batches)
    except Exception as e:
        # Chain explicitly so the original driver error is the recorded cause.
        raise RuntimeError(f"Failed to execute block query: {e}") from e
    finally:
        client.close()
|
| 249 |
+
|
| 250 |
+
def estimate_inmemory_data_size(self) -> Optional[int]:
    """Estimate how many bytes the query result occupies in memory.

    Returns:
        The estimated size in bytes, or ``None`` when no estimate could be
        obtained.
    """
    estimated_bytes = self._get_estimate_size()
    return estimated_bytes
|
| 259 |
+
|
| 260 |
+
def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
    """Split the ClickHouse query into read tasks.

    Args:
        parallelism: The desired number of partitions to read the data into.
            It is capped so that each task covers at least
            ``MIN_ROWS_PER_READ_TASK`` rows, and forced to 1 when a
            ``filter`` is set or when ``order_by`` is not set, because
            OFFSET/FETCH pagination needs a stable row order.

    Returns:
        A list of read tasks; empty when the row count is zero or could not
        be estimated.
    """
    total_rows = self._get_estimate_count()
    if total_rows is None or total_rows == 0:
        return []

    # Never create more tasks than the minimum task size allows.
    max_useful_tasks = math.ceil(total_rows / self.MIN_ROWS_PER_READ_TASK)
    parallelism = min(parallelism, max_useful_tasks)

    # Parallel reads require an explicit order and no filter.
    if parallelism > 1 and self._filter is not None:
        logger.warning(
            "ClickHouse datasource does not currently support parallel reads "
            "when a filter is set; falling back to parallelism of 1."
        )
        parallelism = 1
    if parallelism > 1 and self._order_by is None:
        logger.warning(
            "ClickHouse datasource requires dataset to be explicitly ordered "
            "to support parallelism; falling back to parallelism of 1."
        )
        parallelism = 1

    # Spread rows evenly; the first `remainder` tasks take one extra row.
    base_rows = total_rows // parallelism
    remainder = total_rows % parallelism
    bytes_per_row, schema = self._get_sampled_estimates()

    def _make_task(block_rows: int, offset_rows: int, parallelized: bool) -> ReadTask:
        # Paginated tasks get OFFSET/FETCH clauses; a single task runs the
        # primary query unchanged.
        if parallelized:
            query = self._build_block_query(block_rows, offset_rows)
        else:
            query = self._query
        metadata = BlockMetadata(
            num_rows=block_rows,
            size_bytes=bytes_per_row * block_rows,
            schema=schema,
            input_files=None,
            exec_stats=None,
        )
        return ReadTask(self._create_read_fn(query), metadata)

    if parallelism == 1:
        # One task reads everything, so pagination clauses are unnecessary.
        return [_make_task(total_rows, 0, False)]

    tasks = []
    next_offset = 0
    for task_index in range(parallelism):
        rows_in_task = base_rows + 1 if task_index < remainder else base_rows
        tasks.append(_make_task(rows_in_task, next_offset, True))
        next_offset += rows_in_task
    return tasks
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import pyarrow
|
| 4 |
+
|
| 5 |
+
from ray.data.block import BlockAccessor
|
| 6 |
+
from ray.data.datasource.file_based_datasource import _resolve_kwargs
|
| 7 |
+
from ray.data.datasource.file_datasink import BlockBasedFileDatasink
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CSVDatasink(BlockBasedFileDatasink):
    """File datasink that writes each block as CSV through ``pyarrow.csv``."""

    def __init__(
        self,
        path: str,
        *,
        arrow_csv_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
        arrow_csv_args: Optional[Dict[str, Any]] = None,
        file_format="csv",
        **file_datasink_kwargs,
    ):
        """Configure the sink.

        Args:
            path: Destination path, forwarded to the base datasink.
            arrow_csv_args_fn: Callable returning writer keyword arguments;
                defaults to one that returns an empty dict.
            arrow_csv_args: Static writer keyword arguments; defaults to an
                empty dict.
            file_format: File format forwarded to the base datasink.
            **file_datasink_kwargs: Remaining base-datasink options.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.arrow_csv_args_fn = (
            arrow_csv_args_fn if arrow_csv_args_fn is not None else (lambda: {})
        )
        self.arrow_csv_args = arrow_csv_args if arrow_csv_args is not None else {}

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write ``block`` to ``file`` as CSV.

        ``write_options`` is taken out of the resolved arguments and passed
        positionally; everything else is forwarded as keyword arguments.
        """
        from pyarrow import csv as arrow_csv

        resolved_args = _resolve_kwargs(self.arrow_csv_args_fn, **self.arrow_csv_args)
        options = resolved_args.pop("write_options", None)
        arrow_csv.write_csv(block.to_arrow(), file, options, **resolved_args)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
from urllib.parse import urljoin
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pyarrow
|
| 10 |
+
import requests
|
| 11 |
+
|
| 12 |
+
from ray.data.block import BlockMetadata
|
| 13 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 14 |
+
from ray.util.annotations import PublicAPI
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_STATEMENT_EXEC_POLL_TIME_S = 1
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@PublicAPI(stability="alpha")
class DatabricksUCDatasource(Datasource):
    """Datasource that reads the result of a SQL statement run on Databricks.

    The constructor submits the statement to the
    ``/api/2.0/sql/statements/`` endpoint of ``host``, polls until it leaves
    the PENDING/RUNNING states, and records the result chunk manifest. Read
    tasks then download the chunks as Arrow streams.
    """

    def __init__(
        self,
        host: str,
        token: str,
        warehouse_id: str,
        catalog: str,
        schema: str,
        query: str,
    ):
        """Submit ``query`` and wait for it to finish.

        Args:
            host: Host name used to build the statements API URL.
            token: Bearer token sent in the Authorization header.
            warehouse_id: Warehouse that executes the statement.
            catalog: Catalog sent with the statement.
            schema: Schema sent with the statement.
            query: The SQL statement to execute.

        Raises:
            RuntimeError: If the statement ends in a state other than
                ``SUCCEEDED``.
            KeyboardInterrupt: Re-raised after a cancel request has been
                sent for the running statement.
        """
        self.host = host
        self.token = token
        self.warehouse_id = warehouse_id
        self.catalog = catalog
        self.schema = schema
        self.query = query

        url_base = f"https://{self.host}/api/2.0/sql/statements/"

        payload = json.dumps(
            {
                "statement": self.query,
                "warehouse_id": self.warehouse_id,
                "wait_timeout": "0s",
                "disposition": "EXTERNAL_LINKS",
                "format": "ARROW_STREAM",
                "catalog": self.catalog,
                "schema": self.schema,
            }
        )

        req_headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.token,
        }

        response = requests.post(
            url_base,
            headers=req_headers,
            data=payload,
        )
        response.raise_for_status()
        statement_id = response.json()["statement_id"]

        state = response.json()["status"]["state"]

        logger.info(f"Waiting for query {query!r} execution result.")
        try:
            while state in ["PENDING", "RUNNING"]:
                time.sleep(_STATEMENT_EXEC_POLL_TIME_S)
                response = requests.get(
                    urljoin(url_base, statement_id) + "/",
                    headers=req_headers,
                )
                response.raise_for_status()
                state = response.json()["status"]["state"]
        except KeyboardInterrupt:
            # The user cancelled the command, so cancel the query execution.
            cancel_response = requests.post(
                urljoin(url_base, f"{statement_id}/cancel"),
                headers=req_headers,
            )
            try:
                # Check the cancel request itself, not the last poll response.
                cancel_response.raise_for_status()
            except Exception as e:
                logger.warning(
                    f"Canceling query {query!r} execution failed, reason: {repr(e)}."
                )
            raise

        if state != "SUCCEEDED":
            raise RuntimeError(f"Query {self.query!r} execution failed.")

        manifest = response.json()["manifest"]
        is_truncated = manifest["truncated"]

        if is_truncated:
            logger.warning(
                f"The resulting size of the dataset of '{query!r}' exceeds "
                "100GiB and it is truncated."
            )

        chunks = manifest["chunks"]

        # Make sure the chunk metadata is ordered by chunk index.
        chunks = sorted(chunks, key=lambda x: x["chunk_index"])
        num_chunks = len(chunks)
        self.num_chunks = num_chunks
        self._estimate_inmemory_data_size = sum(chunk["byte_count"] for chunk in chunks)

        def get_read_task(task_index, parallelism):
            # Get the chunk list to be read in this task and preserve the
            # original chunk order.
            chunk_index_list = list(
                np.array_split(range(num_chunks), parallelism)[task_index]
            )

            num_rows = sum(
                chunks[chunk_index]["row_count"] for chunk_index in chunk_index_list
            )
            size_bytes = sum(
                chunks[chunk_index]["byte_count"] for chunk_index in chunk_index_list
            )

            metadata = BlockMetadata(
                num_rows=num_rows,
                size_bytes=size_bytes,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

            def _read_fn():
                for chunk_index in chunk_index_list:
                    resolve_external_link_url = urljoin(
                        url_base, f"{statement_id}/result/chunks/{chunk_index}"
                    )

                    resolve_response = requests.get(
                        resolve_external_link_url, headers=req_headers
                    )
                    resolve_response.raise_for_status()
                    external_url = resolve_response.json()["external_links"][0][
                        "external_link"
                    ]
                    # NOTE: do _NOT_ send the authorization header to external urls
                    raw_response = requests.get(external_url, auth=None, headers=None)
                    raw_response.raise_for_status()

                    with pyarrow.ipc.open_stream(raw_response.content) as reader:
                        arrow_table = reader.read_all()

                    yield arrow_table

            def read_fn():
                if mock_setup_fn_path := os.environ.get(
                    "RAY_DATABRICKS_UC_DATASOURCE_READ_FN_MOCK_TEST_SETUP_FN_PATH"
                ):
                    import ray.cloudpickle as pickle

                    # This is for testing.
                    with open(mock_setup_fn_path, "rb") as f:
                        mock_setup = pickle.load(f)
                    with mock_setup():
                        yield from _read_fn()
                else:
                    yield from _read_fn()

            return ReadTask(read_fn=read_fn, metadata=metadata)

        self._get_read_task = get_read_task

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the sum of ``byte_count`` over all result chunks."""
        return self._estimate_inmemory_data_size

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Create one read task per chunk group.

        ``parallelism`` is reduced to the number of result chunks when it
        exceeds it, so every task has at least one chunk to read.
        """
        assert parallelism > 0, f"Invalid parallelism {parallelism}"

        if parallelism > self.num_chunks:
            parallelism = self.num_chunks
            logger.info(
                "The parallelism is reduced to chunk number due to "
                "insufficient chunk parallelism."
            )

        return [self._get_read_task(index, parallelism) for index in range(parallelism)]
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from json import loads
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ray.data._internal.util import _check_import
|
| 8 |
+
from ray.data.block import BlockMetadata
|
| 9 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DeltaSharingDatasource(Datasource):
    """Datasource that reads the files of a Delta Sharing table."""

    def __init__(
        self,
        url: str,
        json_predicate_hints: Optional[str] = None,
        limit: Optional[int] = None,
        version: Optional[int] = None,
        timestamp: Optional[str] = None,
    ):
        """Store the table location and listing options.

        Args:
            url: A url under the format "<profile>#<share>.<schema>.<table>".
            json_predicate_hints: Passed as ``jsonPredicateHints`` when
                listing the table's files.
            limit: Passed as ``limitHint`` when listing the table's files;
                must be a non-negative int when given.
            version: Passed as ``version`` when listing the table's files.
            timestamp: Passed as ``timestamp`` when listing the table's files.
        """
        _check_import(self, module="delta_sharing", package="delta-sharing")

        if limit is not None:
            assert (
                isinstance(limit, int) and limit >= 0
            ), "'limit' must be a non-negative int"

        self._url = url
        self._json_predicate_hints = json_predicate_hints
        self._limit = limit
        self._version = version
        self._timestamp = timestamp

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``; no size estimate is computed for this datasource."""
        return None

    def _read_files(self, files, converters):
        """Read files with Delta Sharing, yielding one pandas object per file."""
        from delta_sharing.reader import DeltaSharingReader

        for file in files:
            yield DeltaSharingReader._to_pandas(
                action=file, converters=converters, for_cdf=False, limit=None
            )

    def setup_delta_sharing_connections(self, url: str):
        """
        Set up delta sharing connections based on the url.

        :param url: a url under the format "<profile>#<share>.<schema>.<table>"
        :return: a ``(table, rest_client)`` tuple
        """
        from delta_sharing.protocol import DeltaSharingProfile, Table
        from delta_sharing.rest_client import DataSharingRestClient

        profile_str, share, schema, table_str = _parse_delta_sharing_url(url)
        table = Table(name=table_str, share=share, schema=schema)

        profile = DeltaSharingProfile.read_from_file(profile_str)
        rest_client = DataSharingRestClient(profile)
        return table, rest_client

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """List the table's files and group them into read tasks.

        Files are split into at most ``parallelism`` groups in listing
        order. Groups that end up empty (fewer files than ``parallelism``,
        or no files at all) produce no read task.
        """
        assert parallelism > 0, f"Invalid parallelism {parallelism}"
        from delta_sharing.converter import to_converters

        self._table, self._rest_client = self.setup_delta_sharing_connections(self._url)
        self._response = self._rest_client.list_files_in_table(
            self._table,
            jsonPredicateHints=self._json_predicate_hints,
            limitHint=self._limit,
            version=self._version,
            timestamp=self._timestamp,
        )

        if len(self._response.add_files) == 0 or self._limit == 0:
            logger.warning("No files found from the delta sharing table or limit is 0")

        schema_json = loads(self._response.metadata.schema_string)
        self._converters = to_converters(schema_json)

        read_tasks = []
        # get file list to be read in this task and preserve original chunk order
        for files in np.array_split(self._response.add_files, parallelism):
            files = files.tolist()
            if not files:
                # array_split pads with empty groups when there are fewer
                # files than requested partitions; skip those.
                continue
            metadata = BlockMetadata(
                num_rows=None,
                schema=None,
                input_files=files,
                size_bytes=None,
                exec_stats=None,
            )
            converters = self._converters
            read_task = ReadTask(
                # Bind `files` as a default so each task keeps its own group.
                lambda f=files: self._read_files(f, converters),
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _parse_delta_sharing_url(url: str) -> Tuple[str, str, str, str]:
    """
    Developed from delta_sharing's _parse_url function.
    https://github.com/delta-io/delta-sharing/blob/main/python/delta_sharing/delta_sharing.py#L36

    The url is split at its last "#"; the part after it must consist of
    exactly three non-empty, dot-separated fragments.

    Args:
        url: a url under the format "<profile>#<share>.<schema>.<table>"

    Returns:
        a tuple with parsed (profile, share, schema, table)

    Raises:
        ValueError: if the url does not match the expected format.
    """
    profile, separator, coordinates = url.rpartition("#")
    if not separator:
        raise ValueError(f"Invalid 'url': {url}")
    fragments = coordinates.split(".")
    if len(fragments) != 3:
        raise ValueError(f"Invalid 'url': {url}")
    share, schema, table = fragments
    if not all((profile, share, schema, table)):
        raise ValueError(f"Invalid 'url': {url}")
    return (profile, share, schema, table)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Dict, Iterator, List, Optional
|
| 4 |
+
|
| 5 |
+
from ray.data._internal.util import _check_import
|
| 6 |
+
from ray.data.block import BlockMetadata
|
| 7 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HudiDatasource(Datasource):
    """Hudi datasource, for reading Apache Hudi table."""

    def __init__(
        self,
        table_uri: str,
        storage_options: Optional[Dict[str, str]] = None,
    ):
        """Remember the table location and optional storage options."""
        _check_import(self, module="hudi", package="hudi-python")

        self._table_uri = table_uri
        self._storage_options = storage_options

    def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
        """Create one read task per split of the table's file slices.

        Each task's metadata sums the record counts and base-file sizes of
        its file slices and lists the base files' full paths.
        """
        import pyarrow
        from hudi import HudiTable

        def _perform_read(
            table_uri: str,
            base_file_paths: List[str],
            options: Dict[str, str],
        ) -> Iterator["pyarrow.Table"]:
            # Yield one Arrow table per base file, reading through a file
            # group reader created for each path.
            from hudi import HudiFileGroupReader

            for base_file_path in base_file_paths:
                reader = HudiFileGroupReader(table_uri, options)
                record_batch = reader.read_file_slice_by_base_file_path(
                    base_file_path
                )
                yield pyarrow.Table.from_batches([record_batch])

        hudi_table = HudiTable(self._table_uri, self._storage_options)

        # Hudi options are applied second, so they win on key collisions.
        reader_options = {
            **hudi_table.storage_options(),
            **hudi_table.hudi_options(),
        }

        schema = hudi_table.get_schema()
        read_tasks = []
        for file_slices_split in hudi_table.get_file_slices_splits(parallelism):
            total_records = 0
            total_bytes = 0
            relative_paths = []
            input_files = []
            for file_slice in file_slices_split:
                # A file slice in a Hudi table is a logical group of data files
                # within a physical partition. Records stored in a file slice
                # are associated with a commit on the Hudi table's timeline.
                # For more info, see https://hudi.apache.org/docs/file_layouts
                total_records += file_slice.num_records
                relative_path = file_slice.base_file_relative_path()
                relative_paths.append(relative_path)
                input_files.append(os.path.join(self._table_uri, relative_path))
                total_bytes += file_slice.base_file_size

            metadata = BlockMetadata(
                num_rows=total_records,
                schema=schema,
                input_files=input_files,
                size_bytes=total_bytes,
                exec_stats=None,
            )

            read_tasks.append(
                ReadTask(
                    # Bind the path list as a default so each task keeps its own.
                    read_fn=lambda paths=relative_paths: _perform_read(
                        self._table_uri, paths, reader_options
                    ),
                    metadata=metadata,
                )
            )

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``; no in-memory size estimate is available yet."""
        # TODO(xushiyan) add APIs to provide estimated in-memory size
        return None
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
|
| 5 |
+
from ray.data._internal.util import _check_pyarrow_version
|
| 6 |
+
from ray.data.block import Block, BlockAccessor, BlockMetadata
|
| 7 |
+
from ray.data.dataset import Dataset
|
| 8 |
+
from ray.data.datasource import Datasource, ReadTask
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
import datasets
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Holds the ImportError raised by the block below, if any, so it can be
# re-raised later instead of failing at module import time.
TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None

try:
    # Due to HF Dataset's dynamic module system, we need to dynamically import the
    # datasets_modules module on every actor when training.
    # We accomplish this by simply running the following bit of code directly
    # in the module you are currently viewing. This ensures that when we
    # unpickle the Dataset, it runs before pickle tries to
    # import datasets_modules and prevents an exception from being thrown.
    # Same logic is present inside HF Transformers Ray
    # integration: https://github.com/huggingface/transformers/blob/\
    # 7d5fde991d598370d961be8cb7add6541e2b59ce/src/transformers/integrations.py#L271
    # Also see https://github.com/ray-project/ray/issues/28084
    from transformers.utils import is_datasets_available

    if "datasets_modules" not in sys.modules and is_datasets_available():
        # Import the submodule explicitly: `import importlib` alone does not
        # guarantee that `importlib.util` is loaded.
        import importlib.util
        import os

        import datasets.load

        dynamic_modules_path = os.path.join(
            datasets.load.init_dynamic_modules(), "__init__.py"
        )
        # load dynamic_modules from path
        spec = importlib.util.spec_from_file_location(
            "datasets_modules", dynamic_modules_path
        )
        datasets_modules = importlib.util.module_from_spec(spec)
        # Register the module before executing it.
        sys.modules[spec.name] = datasets_modules
        spec.loader.exec_module(datasets_modules)
except ImportError as e:
    TRANSFORMERS_IMPORT_ERROR = e
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class HuggingFaceDatasource(Datasource):
    """Hugging Face Dataset datasource, for reading from a
    `Hugging Face Datasets Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset/>`_.
    This Datasource implements a streamed read using a
    single read task, most beneficial for a
    `Hugging Face Datasets IterableDataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDataset/>`_
    or datasets which are too large to fit in-memory.
    For an in-memory Hugging Face Dataset (`datasets.Dataset`), use :meth:`~ray.data.from_huggingface`
    directly for faster performance.
    """  # noqa: E501

    def __init__(
        self,
        dataset: Union["datasets.Dataset", "datasets.IterableDataset"],
        batch_size: int = 4096,
    ):
        """Create the datasource.

        Args:
            dataset: The Hugging Face dataset to read from.
            batch_size: Number of rows requested per batch when iterating
                over ``dataset`` in the read task.

        Raises:
            ImportError: If importing ``transformers`` (or the Hugging Face
                dynamic modules) failed when this module was loaded.
        """
        if TRANSFORMERS_IMPORT_ERROR is not None:
            raise TRANSFORMERS_IMPORT_ERROR

        self._dataset = dataset
        self._batch_size = batch_size

    @classmethod
    def list_parquet_urls_from_dataset(
        cls, dataset: Union["datasets.Dataset", "datasets.IterableDataset"]
    ) -> List[str]:
        """Return list of Hugging Face hosted parquet file URLs if they
        exist for the data (i.e. if the dataset is a public dataset that
        has not been transformed) else return an empty list.

        Args:
            dataset: The Hugging Face dataset to look up on the Hub.

        Returns:
            The decoded JSON body of the Hub's parquet listing endpoint
            (presumably a list of URL strings -- confirm against the Hub API),
            or ``[]`` if the lookup is not possible or not successful.
        """
        import datasets

        # We can use the dataset name, config name, and split name to load
        # public hugging face datasets from the Hugging Face Hub. More info
        # here: https://huggingface.co/docs/datasets-server/parquet
        dataset_name = dataset.info.dataset_name
        config_name = dataset.info.config_name
        split_name = str(dataset.split)

        # If a dataset is not an iterable dataset, we will check if the
        # dataset with the matching dataset name, config name, and split name
        # on the Hugging Face Hub has the same fingerprint as the
        # dataset passed into this function. If it is not matching, transforms
        # or other operations have been performed so we cannot use the parquet
        # files on the Hugging Face Hub, so we return an empty list.
        if not isinstance(dataset, datasets.IterableDataset):
            from datasets import load_dataset

            try:
                ds = load_dataset(dataset_name, config_name, split=split_name)
                if ds._fingerprint != dataset._fingerprint:
                    return []
            except Exception:
                # If an exception is thrown when trying to reload the dataset
                # we should exit gracefully by returning an empty list.
                return []

        import requests

        public_url = (
            f"https://huggingface.co/api/datasets/{dataset_name}"
            f"/parquet/{config_name}/{split_name}"
        )
        resp = requests.get(public_url)
        if resp.status_code == requests.codes["ok"]:
            # dataset corresponds to a public dataset, return list of parquet_files
            return resp.json()
        else:
            return []

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the ``dataset_size`` attribute of the wrapped dataset."""
        return self._dataset.dataset_size

    def get_read_tasks(
        self,
        parallelism: int,
    ) -> List[ReadTask]:
        """Build the read tasks for this datasource.

        Args:
            parallelism: Currently unused; a single ``ReadTask`` is always
                produced.

        Returns:
            A one-element list with the ``ReadTask`` that streams the dataset
            batch by batch.
        """
        # Note: `parallelism` arg is currently not used by HuggingFaceDatasource.
        # We always generate a single ReadTask to perform the read.
        _check_pyarrow_version()
        import numpy as np
        import pandas as pd
        import pyarrow

        # Bind the batch size locally so that `_read_dataset` does not
        # reference `self`.
        batch_size = self._batch_size

        # NOTE: `np.ndarray` is the array *type*; `np.array` is a factory
        # function and makes `isinstance` raise a TypeError.
        supported_batch_types = (pyarrow.Table, pd.DataFrame, dict, np.ndarray)

        def _read_dataset(dataset: "datasets.IterableDataset") -> Iterable[Block]:
            for batch in dataset.with_format("arrow").iter(batch_size=batch_size):
                # HuggingFace IterableDatasets do not fully support methods like
                # `set_format`, `with_format`, and `formatted_as`, so the dataset
                # can return whatever is the default configured batch type, even if
                # the format is manually overridden before iterating above.
                # Therefore, we limit support to batch formats which have native
                # block types in Ray Data (pyarrow.Table, pd.DataFrame),
                # or can easily be converted to such (dict, np.array).
                # See: https://github.com/huggingface/datasets/issues/3444
                if not isinstance(batch, supported_batch_types):
                    raise ValueError(
                        f"Batch format {type(batch)} isn't supported. Only the "
                        f"following batch formats are supported: "
                        f"dict (corresponds to `None` in `dataset.with_format()`), "
                        f"pyarrow.Table, np.array, pd.DataFrame."
                    )
                # Ensure np.arrays are wrapped in a dict
                # (subsequently converted to a pyarrow.Table).
                if isinstance(batch, np.ndarray):
                    batch = {"item": batch}
                if isinstance(batch, dict):
                    batch = pyarrow_table_from_pydict(batch)
                # Ensure that we return the default block type.
                block = BlockAccessor.for_block(batch).to_default()
                yield block

        # TODO(scottjlee): IterableDataset doesn't provide APIs
        # for getting number of rows, byte size, etc., so the
        # BlockMetadata is currently empty. Properly retrieve
        # or calculate these so that progress bars have meaning.
        meta = BlockMetadata(
            num_rows=None,
            size_bytes=None,
            schema=None,
            input_files=None,
            exec_stats=None,
        )
        read_tasks: List[ReadTask] = [
            ReadTask(
                lambda hfds=self._dataset: _read_dataset(hfds),
                meta,
            )
        ]
        return read_tasks
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module to read an iceberg table into a Ray Dataset, by using the Ray Datasource API.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import heapq
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 10 |
+
|
| 11 |
+
from ray.data._internal.util import _check_import
|
| 12 |
+
from ray.data.block import Block, BlockMetadata
|
| 13 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 14 |
+
from ray.util.annotations import DeveloperAPI
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from pyiceberg.catalog import Catalog
|
| 18 |
+
from pyiceberg.expressions import BooleanExpression
|
| 19 |
+
from pyiceberg.io import FileIO
|
| 20 |
+
from pyiceberg.manifest import DataFile
|
| 21 |
+
from pyiceberg.schema import Schema
|
| 22 |
+
from pyiceberg.table import DataScan, FileScanTask, Table
|
| 23 |
+
from pyiceberg.table.metadata import TableMetadata
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _get_read_task(
    tasks: Iterable["FileScanTask"],
    table_io: "FileIO",
    table_metadata: "TableMetadata",
    row_filter: "BooleanExpression",
    case_sensitive: bool,
    limit: Optional[int],
    schema: "Schema",
) -> Iterable[Block]:
    """Yield the PyIceberg projection of ``tasks`` as a single block.

    Every argument is forwarded to ``pyiceberg.io.pyarrow.project_table``;
    ``table_io`` is passed as ``io`` and ``schema`` as ``projected_schema``.
    """
    from pyiceberg.io import pyarrow as iceberg_arrow

    # The read goes through the PyIceberg API rather than opening the parquet
    # files directly: a FileScanTask is not just a single parquet file, as
    # there might be delete files, etc. associated with it, so PyIceberg must
    # perform the projection.
    projected = iceberg_arrow.project_table(
        tasks=tasks,
        table_metadata=table_metadata,
        io=table_io,
        row_filter=row_filter,
        projected_schema=schema,
        case_sensitive=case_sensitive,
        limit=limit,
    )
    yield projected
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@DeveloperAPI
class IcebergDatasource(Datasource):
    """
    Iceberg datasource to read Iceberg tables into a Ray Dataset. This module heavily
    uses PyIceberg to read iceberg tables. All the routines in this class override
    `ray.data.Datasource`.
    """

    def __init__(
        self,
        table_identifier: str,
        row_filter: Optional[Union[str, "BooleanExpression"]] = None,
        selected_fields: Tuple[str, ...] = ("*",),
        snapshot_id: Optional[int] = None,
        scan_kwargs: Optional[Dict[str, Any]] = None,
        catalog_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize an IcebergDatasource.

        Args:
            table_identifier: Fully qualified table identifier (i.e.,
                "db_name.table_name")
            row_filter: A PyIceberg BooleanExpression to use to filter the data *prior*
                to reading. Defaults to ``AlwaysTrue()`` when not given.
            selected_fields: Which columns from the data to read, passed directly to
                PyIceberg's load functions
            snapshot_id: Optional snapshot ID for the Iceberg table
            scan_kwargs: Optional arguments to pass to PyIceberg's Table.scan()
                function
            catalog_kwargs: Optional arguments to use when setting up the Iceberg
                catalog. A "name" entry, if present, is used as the catalog name
                (default: "default") and is not forwarded as a keyword argument.
        """
        _check_import(self, module="pyiceberg", package="pyiceberg")
        from pyiceberg.expressions import AlwaysTrue

        # Work on shallow copies: "name" is popped and "snapshot_id" is injected
        # below, and neither should be visible in the caller's dictionaries.
        self._scan_kwargs = dict(scan_kwargs) if scan_kwargs is not None else {}
        self._catalog_kwargs = (
            dict(catalog_kwargs) if catalog_kwargs is not None else {}
        )

        self._catalog_name = self._catalog_kwargs.pop("name", "default")

        self.table_identifier = table_identifier

        self._row_filter = row_filter if row_filter is not None else AlwaysTrue()
        self._selected_fields = selected_fields

        # Compare against None so that a snapshot ID of 0 is not silently dropped.
        if snapshot_id is not None:
            self._scan_kwargs["snapshot_id"] = snapshot_id

        # Both are populated lazily by the `plan_files` / `table` properties.
        self._plan_files = None
        self._table = None

    def _get_catalog(self) -> "Catalog":
        """Load the catalog named ``self._catalog_name`` with the catalog kwargs."""
        from pyiceberg import catalog

        return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)

    @property
    def table(self) -> "Table":
        """
        Return the table reference from the catalog (loaded once, then cached)
        """
        if self._table is None:
            catalog = self._get_catalog()
            self._table = catalog.load_table(self.table_identifier)
        return self._table

    @property
    def plan_files(self) -> List["FileScanTask"]:
        """
        Return the plan files specified by this query
        """
        # Calculate and cache the plan_files if they don't already exist
        if self._plan_files is None:
            data_scan = self._get_data_scan()
            self._plan_files = data_scan.plan_files()

        return self._plan_files

    def _get_data_scan(self) -> "DataScan":
        """Create a table scan from the row filter, selected fields and scan kwargs."""
        data_scan = self.table.scan(
            row_filter=self._row_filter,
            selected_fields=self._selected_fields,
            **self._scan_kwargs,
        )

        return data_scan

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the summed on-disk size of all planned data files."""
        # Approximate the size by using the plan files - this will not
        # incorporate the deletes, but that's a reasonable approximation.
        return sum(task.file.file_size_in_bytes for task in self.plan_files)

    @staticmethod
    def _distribute_tasks_into_equal_chunks(
        plan_files: Iterable["FileScanTask"], n_chunks: int
    ) -> List[List["FileScanTask"]]:
        """
        Implement a greedy knapsack algorithm to distribute the files in the scan
        across tasks, based on their file size, as evenly as possible

        Args:
            plan_files: The scan tasks to distribute.
            n_chunks: Number of chunks to create.

        Returns:
            ``n_chunks`` lists of scan tasks; a chunk may be empty if there are
            fewer files than chunks.
        """
        chunks = [[] for _ in range(n_chunks)]

        # Min-heap of (accumulated size, chunk index); the root is always the
        # currently smallest chunk.
        chunk_sizes = [(0, chunk_id) for chunk_id in range(n_chunks)]
        heapq.heapify(chunk_sizes)

        # From largest to smallest, add the plan files to the smallest chunk one at a
        # time
        for plan_file in sorted(
            plan_files, key=lambda f: f.file.file_size_in_bytes, reverse=True
        ):
            chunk_size, chunk_id = heapq.heappop(chunk_sizes)
            chunks[chunk_id].append(plan_file)
            heapq.heappush(
                chunk_sizes,
                (chunk_size + plan_file.file.file_size_in_bytes, chunk_id),
            )

        return chunks

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Split the planned files into at most ``parallelism`` read tasks.

        Args:
            parallelism: Requested number of read tasks; reduced to the number
                of plan files if it exceeds it.

        Returns:
            One ``ReadTask`` per chunk of plan files, each with estimated
            block metadata.
        """
        from pyiceberg.io import pyarrow as pyi_pa_io
        from pyiceberg.manifest import DataFileContent

        # Get the PyIceberg scan
        data_scan = self._get_data_scan()
        # Get the plan files in this query (materialized once, so the length
        # check below does not need to rebuild the list)
        plan_files = list(self.plan_files)

        # Get the projected schema for this scan, given all the row filters,
        # snapshot ID, etc.
        projected_schema = data_scan.projection()
        # Get the arrow schema, to set in the metadata
        pya_schema = pyi_pa_io.schema_to_pyarrow(projected_schema)

        # Set the n_chunks to the min of the number of plan files and the actual
        # requested n_chunks, so that there are no empty tasks
        num_files = len(plan_files)
        if parallelism > num_files:
            parallelism = num_files
            logger.warning(
                f"Reducing the parallelism to {parallelism}, as that is the "
                "number of files"
            )

        # Get required properties for reading tasks - table IO, table metadata,
        # row filter, case sensitivity,limit and projected schema to pass
        # them directly to `_get_read_task` to avoid capture of `self` reference
        # within the closure carrying substantial overhead invoking these tasks
        #
        # See https://github.com/ray-project/ray/issues/49107 for more context
        table_io = self.table.io
        table_metadata = self.table.metadata
        row_filter = self._row_filter
        case_sensitive = self._scan_kwargs.get("case_sensitive", True)
        limit = self._scan_kwargs.get("limit")

        get_read_task = partial(
            _get_read_task,
            table_io=table_io,
            table_metadata=table_metadata,
            row_filter=row_filter,
            case_sensitive=case_sensitive,
            limit=limit,
            schema=projected_schema,
        )

        read_tasks = []
        # Chunk the plan files based on the requested parallelism
        for chunk_tasks in IcebergDatasource._distribute_tasks_into_equal_chunks(
            plan_files, parallelism
        ):
            unique_deletes: Set["DataFile"] = set(
                itertools.chain.from_iterable(
                    task.delete_files for task in chunk_tasks
                )
            )
            # Get a rough estimate of the number of deletes by just looking at
            # position deletes. Equality deletes are harder to estimate, as they
            # can delete multiple rows.
            position_delete_count = sum(
                delete.record_count
                for delete in unique_deletes
                if delete.content == DataFileContent.POSITION_DELETES
            )
            metadata = BlockMetadata(
                num_rows=sum(task.file.record_count for task in chunk_tasks)
                - position_delete_count,
                size_bytes=sum(task.length for task in chunk_tasks),
                schema=pya_schema,
                input_files=[task.file.file_path for task in chunk_tasks],
                exec_stats=None,
            )
            read_tasks.append(
                ReadTask(
                    # Bind the chunk as a default argument so each lambda keeps
                    # its own chunk instead of the loop variable's last value.
                    read_fn=lambda tasks=chunk_tasks: get_read_task(tasks),
                    metadata=metadata,
                )
            )

        return read_tasks
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import pyarrow
|
| 5 |
+
|
| 6 |
+
from ray.data.datasource.file_datasink import RowBasedFileDatasink
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ImageDatasink(RowBasedFileDatasink):
    """Row-based file datasink that writes one image column of each row to a file."""

    def __init__(
        self, path: str, column: str, file_format: str, **file_datasink_kwargs
    ):
        """Store the image column name and the output image format.

        Args:
            path: Passed through to ``RowBasedFileDatasink``.
            column: Key of the row entry holding the image array.
            file_format: Format handed to ``PIL.Image.Image.save`` and to the
                base class.
            **file_datasink_kwargs: Passed through to ``RowBasedFileDatasink``.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.column = column
        self.file_format = file_format

    def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
        """Encode ``row[self.column]`` as an image and write the bytes to ``file``."""
        from PIL import Image

        pixels = row[self.column]
        # Encode into an in-memory buffer first, then write the encoded
        # bytes to the output file in a single call.
        encoded = io.BytesIO()
        Image.fromarray(pixels).save(encoded, format=self.file_format)
        file.write(encoded.getvalue())
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 9 |
+
from ray.data._internal.util import _check_import
|
| 10 |
+
from ray.data.block import Block, BlockMetadata
|
| 11 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 12 |
+
from ray.data.datasource.file_meta_provider import DefaultFileMetadataProvider
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
import pyarrow
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# The default size multiplier for reading image data source.
|
| 21 |
+
# This essentially is using image on-disk file size to estimate
|
| 22 |
+
# in-memory data size.
|
| 23 |
+
IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT = 1
|
| 24 |
+
|
| 25 |
+
# The lower bound value to estimate image encoding ratio.
|
| 26 |
+
IMAGE_ENCODING_RATIO_ESTIMATE_LOWER_BOUND = 0.5
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ImageDatasource(FileBasedDatasource):
    """A datasource that lets you read images."""

    _WRITE_FILE_PER_ROW = True
    _FILE_EXTENSIONS = ["png", "jpg", "jpeg", "tif", "tiff", "bmp", "gif"]
    # Use 8 threads per task to read image files.
    _NUM_THREADS_PER_TASK = 8

    def __init__(
        self,
        paths: Union[str, List[str]],
        size: Optional[Tuple[int, int]] = None,
        mode: Optional[str] = None,
        **file_based_datasource_kwargs,
    ):
        """Create the datasource.

        Args:
            paths: Passed through to ``FileBasedDatasource``.
            size: Optional ``(height, width)`` every image is resized to.
            mode: Optional Pillow mode every image is converted to.
            **file_based_datasource_kwargs: Passed through to
                ``FileBasedDatasource``. If ``meta_provider`` is an
                ``ImageFileMetadataProvider``, it receives the estimated
                encoding ratio.

        Raises:
            ValueError: If ``size`` does not have two entries or has a
                negative entry.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        _check_import(self, module="PIL", package="Pillow")

        if size is not None and len(size) != 2:
            raise ValueError(
                "Expected `size` to contain two integers for height and width, "
                f"but got {len(size)} integers instead."
            )

        if size is not None and (size[0] < 0 or size[1] < 0):
            raise ValueError(
                f"Expected `size` to contain positive integers, but got {size} instead."
            )

        self.size = size
        self.mode = mode

        meta_provider = file_based_datasource_kwargs.get("meta_provider", None)
        if isinstance(meta_provider, ImageFileMetadataProvider):
            self._encoding_ratio = self._estimate_files_encoding_ratio()
            meta_provider._set_encoding_ratio(self._encoding_ratio)
        else:
            self._encoding_ratio = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

    def _read_stream(
        self,
        f: "pyarrow.NativeFile",
        path: str,
    ) -> Iterator[Block]:
        """Decode the image in ``f`` and yield it as a one-row block.

        Args:
            f: Open file to read the encoded image from.
            path: Path of the file; only used in the error message.

        Raises:
            ValueError: If Pillow cannot identify the image.
        """
        from PIL import Image, UnidentifiedImageError

        data = f.readall()

        try:
            image = Image.open(io.BytesIO(data))
        except UnidentifiedImageError as e:
            raise ValueError(f"PIL couldn't load image file at path '{path}'.") from e

        if self.size is not None:
            height, width = self.size
            # `size` is (height, width) while Pillow's resize takes (width, height).
            image = image.resize((width, height), resample=Image.BILINEAR)
        if self.mode is not None:
            image = image.convert(self.mode)

        builder = DelegatingBlockBuilder()
        array = np.array(image)
        item = {"image": array}
        builder.add(item)
        block = builder.build()

        yield block

    def _rows_per_file(self):
        """Each image file yields exactly one row."""
        return 1

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the total known file size scaled by the encoding ratio."""
        total_size = 0
        for file_size in self._file_sizes():
            # NOTE: check if file size is not None, because some metadata provider
            # such as FastFileMetadataProvider does not provide file size information.
            if file_size is not None:
                total_size += file_size
        return total_size * self._encoding_ratio

    def _estimate_files_encoding_ratio(self) -> float:
        """Return an estimate of the image files encoding ratio."""
        start_time = time.perf_counter()
        # Filter out empty files to avoid noise. Files whose size is unknown
        # (None) are skipped as well, since they cannot be compared or summed.
        non_empty_path_and_size = [
            (path, file_size)
            for path, file_size in zip(self._paths(), self._file_sizes())
            if file_size is not None and file_size > 0
        ]
        num_files = len(non_empty_path_and_size)
        if num_files == 0:
            logger.warning(
                "All input image files are empty. "
                "Use on-disk file size to estimate images in-memory size."
            )
            return IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

        if self.size is not None and self.mode is not None:
            # Use image size and mode to calculate data size for all images,
            # because all images are homogeneous with same size after resizing.
            # Resizing is enforced when reading every image in `ImageDatasource`
            # when `size` argument is provided.
            # `dimension` is the per-pixel multiplier used for the estimate.
            if self.mode in ["1", "L", "P"]:
                dimension = 1
            elif self.mode in ["RGB", "YCbCr", "LAB", "HSV"]:
                dimension = 3
            elif self.mode in ["RGBA", "CMYK", "I", "F"]:
                dimension = 4
            else:
                logger.warning(f"Found unknown image mode: {self.mode}.")
                return IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT
            height, width = self.size
            single_image_size = height * width * dimension
            total_estimated_size = single_image_size * num_files
            total_file_size = sum(p[1] for p in non_empty_path_and_size)
            ratio = total_estimated_size / total_file_size
        else:
            # TODO(chengsu): sample images to estimate data size
            ratio = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

        sampling_duration = time.perf_counter() - start_time
        if sampling_duration > 5:
            logger.warning(
                "Image input size estimation took "
                f"{round(sampling_duration, 2)} seconds."
            )
        logger.debug(f"Estimated image encoding ratio from sampling is {ratio}.")
        return max(ratio, IMAGE_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class ImageFileMetadataProvider(DefaultFileMetadataProvider):
    """File metadata provider that scales ``size_bytes`` by an image encoding ratio."""

    # Fallback used until `_set_encoding_ratio` is called, so that
    # `_get_block_metadata` never fails on a missing attribute.
    _encoding_ratio: float = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT

    def _set_encoding_ratio(self, encoding_ratio: float):
        """Set image file encoding ratio, to provide accurate size in bytes metadata."""
        self._encoding_ratio = encoding_ratio

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        *,
        rows_per_file: Optional[int],
        file_sizes: List[Optional[int]],
    ) -> BlockMetadata:
        """Return the parent's block metadata with ``size_bytes`` scaled.

        All arguments are forwarded unchanged to the parent implementation.
        ``size_bytes`` is multiplied by the encoding ratio and truncated to an
        int; it is left untouched when the parent reports ``None``.
        """
        metadata = super()._get_block_metadata(
            paths, schema, rows_per_file=rows_per_file, file_sizes=file_sizes
        )
        if metadata.size_bytes is not None:
            metadata.size_bytes = int(metadata.size_bytes * self._encoding_ratio)
        return metadata
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasink.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import pyarrow
|
| 4 |
+
|
| 5 |
+
from ray.data.block import BlockAccessor
|
| 6 |
+
from ray.data.datasource.file_based_datasource import _resolve_kwargs
|
| 7 |
+
from ray.data.datasource.file_datasink import BlockBasedFileDatasink
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class JSONDatasink(BlockBasedFileDatasink):
    """Block-based file datasink that writes each block as JSON via pandas."""

    def __init__(
        self,
        path: str,
        *,
        pandas_json_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
        pandas_json_args: Optional[Dict[str, Any]] = None,
        file_format: str = "json",
        **file_datasink_kwargs,
    ):
        """Store the pandas ``to_json`` arguments used when writing blocks.

        Args:
            path: Passed through to ``BlockBasedFileDatasink``.
            pandas_json_args_fn: Callable producing writer arguments; defaults
                to one returning an empty dict.
            pandas_json_args: Writer arguments; defaults to an empty dict.
            file_format: Passed through to ``BlockBasedFileDatasink``.
            **file_datasink_kwargs: Passed through to ``BlockBasedFileDatasink``.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.pandas_json_args_fn = (
            pandas_json_args_fn if pandas_json_args_fn is not None else (lambda: {})
        )
        self.pandas_json_args = (
            pandas_json_args if pandas_json_args is not None else {}
        )

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write ``block`` to ``file`` with ``DataFrame.to_json``.

        ``orient`` defaults to "records" and ``lines`` to True unless the
        resolved writer arguments override them.
        """
        json_kwargs = _resolve_kwargs(
            self.pandas_json_args_fn, **self.pandas_json_args
        )
        # Pull out the two options that have datasink-specific defaults; the
        # remaining entries are forwarded untouched.
        orient_opt = json_kwargs.pop("orient", "records")
        lines_opt = json_kwargs.pop("lines", True)

        frame = block.to_pandas()
        frame.to_json(file, orient=orient_opt, lines=lines_opt, **json_kwargs)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/json_datasource.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
|
| 6 |
+
from ray.data.context import DataContext
|
| 7 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
import pyarrow
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class JSONDatasource(FileBasedDatasource):
    """JSON datasource, for reading and writing JSON and JSONL files."""

    _FILE_EXTENSIONS = [
        "json",
        "jsonl",
        # gzip-compressed files
        "json.gz",
        "jsonl.gz",
        # Brotli-compressed files
        "json.br",
        "jsonl.br",
        # Zstandard-compressed files
        "json.zst",
        "jsonl.zst",
        # lz4-compressed files
        "json.lz4",
        "jsonl.lz4",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        *,
        arrow_json_args: Optional[Dict[str, Any]] = None,
        **file_based_datasource_kwargs,
    ):
        """Create the datasource.

        Args:
            paths: A path or list of paths, forwarded to the base class.
            arrow_json_args: Keyword arguments for ``pyarrow.json.read_json``.
                A ``"read_options"`` entry, if present, is split out and used
                as the reader's ``ReadOptions``; otherwise
                ``ReadOptions(use_threads=False)`` is used.
            **file_based_datasource_kwargs: Forwarded to the base class.
        """
        from pyarrow import json

        super().__init__(paths, **file_based_datasource_kwargs)

        # Work on a copy: `read_options` is popped below, and the caller's
        # dictionary must not be modified as a side effect.
        arrow_json_args = dict(arrow_json_args) if arrow_json_args else {}

        self.read_options = arrow_json_args.pop(
            "read_options", json.ReadOptions(use_threads=False)
        )
        self.arrow_json_args = arrow_json_args

    def _read_with_pyarrow_read_json(self, buffer: "pyarrow.lib.Buffer"):
        """Read with PyArrow JSON reader, trying to auto-increase the
        read block size in the case of the read object
        straddling block boundaries.

        Yields:
            The ``pyarrow.Table`` parsed from ``buffer``.

        Raises:
            pyarrow.ArrowInvalid: If the read fails for an unrelated reason, or
                if the block size has reached
                ``DataContext.get_current().target_max_block_size`` and the
                read still fails.
        """
        import pyarrow as pa

        # `pyarrow.json` is a submodule; import it explicitly instead of
        # relying on some other code in this process having imported it.
        from pyarrow import json as pajson

        # When reading large files, the default block size configured in PyArrow can be
        # too small, resulting in the following error: `pyarrow.lib.ArrowInvalid:
        # straddling object straddles two block boundaries (try to increase block
        # size?)`. More information on this issue can be found here:
        # https://github.com/apache/arrow/issues/25674
        # The read will be retried with geometrically increasing block size
        # until the size reaches `DataContext.get_current().target_max_block_size`.
        # The initial block size will start at the PyArrow default block size
        # or it can be manually set through the `read_options` parameter as follows.
        # >>> import pyarrow.json as pajson
        # >>> block_size = 10 << 20 # Set block size to 10MB
        # >>> ray.data.read_json(  # doctest: +SKIP
        # ...     "s3://anonymous@ray-example-data/log.json",
        # ...     read_options=pajson.ReadOptions(block_size=block_size)
        # ... )

        init_block_size = self.read_options.block_size
        max_block_size = DataContext.get_current().target_max_block_size
        try:
            while True:
                try:
                    yield pajson.read_json(
                        BytesIO(buffer),
                        read_options=self.read_options,
                        **self.arrow_json_args,
                    )
                    break
                except pa.ArrowInvalid as e:
                    if "straddling object straddles two block boundaries" not in str(
                        e
                    ):
                        # unrelated error, simply reraise
                        raise
                    if self.read_options.block_size < max_block_size:
                        # Increase the block size in case it was too small.
                        logger.debug(
                            f"JSONDatasource read failed with "
                            f"block_size={self.read_options.block_size}. Retrying with "
                            f"block_size={self.read_options.block_size * 2}."
                        )
                        self.read_options.block_size *= 2
                    else:
                        raise pa.ArrowInvalid(
                            f"{e} - Auto-increasing block size to "
                            f"{self.read_options.block_size} bytes failed. "
                            f"Please try manually increasing the block size through "
                            f"the `read_options` parameter to a larger size. "
                            f"For example: `read_json(..., read_options="
                            f"pyarrow.json.ReadOptions(block_size=10 << 25))`. "
                            f"More information on this issue can be found here: "
                            f"https://github.com/apache/arrow/issues/25674"
                        ) from e
        finally:
            # `self.read_options` is shared by every file this instance reads,
            # so restore the configured block size on success AND on failure.
            self.read_options.block_size = init_block_size

    def _read_with_python_json(self, buffer: "pyarrow.lib.Buffer"):
        """Fallback method to read JSON files with Python's native json.load(),
        in case the default pyarrow json reader fails.

        Yields nothing for an empty buffer; otherwise yields one table built
        from the parsed JSON.
        """
        import json

        import pyarrow as pa

        # Check if the buffer is empty
        if buffer.size == 0:
            return

        parsed_json = json.load(BytesIO(buffer))
        try:
            yield pa.Table.from_pylist(parsed_json)
        except AttributeError as e:
            # For PyArrow < 7.0.0, `pa.Table.from_pylist()` is not available.
            # Construct a dict from the list and call
            # `pa.Table.from_pydict()` instead.
            if "no attribute 'from_pylist'" not in str(e):
                # Some other AttributeError: surface it unchanged rather than
                # relying on an `assert`, which is stripped under `-O`.
                raise
            from collections import defaultdict

            dct = defaultdict(list)
            for row in parsed_json:
                for k, v in row.items():
                    dct[k].append(v)
            yield pyarrow_table_from_pydict(dct)

    # TODO(ekl) The PyArrow JSON reader doesn't support streaming reads.
    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Read the whole file into a buffer and yield the parsed table(s).

        The PyArrow JSON reader is tried first; if it raises
        ``pyarrow.ArrowInvalid``, a warning is logged and the buffer is parsed
        with Python's ``json`` module instead.
        """
        import pyarrow as pa

        buffer: pa.lib.Buffer = f.read_buffer()

        try:
            yield from self._read_with_pyarrow_read_json(buffer)
        except pa.ArrowInvalid as e:
            # If read with PyArrow fails, try falling back to native json.load().
            logger.warning(
                f"Error reading with pyarrow.json.read_json(). "
                f"Falling back to native json.load(), which may be slower. "
                f"PyArrow error was:\n{e}"
            )
            yield from self._read_with_python_json(buffer)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.data._internal.util import _check_import, call_with_retry
|
| 7 |
+
from ray.data.block import BlockMetadata
|
| 8 |
+
from ray.data.context import DataContext
|
| 9 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
import pyarrow
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LanceDatasource(Datasource):
    """Lance datasource, for reading Lance dataset."""

    # Errors to retry when reading Lance fragments.
    READ_FRAGMENTS_ERRORS_TO_RETRY = ["LanceError(IO)"]
    # Maximum number of attempts to read Lance fragments.
    READ_FRAGMENTS_MAX_ATTEMPTS = 10
    # Maximum backoff seconds between attempts to read Lance fragments.
    READ_FRAGMENTS_RETRY_MAX_BACKOFF_SECONDS = 32

    def __init__(
        self,
        uri: str,
        columns: Optional[List[str]] = None,
        filter: Optional[str] = None,
        storage_options: Optional[Dict[str, str]] = None,
        scanner_options: Optional[Dict[str, Any]] = None,
    ):
        """Open the Lance dataset and prepare scanner and retry settings.

        Args:
            uri: Location of the Lance dataset, passed to ``lance.dataset``.
            columns: If given, stored as the ``"columns"`` scanner option.
            filter: If given, stored as the ``"filter"`` scanner option.
            storage_options: Passed to ``lance.dataset``.
            scanner_options: Extra keyword arguments for the dataset scanner.
                ``columns`` and ``filter`` take precedence over entries with
                the same keys.
        """
        _check_import(self, module="lance", package="pylance")

        import lance

        self.uri = uri
        # Copy the options: `columns`/`filter` are written into this dict and
        # the caller's dictionary must not be modified as a side effect.
        self.scanner_options = dict(scanner_options) if scanner_options else {}
        if columns is not None:
            self.scanner_options["columns"] = columns
        if filter is not None:
            self.scanner_options["filter"] = filter
        self.storage_options = storage_options
        self.lance_ds = lance.dataset(uri=uri, storage_options=storage_options)

        # Retry on the Lance-specific errors plus the context-wide I/O errors.
        match = [
            *self.READ_FRAGMENTS_ERRORS_TO_RETRY,
            *DataContext.get_current().retried_io_errors,
        ]
        self._retry_params = {
            "description": "read lance fragments",
            "match": match,
            "max_attempts": self.READ_FRAGMENTS_MAX_ATTEMPTS,
            "max_backoff_s": self.READ_FRAGMENTS_RETRY_MAX_BACKOFF_SECONDS,
        }

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Split the dataset's fragments into at most ``parallelism`` read tasks.

        Fragments are divided with ``np.array_split``; empty splits produce no
        task. Each task reads its fragments by id.
        """
        # Bind to locals so the task closures below do not capture `self`.
        scanner_options = self.scanner_options
        lance_ds = self.lance_ds
        retry_params = self._retry_params

        read_tasks = []
        for fragments in np.array_split(self.lance_ds.get_fragments(), parallelism):
            # `fragments` is a NumPy array, so test its length explicitly
            # (truthiness of a multi-element array is ambiguous).
            if len(fragments) <= 0:
                continue

            fragment_ids = [f.metadata.id for f in fragments]
            num_rows = sum(f.count_rows() for f in fragments)
            input_files = [
                data_file.path() for f in fragments for data_file in f.data_files()
            ]

            # TODO(chengsu): Take column projection into consideration for schema.
            metadata = BlockMetadata(
                num_rows=num_rows,
                schema=fragments[0].schema,
                input_files=input_files,
                size_bytes=None,
                exec_stats=None,
            )

            read_task = ReadTask(
                # `f=fragment_ids` binds this iteration's ids at definition time.
                lambda f=fragment_ids: _read_fragments_with_retry(
                    f,
                    lance_ds,
                    scanner_options,
                    retry_params,
                ),
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``; no size estimate is implemented."""
        # TODO(chengsu): Add memory size estimation to improve auto-tune of parallelism.
        return None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _read_fragments_with_retry(
    fragment_ids,
    lance_ds,
    scanner_options,
    retry_params,
) -> Iterator["pyarrow.Table"]:
    """Invoke ``_read_fragments`` through ``call_with_retry``.

    Args:
        fragment_ids: Ids of the fragments to read.
        lance_ds: The Lance dataset the fragments belong to.
        scanner_options: Keyword arguments for the dataset scanner.
        retry_params: Keyword arguments forwarded to ``call_with_retry``.

    Returns:
        Whatever ``call_with_retry`` returns for the wrapped call.
    """
    # NOTE(review): `_read_fragments` is a generator function, so calling it
    # only creates the generator. Whether errors raised while iterating are
    # covered by the retry depends on `call_with_retry` -- confirm.

    def read_once() -> Iterator["pyarrow.Table"]:
        return _read_fragments(fragment_ids, lance_ds, scanner_options)

    return call_with_retry(read_once, **retry_params)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _read_fragments(
    fragment_ids,
    lance_ds,
    scanner_options,
) -> Iterator["pyarrow.Table"]:
    """Read Lance fragments in batches.

    NOTE: Use fragment ids, instead of fragments as parameter, because pickling
    LanceFragment is expensive.

    Args:
        fragment_ids: Ids of the fragments to read.
        lance_ds: Dataset used to resolve the ids and to build the scanner.
        scanner_options: Keyword arguments for ``lance_ds.scanner``. This
            mapping is not modified.

    Yields:
        One ``pyarrow.Table`` per record batch produced by the scanner.
    """
    import pyarrow

    fragments = [lance_ds.get_fragment(fragment_id) for fragment_id in fragment_ids]
    # Build a per-call options dict instead of writing "fragments" into
    # `scanner_options`, which the caller may share across tasks and retries.
    options = {**scanner_options, "fragments": fragments}
    scanner = lance_ds.scanner(**options)
    for batch in scanner.to_reader():
        yield pyarrow.Table.from_batches([batch])
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasink.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Iterable
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.datasource.mongo_datasource import (
|
| 5 |
+
_validate_database_collection_exist,
|
| 6 |
+
)
|
| 7 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 8 |
+
from ray.data._internal.execution.interfaces import TaskContext
|
| 9 |
+
from ray.data._internal.util import _check_import
|
| 10 |
+
from ray.data.block import Block, BlockAccessor
|
| 11 |
+
from ray.data.datasource.datasink import Datasink
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MongoDatasink(Datasink[None]):
    """Datasink that writes blocks to a MongoDB collection."""

    def __init__(self, uri: str, database: str, collection: str) -> None:
        """Store the connection target after checking the required packages.

        Args:
            uri: MongoDB connection URI.
            database: Name of the target database.
            collection: Name of the target collection.
        """
        _check_import(self, module="pymongo", package="pymongo")
        _check_import(self, module="pymongoarrow", package="pymongoarrow")

        self.uri = uri
        self.database = database
        self.collection = collection

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Combine ``blocks`` into one Arrow table and write it to the collection.

        The database and collection are validated before any block is
        consumed. Each ``MongoClient`` is closed as soon as it is no longer
        needed.

        Raises:
            ValueError: If the database or collection does not exist.
        """
        import pymongo

        # Validate up front, and release the connection before building the
        # (possibly slow to produce) combined block.
        with pymongo.MongoClient(self.uri) as client:
            _validate_database_collection_exist(
                client, self.database, self.collection
            )

        builder = DelegatingBlockBuilder()
        for block in blocks:
            builder.add_block(block)
        combined = builder.build()

        from pymongoarrow.api import write as write_arrow

        table = BlockAccessor.for_block(combined).to_arrow()
        with pymongo.MongoClient(self.uri) as client:
            write_arrow(client[self.database][self.collection], table)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/mongo_datasource.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
from ray.data.block import Block, BlockMetadata
|
| 5 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
import pymongoarrow.api
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MongoDatasource(Datasource):
    """Datasource for reading from MongoDB."""

    def __init__(
        self,
        uri: str,
        database: str,
        collection: str,
        pipeline: Optional[List[Dict]] = None,
        schema: Optional["pymongoarrow.api.Schema"] = None,
        **mongo_args,
    ):
        """Store the read configuration; no connection is opened here.

        Args:
            uri: MongoDB connection URI.
            database: Name of the database to read from.
            collection: Name of the collection to read from.
            pipeline: Aggregation pipeline to run. When empty or ``None``, a
                pipeline matching every document that has an ``_id`` is used.
            schema: Schema passed to ``pymongoarrow``'s ``aggregate_arrow_all``.
            **mongo_args: Extra keyword arguments passed to
                ``aggregate_arrow_all``.
        """
        self._uri = uri
        self._database = database
        self._collection = collection
        self._pipeline = pipeline
        self._schema = schema
        self._mongo_args = mongo_args
        # If pipeline is unspecified, read the entire collection.
        if not pipeline:
            self._pipeline = [{"$match": {"_id": {"$exists": "true"}}}]
        # Initialize Mongo client lazily later when creating read tasks.
        self._client = None

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``None``; no size estimate is implemented."""
        # TODO(jian): Add memory size estimation to improve auto-tune of parallelism.
        return None

    def _get_match_query(self, pipeline: List[Dict]) -> Dict:
        """Return the ``$match`` query of the first pipeline stage, or ``{}``."""
        if len(pipeline) == 0 or "$match" not in pipeline[0]:
            return {}
        return pipeline[0]["$match"]

    def _get_or_create_client(self):
        """Create the client on first use, validate the target, cache stats.

        On first call this connects, checks that the database and collection
        exist, and records the collection's ``avgObjSize`` from ``collstats``.
        Later calls do nothing.
        """
        import pymongo

        if self._client is None:
            self._client = pymongo.MongoClient(self._uri)
            _validate_database_collection_exist(
                self._client, self._database, self._collection
            )
            self._avg_obj_size = self._client[self._database].command(
                "collstats", self._collection
            )["avgObjSize"]

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Partition the collection by ``_id`` and build one task per partition.

        ``$bucketAuto`` groups the matching documents into ``parallelism``
        buckets on ``_id``; each bucket becomes a range query that is
        prepended to the user pipeline.
        """
        from bson.objectid import ObjectId

        self._get_or_create_client()
        coll = self._client[self._database][self._collection]
        match_query = self._get_match_query(self._pipeline)
        partitions_ids = list(
            coll.aggregate(
                [
                    {"$match": match_query},
                    {"$bucketAuto": {"groupBy": "$_id", "buckets": parallelism}},
                ],
                allowDiskUse=True,
            )
        )

        def make_block(
            uri: str,
            database: str,
            collection: str,
            pipeline: List[Dict],
            min_id: ObjectId,
            max_id: ObjectId,
            right_closed: bool,
            schema: "pymongoarrow.api.Schema",
            kwargs: dict,
        ) -> Block:
            import pymongo
            from pymongoarrow.api import aggregate_arrow_all

            # A range query over the partition.
            match = [
                {
                    "$match": {
                        "_id": {
                            "$gte": min_id,
                            "$lte" if right_closed else "$lt": max_id,
                        }
                    }
                }
            ]
            # Each task opens its own connection; close it once the result
            # has been returned by `aggregate_arrow_all`.
            with pymongo.MongoClient(uri) as client:
                return aggregate_arrow_all(
                    client[database][collection],
                    match + pipeline,
                    schema=schema,
                    **kwargs,
                )

        read_tasks: List[ReadTask] = []

        for i, partition in enumerate(partitions_ids):
            metadata = BlockMetadata(
                num_rows=partition["count"],
                size_bytes=partition["count"] * self._avg_obj_size,
                schema=None,
                input_files=None,
                exec_stats=None,
            )
            make_block_args = (
                self._uri,
                self._database,
                self._collection,
                self._pipeline,
                partition["_id"]["min"],
                partition["_id"]["max"],
                # Only the last partition includes its upper bound.
                i == len(partitions_ids) - 1,
                self._schema,
                self._mongo_args,
            )
            read_task = ReadTask(
                # `args=...` binds this iteration's arguments at definition time.
                lambda args=make_block_args: [make_block(*args)],
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _validate_database_collection_exist(client, database: str, collection: str):
|
| 135 |
+
db_names = client.list_database_names()
|
| 136 |
+
if database not in db_names:
|
| 137 |
+
raise ValueError(f"The destination database {database} doesn't exist.")
|
| 138 |
+
collection_names = client[database].list_collection_names()
|
| 139 |
+
if collection not in collection_names:
|
| 140 |
+
raise ValueError(f"The destination collection {collection} doesn't exist.")
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasink.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pyarrow
|
| 3 |
+
|
| 4 |
+
from ray.data.block import BlockAccessor
|
| 5 |
+
from ray.data.datasource.file_datasink import BlockBasedFileDatasink
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NumpyDatasink(BlockBasedFileDatasink):
    """File datasink that saves one column of each block with ``np.save``."""

    def __init__(
        self,
        path: str,
        column: str,
        *,
        file_format: str = "npy",
        **file_datasink_kwargs,
    ):
        """Configure the sink.

        Args:
            path: Output location, forwarded to the base class.
            column: Name of the column to extract from each block.
            file_format: File format forwarded to the base class.
            **file_datasink_kwargs: Forwarded to the base class.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        self.column = column

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write the configured column of ``block`` to ``file``."""
        np.save(file, block.to_numpy(self.column))
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/numpy_datasource.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.data.block import Block, BlockAccessor
|
| 7 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
import pyarrow
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NumpyDatasource(FileBasedDatasource):
    """Numpy datasource, for reading and writing Numpy files."""

    # Name of the single column holding the loaded array.
    _COLUMN_NAME = "data"
    _FILE_EXTENSIONS = ["npy"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        numpy_load_args: Optional[Dict[str, Any]] = None,
        **file_based_datasource_kwargs,
    ):
        """Create the datasource.

        Args:
            paths: A path or list of paths, forwarded to the base class.
            numpy_load_args: Extra keyword arguments for ``np.load``.
            **file_based_datasource_kwargs: Forwarded to the base class.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        if numpy_load_args is None:
            numpy_load_args = {}

        self.numpy_load_args = numpy_load_args

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Load the array stored in ``f`` and yield it as a one-column block."""
        # TODO(ekl) Ideally numpy can read directly from the file, but it
        # seems like it requires the file to be seekable.
        buf = BytesIO(f.readall())

        # NOTE(security): `allow_pickle=True` lets `np.load` unpickle object
        # arrays, which can execute arbitrary code from untrusted files. It
        # stays the default for backward compatibility, but a caller-supplied
        # `allow_pickle` in `numpy_load_args` now overrides it instead of
        # raising a duplicate-keyword TypeError.
        load_args = {"allow_pickle": True, **self.numpy_load_args}
        yield BlockAccessor.batch_to_block(
            {self._COLUMN_NAME: np.load(buf, **load_args)}
        )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_bulk_datasource.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 5 |
+
|
| 6 |
+
if TYPE_CHECKING:
|
| 7 |
+
import pyarrow
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ParquetBulkDatasource(FileBasedDatasource):
    """Minimal Parquet datasource, for reading and writing Parquet files."""

    _FILE_EXTENSIONS = ["parquet"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        read_table_args: Optional[Dict[str, Any]] = None,
        **file_based_datasource_kwargs,
    ):
        """Create the datasource.

        Args:
            paths: A path or list of paths, forwarded to the base class.
            read_table_args: Keyword arguments for
                ``pyarrow.parquet.read_table``. ``use_threads`` defaults to
                ``False`` when not given.
            **file_based_datasource_kwargs: Forwarded to the base class.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        if read_table_args is None:
            read_table_args = {}

        self.read_table_args = read_table_args

    def get_name(self):
        """Return a human-readable name for this datasource.
        This will be used as the names of the read tasks.
        Note: overrides the base `FileBasedDatasource` method.
        """
        return "ParquetBulk"

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Read the Parquet file ``f`` and yield it as a single table."""
        import pyarrow.parquet as pq

        # Build the arguments per call. Popping `use_threads` from
        # `self.read_table_args` would drop a user-supplied value after the
        # first file, so later files would silently fall back to the default.
        read_args = {"use_threads": False, **self.read_table_args}
        yield pq.read_table(f, **read_args)

    def _open_input_source(
        self,
        filesystem: "pyarrow.fs.FileSystem",
        path: str,
        **open_args,
    ) -> "pyarrow.NativeFile":
        """Open ``path`` as a random-access file on ``filesystem``."""
        # Parquet requires `open_input_file` due to random access reads
        return filesystem.open_input_file(path, **open_args)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasink.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import posixpath
|
| 3 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
|
| 4 |
+
|
| 5 |
+
from ray.data._internal.arrow_ops.transform_pyarrow import concat
|
| 6 |
+
from ray.data._internal.execution.interfaces import TaskContext
|
| 7 |
+
from ray.data._internal.util import call_with_retry
|
| 8 |
+
from ray.data.block import Block, BlockAccessor
|
| 9 |
+
from ray.data.context import DataContext
|
| 10 |
+
from ray.data.datasource.file_based_datasource import _resolve_kwargs
|
| 11 |
+
from ray.data.datasource.file_datasink import _FileDatasink
|
| 12 |
+
from ray.data.datasource.filename_provider import FilenameProvider
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
import pyarrow
|
| 16 |
+
|
| 17 |
+
WRITE_FILE_MAX_ATTEMPTS = 10
|
| 18 |
+
WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ParquetDatasink(_FileDatasink):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
path: str,
|
| 27 |
+
*,
|
| 28 |
+
partition_cols: Optional[List[str]] = None,
|
| 29 |
+
arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
|
| 30 |
+
arrow_parquet_args: Optional[Dict[str, Any]] = None,
|
| 31 |
+
min_rows_per_file: Optional[int] = None,
|
| 32 |
+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 33 |
+
try_create_dir: bool = True,
|
| 34 |
+
open_stream_args: Optional[Dict[str, Any]] = None,
|
| 35 |
+
filename_provider: Optional[FilenameProvider] = None,
|
| 36 |
+
dataset_uuid: Optional[str] = None,
|
| 37 |
+
):
|
| 38 |
+
if arrow_parquet_args_fn is None:
|
| 39 |
+
arrow_parquet_args_fn = lambda: {} # noqa: E731
|
| 40 |
+
|
| 41 |
+
if arrow_parquet_args is None:
|
| 42 |
+
arrow_parquet_args = {}
|
| 43 |
+
|
| 44 |
+
self.arrow_parquet_args_fn = arrow_parquet_args_fn
|
| 45 |
+
self.arrow_parquet_args = arrow_parquet_args
|
| 46 |
+
self.min_rows_per_file = min_rows_per_file
|
| 47 |
+
self.partition_cols = partition_cols
|
| 48 |
+
|
| 49 |
+
super().__init__(
|
| 50 |
+
path,
|
| 51 |
+
filesystem=filesystem,
|
| 52 |
+
try_create_dir=try_create_dir,
|
| 53 |
+
open_stream_args=open_stream_args,
|
| 54 |
+
filename_provider=filename_provider,
|
| 55 |
+
dataset_uuid=dataset_uuid,
|
| 56 |
+
file_format="parquet",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def write(
|
| 60 |
+
self,
|
| 61 |
+
blocks: Iterable[Block],
|
| 62 |
+
ctx: TaskContext,
|
| 63 |
+
) -> None:
|
| 64 |
+
import pyarrow as pa
|
| 65 |
+
|
| 66 |
+
blocks = list(blocks)
|
| 67 |
+
|
| 68 |
+
if all(BlockAccessor.for_block(block).num_rows() == 0 for block in blocks):
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
filename = self.filename_provider.get_filename_for_block(
|
| 72 |
+
blocks[0], ctx.task_idx, 0
|
| 73 |
+
)
|
| 74 |
+
write_kwargs = _resolve_kwargs(
|
| 75 |
+
self.arrow_parquet_args_fn, **self.arrow_parquet_args
|
| 76 |
+
)
|
| 77 |
+
user_schema = write_kwargs.pop("schema", None)
|
| 78 |
+
|
| 79 |
+
def write_blocks_to_path():
|
| 80 |
+
tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
|
| 81 |
+
if user_schema is None:
|
| 82 |
+
output_schema = pa.unify_schemas([table.schema for table in tables])
|
| 83 |
+
else:
|
| 84 |
+
output_schema = user_schema
|
| 85 |
+
|
| 86 |
+
if not self.partition_cols:
|
| 87 |
+
self._write_single_file(tables, filename, output_schema, write_kwargs)
|
| 88 |
+
else: # partition writes
|
| 89 |
+
self._write_partition_files(
|
| 90 |
+
tables, filename, output_schema, write_kwargs
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
logger.debug(f"Writing {filename} file to {self.path}.")
|
| 94 |
+
|
| 95 |
+
call_with_retry(
|
| 96 |
+
write_blocks_to_path,
|
| 97 |
+
description=f"write '{filename}' to '{self.path}'",
|
| 98 |
+
match=DataContext.get_current().retried_io_errors,
|
| 99 |
+
max_attempts=WRITE_FILE_MAX_ATTEMPTS,
|
| 100 |
+
max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def _write_single_file(
|
| 104 |
+
self,
|
| 105 |
+
tables: List["pyarrow.Table"],
|
| 106 |
+
filename: str,
|
| 107 |
+
output_schema: "pyarrow.Schema",
|
| 108 |
+
write_kwargs: Dict[str, Any],
|
| 109 |
+
) -> None:
|
| 110 |
+
import pyarrow.parquet as pq
|
| 111 |
+
|
| 112 |
+
write_path = posixpath.join(self.path, filename)
|
| 113 |
+
with self.open_output_stream(write_path) as file:
|
| 114 |
+
with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
|
| 115 |
+
for table in tables:
|
| 116 |
+
table = table.cast(output_schema)
|
| 117 |
+
writer.write_table(table)
|
| 118 |
+
|
| 119 |
+
def _write_partition_files(
    self,
    tables: List["pyarrow.Table"],
    filename: str,
    output_schema: "pyarrow.Schema",
    write_kwargs: Dict[str, Any],
) -> None:
    """Write ``tables`` as one Parquet file per partition-key combination.

    The tables are concatenated and grouped by ``self.partition_cols``. For
    every group a ``col=value/...`` directory is created below ``self.path``
    and a file named ``filename`` holding only the non-partition columns is
    written into it.

    Args:
        tables: Arrow tables to write.
        filename: File name used inside every partition directory.
        output_schema: Full output schema; the partition columns are removed
            from it before it is passed to the Parquet writer.
        write_kwargs: Extra keyword arguments forwarded to
            ``pyarrow.parquet.ParquetWriter``.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    merged = concat(tables)

    # Fields (and their names) that are actually stored in the files, i.e.
    # everything except the partition columns.
    data_fields = [
        schema_field
        for schema_field in output_schema
        if schema_field.name not in self.partition_cols
    ]
    data_col_names = [schema_field.name for schema_field in data_fields]
    data_schema = pa.schema(data_fields)

    # Group the table by the partition keys and, for each key combination,
    # collect the values of the non-partition columns into a list.
    # Ex: the original table contains two columns (a, b) and we are
    # partitioning by column a. The grouped table then looks like:
    #   b_list: [[[0,0],[1,1],[2,2]]]
    #   a: [[1,2,3]]
    grouped = merged.group_by(self.partition_cols).aggregate(
        [(name, "list") for name in data_col_names]
    )
    key_columns = [grouped.column(key) for key in self.partition_cols]

    for row in range(grouped.num_rows):
        # See https://github.com/apache/arrow/issues/14882 for recommended approach
        arrays = [grouped.column(f"{name}_list")[row].values for name in data_col_names]
        group_table = pa.Table.from_arrays(arrays, names=data_col_names)

        segments = [
            f"{key}={key_column[row]}"
            for key, key_column in zip(self.partition_cols, key_columns)
        ]
        partition_dir = posixpath.join(self.path, "/".join(segments))
        self._create_dir(partition_dir)

        file_path = posixpath.join(partition_dir, filename)
        with self.open_output_stream(file_path) as sink, pq.ParquetWriter(
            sink, data_schema, **write_kwargs
        ) as parquet_writer:
            parquet_writer.write_table(group_table)
|
| 169 |
+
|
| 170 |
+
@property
def min_rows_per_write(self) -> Optional[int]:
    """Return the value of ``self.min_rows_per_file`` (may be ``None``)."""
    configured_minimum = self.min_rows_per_file
    return configured_minimum
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/parquet_datasource.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import (
|
| 4 |
+
TYPE_CHECKING,
|
| 5 |
+
Any,
|
| 6 |
+
Callable,
|
| 7 |
+
Dict,
|
| 8 |
+
Iterator,
|
| 9 |
+
List,
|
| 10 |
+
Literal,
|
| 11 |
+
Optional,
|
| 12 |
+
Union,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
from packaging.version import parse as parse_version
|
| 17 |
+
|
| 18 |
+
import ray
|
| 19 |
+
import ray.cloudpickle as cloudpickle
|
| 20 |
+
from ray._private.utils import _get_pyarrow_version
|
| 21 |
+
from ray.data._internal.progress_bar import ProgressBar
|
| 22 |
+
from ray.data._internal.remote_fn import cached_remote_fn
|
| 23 |
+
from ray.data._internal.util import (
|
| 24 |
+
_check_pyarrow_version,
|
| 25 |
+
_is_local_scheme,
|
| 26 |
+
call_with_retry,
|
| 27 |
+
iterate_with_retry,
|
| 28 |
+
)
|
| 29 |
+
from ray.data.block import Block
|
| 30 |
+
from ray.data.context import DataContext
|
| 31 |
+
from ray.data.datasource import Datasource
|
| 32 |
+
from ray.data.datasource.datasource import ReadTask
|
| 33 |
+
from ray.data.datasource.file_based_datasource import FileShuffleConfig
|
| 34 |
+
from ray.data.datasource.file_meta_provider import (
|
| 35 |
+
DefaultFileMetadataProvider,
|
| 36 |
+
_handle_read_os_error,
|
| 37 |
+
)
|
| 38 |
+
from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
|
| 39 |
+
from ray.data.datasource.partitioning import (
|
| 40 |
+
PartitionDataType,
|
| 41 |
+
Partitioning,
|
| 42 |
+
PathPartitionFilter,
|
| 43 |
+
PathPartitionParser,
|
| 44 |
+
)
|
| 45 |
+
from ray.data.datasource.path_util import (
|
| 46 |
+
_has_file_extension,
|
| 47 |
+
_resolve_paths_and_filesystem,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if TYPE_CHECKING:
|
| 51 |
+
import pyarrow
|
| 52 |
+
from pyarrow.dataset import ParquetFileFragment
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
# The `num_cpus` for each metadata prefetching task.
|
| 58 |
+
# Default to 0.5 instead of 1 because it is cheaper than normal read task.
|
| 59 |
+
NUM_CPUS_FOR_META_FETCH_TASK = 0.5
|
| 60 |
+
|
| 61 |
+
# The number of rows to read per batch. This is sized to generate 10MiB batches
|
| 62 |
+
# for rows about 1KiB in size.
|
| 63 |
+
PARQUET_READER_ROW_BATCH_SIZE = 10_000
|
| 64 |
+
FILE_READING_RETRY = 8
|
| 65 |
+
|
| 66 |
+
# The default size multiplier for reading Parquet data source in Arrow.
|
| 67 |
+
# Parquet data format is encoded with various encoding techniques (such as
|
| 68 |
+
# dictionary, RLE, delta), so Arrow in-memory representation uses much more memory
|
| 69 |
+
# compared to Parquet encoded representation. Parquet file statistics only record
|
| 70 |
+
# encoded (i.e. uncompressed) data size information.
|
| 71 |
+
#
|
| 72 |
+
# To estimate real-time in-memory data size, Datasets will try to estimate the
|
| 73 |
+
# correct inflation ratio from Parquet to Arrow, using this constant as the default
|
| 74 |
+
# value for safety. See https://github.com/ray-project/ray/pull/26516 for more context.
|
| 75 |
+
PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT = 5
|
| 76 |
+
|
| 77 |
+
# The lower bound size to estimate Parquet encoding ratio.
|
| 78 |
+
PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND = 1
|
| 79 |
+
|
| 80 |
+
# The percentage of files (1% by default) to be sampled from the dataset to estimate
|
| 81 |
+
# Parquet encoding ratio.
|
| 82 |
+
PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO = 0.01
|
| 83 |
+
|
| 84 |
+
# The minimal and maximal number of file samples to take from the dataset to estimate
|
| 85 |
+
# Parquet encoding ratio.
|
| 86 |
+
# This is to restrict `PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO` within the
|
| 87 |
+
# proper boundary.
|
| 88 |
+
PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES = 2
|
| 89 |
+
PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES = 10
|
| 90 |
+
|
| 91 |
+
# The number of rows to read from each file for sampling. Try to keep it low to avoid
|
| 92 |
+
# reading too much data into memory.
|
| 93 |
+
PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass(frozen=True)
class _SampleInfo:
    """Per-row size measurements taken from one sampled Parquet fragment.

    Both values are produced by true division in ``_sample_fragment`` and are
    therefore floats. Both are ``None`` when the sampled fragment yielded no
    batch or an empty batch.
    """

    # In-memory size of the sampled batch divided by its row count.
    actual_bytes_per_row: Optional[float]
    # Sum of the row groups' ``total_byte_size`` from the Parquet metadata,
    # divided by the metadata's row count.
    estimated_bytes_per_row: Optional[float]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a
|
| 103 |
+
# raw pyarrow file fragment causes S3 network calls.
|
| 104 |
+
class SerializedFragment:
    """Pickled stand-in for a ``ParquetFileFragment``.

    Only the pieces needed to rebuild the fragment (format, path, filesystem
    and partition expression) are pickled, instead of the fragment object
    itself.
    """

    def __init__(self, frag: "ParquetFileFragment"):
        fragment_parts = (
            frag.format,
            frag.path,
            frag.filesystem,
            frag.partition_expression,
        )
        self._data = cloudpickle.dumps(fragment_parts)

    def deserialize(self) -> "ParquetFileFragment":
        """Rebuild the fragment from the pickled parts."""
        # Implicitly trigger S3 subsystem initialization by importing
        # pyarrow.fs.
        import pyarrow.fs  # noqa: F401

        fragment_parts = cloudpickle.loads(self._data)
        file_format, path, filesystem, partition_expression = fragment_parts
        return file_format.make_fragment(path, filesystem, partition_expression)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Visible for test mocking.
|
| 122 |
+
def _deserialize_fragments(
    serialized_fragments: List["SerializedFragment"],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
    """Call ``deserialize()`` on every item and return the results in order."""
    fragments = []
    for serialized in serialized_fragments:
        fragments.append(serialized.deserialize())
    return fragments
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def check_for_legacy_tensor_type(schema):
    """Check for the legacy tensor extension type and raise an error if found.

    Ray Data uses an extension type to represent tensors in Arrow tables. Previously,
    the extension type extended `PyExtensionType`. However, this base type can expose
    users to arbitrary code execution. To prevent this, we don't load the type by
    default.

    Args:
        schema: Object exposing parallel ``names`` and ``types`` sequences
            (e.g. a PyArrow schema).

    Raises:
        RuntimeError: If any column type is both a ``pa.UnknownExtensionType``
            and a ``pa.PyExtensionType``.
    """
    import pyarrow as pa

    # `column_type` rather than `type`, to avoid shadowing the builtin.
    for name, column_type in zip(schema.names, schema.types):
        if isinstance(column_type, pa.UnknownExtensionType) and isinstance(
            column_type, pa.PyExtensionType
        ):
            raise RuntimeError(
                f"Ray Data couldn't infer the type of column '{name}'. This might mean "
                "you're trying to read data written with an older version of Ray. "
                "Reading data written with older versions of Ray might expose you to "
                "arbitrary code execution. To try reading the data anyway, set "
                "`RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE=1` on *all* nodes. "
                "To learn more, see https://github.com/ray-project/ray/issues/41314."
            )
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ParquetDatasource(Datasource):
    """Parquet datasource, for reading and writing Parquet files.

    The primary difference from ParquetBulkDatasource is that this uses
    PyArrow's `ParquetDataset` abstraction for dataset reads, and thus offers
    automatic Arrow dataset schema inference and row count collection at the
    cost of some potential performance and/or compatibility penalties.
    """

    def __init__(
        self,
        paths: Union[str, List[str]],
        *,
        columns: Optional[List[str]] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        to_batch_kwargs: Optional[Dict[str, Any]] = None,
        _block_udf: Optional[Callable[[Block], Block]] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
        meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(),
        partition_filter: PathPartitionFilter = None,
        partitioning: Optional[Partitioning] = Partitioning("hive"),
        shuffle: Union[Literal["files"], None] = None,
        include_paths: bool = False,
        file_extensions: Optional[List[str]] = None,
    ):
        """Resolve the input paths, open the dataset, and sample it.

        Args:
            paths: A path or list of paths to read.
            columns: Column names to read; also narrows the schema.
            dataset_kwargs: Extra keyword arguments for ``ParquetDataset``.
                Must not contain ``"partitioning"``. The dict is copied and
                never modified.
            to_batch_kwargs: Extra keyword arguments for
                ``fragment.to_batches``.
            _block_udf: Optional function applied to every produced table.
            filesystem: PyArrow filesystem to read from.
            schema: Schema to read with; taken from the dataset when omitted.
            meta_provider: Provider used to prefetch file metadata and to
                build the metadata of each read task.
            partition_filter: Callable that filters the expanded paths.
            partitioning: Partitioning scheme used to parse partition values
                from the file paths.
            shuffle: ``"files"`` or a ``FileShuffleConfig`` to shuffle the
                file order when read tasks are created.
            include_paths: Whether to append a ``"path"`` column.
            file_extensions: Keep only paths with one of these extensions.

        Raises:
            ValueError: If local paths are used while connected through Ray
                Client, or if ``dataset_kwargs`` contains ``"partitioning"``.
        """
        _check_pyarrow_version()

        import pyarrow as pa

        self._supports_distributed_reads = not _is_local_scheme(paths)
        if not self._supports_distributed_reads and ray.util.client.ray.is_connected():
            raise ValueError(
                "Because you're using Ray Client, read tasks scheduled on the Ray "
                "cluster can't access your local files. To fix this issue, store "
                "files in cloud storage or a distributed filesystem like NFS."
            )

        self._local_scheduling = None
        if not self._supports_distributed_reads:
            from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

            self._local_scheduling = NodeAffinitySchedulingStrategy(
                ray.get_runtime_context().get_node_id(), soft=False
            )

        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)

        # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet
        # files. To avoid this, we expand the input paths with the default metadata
        # provider and then apply the partition filter or file extensions.
        if partition_filter is not None or file_extensions is not None:
            default_meta_provider = DefaultFileMetadataProvider()
            expanded_paths, _ = map(
                list, zip(*default_meta_provider.expand_paths(paths, filesystem))
            )

            paths = list(expanded_paths)
            if partition_filter is not None:
                paths = partition_filter(paths)
            if file_extensions is not None:
                paths = [
                    path for path in paths if _has_file_extension(path, file_extensions)
                ]

            filtered_paths = set(expanded_paths) - set(paths)
            if filtered_paths:
                logger.info(f"Filtered out {len(filtered_paths)} paths")

        # Work on a private copy: the "partitioning" key is written below, and
        # writing it into the caller's dict would make a second datasource
        # built from the same dict fail the check that follows.
        dataset_kwargs = {} if dataset_kwargs is None else dict(dataset_kwargs)

        if "partitioning" in dataset_kwargs:
            raise ValueError(
                "The 'partitioning' parameter isn't supported in 'dataset_kwargs'. "
                "Use the top-level 'partitioning' parameter instead."
            )

        # This datasource manually adds partition data at the Ray Data-level. To avoid
        # duplicating the partition data, we disable PyArrow's partitioning.
        dataset_kwargs["partitioning"] = None

        # `read_schema` is the schema object that will be used to perform
        # read operations.
        # It should be None, unless user has specified the schema or columns.
        # We don't use the inferred schema for read, because the pyarrow only infers
        # schema based on the first file. Thus, files with different schemas will end
        # up producing blocks with wrong schema.
        # See https://github.com/ray-project/ray/issues/47960 for more context.
        read_schema = schema
        pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)

        if schema is None:
            schema = pq_ds.schema
            schema = _add_partition_fields_to_schema(partitioning, schema, pq_ds)

        if columns:
            schema = pa.schema(
                [schema.field(column) for column in columns], schema.metadata
            )
            read_schema = schema

        check_for_legacy_tensor_type(schema)

        if _block_udf is not None:
            # Try to infer dataset schema by passing dummy table through UDF.
            dummy_table = schema.empty_table()
            try:
                schema = _block_udf(dummy_table).schema.with_metadata(schema.metadata)
            except Exception:
                # Best effort only: keep the pre-UDF schema if the UDF fails.
                logger.debug(
                    "Failed to infer schema of dataset by passing dummy table "
                    "through UDF due to the following exception:",
                    exc_info=True,
                )

        try:
            prefetch_remote_args = {}
            prefetch_remote_args["num_cpus"] = NUM_CPUS_FOR_META_FETCH_TASK
            if self._local_scheduling:
                prefetch_remote_args["scheduling_strategy"] = self._local_scheduling
            else:
                # Use the scheduling strategy ("SPREAD" by default) provided in
                # `DataContext`, to spread out prefetch tasks in cluster, avoid
                # AWS S3 throttling error.
                # Note: this is the same scheduling strategy used by read tasks.
                prefetch_remote_args[
                    "scheduling_strategy"
                ] = DataContext.get_current().scheduling_strategy

            self._metadata = (
                meta_provider.prefetch_file_metadata(
                    pq_ds.fragments, **prefetch_remote_args
                )
                or []
            )
        except OSError as e:
            _handle_read_os_error(e, paths)

        if to_batch_kwargs is None:
            to_batch_kwargs = {}

        # NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected
        # network calls when this datasource is serialized. See the
        # `SerializedFragment` implementation for more details.
        self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments]
        self._pq_paths = [p.path for p in pq_ds.fragments]
        self._meta_provider = meta_provider
        self._block_udf = _block_udf
        self._to_batches_kwargs = to_batch_kwargs
        self._columns = columns
        self._read_schema = read_schema
        self._schema = schema
        self._file_metadata_shuffler = None
        self._include_paths = include_paths
        self._partitioning = partitioning
        if shuffle == "files":
            self._file_metadata_shuffler = np.random.default_rng()
        elif isinstance(shuffle, FileShuffleConfig):
            self._file_metadata_shuffler = np.random.default_rng(shuffle.seed)

        sample_infos = sample_fragments(
            self._pq_fragments,
            to_batches_kwargs=to_batch_kwargs,
            columns=columns,
            schema=self._read_schema,
            local_scheduling=self._local_scheduling,
        )
        self._encoding_ratio = estimate_files_encoding_ratio(sample_infos)
        self._default_read_batch_size_rows = estimate_default_read_batch_size_rows(
            sample_infos
        )

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the summed prefetched ``total_byte_size`` scaled by the
        estimated encoding ratio."""
        total_size = 0
        for file_metadata in self._metadata:
            total_size += file_metadata.total_byte_size
        return total_size * self._encoding_ratio

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Split the fragments into at most ``parallelism`` read tasks.

        Args:
            parallelism: Number of chunks the fragments are split into. Empty
                chunks produce no task.

        Returns:
            One ``ReadTask`` per non-empty chunk of fragments.
        """
        # NOTE: We override the base class FileBasedDatasource.get_read_tasks()
        # method in order to leverage pyarrow's ParquetDataset abstraction,
        # which simplifies partitioning logic. We still use
        # FileBasedDatasource's write side, however.

        # Pad a *copy* of the prefetched metadata to the number of fragments
        # (this can happen when no file metadata was prefetched). Padding
        # `self._metadata` itself would leave `None` entries behind for
        # `estimate_inmemory_data_size()` to trip over.
        pq_metadata = list(self._metadata)
        num_missing = len(self._pq_fragments) - len(pq_metadata)
        if num_missing > 0:
            pq_metadata.extend([None] * num_missing)

        if self._file_metadata_shuffler is not None:
            files_metadata = list(zip(self._pq_fragments, self._pq_paths, pq_metadata))
            shuffled_files_metadata = [
                files_metadata[i]
                for i in self._file_metadata_shuffler.permutation(len(files_metadata))
            ]
            pq_fragments, pq_paths, pq_metadata = list(
                map(list, zip(*shuffled_files_metadata))
            )
        else:
            pq_fragments, pq_paths, pq_metadata = (
                self._pq_fragments,
                self._pq_paths,
                pq_metadata,
            )

        read_tasks = []
        for fragments, paths, metadata in zip(
            np.array_split(pq_fragments, parallelism),
            np.array_split(pq_paths, parallelism),
            np.array_split(pq_metadata, parallelism),
        ):
            if len(fragments) <= 0:
                continue

            meta = self._meta_provider(
                paths,
                self._schema,
                num_fragments=len(fragments),
                prefetched_metadata=metadata,
            )
            # If there is a filter operation, reset the calculated row count,
            # since the resulting row count is unknown.
            if self._to_batches_kwargs.get("filter") is not None:
                meta.num_rows = None

            if meta.size_bytes is not None:
                meta.size_bytes = int(meta.size_bytes * self._encoding_ratio)

            # Bind the attributes to locals so the lambda below references
            # plain values (presumably to keep `self` out of the closure).
            (
                block_udf,
                to_batches_kwargs,
                default_read_batch_size_rows,
                columns,
                read_schema,
                include_paths,
                partitioning,
            ) = (
                self._block_udf,
                self._to_batches_kwargs,
                self._default_read_batch_size_rows,
                self._columns,
                self._read_schema,
                self._include_paths,
                self._partitioning,
            )
            read_tasks.append(
                ReadTask(
                    # `f=fragments` pins this iteration's chunk to the lambda.
                    lambda f=fragments: read_fragments(
                        block_udf,
                        to_batches_kwargs,
                        default_read_batch_size_rows,
                        columns,
                        read_schema,
                        f,
                        include_paths,
                        partitioning,
                    ),
                    meta,
                )
            )

        return read_tasks

    def get_name(self):
        """Return a human-readable name for this datasource.

        This will be used as the names of the read tasks.
        """
        return "Parquet"

    @property
    def supports_distributed_reads(self) -> bool:
        """Whether the input paths do not use the local scheme."""
        return self._supports_distributed_reads
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def read_fragments(
    block_udf,
    to_batches_kwargs,
    default_read_batch_size_rows,
    columns,
    schema,
    serialized_fragments: List[SerializedFragment],
    include_paths: bool,
    partitioning: Partitioning,
) -> Iterator["pyarrow.Table"]:
    """Read the given fragments and yield their non-empty batches as tables.

    Args:
        block_udf: Optional function applied to each table before it is
            yielded.
        to_batches_kwargs: Keyword arguments for ``fragment.to_batches``.
            ``use_threads`` (default ``False``) and ``batch_size`` (default
            ``default_read_batch_size_rows``) are read from it. The dict
            itself is not modified.
        default_read_batch_size_rows: Batch size used when
            ``to_batches_kwargs`` has no ``batch_size``.
        columns: Columns to read; also restricts which partition fields are
            added.
        schema: Schema passed to ``to_batches`` and ``Table.from_batches``.
        serialized_fragments: Fragments to deserialize and read.
        include_paths: Whether to append a ``"path"`` column.
        partitioning: Partitioning used to parse partition values from each
            fragment path, or ``None`` to skip that.

    Yields:
        One table per non-empty record batch.
    """
    # This import is necessary to load the tensor extension type.
    from ray.data.extensions.tensor_extension import ArrowTensorType  # noqa

    # Deserialize after loading the filesystem class.
    fragments: List[
        "pyarrow._dataset.ParquetFileFragment"
    ] = _deserialize_fragments_with_retry(serialized_fragments)

    # Ensure that we're reading at least one dataset fragment.
    assert len(fragments) > 0

    import pyarrow as pa

    logger.debug(f"Reading {len(fragments)} parquet fragments")

    # Pop from a private copy. The same dict object can be shared by several
    # read tasks; popping from it directly would make every call after the
    # first one lose the user's `use_threads` / `batch_size`.
    to_batches_kwargs = dict(to_batches_kwargs)
    use_threads = to_batches_kwargs.pop("use_threads", False)
    batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)

    for fragment in fragments:
        partitions = {}
        if partitioning is not None:
            parse = PathPartitionParser(partitioning)
            partitions = parse(fragment.path)

        # Filter out partitions that aren't in the user-specified columns list.
        if columns is not None:
            partitions = {
                field_name: value
                for field_name, value in partitions.items()
                if field_name in columns
            }

        def get_batch_iterable():
            return fragment.to_batches(
                use_threads=use_threads,
                columns=columns,
                schema=schema,
                batch_size=batch_size,
                **to_batches_kwargs,
            )

        # S3 can raise transient errors during iteration, and PyArrow doesn't expose a
        # way to retry specific batches.
        ctx = ray.data.DataContext.get_current()
        for batch in iterate_with_retry(
            get_batch_iterable, "load batch", match=ctx.retried_io_errors
        ):
            table = pa.Table.from_batches([batch], schema=schema)

            # Empty tables are dropped. Check this before decorating the
            # table: the per-row path column below would otherwise be built
            # from an empty list, which gives PyArrow no values to infer a
            # type from.
            if table.num_rows <= 0:
                continue

            if include_paths:
                table = table.append_column("path", [[fragment.path]] * len(table))
            if partitions:
                table = _add_partitions_to_table(partitions, table)

            if block_udf is not None:
                yield block_udf(table)
            else:
                yield table
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def _deserialize_fragments_with_retry(fragments):
    """Deserialize ``fragments``, retrying up to ``FILE_READING_RETRY`` attempts."""

    # The deserialization retry helps when the upstream datasource is not able to
    # handle overloaded read requests or fails with some retriable failures.
    # For example, when reading data from an HA HDFS service, HDFS might
    # lose the connection for some unknown reason, especially when
    # simultaneously running many hyperparameter tuning jobs
    # with the ray.data parallelism setting at a high value like the default 200.
    # Such a connection failure can be recovered from with some waiting and retry.
    def deserialize_once():
        return _deserialize_fragments(fragments)

    return call_with_retry(
        deserialize_once,
        description="deserialize fragments",
        max_attempts=FILE_READING_RETRY,
    )
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def _sample_fragment(
    to_batches_kwargs,
    columns,
    schema,
    file_fragment: SerializedFragment,
) -> _SampleInfo:
    """Sample the first batch of ``file_fragment`` and measure its row size.

    Args:
        to_batches_kwargs: Keyword arguments for ``fragment.to_batches``. Any
            ``batch_size`` entry is ignored. The dict itself is not modified.
        columns: Columns to read.
        schema: Schema to read with.
        file_fragment: The serialized fragment to sample.

    Returns:
        A ``_SampleInfo`` with the in-memory and the metadata-reported bytes
        per row, or with both set to ``None`` if the fragment produced no
        batch or an empty batch.
    """
    # Sample the first rows batch from file fragment `file_fragment`.
    fragment = _deserialize_fragments_with_retry([file_fragment])[0]

    # Only sample the first row group.
    fragment = fragment.subset(row_group_ids=[0])
    batch_size = max(
        min(fragment.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1
    )
    # Use the batch_size calculated above, and ignore the one specified by user if set.
    # This is to avoid sampling too few or too many rows.
    # The entry is filtered out of a copy, so the caller's dict (which may
    # also be used for the real read) keeps its `batch_size`.
    sampling_kwargs = {
        key: value for key, value in to_batches_kwargs.items() if key != "batch_size"
    }
    batches = fragment.to_batches(
        columns=columns,
        schema=schema,
        batch_size=batch_size,
        **sampling_kwargs,
    )

    # Use first batch in-memory size for estimation.
    try:
        batch = next(batches)
    except StopIteration:
        return _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None)

    if batch.num_rows <= 0:
        return _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None)

    metadata = fragment.metadata
    total_size = 0
    for idx in range(metadata.num_row_groups):
        total_size += metadata.row_group(idx).total_byte_size
    return _SampleInfo(
        actual_bytes_per_row=batch.nbytes / batch.num_rows,
        estimated_bytes_per_row=total_size / metadata.num_rows,
    )
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def estimate_files_encoding_ratio(sample_infos: List[_SampleInfo]) -> float:
    """Return an estimate of the Parquet files encoding ratio.

    To avoid OOMs, it is safer to return an over-estimate than an underestimate.

    Args:
        sample_infos: Per-file samples to average over.

    Returns:
        ``PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT`` if size estimation is
        disabled in the current ``DataContext`` or there are no samples;
        otherwise the mean per-sample ratio, but never less than
        ``PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND``.
    """
    if not DataContext.get_current().decoding_size_estimation:
        return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT

    # Without samples the mean below would be NaN (and `max(nan, ...)` stays
    # NaN), so fall back to the default.
    if not sample_infos:
        return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT

    def compute_encoding_ratio(sample_info: _SampleInfo) -> float:
        # A missing measurement, or a zero denominator (which would raise
        # ZeroDivisionError), contributes the lower bound.
        if (
            sample_info.actual_bytes_per_row is None
            or not sample_info.estimated_bytes_per_row
        ):
            return PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
        return sample_info.actual_bytes_per_row / sample_info.estimated_bytes_per_row

    ratio = np.mean([compute_encoding_ratio(info) for info in sample_infos])
    logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.")
    return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def estimate_default_read_batch_size_rows(sample_infos: List[_SampleInfo]) -> int:
    """Return the default number of rows to read per batch.

    For every sample, the row count is the number of rows that fits into one
    tenth of ``DataContext.get_current().target_max_block_size``, clamped to
    ``[1, PARQUET_READER_ROW_BATCH_SIZE]``. The per-sample values are averaged
    and truncated to an integer.

    Args:
        sample_infos: Per-file samples to average over.

    Returns:
        The averaged batch size in rows, or ``PARQUET_READER_ROW_BATCH_SIZE``
        if there are no samples.
    """
    # Without samples the mean below would be NaN, which is not a usable row
    # count.
    if not sample_infos:
        return PARQUET_READER_ROW_BATCH_SIZE

    def compute_batch_size_rows(sample_info: _SampleInfo) -> int:
        # 'actual_bytes_per_row' is None if the sampled file was empty and 0 if the data
        # was all null.
        if not sample_info.actual_bytes_per_row:
            return PARQUET_READER_ROW_BATCH_SIZE

        max_parquet_reader_row_batch_size_bytes = (
            DataContext.get_current().target_max_block_size // 10
        )
        return max(
            1,
            min(
                PARQUET_READER_ROW_BATCH_SIZE,
                max_parquet_reader_row_batch_size_bytes
                // sample_info.actual_bytes_per_row,
            ),
        )

    # The mean is a float; convert it so that the result matches the declared
    # return type and can be used directly as a row count. Every per-sample
    # value is >= 1, so the truncated mean is >= 1 as well.
    return int(np.mean([compute_batch_size_rows(info) for info in sample_infos]))
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def get_parquet_dataset(paths, filesystem, dataset_kwargs):
    """Build a ``pyarrow.parquet.ParquetDataset`` for ``paths``.

    Args:
        paths: List of paths; a single-element list is unwrapped before being
            handed to PyArrow.
        filesystem: Filesystem forwarded to ``ParquetDataset``.
        dataset_kwargs: Extra keyword arguments forwarded to ``ParquetDataset``.

    Returns:
        The constructed dataset. ``OSError`` raised during construction is
        delegated to ``_handle_read_os_error``.
    """
    import pyarrow.parquet as pq

    # If you pass a list containing a single directory path to `ParquetDataset`,
    # PyArrow errors with 'IsADirectoryError: Path ... points to a directory, but
    # only file paths are supported'. To avoid this, we pass the directory path
    # directly.
    if len(paths) == 1:
        paths = paths[0]

    try:
        # The `use_legacy_dataset` parameter is deprecated in Arrow 15, so it is
        # only supplied for older PyArrow versions.
        version_kwargs = {}
        if parse_version(_get_pyarrow_version()) < parse_version("15.0.0"):
            version_kwargs["use_legacy_dataset"] = False

        dataset = pq.ParquetDataset(
            paths,
            **dataset_kwargs,
            filesystem=filesystem,
            **version_kwargs,
        )
    except OSError as e:
        _handle_read_os_error(e, paths)

    return dataset
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def sample_fragments(
    serialized_fragments,
    *,
    to_batches_kwargs,
    columns,
    schema,
    local_scheduling=None,
) -> List[_SampleInfo]:
    """Sample a subset of Parquet fragments remotely and collect the results.

    The number of sampled files is a fraction of the total
    (``PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO``), clamped between the
    configured minimum and maximum sample counts, none of which may exceed the
    number of files.

    Args:
        serialized_fragments: Sequence of fragments to choose samples from.
        to_batches_kwargs: Forwarded to the remote sampling function.
        columns: Forwarded to the remote sampling function.
        schema: Forwarded to the remote sampling function.
        local_scheduling: Optional scheduling strategy; when falsy, the current
            ``DataContext`` scheduling strategy is used.

    Returns:
        The sampling results gathered from all launched tasks.
    """
    # Sample a few rows from Parquet files to estimate the encoding ratio.
    # Tasks are launched remotely so multiple files are sampled in parallel.
    # TODO(ekl/cheng) take into account column pruning.
    num_files = len(serialized_fragments)
    min_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files)
    max_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES, num_files)
    requested_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
    num_samples = max(min(requested_samples, max_num_samples), min_num_samples)

    # Pick evenly spaced files, to avoid a biased prediction if data is skewed.
    sampled_indices = np.linspace(0, num_files - 1, num_samples).astype(int).tolist()
    file_samples = [serialized_fragments[idx] for idx in sampled_indices]

    sample_fragment = cached_remote_fn(_sample_fragment)
    scheduling = local_scheduling or DataContext.get_current().scheduling_strategy

    # One task per sampled file. Sampling can be memory-intensive, so the
    # scheduling strategy chosen above is applied to every task, and OSError is
    # retried in case of transient errors during sampling.
    futures = [
        sample_fragment.options(
            scheduling_strategy=scheduling,
            retry_exceptions=[OSError],
        ).remote(
            to_batches_kwargs,
            columns,
            schema,
            sample,
        )
        for sample in file_samples
    ]

    sample_bar = ProgressBar("Parquet Files Sample", len(futures), unit="file")
    sample_infos = sample_bar.fetch_until_complete(futures)
    sample_bar.close()

    return sample_infos
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def _add_partitions_to_table(
    partitions: Dict[str, PartitionDataType], table: "pyarrow.Table"
) -> "pyarrow.Table":
    """Return ``table`` with one constant-valued column per partition field.

    A field that already exists in the table's schema is overwritten in place;
    otherwise the column is appended.

    Args:
        partitions: Mapping of partition field name to its value.
        table: Table to extend.

    Returns:
        The resulting table.
    """
    import pyarrow as pa

    num_rows = len(table)
    for field_name, value in partitions.items():
        # Repeat the partition value once per row.
        column = pa.array([value] * num_rows)
        existing_index = table.schema.get_field_index(field_name)
        if existing_index == -1:
            table = table.append_column(field_name, column)
        else:
            table = table.set_column(existing_index, field_name, column)

    return table
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def _add_partition_fields_to_schema(
    partitioning: Partitioning,
    schema: "pyarrow.Schema",
    parquet_dataset: "pyarrow.dataset.Dataset",
) -> "pyarrow.Schema":
    """Return a new schema with partition fields added.

    This function infers the partition fields from the first file path in the
    dataset. Fields listed in ``partitioning.field_types`` get the Arrow type
    converted from that entry; all others are typed as strings.

    Args:
        partitioning: Partitioning scheme, or ``None`` if not partitioned.
        schema: Schema to extend.
        parquet_dataset: Dataset whose first fragment path is parsed.

    Returns:
        ``schema`` unchanged when the dataset has no fragments or there is no
        partitioning; otherwise the schema with partition fields appended.
    """
    import pyarrow as pa

    # With an empty dataset the partitioning can't be inferred, and with no
    # partitioning there are no fields to add.
    if len(parquet_dataset.fragments) == 0 or partitioning is None:
        return schema

    first_path = parquet_dataset.fragments[0].path
    parse = PathPartitionParser(partitioning)
    for field_name in parse(first_path):
        field_type = (
            pa.from_numpy_dtype(partitioning.field_types[field_name])
            if field_name in partitioning.field_types
            else pa.string()
        )
        schema = schema.append(pa.field(field_name, field_type))

    return schema
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import builtins
|
| 2 |
+
import functools
|
| 3 |
+
from copy import copy
|
| 4 |
+
from typing import Iterable, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ray.data._internal.util import _check_pyarrow_version
|
| 9 |
+
from ray.data.block import Block, BlockAccessor, BlockMetadata
|
| 10 |
+
from ray.data.context import DataContext
|
| 11 |
+
from ray.data.datasource import Datasource, ReadTask
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RangeDatasource(Datasource):
    """An example datasource that generates ranges of numbers from [0..n)."""

    def __init__(
        self,
        n: int,
        block_format: str = "arrow",
        tensor_shape: Tuple = (1,),
        column_name: Optional[str] = None,
    ):
        """Store the range configuration.

        Args:
            n: Number of rows to generate; coerced with ``int``.
            block_format: ``"arrow"``, ``"tensor"`` or ``"list"``.
            tensor_shape: Per-row shape used for the ``"tensor"`` format.
            column_name: Optional column name for the generated values.
        """
        self._n = int(n)
        self._block_format = block_format
        self._tensor_shape = tensor_shape
        self._column_name = column_name

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return ``8 * n * elements_per_row`` as the estimated size in bytes."""
        element_size = (
            int(np.prod(self._tensor_shape)) if self._block_format == "tensor" else 1
        )
        return 8 * self._n * element_size

    def get_read_tasks(
        self,
        parallelism: int,
    ) -> List[ReadTask]:
        """Split ``[0, n)`` into read tasks of about ``n // parallelism`` rows.

        Each task lazily yields blocks of at most ``target_rows_per_block`` rows,
        a value derived from the current ``DataContext.target_max_block_size``.
        """
        read_tasks: List[ReadTask] = []
        n = self._n
        block_format = self._block_format
        tensor_shape = self._tensor_shape
        block_size = max(1, n // parallelism)
        # TODO(swang): This target block size may not match the driver's
        # context if it was overridden. Set target max block size during
        # optimizer stage to fix this.
        ctx = DataContext.get_current()
        if self._n == 0:
            target_rows_per_block = 0
        else:
            row_size_bytes = max(self.estimate_inmemory_data_size() // self._n, 1)
            target_rows_per_block = max(1, ctx.target_max_block_size // row_size_bytes)

        # Example of a read task. In a real datasource, this would pull data
        # from an external system instead of generating dummy data.
        def make_block(start: int, count: int) -> Block:
            stop = start + count
            if block_format == "arrow":
                import pyarrow as pa

                return pa.Table.from_arrays(
                    [np.arange(start, stop)],
                    names=[self._column_name or "value"],
                )
            if block_format == "tensor":
                import pyarrow as pa  # noqa: F401

                # Broadcast each row index across a tensor of `tensor_shape`.
                row_ids = np.expand_dims(
                    np.arange(start, stop),
                    tuple(range(1, 1 + len(tensor_shape))),
                )
                tensor = np.ones(tensor_shape, dtype=np.int64) * row_ids
                return BlockAccessor.batch_to_block(
                    {self._column_name: tensor} if self._column_name else tensor
                )
            return list(builtins.range(start, stop))

        def make_blocks(
            start: int, count: int, target_rows_per_block: int
        ) -> Iterable[Block]:
            cursor, remaining = start, count
            while remaining > 0:
                num_rows = min(remaining, target_rows_per_block)
                yield make_block(cursor, num_rows)
                cursor += num_rows
                remaining -= num_rows

        element_size = int(np.prod(tensor_shape)) if block_format == "tensor" else 1

        for i in range(0, n, block_size):
            count = min(block_size, n - i)
            meta = BlockMetadata(
                num_rows=count,
                size_bytes=8 * count * element_size,
                schema=copy(self._schema),
                input_files=None,
                exec_stats=None,
            )
            # Bind `i` and `count` as defaults so each task keeps its own slice.
            read_tasks.append(
                ReadTask(
                    lambda i=i, count=count: make_blocks(
                        i, count, target_rows_per_block
                    ),
                    meta,
                )
            )

        return read_tasks

    @functools.cached_property
    def _schema(self):
        """Schema of the generated blocks, or ``None`` when ``n`` is 0.

        Raises:
            ValueError: If the block format is not arrow, tensor or list.
        """
        if self._n == 0:
            return None

        block_format = self._block_format
        if block_format == "arrow":
            _check_pyarrow_version()
            import pyarrow as pa

            return pa.Table.from_pydict({self._column_name or "value": [0]}).schema

        if block_format == "tensor":
            _check_pyarrow_version()
            import pyarrow as pa  # noqa: F401

            # Build a small sample batch purely to derive its schema.
            tensor = np.ones(self._tensor_shape, dtype=np.int64) * np.expand_dims(
                np.arange(0, 10), tuple(range(1, 1 + len(self._tensor_shape)))
            )
            return BlockAccessor.batch_to_block(
                {self._column_name: tensor} if self._column_name else tensor
            ).schema

        if block_format == "list":
            return int

        raise ValueError("Unsupported block type", self._block_format)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Iterable
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.datasource.sql_datasource import Connection, _connect
|
| 4 |
+
from ray.data._internal.execution.interfaces import TaskContext
|
| 5 |
+
from ray.data.block import Block, BlockAccessor
|
| 6 |
+
from ray.data.datasource.datasink import Datasink
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SQLDatasink(Datasink[None]):
    """Datasink that writes block rows through a SQL statement in batches."""

    # Rows are flushed with `executemany` once this many have accumulated.
    _MAX_ROWS_PER_WRITE = 128

    def __init__(self, sql: str, connection_factory: Callable[[], Connection]):
        """Store the SQL statement and the factory that creates connections."""
        self.sql = sql
        self.connection_factory = connection_factory

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Execute ``self.sql`` for every row of every block.

        Rows are buffered per block and sent in groups of at most
        ``_MAX_ROWS_PER_WRITE``; any remainder is flushed at the end of the block.
        """
        max_rows = self._MAX_ROWS_PER_WRITE
        with _connect(self.connection_factory) as cursor:
            for block in blocks:
                rows = BlockAccessor.for_block(block).iter_rows(
                    public_row_format=False
                )

                pending = []
                for row in rows:
                    pending.append(tuple(row.values()))
                    assert len(pending) <= max_rows, len(pending)
                    if len(pending) == max_rows:
                        cursor.executemany(self.sql, pending)
                        pending = []

                if pending:
                    cursor.executemany(self.sql, pending)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, Iterator, List
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 4 |
+
from ray.data.block import Block
|
| 5 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
import pyarrow
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TextDatasource(FileBasedDatasource):
    """Text datasource, for reading and writing text files."""

    _COLUMN_NAME = "text"

    def __init__(
        self,
        paths: List[str],
        *,
        drop_empty_lines: bool = False,
        encoding: str = "utf-8",
        **file_based_datasource_kwargs
    ):
        """Configure the datasource.

        Args:
            paths: Paths forwarded to ``FileBasedDatasource``.
            drop_empty_lines: Skip lines that are empty after ``strip()``.
            encoding: Encoding used to decode file contents.
            **file_based_datasource_kwargs: Forwarded to the base class.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self.drop_empty_lines = drop_empty_lines
        self.encoding = encoding

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Read the whole file and yield one block with a row per line."""
        data = f.readall()

        builder = DelegatingBlockBuilder()

        for line in data.decode(self.encoding).split("\n"):
            if self.drop_empty_lines and line.strip() == "":
                continue
            builder.add({self._COLUMN_NAME: line})

        yield builder.build()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import struct
|
| 2 |
+
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from .tfrecords_datasource import _get_single_true_type
|
| 7 |
+
from ray.data._internal.util import _check_import
|
| 8 |
+
from ray.data.block import BlockAccessor
|
| 9 |
+
from ray.data.datasource.file_datasink import BlockBasedFileDatasink
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
import pyarrow
|
| 13 |
+
import tensorflow as tf
|
| 14 |
+
from tensorflow_metadata.proto.v0 import schema_pb2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TFRecordDatasink(BlockBasedFileDatasink):
    """Datasink that writes each block as a sequence of TFRecord entries."""

    def __init__(
        self,
        path: str,
        *,
        tf_schema: Optional["schema_pb2.Schema"] = None,
        file_format: str = "tar",
        **file_datasink_kwargs,
    ):
        """Configure the datasink.

        Args:
            path: Output path forwarded to ``BlockBasedFileDatasink``.
            tf_schema: Optional schema used when converting rows to examples.
            file_format: Forwarded to the base class.
            **file_datasink_kwargs: Forwarded to the base class.
        """
        super().__init__(path, file_format=file_format, **file_datasink_kwargs)

        # Checked up front: the record writer imports `crc32c` lazily.
        _check_import(self, module="crc32c", package="crc32c")

        self.tf_schema = tf_schema

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Serialize every row of ``block`` to ``file`` in TFRecord format."""
        # It seems like TFRecords are typically row-based,
        # https://www.tensorflow.org/tutorials/load_data/tfrecord#writing_a_tfrecord_file_2
        # so each row of the block is converted to a tf.train.Example proto
        # and written to the file one record at a time.
        table = block.to_arrow()
        for example in _convert_arrow_table_to_examples(table, self.tf_schema):
            _write_record(file, example)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _convert_arrow_table_to_examples(
    arrow_table: "pyarrow.Table",
    tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
    """Yield one ``tf.train.Example`` per row of ``arrow_table``.

    Args:
        arrow_table: Table whose rows are converted.
        tf_schema: Optional schema; when given, every column must appear in it
            and its feature types are passed to ``_value_to_feature``.

    Raises:
        ValueError: If a schema is given and a column is missing from it.
    """
    import tensorflow as tf

    # Map feature name -> feature type from the user-specified schema, if any.
    schema_dict = {}
    if tf_schema is not None:
        schema_dict = {
            schema_feature.name: schema_feature.type
            for schema_feature in tf_schema.feature
        }

    column_names = arrow_table.column_names
    for row_index in range(arrow_table.num_rows):
        # Build the feature dictionary for this row.
        features: Dict[str, "tf.train.Feature"] = {}
        for name in column_names:
            if tf_schema is not None and name not in schema_dict:
                raise ValueError(
                    f"Found extra unexpected feature {name} "
                    f"not in specified schema: {tf_schema}"
                )
            features[name] = _value_to_feature(
                arrow_table[name][row_index],
                schema_dict.get(name),
            )

        # Wrap the dictionary in an Example proto.
        yield tf.train.Example(features=tf.train.Features(feature=features))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _value_to_feature(
    value: Union["pyarrow.Scalar", "pyarrow.Array"],
    schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "tf.train.Feature":
    """Convert one Arrow value to a ``tf.train.Feature``.

    Args:
        value: Arrow scalar; list scalars become multi-valued features, other
            scalars become a single-element (or empty, if null) feature.
        schema_feature_type: Optional feature type from a user schema; when
            given it must agree with the Arrow type and decides the output kind.

    Raises:
        ModuleNotFoundError: If a schema type is given but
            ``tensorflow_metadata`` is not installed.
        ValueError: On a schema/Arrow type mismatch, a null-typed value, or an
            Arrow type that maps to none of bytes, float or int.
    """
    import pyarrow as pa
    import tensorflow as tf

    if isinstance(value, pa.ListScalar):
        # Use the underlying type of the ListScalar's value in
        # determining the output feature's data type.
        value_type = value.type.value_type
        value = value.as_py()
    else:
        value_type = value.type
        py_value = value.as_py()
        value = [] if py_value is None else [py_value]

    underlying_value_type = {
        "bytes": pa.types.is_binary(value_type),
        "string": pa.types.is_string(value_type),
        "float": pa.types.is_floating(value_type),
        "int": pa.types.is_integer(value_type),
    }
    assert sum(bool(flag) for flag in underlying_value_type.values()) <= 1

    if schema_feature_type is not None:
        try:
            from tensorflow_metadata.proto.v0 import schema_pb2
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "To use TensorFlow schemas, please install "
                "the tensorflow-metadata package."
            )
        # A BYTES schema type covers both Arrow binary and Arrow string data.
        is_bytes_feature = schema_feature_type == schema_pb2.FeatureType.BYTES
        specified_feature_type = {
            "bytes": is_bytes_feature and not underlying_value_type["string"],
            "string": is_bytes_feature and underlying_value_type["string"],
            "float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
            "int": schema_feature_type == schema_pb2.FeatureType.INT,
        }

        und_type = _get_single_true_type(underlying_value_type)
        spec_type = _get_single_true_type(specified_feature_type)
        if und_type is not None and und_type != spec_type:
            raise ValueError(
                "Schema field type mismatch during write: specified type is "
                f"{spec_type}, but underlying type is {und_type}",
            )
        # Override the underlying value type with the type in the user-specified
        # schema.
        underlying_value_type = specified_feature_type

    if underlying_value_type["int"]:
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    if underlying_value_type["float"]:
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    if underlying_value_type["bytes"]:
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    if underlying_value_type["string"]:
        encoded = [v.encode() for v in value]  # casting to bytes
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))

    if pa.types.is_null(value_type):
        raise ValueError(
            "Unable to infer type from partially missing column. "
            "Try setting read parallelism = 1, or use an input data source which "
            "explicitly specifies the schema."
        )
    raise ValueError(
        f"Value is of type {value_type}, "
        "which we cannot convert to a supported tf.train.Feature storage type "
        "(bytes, float, or int)."
    )
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501
|
| 159 |
+
#
|
| 160 |
+
# MIT License
|
| 161 |
+
#
|
| 162 |
+
# Copyright (c) 2020 Vahid Kazemi
|
| 163 |
+
#
|
| 164 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 165 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 166 |
+
# in the Software without restriction, including without limitation the rights
|
| 167 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 168 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 169 |
+
# furnished to do so, subject to the following conditions:
|
| 170 |
+
#
|
| 171 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 172 |
+
# copies or substantial portions of the Software.
|
| 173 |
+
#
|
| 174 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 175 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 176 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 177 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 178 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 179 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 180 |
+
# SOFTWARE.
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _write_record(
    file: "pyarrow.NativeFile",
    example: "tf.train.Example",
) -> None:
    """Write one serialized example to ``file`` as a framed record.

    The frame is, in order: the payload length packed as little-endian
    unsigned 64-bit, the masked CRC of those length bytes, the payload, and the
    masked CRC of the payload.
    """
    payload = example.SerializeToString()
    header = struct.pack("<Q", len(payload))
    file.write(header)
    file.write(_masked_crc(header))
    file.write(payload)
    file.write(_masked_crc(payload))
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _masked_crc(data: bytes) -> bytes:
    """Return the masked CRC of ``data`` as 4 little-endian bytes.

    The checksum from ``crc32c.crc32`` is combined as
    ``((crc >> 15) | (crc << 17)) + 0xA282EAD8`` and truncated to 32 bits.
    """
    import crc32c

    mask = 0xA282EAD8
    crc = crc32c.crc32(data)
    rotated = (crc >> 15) | (crc << 17)
    # Keep only the low 32 bits before packing.
    masked = np.uint32((rotated + mask) & np.iinfo(np.uint32).max)
    return struct.pack("<I", masked)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasource.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import struct
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import pyarrow
|
| 7 |
+
|
| 8 |
+
from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
|
| 9 |
+
from ray.data.aggregate import AggregateFn
|
| 10 |
+
from ray.data.block import Block
|
| 11 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 12 |
+
from ray.util.annotations import PublicAPI
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import tensorflow as tf
|
| 17 |
+
from tensorflow_metadata.proto.v0 import schema_pb2
|
| 18 |
+
|
| 19 |
+
from ray.data.dataset import Dataset
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@PublicAPI(stability="alpha")
@dataclass
class TFXReadOptions:
    """
    Specifies read options when reading TFRecord files with TFX.
    """

    # Number of consecutive elements of this dataset to combine in a single
    # batch when tfx-bsl is used to read the tfrecord files.
    batch_size: int = 2048

    # Toggles the schema inference applied; applicable only if tfx-bsl is used
    # and the tf_schema argument is missing. Defaults to True.
    auto_infer_schema: bool = True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class TFRecordDatasource(FileBasedDatasource):
    """Datasource that turns TFRecord files into Arrow blocks.

    Records are decoded either one by one as ``tf.train.Example`` messages
    or, when ``tfx_read_options`` is given, in batches through tfx-bsl.
    """

    _FILE_EXTENSIONS = ["tfrecords"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        tf_schema: Optional["schema_pb2.Schema"] = None,
        tfx_read_options: Optional["TFXReadOptions"] = None,
        **file_based_datasource_kwargs,
    ):
        """
        Args:
            paths: Path or list of paths, forwarded to the base datasource.
            tf_schema: Optional TensorFlow Schema which is used to explicitly
                set the schema of the underlying Dataset.
            tfx_read_options: Optional options for enabling reading tfrecords
                using tfx-bsl.
            **file_based_datasource_kwargs: Forwarded unchanged to the base
                class constructor.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self._tf_schema = tf_schema
        self._tfx_read_options = tfx_read_options

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Yield blocks for one file, choosing the tfx-bsl or default reader."""
        if self._tfx_read_options:
            reader = self._tfx_read_stream
        else:
            reader = self._default_read_stream
        yield from reader(f, path)

    def _default_read_stream(
        self, f: "pyarrow.NativeFile", path: str
    ) -> Iterator[Block]:
        """Parse every record as ``tf.train.Example`` and yield one table each.

        Raises:
            ValueError: If a record cannot be decoded as ``tf.train.Example``.
        """
        import tensorflow as tf
        from google.protobuf.message import DecodeError

        for raw_record in _read_records(f, path):
            parsed = tf.train.Example()
            try:
                parsed.ParseFromString(raw_record)
            except DecodeError as e:
                raise ValueError(
                    "`TFRecordDatasource` failed to parse `tf.train.Example` "
                    f"record in '{path}'. This error can occur if your TFRecord "
                    f"file contains a message type other than `tf.train.Example`: {e}"
                )

            columns = _convert_example_to_dict(parsed, self._tf_schema)
            yield pyarrow_table_from_pydict(columns)

    def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Read the file through ``tf.data`` and decode batches with tfx-bsl.

        The open handle ``f`` is not used; the file is re-opened by TensorFlow
        from the path returned by ``_resolve_full_path``.

        Raises:
            RuntimeError: If reading or decoding fails for any reason.
        """
        import tensorflow as tf
        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

        full_path = self._resolve_full_path(path)

        stream_args = self._open_stream_args or {}
        compression = stream_args.get("compression", None)
        if compression:
            compression = compression.upper()

        if self._tf_schema:
            serialized_schema = self._tf_schema.SerializeToString()
        else:
            serialized_schema = None

        decoder = ExamplesToRecordBatchDecoder(serialized_schema)
        caught_error = None
        try:
            records = tf.data.TFRecordDataset(
                full_path, compression_type=compression
            )
            for batch in records.batch(self._tfx_read_options.batch_size):
                record_batch = decoder.DecodeBatch(batch.numpy())
                yield _cast_large_list_to_list(
                    pyarrow.Table.from_batches([record_batch])
                )
        except Exception as error:
            logger.exception(f"Failed to read TFRecord file {full_path}")
            caught_error = error

        # The exception is deliberately re-raised outside of the except
        # block: tensorflow's DataLossError cannot be pickled, and raising
        # from within the handler would keep a reference to the original
        # error, which would make the new error unpicklable as well.
        if caught_error:
            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

    def _resolve_full_path(self, relative_path):
        """Prefix ``relative_path`` with the URI scheme of the filesystem.

        Paths on filesystems that are not recognised here are returned
        unchanged.
        """
        fs = self._filesystem
        if isinstance(fs, pyarrow.fs.S3FileSystem):
            return f"s3://{relative_path}"
        if isinstance(fs, pyarrow.fs.GcsFileSystem):
            return f"gs://{relative_path}"
        if isinstance(fs, pyarrow.fs.HadoopFileSystem):
            return f"hdfs:///{relative_path}"
        if isinstance(fs, pyarrow.fs.PyFileSystem):
            protocol = fs.handler.fs.protocol
            # The wrapped filesystem may report several protocol aliases;
            # the first one is used.
            if isinstance(protocol, (list, tuple)):
                protocol = protocol[0]
            if protocol == "gcs":
                protocol = "gs"
            return f"{protocol}://{relative_path}"

        return relative_path
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _convert_example_to_dict(
    example: "tf.train.Example",
    tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, pyarrow.Array]:
    """Convert the features of one example into a name -> Arrow array dict.

    Args:
        example: Parsed example whose ``features.feature`` map is converted.
        tf_schema: Optional user-specified schema. When given, every feature
            of the example must be declared in it.

    Returns:
        Mapping from feature name to the array built by
        ``_get_feature_value``.

    Raises:
        ValueError: If a schema is given and the example has a feature that
            the schema does not list.
    """
    has_schema = tf_schema is not None
    # Flatten the user-specified schema into name -> type for quick lookup.
    declared_types = {}
    if has_schema:
        declared_types = {entry.name: entry.type for entry in tf_schema.feature}

    columns = {}
    for feature_name, feature in example.features.feature.items():
        if has_schema and feature_name not in declared_types:
            raise ValueError(
                f"Found extra unexpected feature {feature_name} "
                f"not in specified schema: {tf_schema}"
            )
        columns[feature_name] = _get_feature_value(
            feature, declared_types.get(feature_name)
        )
    return columns
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _get_single_true_type(dct) -> str:
    """Return the first key of ``dct`` whose value is truthy.

    Used on dicts of the form ``{field_type: is_valid}`` to pick the field
    type reported by a schema or a data source. ``None`` is returned when no
    key has a truthy value.
    """
    for candidate, is_set in dct.items():
        if is_set:
            return candidate
    return None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _get_feature_value(
    feature: "tf.train.Feature",
    schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> pyarrow.Array:
    """Build a one-row Arrow array from a single ``tf.train.Feature``.

    Args:
        feature: Feature holding a bytes, float or int64 list (or nothing).
        schema_feature_type: Optional type from a user-specified schema. When
            given it must agree with the populated list of ``feature`` and
            the result is always list-typed.

    Returns:
        An array of length one. Without a schema a single-element feature is
        stored as a scalar; otherwise the value is stored as a list.

    Raises:
        ModuleNotFoundError: If a schema type is given but
            ``tensorflow_metadata`` cannot be imported.
        ValueError: If the schema type differs from the populated list.
    """
    import pyarrow as pa

    populated = {
        "bytes": feature.HasField("bytes_list"),
        "float": feature.HasField("float_list"),
        "int": feature.HasField("int64_list"),
    }
    # At most one of the three lists should be set; none set means the
    # feature value is empty.
    assert sum(bool(flag) for flag in populated.values()) <= 1

    if schema_feature_type is not None:
        try:
            from tensorflow_metadata.proto.v0 import schema_pb2
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "To use TensorFlow schemas, please install "
                "the tensorflow-metadata package."
            )
        # Compare the type requested by the schema with the one in the data.
        requested = {
            "bytes": schema_feature_type == schema_pb2.FeatureType.BYTES,
            "float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
            "int": schema_feature_type == schema_pb2.FeatureType.INT,
        }
        und_type = _get_single_true_type(populated)
        spec_type = _get_single_true_type(requested)
        if und_type is not None and und_type != spec_type:
            raise ValueError(
                "Schema field type mismatch during read: specified type is "
                f"{spec_type}, but underlying type is {und_type}",
            )
        # From here on the schema type takes precedence over the data type.
        populated = requested

    if populated["bytes"]:
        values, arrow_type = feature.bytes_list.value, pa.binary()
    elif populated["float"]:
        values, arrow_type = feature.float_list.value, pa.float32()
    elif populated["int"]:
        values, arrow_type = feature.int64_list.value, pa.int64()
    else:
        values, arrow_type = [], pa.null()
    values = list(values)

    if len(values) == 1 and schema_feature_type is None:
        # A lone value is stored as a scalar, which is more convenient for
        # preprocessing UDFs than a single-element list.
        return pa.array([values[0]], type=arrow_type)

    # An empty value gets a null element type so that pyarrow can build a
    # valid array; the real type is inferred later from records that have
    # non-empty values for this feature.
    element_type = pa.null() if len(values) == 0 else arrow_type
    return pa.array([values], type=pa.list_(element_type))
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/reader.py#L16-L96 # noqa: E501
|
| 248 |
+
#
|
| 249 |
+
# MIT License
|
| 250 |
+
#
|
| 251 |
+
# Copyright (c) 2020 Vahid Kazemi
|
| 252 |
+
#
|
| 253 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 254 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 255 |
+
# in the Software without restriction, including without limitation the rights
|
| 256 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 257 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 258 |
+
# furnished to do so, subject to the following conditions:
|
| 259 |
+
#
|
| 260 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 261 |
+
# copies or substantial portions of the Software.
|
| 262 |
+
#
|
| 263 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 264 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 265 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 266 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 267 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 268 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 269 |
+
# SOFTWARE.
|
| 270 |
+
def _read_records(
    file: "pyarrow.NativeFile",
    path: str,
) -> Iterable[memoryview]:
    """
    Read records from TFRecord file.

    A TFRecord file contains a sequence of records. The file can only be read
    sequentially. Each record is stored in the following formats:
        uint64 length
        uint32 masked_crc32_of_length
        byte   data[length]
        uint32 masked_crc32_of_data

    See https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecords_format_details
    for more details.

    Args:
        file: Binary file object supporting ``readinto``.
        path: Path of the file; only used in error messages.

    Yields:
        A ``memoryview`` over the payload of one record. The view points into
        an internal buffer that is overwritten by the next record, so it has
        to be consumed (or copied) before the generator is advanced.

    Raises:
        RuntimeError: If the file is truncated or otherwise cannot be read.
            The underlying error is attached as ``__cause__``.
    """
    length_bytes = bytearray(8)
    crc_bytes = bytearray(4)
    datum_bytes = bytearray(1024 * 1024)
    row_count = 0
    # Bound before the loop because the error handler below inspects it; a
    # failure while reading the very first length field would otherwise hit
    # an unbound local and hide the real error.
    data_length = None
    while True:
        try:
            # Read "length" field.
            num_length_bytes_read = file.readinto(length_bytes)
            if num_length_bytes_read == 0:
                break
            elif num_length_bytes_read != 8:
                raise ValueError(
                    "Failed to read the length of record data. Expected 8 bytes but "
                    f"got {num_length_bytes_read} bytes."
                )

            # Read "masked_crc32_of_length" field.
            num_length_crc_bytes_read = file.readinto(crc_bytes)
            if num_length_crc_bytes_read != 4:
                raise ValueError(
                    "Failed to read the length of CRC-32C hashes. Expected 4 bytes "
                    f"but got {num_length_crc_bytes_read} bytes."
                )

            # Read "data[length]" field.
            (data_length,) = struct.unpack("<Q", length_bytes)
            if data_length > len(datum_bytes):
                # Grow with 50% headroom. The old content is not needed, so
                # a fresh buffer is allocated instead of copying it over.
                datum_bytes = bytearray(int(data_length * 1.5))
            datum_bytes_view = memoryview(datum_bytes)[:data_length]
            num_datum_bytes_read = file.readinto(datum_bytes_view)
            if num_datum_bytes_read != data_length:
                raise ValueError(
                    f"Failed to read the record. Expected {data_length} bytes but got "
                    f"{num_datum_bytes_read} bytes."
                )

            # Read "masked_crc32_of_data" field.
            # TODO(chengsu): ideally we should check CRC-32C against the actual data.
            num_crc_bytes_read = file.readinto(crc_bytes)
            if num_crc_bytes_read != 4:
                raise ValueError(
                    "Failed to read the CRC-32C hashes. Expected 4 bytes but got "
                    f"{num_crc_bytes_read} bytes."
                )

            # Return the data.
            yield datum_bytes_view

            row_count += 1
            data_length = None
        except Exception as e:
            error_message = (
                f"Failed to read TFRecord file {path}. Please ensure that the "
                f"TFRecord file has correct format. Already read {row_count} rows."
            )
            if data_length is not None:
                error_message += f" Byte size of current record data is {data_length}."
            raise RuntimeError(error_message) from e
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _cast_large_list_to_list(batch: pyarrow.Table):
    """
    Cast ``pyarrow.large_list`` columns to ``pyarrow.list_`` and
    ``pyarrow.large_binary`` to ``pyarrow.binary`` so that all types produced
    by the tfrecord datasource are usable with dataset.to_tf().

    Columns of any other type keep their original field.
    """
    source_schema = batch.schema
    target_fields = {}

    for name in source_schema.names:
        source_field = source_schema.field(name)
        dtype = source_field.type

        if type(dtype) is pyarrow.lib.LargeListType:
            item_type = dtype.value_type
            # Large binary elements are narrowed together with the list.
            if item_type == pyarrow.large_binary():
                item_type = pyarrow.binary()
            target_fields[name] = pyarrow.list_(item_type)
        elif dtype == pyarrow.large_binary():
            target_fields[name] = pyarrow.binary()
        else:
            target_fields[name] = source_field

    return batch.cast(pyarrow.schema(target_fields))
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _infer_schema_and_transform(dataset: "Dataset"):
    """Unwrap list columns of ``dataset`` whose lists never exceed one item.

    The maximum list length of every column is computed with the
    ``_MaxListSize`` aggregation and then passed to
    ``_unwrap_single_value_lists``, which is applied batch by batch.
    """
    column_names = dataset.schema().names
    aggregated = dataset.aggregate(_MaxListSize(column_names))
    max_lengths = aggregated["max_list_size"]

    return dataset.map_batches(
        _unwrap_single_value_lists,
        fn_kwargs={"col_lengths": max_lengths},
        batch_format="pyarrow",
    )
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _unwrap_single_value_lists(batch: pyarrow.Table, col_lengths: Dict[str, int]):
    """
    Transform the batch by converting list columns that always contain a
    single value to their underlying value type
    (e.g. pyarrow.int64() or pyarrow.float32()).

    Args:
        batch: Table whose columns named in ``col_lengths`` are list-typed.
        col_lengths: Maximum list length per column name.

    Returns:
        A new table built from the columns named in ``col_lengths``.
    """
    unwrapped = {}

    for name, max_len in col_lengths.items():
        column = batch[name]
        element_type = column.type.value_type

        if max_len != 1:
            unwrapped[name] = column
            continue

        # NOTE(review): a column with max length 1 is left out of the result
        # when this batch has no rows -- confirm that this is intended.
        if column:
            scalars = []
            for cell in column:
                items = cell.as_py()
                # Empty or null lists become a null scalar.
                scalars.append(items[0] if items else None)
            unwrapped[name] = pyarrow.array(scalars, type=element_type)

    return pyarrow.table(unwrapped)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class _MaxListSize(AggregateFn):
    """Aggregation tracking, per column, the largest ``len()`` of any value.

    The result is a dict mapping each column name to that maximum (0 when no
    truthy value was seen) and is reported under the name ``max_list_size``.
    """

    def __init__(self, columns: List[str]):
        self._columns = columns
        super().__init__(
            init=self._init,
            merge=self._merge,
            accumulate_row=self._accumulate_row,
            finalize=lambda acc: acc,
            name="max_list_size",
        )

    def _init(self, k: str):
        """Return a fresh accumulator with every column at length 0."""
        return dict.fromkeys(self._columns, 0)

    def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
        """Combine two accumulators by taking the per-column maximum."""
        return {col: max(acc1[col], acc2[col]) for col in self._columns}

    def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
        """Fold the value lengths of one row into ``acc`` and return it."""
        for key in row:
            cell = row[key]
            # Falsy cells (e.g. empty lists or None) do not affect the max.
            if not cell:
                continue
            acc[key] = max(len(cell), acc[key])

        return acc
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 4 |
+
from ray.data.block import BlockMetadata
|
| 5 |
+
from ray.data.datasource.datasource import Datasource, ReadTask
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
TORCH_DATASOURCE_READER_BATCH_SIZE = 32
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TorchDatasource(Datasource):
    """Datasource for reading from `Torch
    datasets <https://pytorch.org/docs/stable/data.html/>`_.

    The read is streamed through one single read task.
    """

    def __init__(
        self,
        dataset: "torch.utils.data.Dataset",
    ):
        self._dataset = dataset

    def get_read_tasks(self, parallelism):
        """Return the one read task that streams the whole dataset.

        Only ``parallelism == 1`` is supported.
        """
        assert parallelism == 1

        metadata = BlockMetadata(
            num_rows=len(self._dataset),
            size_bytes=None,
            schema=None,
            input_files=None,
            exec_stats=None,
        )
        # The dataset is bound as a default argument of the lambda, so the
        # task does not need to reference ``self``.
        task = ReadTask(
            lambda subset=self._dataset: _read_subset(subset),
            metadata=metadata,
        )

        return [task]

    def estimate_inmemory_data_size(self):
        """No size estimate is provided for Torch datasets."""
        return None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _read_subset(subset: "torch.utils.data.Subset"):
    """Yield blocks of at most ``TORCH_DATASOURCE_READER_BATCH_SIZE`` items.

    Every block has a single column named ``item``. Items left over after the
    last full batch are emitted as a final, smaller block.
    """
    pending = []
    for element in subset:
        pending.append(element)
        if len(pending) != TORCH_DATASOURCE_READER_BATCH_SIZE:
            continue
        builder = DelegatingBlockBuilder()
        builder.add_batch({"item": pending})
        yield builder.build()
        # The same list object is reused for the next batch.
        pending.clear()

    if pending:
        builder = DelegatingBlockBuilder()
        builder.add_batch({"item": pending})
        yield builder.build()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import TYPE_CHECKING, List, Union
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 5 |
+
from ray.data._internal.util import _check_import
|
| 6 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
import pyarrow
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class VideoDatasource(FileBasedDatasource):
    """Datasource that decodes video files frame by frame using ``decord``."""

    _FILE_EXTENSIONS = [
        "mp4",
        "mkv",
        "mov",
        "avi",
        "wmv",
        "flv",
        "webm",
        "m4v",
        "3gp",
        "mpeg",
        "mpg",
        "ts",
        "ogv",
        "rm",
        "rmvb",
        "vob",
        "asf",
        "f4v",
        "m2ts",
        "mts",
        "divx",
        "xvid",
        "mxf",
    ]

    def __init__(
        self,
        paths: Union[str, List[str]],
        **file_based_datasource_kwargs,
    ):
        """
        Args:
            paths: Path or list of paths, forwarded to the base datasource.
            **file_based_datasource_kwargs: Forwarded unchanged to the base
                class constructor.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        # Checked here, at construction time; the import itself is deferred
        # to ``_read_stream``.
        _check_import(self, module="decord", package="decord")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Yield one single-row block per frame of the video in ``f``.

        Each row has the decoded frame under ``frame`` and its position in
        the video under ``frame_index``.
        """
        from decord import VideoReader

        for position, frame in enumerate(VideoReader(f)):
            row = {"frame": frame.asnumpy(), "frame_index": position}
            builder = DelegatingBlockBuilder()
            builder.add(row)
            yield builder.build()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import tarfile
|
| 3 |
+
import time
|
| 4 |
+
import uuid
|
| 5 |
+
from typing import Optional, Union
|
| 6 |
+
|
| 7 |
+
import pyarrow
|
| 8 |
+
|
| 9 |
+
from ray.data._internal.datasource.webdataset_datasource import (
|
| 10 |
+
_apply_list,
|
| 11 |
+
_default_encoder,
|
| 12 |
+
_make_iterable,
|
| 13 |
+
)
|
| 14 |
+
from ray.data.block import BlockAccessor
|
| 15 |
+
from ray.data.datasource.file_datasink import BlockBasedFileDatasink
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class WebDatasetDatasink(BlockBasedFileDatasink):
    """Datasink that writes each block as a WebDataset-style tar archive."""

    def __init__(
        self,
        path: str,
        encoder: Optional[Union[bool, str, callable, list]] = True,
        *,
        file_format: str = "tar",
        **file_datasink_kwargs,
    ):
        """
        Args:
            path: Output location, forwarded to the base datasink.
            encoder: Encoder specification applied to every sample before it
                is written; ``None`` disables encoding.
            file_format: Accepted but not used; the base class is always
                configured with ``"tar"``.
                NOTE(review): confirm that ignoring this value is intended.
            **file_datasink_kwargs: Forwarded unchanged to the base class.
        """
        super().__init__(path, file_format="tar", **file_datasink_kwargs)

        self.encoder = encoder

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write every sample of ``block`` into a tar stream on ``file``.

        Each non-``None`` field whose name does not start with ``__`` becomes
        one archive member named ``<__key__>.<field>``. Samples without a
        ``__key__`` get a random one.
        """
        archive = tarfile.open(fileobj=file, mode="w|")
        for sample in _make_iterable(block):
            if not isinstance(sample, dict):
                sample = sample.as_pydict()
            if self.encoder is not None:
                sample = _apply_list(self.encoder, sample, default=_default_encoder)
            if "__key__" not in sample:
                sample["__key__"] = uuid.uuid4().hex
            key = sample["__key__"]
            for field, content in sample.items():
                # Metadata fields and missing values are not written.
                if content is None or field.startswith("__"):
                    continue
                assert isinstance(content, bytes) or isinstance(content, str)
                if isinstance(content, bytes):
                    payload = content
                else:
                    payload = content.encode("utf-8")
                member = tarfile.TarInfo(f"{key}.{field}")
                member.size = len(payload)
                member.mtime = time.time()
                member.mode = 0o644
                member.uname = "data"
                member.gname = "data"
                archive.addfile(member, io.BytesIO(payload))
        archive.close()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright NVIDIA Corporation 2023
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import fnmatch
|
| 5 |
+
import io
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import tarfile
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
import ray
|
| 13 |
+
from ray.data._internal.util import iterate_with_retry
|
| 14 |
+
from ray.data.block import BlockAccessor
|
| 15 |
+
from ray.data.datasource.file_based_datasource import FileBasedDatasource
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
import pyarrow
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _base_plus_ext(path: str):
    """Split a path into its base and all of its file extensions.

    Args:
        path: path with extensions

    Returns:
        Tuple ``(base, extensions)`` where ``extensions`` is everything after
        the first dot of the last path component, or ``(None, None)`` when
        the path does not have that form.
    """
    found = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    return found.group(1, 2) if found else (None, None)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _valid_sample(sample: Dict[str, Any]):
    """Tell whether a sample is usable.

    A sample is valid when it is a non-empty dict whose ``__bad__`` entry is
    absent or falsy.

    Args:
        sample: sample to be checked
    """
    # ``None`` is rejected here as well, since it is not a dict.
    if not isinstance(sample, dict):
        return False
    if len(sample.keys()) == 0:
        return False
    return not sample.get("__bad__", False)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _apply_list(
    f: Union[Callable, List[Callable]], sample: Dict[str, Any], default: Callable = None
):
    """Run a sample through one function or a list of functions, in order.

    Args:
        f: function or list of functions; ``None`` leaves the sample as is
        sample: sample to be transformed
        default: function used for entries of ``f`` that are not callable;
            such an entry is passed to it as the ``format`` keyword.
            Defaults to None.

    Returns:
        The sample returned by the last function.
    """
    if f is None:
        return sample

    steps = f if isinstance(f, list) else [f]
    for step in steps:
        use_default = default is not None and not callable(step)
        transform = partial(default, format=step) if use_default else step
        sample = transform(sample)
    return sample
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _check_suffix(suffix: str, suffixes: Union[list, callable]):
    """Decide whether a suffix is accepted.

    ``suffixes`` can be ``None`` (accept everything), a callable that is
    asked directly, or a list of patterns. A pattern containing ``*`` or
    ``?`` is matched as a glob against the dotted suffix; any other pattern
    must equal the suffix, with or without a leading dot.

    Args:
        suffix: suffix to be checked
        suffixes: list of valid suffixes
    """
    if suffixes is None:
        return True
    if callable(suffixes):
        return suffixes(suffix)

    def _accepts(pattern):
        if "*" in pattern or "?" in pattern:
            return fnmatch.fnmatch("." + suffix, pattern)
        return suffix == pattern or "." + suffix == pattern

    return any(_accepts(pattern) for pattern in suffixes)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _tar_file_iterator(
    fileobj: Any,
    fileselect: Optional[Union[bool, callable, list]] = None,
    filerename: Optional[Union[bool, callable, list]] = None,
    verbose_open: bool = False,
    meta: dict = None,
):
    """Iterate over tar file, yielding filename, content pairs for the given tar stream.

    Args:
        fileobj: file object
        fileselect: patterns or function selecting
            files to be selected
        filerename: function or list of functions applied to each member
            name before it is checked against ``fileselect``
        verbose_open: print a line when the stream is opened and when it
            has been fully read
        meta: only used in the ``verbose_open`` messages

    Yields:
        A dict with the (renamed) member name under ``fname`` and the raw
        member bytes under ``data`` for every selected regular file.
    """
    meta = meta or {}
    # The context manager releases the TarFile even when the consumer stops
    # iterating early or a member cannot be processed.
    with tarfile.open(fileobj=fileobj, mode="r|*") as stream:
        if verbose_open:
            print(f"start {meta}")
        for tarinfo in stream:
            fname = tarinfo.name
            if not tarinfo.isreg() or fname is None:
                continue
            fname = _apply_list(filerename, fname)
            assert isinstance(fname, str)
            if not _check_suffix(fname, fileselect):
                continue
            # Read the payload only after the member passed the filter, so
            # that rejected files are never loaded into memory.
            data = stream.extractfile(tarinfo).read()
            yield dict(fname=fname, data=data)
        if verbose_open:
            print(f"done {meta}")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _group_by_keys(
    data: List[Dict[str, Any]],
    keys: callable = _base_plus_ext,
    suffixes: Optional[Union[list, callable]] = None,
    meta: dict = None,
):
    """Group consecutive file entries that share a key into samples.

    Args:
        data: iterator over dicts with ``fname`` and ``data`` entries
        keys: function that returns key, suffix for a given file name
        suffixes: list of suffixes to be included in the sample
        meta: metadata to be added to each sample

    Yields:
        One dict per group of consecutive entries with the same key. It has
        the key under ``__key__``, one entry per accepted suffix, and the
        entries of ``meta``.

    Raises:
        ValueError: If the same suffix occurs twice within one sample.
    """
    meta = meta or {}
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        # Names that cannot be split into key and suffix are skipped.
        if prefix is None:
            continue
        if current_sample is None or prefix != current_sample["__key__"]:
            # A new key starts: emit the sample collected so far.
            if _valid_sample(current_sample):
                current_sample.update(meta)
                yield current_sample
            current_sample = dict(__key__=prefix)
            if "__url__" in filesample:
                current_sample["__url__"] = filesample["__url__"]
        if suffix in current_sample:
            # ``meta`` may not carry a URL (it defaults to an empty dict), so
            # it is looked up with ``get`` to avoid masking this error with
            # a KeyError.
            raise ValueError(
                f"{fname}: duplicate file name in tar file "
                + f"{suffix} {current_sample.keys()}, tar is {meta.get('__url__')}"
            )
        if suffixes is None or _check_suffix(suffix, suffixes):
            current_sample[suffix] = value
    if _valid_sample(current_sample):
        current_sample.update(meta)
        yield current_sample
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
    """A default decoder for webdataset.

    This handles common file extensions: .txt, .cls, .cls2,
    .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
    These are the most common extensions used in webdataset.
    For other extensions, users can provide their own decoder.

    NOTE(security): ``.pt``/``.pth`` and ``.pickle``/``.pkl`` entries are
    loaded with ``torch.load`` and ``pickle.loads``, which can execute
    arbitrary code. Only decode archives from trusted sources.

    Args:
        sample: sample to decode; it is copied, not modified
        format: ``"PIL"`` keeps images as ``PIL.Image`` objects, any other
            value converts them to numpy arrays

    Returns:
        A new dict in which the values of recognised extensions are decoded;
        keys starting with ``__`` and unknown extensions are left untouched.
    """
    decoded = dict(sample)
    for key, raw in decoded.items():
        if key.startswith("__"):
            continue
        extension = key.split(".")[-1]
        if extension in ("txt", "text"):
            decoded[key] = raw.decode("utf-8")
        elif extension in ("cls", "cls2"):
            decoded[key] = int(raw.decode("utf-8"))
        elif extension in ("jpg", "png", "ppm", "pgm", "pbm", "pnm"):
            import numpy as np
            import PIL.Image

            image = PIL.Image.open(io.BytesIO(raw))
            decoded[key] = image if format == "PIL" else np.asarray(image)
        elif extension == "json":
            decoded[key] = json.loads(raw)
        elif extension == "npy":
            import numpy as np

            decoded[key] = np.load(io.BytesIO(raw))
        elif extension == "mp":
            import msgpack

            decoded[key] = msgpack.unpackb(raw, raw=False)
        elif extension in ("pt", "pth"):
            import torch

            decoded[key] = torch.load(io.BytesIO(raw))
        elif extension in ("pickle", "pkl"):
            import pickle

            decoded[key] = pickle.loads(raw)
    return decoded
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Maps a file extension to the format name handed to ``PIL.Image.Image.save``
# when the two differ.
extension_to_format = {"jpg": "jpeg"}


def _default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True):
    """Encode the members of a WebDataset sample based on their extensions.

    Handles the common extensions: .txt/.text (UTF-8 text), .cls/.cls2,
    .jpg/.jpeg/.png/.ppm/.pgm/.pbm/.pnm (images), .json, .npy, .mp (msgpack),
    .pt/.pth (torch) and .pickle/.pkl. Keys that start with ``__`` and keys
    with any other extension are passed through unchanged. For other
    extensions, users can provide their own encoder.

    Args:
        sample (Dict[str, Any]): Mapping from member name to the value to
            encode. The input mapping is not modified; a shallow copy is
            encoded and returned.
        format: Accepted for signature parity with the default decoder; it is
            not used by this function.

    Returns:
        A new dict with the same keys as ``sample`` and encoded values.

    Raises:
        TypeError: If the value of an image member is neither a NumPy array
            nor a ``PIL.Image.Image``.
    """
    encoded = dict(sample)
    for key, value in sample.items():
        if key.startswith("__"):
            # Metadata entries are never encoded.
            continue
        extension = key.split(".")[-1]
        # "text" mirrors what the default decoder accepts.
        if extension in ["txt", "text"]:
            # Values that are already bytes are stored as they are.
            encoded[key] = value if isinstance(value, bytes) else value.encode("utf-8")
        elif extension in ["cls", "cls2"]:
            encoded[key] = str(value).encode("utf-8")
        elif extension in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]:
            import numpy as np
            import PIL.Image

            if isinstance(value, np.ndarray):
                value = PIL.Image.fromarray(value)
            # Explicit check rather than ``assert`` so that it also runs
            # under ``python -O``.
            if not isinstance(value, PIL.Image.Image):
                raise TypeError(
                    f"{key}: expected a numpy array or PIL image, got {type(value)}"
                )
            stream = io.BytesIO()
            value.save(
                stream, format=extension_to_format.get(extension.lower(), extension)
            )
            encoded[key] = stream.getvalue()
        elif extension == "json":
            encoded[key] = json.dumps(value).encode("utf-8")
        elif extension == "npy":
            import numpy as np

            stream = io.BytesIO()
            np.save(stream, value)
            encoded[key] = stream.getvalue()
        elif extension == "mp":
            import msgpack

            encoded[key] = msgpack.dumps(value)
        elif extension in ["pt", "pth"]:
            import torch

            stream = io.BytesIO()
            torch.save(value, stream)
            encoded[key] = stream.getvalue()
        elif extension in ["pickle", "pkl"]:
            import pickle

            stream = io.BytesIO()
            pickle.dump(value, stream)
            encoded[key] = stream.getvalue()
    return encoded
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _make_iterable(block: BlockAccessor):
    """Return an iterable over the rows of a block.

    Acts as a single hook where handling of more complex blocks can be
    added later.

    Args:
        block: Accessor for a Ray Dataset block.

    Returns:
        Iterable[Dict[str,Any]]: The rows produced by
        ``block.iter_rows(public_row_format=False)``.
    """
    rows = block.iter_rows(public_row_format=False)
    return rows
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class WebDatasetDatasource(FileBasedDatasource):
    """A Datasource for WebDataset datasets (tar format with naming conventions)."""

    _FILE_EXTENSIONS = ["tar"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        decoder: Optional[Union[bool, str, callable, list]] = True,
        fileselect: Optional[Union[bool, callable, list]] = None,
        filerename: Optional[Union[bool, callable, list]] = None,
        suffixes: Optional[Union[bool, callable, list]] = None,
        verbose_open: bool = False,
        expand_json: bool = False,
        **file_based_datasource_kwargs,
    ):
        """Store the reading options; they are applied in ``_read_stream``.

        Args:
            paths: Path or list of paths, forwarded to the base datasource.
            decoder: Decoder or list of decoders applied to every sample.
                ``None`` disables decoding.
            fileselect: Forwarded to the tar file iterator to select files
                while reading.
            filerename: Forwarded to the tar file iterator.
            suffixes: Forwarded to the grouping step to select which suffixes
                are kept in a sample.
            verbose_open: Forwarded to the tar file iterator.
            expand_json: If True, the entries of a sample's ``json`` member
                are added to the sample as additional columns.
            **file_based_datasource_kwargs: Forwarded to the base datasource.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self.decoder = decoder
        self.fileselect = fileselect
        self.filerename = filerename
        self.suffixes = suffixes
        self.verbose_open = verbose_open
        self.expand_json = expand_json

    @staticmethod
    def _parse_json_member(value):
        """Return the ``json`` member of a sample as a dict.

        Accepts UTF-8 bytes, a string, or an already decoded dict.

        Raises:
            TypeError: If ``value`` has another type, or if the parsed JSON
                document is not an object.
        """
        if isinstance(value, bytes):
            parsed_json = json.loads(value.decode("utf-8"))
        elif isinstance(value, str):
            parsed_json = json.loads(value)
        elif isinstance(value, dict):
            parsed_json = value
        else:
            raise TypeError(f"Unsupported data type {type(value)} for sample")
        if not isinstance(parsed_json, dict):
            # A JSON array or scalar has no fields that could become columns.
            raise TypeError(
                f"Expected a JSON object to expand, got {type(parsed_json)}"
            )
        return parsed_json

    def _read_stream(self, stream: "pyarrow.NativeFile", path: str):
        """Read and decode samples from a stream.

        Note that fileselect selects files during reading, while suffixes
        selects files during the grouping step. Both, like the decoder and
        the other options, are taken from the attributes set in ``__init__``.

        Args:
            stream: File descriptor to read from.
            path: Path to the data; recorded in each sample as ``__url__``.

        Yields:
            pandas.DataFrame: One single-row frame per sample.

        Raises:
            TypeError: If ``expand_json`` is set and a sample's ``json``
                member cannot be interpreted as a JSON object.
        """

        import pandas as pd

        def get_tar_file_iterator():
            return _tar_file_iterator(
                stream,
                fileselect=self.fileselect,
                filerename=self.filerename,
                verbose_open=self.verbose_open,
            )

        # S3 can raise transient errors during iteration
        ctx = ray.data.DataContext.get_current()
        files = iterate_with_retry(
            get_tar_file_iterator, "iterate tar file", match=ctx.retried_io_errors
        )

        samples = _group_by_keys(files, meta=dict(__url__=path), suffixes=self.suffixes)
        for sample in samples:
            if self.decoder is not None:
                sample = _apply_list(self.decoder, sample, default=_default_decoder)
            # A sample without a "json" member has nothing to expand; it is
            # emitted unchanged instead of failing with a bare KeyError.
            if self.expand_json and "json" in sample:
                parsed_json = self._parse_json_member(sample["json"])
                for k, v in parsed_json.items():
                    sample.setdefault(k, []).append(v)
            # Every value becomes a one-element column; one-element lists
            # (such as the expanded JSON fields) are used as the column as is.
            yield pd.DataFrame(
                {
                    k: v if isinstance(v, list) and len(v) == 1 else [v]
                    for k, v in sample.items()
                }
            )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc
ADDED
|
Binary file (2.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
|
| 4 |
+
from ray.data._internal.stats import DatasetStats
|
| 5 |
+
from ray.data._internal.util import create_dataset_tag
|
| 6 |
+
from ray.data.iterator import DataIterator
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
import pyarrow
|
| 10 |
+
|
| 11 |
+
from ray.data import Dataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DataIteratorImpl(DataIterator):
    """``DataIterator`` that delegates to a single base ``Dataset``."""

    def __init__(
        self,
        base_dataset: "Dataset",
    ):
        # All queries are answered by this dataset.
        self._base_dataset = base_dataset

    def __repr__(self) -> str:
        return f"DataIterator({self._base_dataset})"

    def _to_ref_bundle_iterator(
        self,
    ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
        """Start executing the dataset's plan and return its bundle iterator.

        The executor created for the run is recorded on the dataset as
        ``_current_executor``. The third element of the result is always
        ``False`` here.
        """
        dataset = self._base_dataset
        bundle_iter, dataset_stats, executor = dataset._plan.execute_to_iterator()
        dataset._current_executor = executor
        return bundle_iter, dataset_stats, False

    def stats(self) -> str:
        """Return the stats of the base dataset."""
        dataset = self._base_dataset
        return dataset.stats()

    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
        """Return the schema of the base dataset."""
        dataset = self._base_dataset
        return dataset.schema()

    def _get_dataset_tag(self):
        """Build the dataset tag from the plan's dataset name and the UUID."""
        dataset = self._base_dataset
        return create_dataset_tag(dataset._plan._dataset_name, dataset._uuid)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import threading
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import replace
|
| 5 |
+
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import ray
|
| 8 |
+
from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle
|
| 9 |
+
from ray.data._internal.execution.legacy_compat import execute_to_legacy_bundle_iterator
|
| 10 |
+
from ray.data._internal.execution.operators.output_splitter import OutputSplitter
|
| 11 |
+
from ray.data._internal.execution.streaming_executor import StreamingExecutor
|
| 12 |
+
from ray.data._internal.stats import DatasetStats
|
| 13 |
+
from ray.data._internal.util import create_dataset_tag
|
| 14 |
+
from ray.data.block import Block, BlockMetadata
|
| 15 |
+
from ray.data.iterator import DataIterator
|
| 16 |
+
from ray.types import ObjectRef
|
| 17 |
+
from ray.util.debug import log_once
|
| 18 |
+
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
import pyarrow
|
| 22 |
+
|
| 23 |
+
from ray.data import Dataset
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)


# Seconds a client may wait at the epoch barrier before a warning is logged.
BLOCKED_CLIENT_WARN_TIMEOUT = 30


class StreamSplitDataIterator(DataIterator):
    """Implements a collection of iterators over a shared data stream."""

    @staticmethod
    def create(
        base_dataset: "Dataset",
        n: int,
        equal: bool,
        locality_hints: Optional[List[NodeIdStr]],
    ) -> List["StreamSplitDataIterator"]:
        """Create a split iterator from the given base Dataset and options.

        A single coordinator actor is started and shared by the ``n``
        returned iterators.

        See also: `Dataset.streaming_split`.
        """
        # The coordinator is scheduled on the node of the caller.
        local_node_strategy = NodeAffinitySchedulingStrategy(
            ray.get_runtime_context().get_node_id(), soft=False
        )
        # To avoid deadlock, the concurrency on this actor must be set to at least `n`.
        coord_actor = SplitCoordinator.options(
            max_concurrency=n,
            scheduling_strategy=local_node_strategy,
        ).remote(base_dataset, n, equal, locality_hints)

        iterators = []
        for split_idx in range(n):
            iterators.append(
                StreamSplitDataIterator(base_dataset, coord_actor, split_idx, n)
            )
        return iterators

    def __init__(
        self,
        base_dataset: "Dataset",
        coord_actor: ray.actor.ActorHandle,
        output_split_idx: int,
        world_size: int,
    ):
        self._base_dataset = base_dataset
        self._coord_actor = coord_actor
        self._output_split_idx = output_split_idx
        self._world_size = world_size
        self._iter_stats = DatasetStats(metadata={}, parent=None)

    def _to_ref_bundle_iterator(
        self,
    ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
        """Return a generator of this split's bundles plus the local stats."""

        def gen_blocks() -> Iterator[RefBundle]:
            split_idx = self._output_split_idx
            coordinator = self._coord_actor
            cur_epoch = ray.get(coordinator.start_epoch.remote(split_idx))
            pending = coordinator.get.remote(cur_epoch, split_idx)
            while True:
                block_ref_and_md: Optional[
                    Tuple[ObjectRef[Block], BlockMetadata]
                ] = ray.get(pending)
                if not block_ref_and_md:
                    # The coordinator signals the end of the epoch.
                    return
                # Request the next block before handing out the current one.
                pending = coordinator.get.remote(cur_epoch, split_idx)
                yield RefBundle(blocks=(block_ref_and_md,), owns_blocks=False)

        return gen_blocks(), self._iter_stats, False

    def stats(self) -> str:
        """Implements DataIterator."""
        # Merge the locally recorded iter stats and the remotely recorded
        # stream execution stats.
        coord_stats = ray.get(self._coord_actor.stats.remote())
        summary = coord_stats.to_summary()
        summary.iter_stats = self._iter_stats.to_summary().iter_stats
        coord_time = coord_stats.streaming_split_coordinator_s.get()
        summary.iter_stats.streaming_split_coord_time.add(coord_time)
        return summary.to_string()

    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
        """Implements DataIterator."""
        return self._base_dataset.schema()

    def world_size(self) -> int:
        """Returns the number of splits total."""
        return self._world_size

    def _get_dataset_tag(self):
        """Build the dataset tag from the dataset name, UUID and split index."""
        dataset = self._base_dataset
        return create_dataset_tag(
            dataset._plan._dataset_name,
            dataset._uuid,
            self._output_split_idx,
        )
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@ray.remote(num_cpus=0)
class SplitCoordinator:
    """Coordinator actor for routing blocks to output splits.

    This actor runs a streaming executor locally on its main thread. Clients can
    retrieve results via actor calls running on other threads.
    """

    def __init__(
        self,
        dataset: "Dataset",
        n: int,
        equal: bool,
        locality_hints: Optional[List[NodeIdStr]],
    ):
        # Set current DataContext.
        self._data_context = dataset.context
        ray.data.DataContext._set_current(self._data_context)
        # Automatically set locality with output to the specified location hints.
        if locality_hints:
            self._data_context.execution_options.locality_with_output = locality_hints
            logger.info(f"Auto configuring locality_with_output={locality_hints}")

        self._base_dataset = dataset
        self._n = n
        self._equal = equal
        self._locality_hints = locality_hints
        self._lock = threading.RLock()
        # Executor of the current epoch; None until the first epoch starts.
        self._executor = None

        # Guarded by self._lock.
        # Blocks already fetched for a split but not yet handed to its client.
        self._next_bundle: Dict[int, RefBundle] = {}
        # Number of clients that have not yet arrived at the barrier.
        self._unfinished_clients_in_epoch = n
        # Starts at -1 so that the first epoch started gets id 0.
        self._cur_epoch = -1

        def gen_epochs():
            # Each iteration creates a fresh executor and output iterator,
            # i.e. one execution of the dataset plan per epoch.
            while True:
                executor = StreamingExecutor(
                    self._data_context,
                    create_dataset_tag(
                        self._base_dataset._name, self._base_dataset._uuid
                    ),
                )
                self._executor = executor

                def add_split_op(dag):
                    # Append the splitter that routes output to the n splits.
                    return OutputSplitter(
                        dag,
                        n,
                        equal,
                        self._data_context,
                        locality_hints,
                    )

                output_iterator = execute_to_legacy_bundle_iterator(
                    executor,
                    dataset._plan,
                    dag_rewrite=add_split_op,
                )
                yield output_iterator

        self._next_epoch = gen_epochs()
        self._output_iterator = None
        # Store the error raised from the `gen_epoch` call.
        self._gen_epoch_error: Optional[Exception] = None

    def stats(self) -> DatasetStats:
        """Returns stats from the current executor, else from the base dataset."""
        if self._executor:
            return self._executor.get_stats()
        return self._base_dataset._plan.stats()

    def start_epoch(self, split_idx: int) -> int:
        """Called to start an epoch.

        Blocks until all clients have called this method for the epoch.

        Args:
            split_idx: Index of the calling split; used in the warning that
                is logged when the wait takes long.

        Returns:
            Integer id of the epoch, which must be used when accessing
            results via get().
        """

        # Wait for all clients to arrive at the barrier before starting a new epoch.
        epoch_id = self._barrier(split_idx)
        return epoch_id

    def get(
        self, epoch_id: int, output_split_idx: int
    ) -> Optional[Tuple[ObjectRef[Block], BlockMetadata]]:
        """Blocking get operation.

        This is intended to be called concurrently from multiple clients.

        Args:
            epoch_id: Epoch id returned by start_epoch().
            output_split_idx: Index of the split to read from.

        Returns:
            The next block entry for the split, or None once the output
            iterator is exhausted.

        Raises:
            ValueError: If ``epoch_id`` is not the current epoch.
        """
        start_time = time.perf_counter()
        if epoch_id != self._cur_epoch:
            raise ValueError(
                "Invalid iterator: the dataset has moved on to another epoch."
            )

        try:
            # Ensure there is at least one bundle.
            with self._lock:
                if output_split_idx in self._next_bundle:
                    next_bundle = self._next_bundle[output_split_idx]
                else:
                    next_bundle = None

            # Fetch next bundle if needed.
            while next_bundle is None or not next_bundle.blocks:
                # This is a BLOCKING call, so do it outside the lock.
                next_bundle = self._output_iterator.get_next(output_split_idx)

            # Hand out the last block and keep the rest for later calls.
            block = next_bundle.blocks[-1]
            next_bundle = replace(next_bundle, blocks=next_bundle.blocks[:-1])

            # Accumulate any remaining blocks in next_bundle map as needed.
            with self._lock:
                self._next_bundle[output_split_idx] = next_bundle
                if not next_bundle.blocks:
                    del self._next_bundle[output_split_idx]

            return block
        except StopIteration:
            # Raised by get_next() when the split has no more output.
            return None
        finally:
            # Record the time spent in this call, whatever its outcome.
            stats = self.stats()
            if stats and stats.streaming_split_coordinator_s:
                stats.streaming_split_coordinator_s.add(
                    time.perf_counter() - start_time
                )

    def _barrier(self, split_idx: int) -> int:
        """Arrive and block until the start of the given epoch.

        Returns:
            The id of the epoch that was started.

        Raises:
            Exception: Re-raises the error stored when creating the output
                iterator of an epoch failed.
        """

        # Decrement and await all clients to arrive here.
        with self._lock:
            starting_epoch = self._cur_epoch
            self._unfinished_clients_in_epoch -= 1

        start_time = time.time()
        # Poll until every client has arrived or another thread has already
        # advanced the epoch.
        while (
            self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0
        ):
            if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT:
                if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"):
                    logger.warning(
                        f"StreamSplitDataIterator(epoch={starting_epoch}, "
                        f"split={split_idx}) blocked waiting on other clients "
                        f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All "
                        "clients must read from the DataIterator splits at "
                        "the same time. This warning will not be printed again "
                        "for this epoch."
                    )
            time.sleep(0.1)

        # Advance to the next epoch.
        with self._lock:
            # Only the first thread to get here advances the epoch; the
            # others see the changed epoch and skip this branch.
            if self._cur_epoch == starting_epoch:
                self._cur_epoch += 1
                self._unfinished_clients_in_epoch = self._n
                try:
                    self._output_iterator = next(self._next_epoch)
                except Exception as e:
                    self._gen_epoch_error = e

        if self._gen_epoch_error is not None:
            # If there was an error when advancing to the next epoch,
            # re-raise it for all threads.
            raise self._gen_epoch_error

        assert self._output_iterator is not None
        return starting_epoch + 1
|
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (199 Bytes). View file
|
|
|