diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35ee7241f7937535882c48c3f43f27c2315469dd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8adc9f782ec6024643042efad6299ac094057ac8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/aggregate.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed82531b8823906c75b6a5666bbb361799c5cf2f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/arrow_block.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc21fc28316e04bff9cbe0cb41ce94432e6b027f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/memory_tracing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..f7c4bc58405b924489bbb07021544dd8d86ae54d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/null_aggregate.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b7186319f8eb726834a378cd04e79d204afa479 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/numpy_support.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05c19c5c3f6e7ba35bc40e80da6430f6726b554d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/pandas_block.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e47afa2f22f17951c0f16e73d77629439e1386a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/progress_bar.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2577960f0ab05ef0a0e579ba255bb9c2fd06c79 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/__pycache__/split.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3a28d26f45fdee6cd63ccce32cd4b8f9bb6e81 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/audio_datasource.py @@ -0,0 +1,57 @@ +import io +from typing import TYPE_CHECKING, Iterator, List, Union + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.util import _check_import +from ray.data.block import Block +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + + +class AudioDatasource(FileBasedDatasource): + _FILE_EXTENSIONS = [ + "mp3", + "wav", + "aac", + "flac", + "ogg", + "m4a", + "wma", + "alac", + "aiff", + "pcm", + "amr", + "opus", + "ra", + "rm", + "au", + "mid", + "midi", + "caf", + ] + + def __init__( + self, + paths: Union[str, List[str]], + **file_based_datasource_kwargs, + ): + super().__init__(paths, **file_based_datasource_kwargs) + + _check_import(self, module="soundfile", package="soundfile") + + def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: + import soundfile + + # `soundfile` doesn't support reading from a `pyarrow.NativeFile` directly, so + # we need to read the file into memory first. 
+ stream = io.BytesIO(f.read()) + amplitude, sample_rate = soundfile.read(stream, always_2d=True, dtype="float32") + + # (amplitude, channels) -> (channels, amplitude) + amplitude = amplitude.transpose((1, 0)) + + builder = DelegatingBlockBuilder() + builder.add({"amplitude": amplitude, "sample_rate": sample_rate}) + yield builder.build() diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..288ec136ce2cd1b43d8db2e3502fe930701ba512 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/avro_datasource.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING, Iterator, List, Union + +from ray.data._internal.output_buffer import BlockOutputBuffer +from ray.data._internal.util import _check_import +from ray.data.block import Block +from ray.data.context import DataContext +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + + +class AvroDatasource(FileBasedDatasource): + """A datasource that reads Avro files.""" + + _FILE_EXTENSIONS = ["avro"] + + def __init__( + self, + paths: Union[str, List[str]], + **file_based_datasource_kwargs, + ): + super().__init__(paths, **file_based_datasource_kwargs) + + _check_import(self, module="fastavro", package="fastavro") + + def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: + import fastavro + + # Read the Avro file. This assumes the Avro file includes its schema. 
+ reader = fastavro.reader(f) + + ctx = DataContext.get_current() + output_buffer = BlockOutputBuffer(ctx.target_max_block_size) + for record in reader: + output_buffer.add(record) + while output_buffer.has_next(): + yield output_buffer.next() + + output_buffer.finalize() + while output_buffer.has_next(): + yield output_buffer.next() diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..92178996f3eaff7096f1b928145ced9418ffe83f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasink.py @@ -0,0 +1,129 @@ +import logging +import os +import tempfile +import time +import uuid +from typing import Iterable, Optional + +import pyarrow.parquet as pq + +import ray +from ray.data._internal.execution.interfaces import TaskContext +from ray.data._internal.datasource import bigquery_datasource +from ray.data._internal.remote_fn import cached_remote_fn +from ray.data._internal.util import _check_import +from ray.data.block import Block, BlockAccessor +from ray.data.datasource.datasink import Datasink + +logger = logging.getLogger(__name__) + +DEFAULT_MAX_RETRY_CNT = 10 +RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11 + + +class BigQueryDatasink(Datasink[None]): + def __init__( + self, + project_id: str, + dataset: str, + max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT, + overwrite_table: Optional[bool] = True, + ) -> None: + _check_import(self, module="google.cloud", package="bigquery") + _check_import(self, module="google.cloud", package="bigquery_storage") + _check_import(self, module="google.api_core", package="exceptions") + + self.project_id = project_id + self.dataset = dataset + self.max_retry_cnt = max_retry_cnt + self.overwrite_table = overwrite_table + + def on_write_start(self) -> None: + from google.api_core import exceptions + + if 
self.project_id is None or self.dataset is None: + raise ValueError("project_id and dataset are required args") + + # Set up datasets to write + client = bigquery_datasource._create_client(project_id=self.project_id) + dataset_id = self.dataset.split(".", 1)[0] + try: + client.get_dataset(dataset_id) + except exceptions.NotFound: + client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30) + logger.info("Created dataset " + dataset_id) + + # Delete table if overwrite_table is True + if self.overwrite_table: + logger.info( + f"Attempting to delete table {self.dataset}" + + " if it already exists since kwarg overwrite_table = True." + ) + client.delete_table(f"{self.project_id}.{self.dataset}", not_found_ok=True) + else: + logger.info( + f"The write will append to table {self.dataset}" + + " if it already exists since kwarg overwrite_table = False." + ) + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> None: + def _write_single_block(block: Block, project_id: str, dataset: str) -> None: + from google.api_core import exceptions + from google.cloud import bigquery + + block = BlockAccessor.for_block(block).to_arrow() + + client = bigquery_datasource._create_client(project=project_id) + job_config = bigquery.LoadJobConfig(autodetect=True) + job_config.source_format = bigquery.SourceFormat.PARQUET + job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND + + with tempfile.TemporaryDirectory() as temp_dir: + fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet") + pq.write_table(block, fp, compression="SNAPPY") + + retry_cnt = 0 + while retry_cnt <= self.max_retry_cnt: + with open(fp, "rb") as source_file: + job = client.load_table_from_file( + source_file, dataset, job_config=job_config + ) + try: + logger.info(job.result()) + break + except exceptions.Forbidden as e: + retry_cnt += 1 + if retry_cnt > self.max_retry_cnt: + break + logger.info( + "A block write encountered a rate limit exceeded error" + + f" 
{retry_cnt} time(s). Sleeping to try again." + ) + logging.debug(e) + time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME) + + # Raise exception if retry_cnt exceeds max_retry_cnt + if retry_cnt > self.max_retry_cnt: + logger.info( + f"Maximum ({self.max_retry_cnt}) retry count exceeded. Ray" + + " will attempt to retry the block write via fault tolerance." + ) + raise RuntimeError( + f"Write failed due to {retry_cnt}" + + " repeated API rate limit exceeded responses. Consider" + + " specifiying the max_retry_cnt kwarg with a higher value." + ) + + _write_single_block = cached_remote_fn(_write_single_block) + + # Launch a remote task for each block within this write task + ray.get( + [ + _write_single_block.remote(block, self.project_id, self.dataset) + for block in blocks + ] + ) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..f60fa9f5572ccf6ef45cc00e3b1a924e43bb4534 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/bigquery_datasource.py @@ -0,0 +1,156 @@ +import logging +from typing import List, Optional + +from ray.data._internal.util import _check_import +from ray.data.block import Block, BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask + +logger = logging.getLogger(__name__) + + +def _create_user_agent() -> str: + import ray + + return f"ray/{ray.__version__}" + + +def _create_client_info(): + from google.api_core.client_info import ClientInfo + + return ClientInfo( + user_agent=_create_user_agent(), + ) + + +def _create_client_info_gapic(): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo( + user_agent=_create_user_agent(), + ) + + +def _create_client(project_id: str): + from google.cloud import bigquery + + return bigquery.Client( + project=project_id, + 
client_info=_create_client_info(), + ) + + +def _create_read_client(): + from google.cloud import bigquery_storage + + return bigquery_storage.BigQueryReadClient( + client_info=_create_client_info_gapic(), + ) + + +class BigQueryDatasource(Datasource): + def __init__( + self, + project_id: str, + dataset: Optional[str] = None, + query: Optional[str] = None, + ): + _check_import(self, module="google.cloud", package="bigquery") + _check_import(self, module="google.cloud", package="bigquery_storage") + _check_import(self, module="google.api_core", package="exceptions") + + self._project_id = project_id + self._dataset = dataset + self._query = query + + if query is not None and dataset is not None: + raise ValueError( + "Query and dataset kwargs cannot both be provided " + + "(must be mutually exclusive)." + ) + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + from google.cloud import bigquery_storage + + def _read_single_partition(stream) -> Block: + client = _create_read_client() + reader = client.read_rows(stream.name) + return reader.to_arrow() + + if self._query: + query_client = _create_client(project_id=self._project_id) + query_job = query_client.query(self._query) + query_job.result() + destination = str(query_job.destination) + dataset_id = destination.split(".")[-2] + table_id = destination.split(".")[-1] + else: + self._validate_dataset_table_exist(self._project_id, self._dataset) + dataset_id = self._dataset.split(".")[0] + table_id = self._dataset.split(".")[1] + + bqs_client = _create_read_client() + table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}" + + if parallelism == -1: + parallelism = None + requested_session = bigquery_storage.types.ReadSession( + table=table, + data_format=bigquery_storage.types.DataFormat.ARROW, + ) + read_session = bqs_client.create_read_session( + parent=f"projects/{self._project_id}", + read_session=requested_session, + max_stream_count=parallelism, + ) + + read_tasks = [] + 
logger.info("Created streams: " + str(len(read_session.streams))) + if len(read_session.streams) < parallelism: + logger.info( + "The number of streams created by the " + + "BigQuery Storage Read API is less than the requested " + + "parallelism due to the size of the dataset." + ) + + for stream in read_session.streams: + # Create a metadata block object to store schema, etc. + metadata = BlockMetadata( + num_rows=None, + size_bytes=None, + schema=None, + input_files=None, + exec_stats=None, + ) + + # Create the read task and pass the no-arg wrapper and metadata in + read_task = ReadTask( + lambda stream=stream: [_read_single_partition(stream)], + metadata, + ) + read_tasks.append(read_task) + + return read_tasks + + def estimate_inmemory_data_size(self) -> Optional[int]: + return None + + def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None: + from google.api_core import exceptions + + client = _create_client(project_id=project_id) + dataset_id = dataset.split(".")[0] + try: + client.get_dataset(dataset_id) + except exceptions.NotFound: + raise ValueError( + "Dataset {} is not found. Please ensure that it exists.".format( + dataset_id + ) + ) + + try: + client.get_table(dataset) + except exceptions.NotFound: + raise ValueError( + "Table {} is not found. 
Please ensure that it exists.".format(dataset) + ) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..31e6d89969c61cc9d12b6dc289bee35bd7cc9f55 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/binary_datasource.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from ray.data._internal.arrow_block import ArrowBlockBuilder +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + + +class BinaryDatasource(FileBasedDatasource): + """Binary datasource, for reading and writing binary files.""" + + _COLUMN_NAME = "bytes" + + def _read_stream(self, f: "pyarrow.NativeFile", path: str): + data = f.readall() + + builder = ArrowBlockBuilder() + item = {self._COLUMN_NAME: data} + builder.add(item) + yield builder.build() + + def _rows_per_file(self): + return 1 diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..449206d3501538a43b56c6e3cf86b1e3d9ac9683 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/clickhouse_datasource.py @@ -0,0 +1,349 @@ +import logging +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +from ray.data._internal.util import _check_import +from ray.data.block import Block, BlockAccessor, BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger(__name__) + + +def _is_filter_string_safe(filter_str: str) -> bool: + in_string = False + escape_next = False + for c in filter_str: + if 
in_string: + # If we're inside a string, check if we're closing it. + if c == "'" and not escape_next: + in_string = False + escape_next = (c == "\\") and not escape_next + else: + # If we're not in a string, entering one if we see a single quote + if c == "'": + in_string = True + escape_next = False + # Disallow semicolon if we're not in a string + elif c == ";": + return False + else: + escape_next = False + # If we end inside a string, it's suspicious, but let's allow + # it to be further validated by the DB. Just return True here. + return True + + +@DeveloperAPI +class ClickHouseDatasource(Datasource): + """ + A Ray datasource for reading from ClickHouse. + + Args: + table: Fully qualified table or view identifier (e.g., + "default.table_name"). + dsn: A string in DSN (Data Source Name) HTTP format (e.g., + "clickhouse+http://username:password@host:8124/default"). + For more information, see `ClickHouse Connection String doc + `_. + columns: Optional List of columns to select from the data source. + If no columns are specified, all columns will be selected by default. + filter: Optional SQL filter string that will be used in the + WHERE statement (e.g., "label = 2 AND text IS NOT NULL"). + The filter must be valid for use in a ClickHouse SQL WHERE clause. + Note: Parallel reads are not currently supported when a filter is set. + Specifying a filter forces the parallelism to 1 to ensure deterministic + and consistent results. For more information, see + `ClickHouse SQL WHERE Clause doc + `_. + order_by: Optional Tuple containing a list of columns to order by + and a boolean indicating the order. Note: order_by is required to + support parallelism. + client_settings: Optional ClickHouse server settings to be used with the + session/every request. For more information, see + `ClickHouse Client Settings doc + `_. + client_kwargs: Optional Additional keyword arguments to pass to the + ClickHouse client. 
For more information, + see `ClickHouse Core Settings doc + `_. + """ + + NUM_SAMPLE_ROWS = 100 + MIN_ROWS_PER_READ_TASK = 50 + _BASE_QUERY = "SELECT {select_clause} FROM {table}" + _EXPLAIN_FILTERS_QUERY = "EXPLAIN SELECT 1 FROM {table} WHERE {filter_clause}" + _SIZE_ESTIMATE_QUERY = "SELECT SUM(byteSize(*)) AS estimate FROM ({query})" + _COUNT_ESTIMATE_QUERY = "SELECT COUNT(*) AS estimate FROM ({query})" + _SAMPLE_BLOCK_QUERY = "{query} LIMIT {limit_row_count}" + _FIRST_BLOCK_QUERY = """ + {query} + FETCH FIRST {fetch_row_count} {fetch_row_or_rows} ONLY + """ + _NEXT_BLOCK_QUERY = """ + {query} + OFFSET {offset_row_count} {offset_row_or_rows} + FETCH NEXT {fetch_row_count} {fetch_row_or_rows} ONLY + """ + + def __init__( + self, + table: str, + dsn: str, + columns: Optional[List[str]] = None, + filter: Optional[str] = None, + order_by: Optional[Tuple[List[str], bool]] = None, + client_settings: Optional[Dict[str, Any]] = None, + client_kwargs: Optional[Dict[str, Any]] = None, + ): + self._table = table + self._dsn = dsn + self._columns = columns + self._filter = filter + self._order_by = order_by + self._client_settings = client_settings or {} + self._client_kwargs = client_kwargs or {} + self._query = self._generate_query() + + def _init_client(self): + _check_import(self, module="clickhouse_connect", package="clickhouse-connect") + import clickhouse_connect + + return clickhouse_connect.get_client( + dsn=self._dsn, + settings=self._client_settings or {}, + **self._client_kwargs or {}, + ) + + def _validate_filter(self): + if not self._filter: + return + # Minimal lexical check (regex or manual approach for semicolons, etc.). + if not _is_filter_string_safe(self._filter): + raise ValueError( + f"Invalid characters outside of " + f"string literals in filter: {self._filter}" + ) + # Test "EXPLAIN" query to confirm parse-ability and catch expression errors. 
+ client = self._init_client() + try: + test_query = self._EXPLAIN_FILTERS_QUERY.format( + table=self._table, + filter_clause=self._filter, + ) + client.query(test_query) + except Exception as e: + raise ValueError( + f"Invalid filter expression: {self._filter}. Error: {e}", + ) + finally: + client.close() + + def _generate_query(self) -> str: + query = self._BASE_QUERY.format( + select_clause=", ".join(self._columns) if self._columns else "*", + table=self._table, + ) + if self._filter: + self._validate_filter() + query += f" WHERE {self._filter}" + if self._order_by: + columns, desc = self._order_by + direction = " DESC" if desc else "" + if len(columns) == 1: + query += f" ORDER BY {columns[0]}{direction}" + elif len(columns) > 1: + columns_clause = ", ".join(columns) + query += f" ORDER BY ({columns_clause}){direction}" + return query + + def _build_block_query(self, limit_row_count: int, offset_row_count: int) -> str: + if offset_row_count == 0: + # The first block query is optimized to use FETCH FIRST clause + # with an OFFSET specified. + return self._FIRST_BLOCK_QUERY.format( + query=self._query, + fetch_row_count=limit_row_count, + fetch_row_or_rows="ROWS" if limit_row_count > 1 else "ROW", + ) + # Subsequent block queries use OFFSET and FETCH NEXT clauses to read the + # next block of data. + return self._NEXT_BLOCK_QUERY.format( + query=self._query, + offset_row_count=offset_row_count, + offset_row_or_rows="ROWS" if offset_row_count > 1 else "ROW", + fetch_row_count=limit_row_count, + fetch_row_or_rows="ROWS" if limit_row_count > 1 else "ROW", + ) + + def _create_read_fn( + self, + query: str, + ) -> Callable[[], Iterable[Block]]: + def read_fn() -> Iterable[Block]: + return [self._execute_block_query(query)] + + return read_fn + + def _get_sampled_estimates(self): + if self._order_by is not None: + # If the query is ordered, we can use a FETCH clause to get a sample. + # This reduces the CPU overhead on ClickHouse and speeds up the + # estimation query. 
+ query = self._FIRST_BLOCK_QUERY.format( + query=self._query, + fetch_row_count=self.NUM_SAMPLE_ROWS, + fetch_row_or_rows="ROWS" if self.NUM_SAMPLE_ROWS > 1 else "ROW", + ) + else: + # If the query is not ordered, we need to use a LIMIT clause to + # get a sample. + query = self._SAMPLE_BLOCK_QUERY.format( + query=self._query, + limit_row_count=self.NUM_SAMPLE_ROWS, + ) + sample_block_accessor = BlockAccessor.for_block( + self._execute_block_query(query) + ) + estimated_size_bytes_per_row = math.ceil( + sample_block_accessor.size_bytes() / sample_block_accessor.num_rows() + ) + sample_block_schema = sample_block_accessor.schema() + return estimated_size_bytes_per_row, sample_block_schema + + def _get_estimate_count(self) -> Optional[int]: + return self._execute_estimate_query(self._COUNT_ESTIMATE_QUERY) + + def _get_estimate_size(self) -> Optional[int]: + return self._execute_estimate_query(self._SIZE_ESTIMATE_QUERY) + + def _execute_estimate_query(self, estimate_query: str) -> Optional[int]: + client = self._init_client() + try: + # Estimate queries wrap around the primary query, self._query. + # This allows us to use self._query as a sub-query to efficiently + # and accurately estimate the size or count of the result set. 
+ query = estimate_query.format(query=self._query) + result = client.query(query) + if result and len(result.result_rows) > 0: + estimate = result.result_rows[0][0] + return int(estimate) if estimate is not None else None + except Exception as e: + logger.warning(f"Failed to execute estimate query: {e}") + finally: + client.close() + return None + + def _execute_block_query(self, query: str) -> Block: + import pyarrow as pa + + client = self._init_client() + try: + with client.query_arrow_stream(query) as stream: + record_batches = list(stream) # Collect all record batches + return pa.Table.from_batches(record_batches) + except Exception as e: + raise RuntimeError(f"Failed to execute block query: {e}") + finally: + client.close() + + def estimate_inmemory_data_size(self) -> Optional[int]: + """ + Estimate the in-memory data size for the query. + + Returns: + Estimated in-memory data size in bytes, or + None if the estimation cannot be performed. + """ + return self._get_estimate_size() + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + """ + Create read tasks for the ClickHouse query. + + Args: + parallelism: The desired number of partitions to read the data into. + - If ``order_by`` is not set, parallelism will be forced to 1. + - If ``filter`` is set, parallelism will also be forced to 1 + to ensure deterministic results. + + Returns: + A list of read tasks to be executed. + """ + num_rows_total = self._get_estimate_count() + if num_rows_total == 0 or num_rows_total is None: + return [] + parallelism = min( + parallelism, math.ceil(num_rows_total / self.MIN_ROWS_PER_READ_TASK) + ) + # To ensure consistent order of query results, self._order_by + # must be specified and self.filter must not be specified + # in order to support parallelism. + if self._filter is not None and parallelism > 1: + logger.warning( + "ClickHouse datasource does not currently support parallel reads " + "when a filter is set; falling back to parallelism of 1." 
+ ) + # When filter is specified and parallelism is greater than 1, + # we need to reduce parallelism to 1 to ensure consistent results. + parallelism = 1 + # To ensure consistent order of query results, self._order_by + # must be specified in order to support parallelism. + if self._order_by is None and parallelism > 1: + logger.warning( + "ClickHouse datasource requires dataset to be explicitly ordered " + "to support parallelism; falling back to parallelism of 1." + ) + # When order_by is not specified and parallelism is greater than 1, + # we need to reduce parallelism to 1 to ensure consistent results. + parallelism = 1 + # By reducing parallelism to 1 when either of the conditions above are met, + # we ensure the downstream process is treated exactly as a non-parallelized + # (single block) process would be, thus ensuring output consistency. + num_rows_per_block = num_rows_total // parallelism + num_blocks_with_extra_row = num_rows_total % parallelism + ( + estimated_size_bytes_per_row, + sample_block_schema, + ) = self._get_sampled_estimates() + + def _get_read_task( + block_rows: int, offset_rows: int, parallelized: bool + ) -> ReadTask: + if parallelized: + # When parallelized, we need to build a block query with OFFSET + # and FETCH clauses. + query = self._build_block_query(block_rows, offset_rows) + else: + # When not parallelized, we can use the original query without + # OFFSET and FETCH clauses. + query = self._query + return ReadTask( + self._create_read_fn(query), + BlockMetadata( + num_rows=block_rows, + size_bytes=estimated_size_bytes_per_row * block_rows, + schema=sample_block_schema, + input_files=None, + exec_stats=None, + ), + ) + + if parallelism == 1: + # When parallelism is 1, we can read the entire dataset in a single task. + # We then optimize this scenario by using self._query directly without + # unnecessary OFFSET and FETCH clauses. 
+ return [_get_read_task(num_rows_total, 0, False)] + + # Otherwise we need to split the dataset into multiple tasks. + # Each task will include OFFSET and FETCH clauses to efficiently + # read a subset of the dataset. + read_tasks = [] + offset = 0 + for i in range(parallelism): + this_block_size = num_rows_per_block + if i < num_blocks_with_extra_row: + this_block_size += 1 + read_tasks.append(_get_read_task(this_block_size, offset, True)) + offset += this_block_size + return read_tasks diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..565abf85fe2a4d15ad16e530dba3b8eec9c643d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/csv_datasink.py @@ -0,0 +1,36 @@ +from typing import Any, Callable, Dict, Optional + +import pyarrow + +from ray.data.block import BlockAccessor +from ray.data.datasource.file_based_datasource import _resolve_kwargs +from ray.data.datasource.file_datasink import BlockBasedFileDatasink + + +class CSVDatasink(BlockBasedFileDatasink): + def __init__( + self, + path: str, + *, + arrow_csv_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + arrow_csv_args: Optional[Dict[str, Any]] = None, + file_format="csv", + **file_datasink_kwargs, + ): + super().__init__(path, file_format=file_format, **file_datasink_kwargs) + + if arrow_csv_args_fn is None: + arrow_csv_args_fn = lambda: {} # noqa: E731 + + if arrow_csv_args is None: + arrow_csv_args = {} + + self.arrow_csv_args_fn = arrow_csv_args_fn + self.arrow_csv_args = arrow_csv_args + + def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): + from pyarrow import csv + + writer_args = _resolve_kwargs(self.arrow_csv_args_fn, **self.arrow_csv_args) + write_options = writer_args.pop("write_options", None) + csv.write_csv(block.to_arrow(), file, 
write_options, **writer_args) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..860cf2be9f2f143e0a4bb9509aad676f51d7fae0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/databricks_uc_datasource.py @@ -0,0 +1,187 @@ +import json +import logging +import os +import time +from typing import List, Optional +from urllib.parse import urljoin + +import numpy as np +import pyarrow +import requests + +from ray.data.block import BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask +from ray.util.annotations import PublicAPI + +logger = logging.getLogger(__name__) + + +_STATEMENT_EXEC_POLL_TIME_S = 1 + + +@PublicAPI(stability="alpha") +class DatabricksUCDatasource(Datasource): + def __init__( + self, + host: str, + token: str, + warehouse_id: str, + catalog: str, + schema: str, + query: str, + ): + self.host = host + self.token = token + self.warehouse_id = warehouse_id + self.catalog = catalog + self.schema = schema + self.query = query + + url_base = f"https://{self.host}/api/2.0/sql/statements/" + + payload = json.dumps( + { + "statement": self.query, + "warehouse_id": self.warehouse_id, + "wait_timeout": "0s", + "disposition": "EXTERNAL_LINKS", + "format": "ARROW_STREAM", + "catalog": self.catalog, + "schema": self.schema, + } + ) + + req_headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.token, + } + + response = requests.post( + url_base, + headers=req_headers, + data=payload, + ) + response.raise_for_status() + statement_id = response.json()["statement_id"] + + state = response.json()["status"]["state"] + + logger.info(f"Waiting for query {query!r} execution result.") + try: + while state in ["PENDING", "RUNNING"]: + time.sleep(_STATEMENT_EXEC_POLL_TIME_S) + 
response = requests.get( + urljoin(url_base, statement_id) + "/", + headers=req_headers, + ) + response.raise_for_status() + state = response.json()["status"]["state"] + except KeyboardInterrupt: + # User cancel the command, so we cancel query execution. + requests.post( + urljoin(url_base, f"{statement_id}/cancel"), + headers=req_headers, + ) + try: + response.raise_for_status() + except Exception as e: + logger.warning( + f"Canceling query {query!r} execution failed, reason: {repr(e)}." + ) + raise + + if state != "SUCCEEDED": + raise RuntimeError(f"Query {self.query!r} execution failed.") + + manifest = response.json()["manifest"] + is_truncated = manifest["truncated"] + + if is_truncated: + logger.warning( + f"The resulting size of the dataset of '{query!r}' exceeds " + "100GiB and it is truncated." + ) + + chunks = manifest["chunks"] + + # Make chunks metadata are ordered by index. + chunks = sorted(chunks, key=lambda x: x["chunk_index"]) + num_chunks = len(chunks) + self.num_chunks = num_chunks + self._estimate_inmemory_data_size = sum(chunk["byte_count"] for chunk in chunks) + + def get_read_task(task_index, parallelism): + # get chunk list to be read in this task and preserve original chunk order + chunk_index_list = list( + np.array_split(range(num_chunks), parallelism)[task_index] + ) + + num_rows = sum( + chunks[chunk_index]["row_count"] for chunk_index in chunk_index_list + ) + size_bytes = sum( + chunks[chunk_index]["byte_count"] for chunk_index in chunk_index_list + ) + + metadata = BlockMetadata( + num_rows=num_rows, + size_bytes=size_bytes, + schema=None, + input_files=None, + exec_stats=None, + ) + + def _read_fn(): + for chunk_index in chunk_index_list: + resolve_external_link_url = urljoin( + url_base, f"{statement_id}/result/chunks/{chunk_index}" + ) + + resolve_response = requests.get( + resolve_external_link_url, headers=req_headers + ) + resolve_response.raise_for_status() + external_url = resolve_response.json()["external_links"][0][ + 
"external_link" + ] + # NOTE: do _NOT_ send the authorization header to external urls + raw_response = requests.get(external_url, auth=None, headers=None) + raw_response.raise_for_status() + + with pyarrow.ipc.open_stream(raw_response.content) as reader: + arrow_table = reader.read_all() + + yield arrow_table + + def read_fn(): + if mock_setup_fn_path := os.environ.get( + "RAY_DATABRICKS_UC_DATASOURCE_READ_FN_MOCK_TEST_SETUP_FN_PATH" + ): + import ray.cloudpickle as pickle + + # This is for testing. + with open(mock_setup_fn_path, "rb") as f: + mock_setup = pickle.load(f) + with mock_setup(): + yield from _read_fn() + else: + yield from _read_fn() + + return ReadTask(read_fn=read_fn, metadata=metadata) + + self._get_read_task = get_read_task + + def estimate_inmemory_data_size(self) -> Optional[int]: + return self._estimate_inmemory_data_size + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + assert parallelism > 0, f"Invalid parallelism {parallelism}" + + if parallelism > self.num_chunks: + parallelism = self.num_chunks + logger.info( + "The parallelism is reduced to chunk number due to " + "insufficient chunk parallelism." 
+ ) + + return [self._get_read_task(index, parallelism) for index in range(parallelism)] diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..1909c664587a51610da9b5aee5773dcdff27faa8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/delta_sharing_datasource.py @@ -0,0 +1,126 @@ +import logging +from json import loads +from typing import List, Optional, Tuple + +import numpy as np + +from ray.data._internal.util import _check_import +from ray.data.block import BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask + +logger = logging.getLogger(__name__) + + +class DeltaSharingDatasource(Datasource): + def __init__( + self, + url: str, + json_predicate_hints: Optional[str] = None, + limit: Optional[int] = None, + version: Optional[int] = None, + timestamp: Optional[str] = None, + ): + _check_import(self, module="delta_sharing", package="delta-sharing") + + if limit is not None: + assert ( + isinstance(limit, int) and limit >= 0 + ), "'limit' must be a non-negative int" + + self._url = url + self._json_predicate_hints = json_predicate_hints + self._limit = limit + self._version = version + self._timestamp = timestamp + + def estimate_inmemory_data_size(self) -> Optional[int]: + return None + + def _read_files(self, files, converters): + """Read files with Delta Sharing.""" + from delta_sharing.reader import DeltaSharingReader + + for file in files: + yield DeltaSharingReader._to_pandas( + action=file, converters=converters, for_cdf=False, limit=None + ) + + def setup_delta_sharing_connections(self, url: str): + """ + Set up delta sharing connections based on the url. + + :param url: a url under the format "#.." 
+ : + """ + from delta_sharing.protocol import DeltaSharingProfile, Table + from delta_sharing.rest_client import DataSharingRestClient + + profile_str, share, schema, table_str = _parse_delta_sharing_url(url) + table = Table(name=table_str, share=share, schema=schema) + + profile = DeltaSharingProfile.read_from_file(profile_str) + rest_client = DataSharingRestClient(profile) + return table, rest_client + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + assert parallelism > 0, f"Invalid parallelism {parallelism}" + from delta_sharing.converter import to_converters + + self._table, self._rest_client = self.setup_delta_sharing_connections(self._url) + self._response = self._rest_client.list_files_in_table( + self._table, + jsonPredicateHints=self._json_predicate_hints, + limitHint=self._limit, + version=self._version, + timestamp=self._timestamp, + ) + + if len(self._response.add_files) == 0 or self._limit == 0: + logger.warning("No files found from the delta sharing table or limit is 0") + + schema_json = loads(self._response.metadata.schema_string) + self._converters = to_converters(schema_json) + + read_tasks = [] + # get file list to be read in this task and preserve original chunk order + for files in np.array_split(self._response.add_files, parallelism): + files = files.tolist() + metadata = BlockMetadata( + num_rows=None, + schema=None, + input_files=files, + size_bytes=None, + exec_stats=None, + ) + converters = self._converters + read_task = ReadTask( + lambda f=files: self._read_files(f, converters), + metadata, + ) + read_tasks.append(read_task) + + return read_tasks + + +def _parse_delta_sharing_url(url: str) -> Tuple[str, str, str, str]: + """ + Developed from delta_sharing's _parse_url function. + https://github.com/delta-io/delta-sharing/blob/main/python/delta_sharing/delta_sharing.py#L36 + + Args: + url: a url under the format "#..
" + + Returns: + a tuple with parsed (profile, share, schema, table) + """ + shape_index = url.rfind("#") + if shape_index < 0: + raise ValueError(f"Invalid 'url': {url}") + profile = url[0:shape_index] + fragments = url[shape_index + 1 :].split(".") + if len(fragments) != 3: + raise ValueError(f"Invalid 'url': {url}") + share, schema, table = fragments + if len(profile) == 0 or len(share) == 0 or len(schema) == 0 or len(table) == 0: + raise ValueError(f"Invalid 'url': {url}") + return (profile, share, schema, table) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1e271ca262847d634e1bcfb3d8850ef81e73f1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/hudi_datasource.py @@ -0,0 +1,87 @@ +import logging +import os +from typing import Dict, Iterator, List, Optional + +from ray.data._internal.util import _check_import +from ray.data.block import BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask + +logger = logging.getLogger(__name__) + + +class HudiDatasource(Datasource): + """Hudi datasource, for reading Apache Hudi table.""" + + def __init__( + self, + table_uri: str, + storage_options: Optional[Dict[str, str]] = None, + ): + _check_import(self, module="hudi", package="hudi-python") + + self._table_uri = table_uri + self._storage_options = storage_options + + def get_read_tasks(self, parallelism: int) -> List["ReadTask"]: + import pyarrow + from hudi import HudiTable + + def _perform_read( + table_uri: str, + base_file_paths: List[str], + options: Dict[str, str], + ) -> Iterator["pyarrow.Table"]: + from hudi import HudiFileGroupReader + + for p in base_file_paths: + file_group_reader = HudiFileGroupReader(table_uri, options) + batch = file_group_reader.read_file_slice_by_base_file_path(p) + yield 
pyarrow.Table.from_batches([batch]) + + hudi_table = HudiTable(self._table_uri, self._storage_options) + + reader_options = { + **hudi_table.storage_options(), + **hudi_table.hudi_options(), + } + + schema = hudi_table.get_schema() + read_tasks = [] + for file_slices_split in hudi_table.get_file_slices_splits(parallelism): + num_rows = 0 + relative_paths = [] + input_files = [] + size_bytes = 0 + for file_slice in file_slices_split: + # A file slice in a Hudi table is a logical group of data files + # within a physical partition. Records stored in a file slice + # are associated with a commit on the Hudi table's timeline. + # For more info, see https://hudi.apache.org/docs/file_layouts + num_rows += file_slice.num_records + relative_path = file_slice.base_file_relative_path() + relative_paths.append(relative_path) + full_path = os.path.join(self._table_uri, relative_path) + input_files.append(full_path) + size_bytes += file_slice.base_file_size + + metadata = BlockMetadata( + num_rows=num_rows, + schema=schema, + input_files=input_files, + size_bytes=size_bytes, + exec_stats=None, + ) + + read_task = ReadTask( + read_fn=lambda paths=relative_paths: _perform_read( + self._table_uri, paths, reader_options + ), + metadata=metadata, + ) + read_tasks.append(read_task) + + return read_tasks + + def estimate_inmemory_data_size(self) -> Optional[int]: + # TODO(xushiyan) add APIs to provide estimated in-memory size + return None diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb3f9368a3c049211c5e96f16f2e90c0cf66c97 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/huggingface_datasource.py @@ -0,0 +1,176 @@ +import sys +from typing import TYPE_CHECKING, Iterable, List, Optional, Union + +from 
ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict +from ray.data._internal.util import _check_pyarrow_version +from ray.data.block import Block, BlockAccessor, BlockMetadata +from ray.data.dataset import Dataset +from ray.data.datasource import Datasource, ReadTask + +if TYPE_CHECKING: + import datasets + + +TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None + +try: + # Due to HF Dataset's dynamic module system, we need to dynamically import the + # datasets_modules module on every actor when training. + # We accomplish this by simply running the following bit of code directly + # in the module you are currently viewing. This ensures that when we + # unpickle the Dataset, it runs before pickle tries to + # import datasets_modules and prevents an exception from being thrown. + # Same logic is present inside HF Transformers Ray + # integration: https://github.com/huggingface/transformers/blob/\ + # 7d5fde991d598370d961be8cb7add6541e2b59ce/src/transformers/integrations.py#L271 + # Also see https://github.com/ray-project/ray/issues/28084 + from transformers.utils import is_datasets_available + + if "datasets_modules" not in sys.modules and is_datasets_available(): + import importlib + import os + + import datasets.load + + dynamic_modules_path = os.path.join( + datasets.load.init_dynamic_modules(), "__init__.py" + ) + # load dynamic_modules from path + spec = importlib.util.spec_from_file_location( + "datasets_modules", dynamic_modules_path + ) + datasets_modules = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = datasets_modules + spec.loader.exec_module(datasets_modules) +except ImportError as e: + TRANSFORMERS_IMPORT_ERROR = e + + +class HuggingFaceDatasource(Datasource): + """Hugging Face Dataset datasource, for reading from a + `Hugging Face Datasets Dataset `_. 
+ This Datasource implements a streamed read using a + single read task, most beneficial for a + `Hugging Face Datasets IterableDataset `_ + or datasets which are too large to fit in-memory. + For an in-memory Hugging Face Dataset (`datasets.Dataset`), use :meth:`~ray.data.from_huggingface` + directly for faster performance. + """ # noqa: E501 + + def __init__( + self, + dataset: Union["datasets.Dataset", "datasets.IterableDataset"], + batch_size: int = 4096, + ): + if TRANSFORMERS_IMPORT_ERROR is not None: + raise TRANSFORMERS_IMPORT_ERROR + + self._dataset = dataset + self._batch_size = batch_size + + @classmethod + def list_parquet_urls_from_dataset( + cls, dataset: Union["datasets.Dataset", "datasets.IterableDataset"] + ) -> Dataset: + """Return list of Hugging Face hosted parquet file URLs if they + exist for the data (i.e. if the dataset is a public dataset that + has not been transformed) else return an empty list.""" + import datasets + + # We can use the dataset name, config name, and split name to load + # public hugging face datasets from the Hugging Face Hub. More info + # here: https://huggingface.co/docs/datasets-server/parquet + dataset_name = dataset.info.dataset_name + config_name = dataset.info.config_name + split_name = str(dataset.split) + + # If a dataset is not an iterable dataset, we will check if the + # dataset with the matching dataset name, config name, and split name + # on the Hugging Face Hub has the same fingerprint as the + # dataset passed into this function. If it is not matching, transforms + # or other operations have been performed so we cannot use the parquet + # files on the Hugging Face Hub, so we return an empty list. 
+ if not isinstance(dataset, datasets.IterableDataset): + from datasets import load_dataset + + try: + ds = load_dataset(dataset_name, config_name, split=split_name) + if ds._fingerprint != dataset._fingerprint: + return [] + except Exception: + # If an exception is thrown when trying to reload the dataset + # we should exit gracefully by returning an empty list. + return [] + + import requests + + public_url = ( + f"https://huggingface.co/api/datasets/{dataset_name}" + f"/parquet/{config_name}/{split_name}" + ) + resp = requests.get(public_url) + if resp.status_code == requests.codes["ok"]: + # dataset corresponds to a public dataset, return list of parquet_files + return resp.json() + else: + return [] + + def estimate_inmemory_data_size(self) -> Optional[int]: + return self._dataset.dataset_size + + def get_read_tasks( + self, + parallelism: int, + ) -> List[ReadTask]: + # Note: `parallelism` arg is currently not used by HuggingFaceDatasource. + # We always generate a single ReadTask to perform the read. + _check_pyarrow_version() + import numpy as np + import pandas as pd + import pyarrow + + def _read_dataset(dataset: "datasets.IterableDataset") -> Iterable[Block]: + for batch in dataset.with_format("arrow").iter(batch_size=self._batch_size): + # HuggingFace IterableDatasets do not fully support methods like + # `set_format`, `with_format`, and `formatted_as`, so the dataset + # can return whatever is the default configured batch type, even if + # the format is manually overriden before iterating above. + # Therefore, we limit support to batch formats which have native + # block types in Ray Data (pyarrow.Table, pd.DataFrame), + # or can easily be converted to such (dict, np.array). + # See: https://github.com/huggingface/datasets/issues/3444 + if not isinstance(batch, (pyarrow.Table, pd.DataFrame, dict, np.array)): + raise ValueError( + f"Batch format {type(batch)} isn't supported. 
Only the " + f"following batch formats are supported: " + f"dict (corresponds to `None` in `dataset.with_format()`), " + f"pyarrow.Table, np.array, pd.DataFrame." + ) + # Ensure np.arrays are wrapped in a dict + # (subsequently converted to a pyarrow.Table). + if isinstance(batch, np.ndarray): + batch = {"item": batch} + if isinstance(batch, dict): + batch = pyarrow_table_from_pydict(batch) + # Ensure that we return the default block type. + block = BlockAccessor.for_block(batch).to_default() + yield block + + # TODO(scottjlee): IterableDataset doesn't provide APIs + # for getting number of rows, byte size, etc., so the + # BlockMetadata is currently empty. Properly retrieve + # or calculate these so that progress bars have meaning. + meta = BlockMetadata( + num_rows=None, + size_bytes=None, + schema=None, + input_files=None, + exec_stats=None, + ) + read_tasks: List[ReadTask] = [ + ReadTask( + lambda hfds=self._dataset: _read_dataset(hfds), + meta, + ) + ] + return read_tasks diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..064c2ea26c789bbf09afd4f4f281a923f6bf6196 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/iceberg_datasource.py @@ -0,0 +1,261 @@ +""" +Module to read an iceberg table into a Ray Dataset, by using the Ray Datasource API. 
+""" + +import heapq +import itertools +import logging +from functools import partial +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +from ray.data._internal.util import _check_import +from ray.data.block import Block, BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + from pyiceberg.catalog import Catalog + from pyiceberg.expressions import BooleanExpression + from pyiceberg.io import FileIO + from pyiceberg.manifest import DataFile + from pyiceberg.schema import Schema + from pyiceberg.table import DataScan, FileScanTask, Table + from pyiceberg.table.metadata import TableMetadata + +logger = logging.getLogger(__name__) + + +def _get_read_task( + tasks: Iterable["FileScanTask"], + table_io: "FileIO", + table_metadata: "TableMetadata", + row_filter: "BooleanExpression", + case_sensitive: bool, + limit: Optional[int], + schema: "Schema", +) -> Iterable[Block]: + from pyiceberg.io import pyarrow as pyi_pa_io + + # Use the PyIceberg API to read only a single task (specifically, a + # FileScanTask) - note that this is not as simple as reading a single + # parquet file, as there might be delete files, etc. associated, so we + # must use the PyIceberg API for the projection. + yield pyi_pa_io.project_table( + tasks=tasks, + table_metadata=table_metadata, + io=table_io, + row_filter=row_filter, + projected_schema=schema, + case_sensitive=case_sensitive, + limit=limit, + ) + + +@DeveloperAPI +class IcebergDatasource(Datasource): + """ + Iceberg datasource to read Iceberg tables into a Ray Dataset. This module heavily + uses PyIceberg to read iceberg tables. All the routines in this class override + `ray.data.Datasource`. + """ + + def __init__( + self, + table_identifier: str, + row_filter: Union[str, "BooleanExpression"] = None, + selected_fields: Tuple[str, ...] 
= ("*",), + snapshot_id: Optional[int] = None, + scan_kwargs: Optional[Dict[str, Any]] = None, + catalog_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Initialize an IcebergDatasource. + + Args: + table_identifier: Fully qualified table identifier (i.e., + "db_name.table_name") + row_filter: A PyIceberg BooleanExpression to use to filter the data *prior* + to reading + selected_fields: Which columns from the data to read, passed directly to + PyIceberg's load functions + snapshot_id: Optional snapshot ID for the Iceberg table + scan_kwargs: Optional arguments to pass to PyIceberg's Table.scan() + function + catalog_kwargs: Optional arguments to use when setting up the Iceberg + catalog + """ + _check_import(self, module="pyiceberg", package="pyiceberg") + from pyiceberg.expressions import AlwaysTrue + + self._scan_kwargs = scan_kwargs if scan_kwargs is not None else {} + self._catalog_kwargs = catalog_kwargs if catalog_kwargs is not None else {} + + if "name" in self._catalog_kwargs: + self._catalog_name = self._catalog_kwargs.pop("name") + else: + self._catalog_name = "default" + + self.table_identifier = table_identifier + + self._row_filter = row_filter if row_filter is not None else AlwaysTrue() + self._selected_fields = selected_fields + + if snapshot_id: + self._scan_kwargs["snapshot_id"] = snapshot_id + + self._plan_files = None + self._table = None + + def _get_catalog(self) -> "Catalog": + from pyiceberg import catalog + + return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs) + + @property + def table(self) -> "Table": + """ + Return the table reference from the catalog + """ + if self._table is None: + catalog = self._get_catalog() + self._table = catalog.load_table(self.table_identifier) + return self._table + + @property + def plan_files(self) -> List["FileScanTask"]: + """ + Return the plan files specified by this query + """ + # Calculate and cache the plan_files if they don't already exist + if self._plan_files is None: + 
data_scan = self._get_data_scan() + self._plan_files = data_scan.plan_files() + + return self._plan_files + + def _get_data_scan(self) -> "DataScan": + + data_scan = self.table.scan( + row_filter=self._row_filter, + selected_fields=self._selected_fields, + **self._scan_kwargs, + ) + + return data_scan + + def estimate_inmemory_data_size(self) -> Optional[int]: + # Approximate the size by using the plan files - this will not + # incorporate the deletes, but that's a reasonable approximation + # task + return sum(task.file.file_size_in_bytes for task in self.plan_files) + + @staticmethod + def _distribute_tasks_into_equal_chunks( + plan_files: Iterable["FileScanTask"], n_chunks: int + ) -> List[List["FileScanTask"]]: + """ + Implement a greedy knapsack algorithm to distribute the files in the scan + across tasks, based on their file size, as evenly as possible + """ + chunks = [list() for _ in range(n_chunks)] + + chunk_sizes = [(0, chunk_id) for chunk_id in range(n_chunks)] + heapq.heapify(chunk_sizes) + + # From largest to smallest, add the plan files to the smallest chunk one at a + # time + for plan_file in sorted( + plan_files, key=lambda f: f.file.file_size_in_bytes, reverse=True + ): + smallest_chunk = heapq.heappop(chunk_sizes) + chunks[smallest_chunk[1]].append(plan_file) + heapq.heappush( + chunk_sizes, + ( + smallest_chunk[0] + plan_file.file.file_size_in_bytes, + smallest_chunk[1], + ), + ) + + return chunks + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + from pyiceberg.io import pyarrow as pyi_pa_io + from pyiceberg.manifest import DataFileContent + + # Get the PyIceberg scan + data_scan = self._get_data_scan() + # Get the plan files in this query + plan_files = self.plan_files + + # Get the projected schema for this scan, given all the row filters, + # snapshot ID, etc. 
+ projected_schema = data_scan.projection() + # Get the arrow schema, to set in the metadata + pya_schema = pyi_pa_io.schema_to_pyarrow(projected_schema) + + # Set the n_chunks to the min of the number of plan files and the actual + # requested n_chunks, so that there are no empty tasks + if parallelism > len(list(plan_files)): + parallelism = len(list(plan_files)) + logger.warning( + f"Reducing the parallelism to {parallelism}, as that is the" + "number of files" + ) + + # Get required properties for reading tasks - table IO, table metadata, + # row filter, case sensitivity,limit and projected schema to pass + # them directly to `_get_read_task` to avoid capture of `self` reference + # within the closure carrying substantial overhead invoking these tasks + # + # See https://github.com/ray-project/ray/issues/49107 for more context + table_io = self.table.io + table_metadata = self.table.metadata + row_filter = self._row_filter + case_sensitive = self._scan_kwargs.get("case_sensitive", True) + limit = self._scan_kwargs.get("limit") + + get_read_task = partial( + _get_read_task, + table_io=table_io, + table_metadata=table_metadata, + row_filter=row_filter, + case_sensitive=case_sensitive, + limit=limit, + schema=projected_schema, + ) + + read_tasks = [] + # Chunk the plan files based on the requested parallelism + for chunk_tasks in IcebergDatasource._distribute_tasks_into_equal_chunks( + plan_files, parallelism + ): + unique_deletes: Set[DataFile] = set( + itertools.chain.from_iterable( + [task.delete_files for task in chunk_tasks] + ) + ) + # Get a rough estimate of the number of deletes by just looking at + # position deletes. Equality deletes are harder to estimate, as they + # can delete multiple rows. 
+ position_delete_count = sum( + delete.record_count + for delete in unique_deletes + if delete.content == DataFileContent.POSITION_DELETES + ) + metadata = BlockMetadata( + num_rows=sum(task.file.record_count for task in chunk_tasks) + - position_delete_count, + size_bytes=sum(task.length for task in chunk_tasks), + schema=pya_schema, + input_files=[task.file.file_path for task in chunk_tasks], + exec_stats=None, + ) + read_tasks.append( + ReadTask( + read_fn=lambda tasks=chunk_tasks: get_read_task(tasks), + metadata=metadata, + ) + ) + + return read_tasks diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..ac561fbaaa9f40bc9e9562eb75186bdea23433d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasink.py @@ -0,0 +1,24 @@ +import io +from typing import Any, Dict + +import pyarrow + +from ray.data.datasource.file_datasink import RowBasedFileDatasink + + +class ImageDatasink(RowBasedFileDatasink): + def __init__( + self, path: str, column: str, file_format: str, **file_datasink_kwargs + ): + super().__init__(path, file_format=file_format, **file_datasink_kwargs) + + self.column = column + self.file_format = file_format + + def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"): + from PIL import Image + + image = Image.fromarray(row[self.column]) + buffer = io.BytesIO() + image.save(buffer, format=self.file_format) + file.write(buffer.getvalue()) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..bcbf481f863b9020474bc4216b8ba70c771f29f0 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/image_datasource.py @@ -0,0 +1,175 @@ +import io +import logging +import time +from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union + +import numpy as np + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.util import _check_import +from ray.data.block import Block, BlockMetadata +from ray.data.datasource.file_based_datasource import FileBasedDatasource +from ray.data.datasource.file_meta_provider import DefaultFileMetadataProvider + +if TYPE_CHECKING: + import pyarrow + + +logger = logging.getLogger(__name__) + +# The default size multiplier for reading image data source. +# This essentially is using image on-disk file size to estimate +# in-memory data size. +IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT = 1 + +# The lower bound value to estimate image encoding ratio. +IMAGE_ENCODING_RATIO_ESTIMATE_LOWER_BOUND = 0.5 + + +class ImageDatasource(FileBasedDatasource): + """A datasource that lets you read images.""" + + _WRITE_FILE_PER_ROW = True + _FILE_EXTENSIONS = ["png", "jpg", "jpeg", "tif", "tiff", "bmp", "gif"] + # Use 8 threads per task to read image files. + _NUM_THREADS_PER_TASK = 8 + + def __init__( + self, + paths: Union[str, List[str]], + size: Optional[Tuple[int, int]] = None, + mode: Optional[str] = None, + **file_based_datasource_kwargs, + ): + super().__init__(paths, **file_based_datasource_kwargs) + + _check_import(self, module="PIL", package="Pillow") + + if size is not None and len(size) != 2: + raise ValueError( + "Expected `size` to contain two integers for height and width, " + f"but got {len(size)} integers instead." + ) + + if size is not None and (size[0] < 0 or size[1] < 0): + raise ValueError( + f"Expected `size` to contain positive integers, but got {size} instead." 
+ ) + + self.size = size + self.mode = mode + + meta_provider = file_based_datasource_kwargs.get("meta_provider", None) + if isinstance(meta_provider, ImageFileMetadataProvider): + self._encoding_ratio = self._estimate_files_encoding_ratio() + meta_provider._set_encoding_ratio(self._encoding_ratio) + else: + self._encoding_ratio = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT + + def _read_stream( + self, + f: "pyarrow.NativeFile", + path: str, + ) -> Iterator[Block]: + from PIL import Image, UnidentifiedImageError + + data = f.readall() + + try: + image = Image.open(io.BytesIO(data)) + except UnidentifiedImageError as e: + raise ValueError(f"PIL couldn't load image file at path '{path}'.") from e + + if self.size is not None: + height, width = self.size + image = image.resize((width, height), resample=Image.BILINEAR) + if self.mode is not None: + image = image.convert(self.mode) + + builder = DelegatingBlockBuilder() + array = np.array(image) + item = {"image": array} + builder.add(item) + block = builder.build() + + yield block + + def _rows_per_file(self): + return 1 + + def estimate_inmemory_data_size(self) -> Optional[int]: + total_size = 0 + for file_size in self._file_sizes(): + # NOTE: check if file size is not None, because some metadata provider + # such as FastFileMetadataProvider does not provide file size information. + if file_size is not None: + total_size += file_size + return total_size * self._encoding_ratio + + def _estimate_files_encoding_ratio(self) -> float: + """Return an estimate of the image files encoding ratio.""" + start_time = time.perf_counter() + # Filter out empty file to avoid noise. + non_empty_path_and_size = list( + filter(lambda p: p[1] > 0, zip(self._paths(), self._file_sizes())) + ) + num_files = len(non_empty_path_and_size) + if num_files == 0: + logger.warn( + "All input image files are empty. " + "Use on-disk file size to estimate images in-memory size." 
+ ) + return IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT + + if self.size is not None and self.mode is not None: + # Use image size and mode to calculate data size for all images, + # because all images are homogeneous with same size after resizing. + # Resizing is enforced when reading every image in `ImageDatasource` + # when `size` argument is provided. + if self.mode in ["1", "L", "P"]: + dimension = 1 + elif self.mode in ["RGB", "YCbCr", "LAB", "HSV"]: + dimension = 3 + elif self.mode in ["RGBA", "CMYK", "I", "F"]: + dimension = 4 + else: + logger.warn(f"Found unknown image mode: {self.mode}.") + return IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT + height, width = self.size + single_image_size = height * width * dimension + total_estimated_size = single_image_size * num_files + total_file_size = sum(p[1] for p in non_empty_path_and_size) + ratio = total_estimated_size / total_file_size + else: + # TODO(chengsu): sample images to estimate data size + ratio = IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT + + sampling_duration = time.perf_counter() - start_time + if sampling_duration > 5: + logger.warn( + "Image input size estimation took " + f"{round(sampling_duration, 2)} seconds." 
class ImageFileMetadataProvider(DefaultFileMetadataProvider):
    """File metadata provider that scales reported sizes by an encoding ratio.

    The block metadata produced by the parent provider is post-processed so
    that ``size_bytes`` is multiplied by the ratio set through
    ``_set_encoding_ratio``.
    """

    def _set_encoding_ratio(self, encoding_ratio: float):
        """Set image file encoding ratio, to provide accurate size in bytes metadata."""
        self._encoding_ratio = encoding_ratio

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        *,
        rows_per_file: Optional[int],
        file_sizes: List[Optional[int]],
    ) -> BlockMetadata:
        """Return the parent's block metadata with ``size_bytes`` rescaled.

        Args:
            paths: Passed through unchanged to the parent implementation.
            schema: Passed through unchanged to the parent implementation.
            rows_per_file: Passed through unchanged to the parent implementation.
            file_sizes: Passed through unchanged to the parent implementation.

        Returns:
            The metadata built by the parent provider. When it carries a
            ``size_bytes`` value, that value is multiplied by the encoding
            ratio and truncated to an integer.
        """
        metadata = super()._get_block_metadata(
            paths, schema, rows_per_file=rows_per_file, file_sizes=file_sizes
        )
        if metadata.size_bytes is not None:
            # `_set_encoding_ratio` may never have been called on this
            # instance; fall back to the module default instead of failing
            # with an AttributeError.
            encoding_ratio = getattr(
                self, "_encoding_ratio", IMAGE_ENCODING_RATIO_ESTIMATE_DEFAULT
            )
            metadata.size_bytes = int(metadata.size_bytes * encoding_ratio)
        return metadata
def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
    """Write ``block`` to ``file`` as JSON through ``DataFrame.to_json``.

    The writer arguments are resolved from ``self.pandas_json_args_fn`` and
    ``self.pandas_json_args``. ``orient`` defaults to ``"records"`` and
    ``lines`` defaults to ``True`` when the resolved arguments do not set
    them; every remaining argument is forwarded unchanged.
    """
    resolved_args = _resolve_kwargs(
        self.pandas_json_args_fn, **self.pandas_json_args
    )
    # Pull out the two options that have datasink-specific defaults.
    orient = resolved_args.pop("orient", "records")
    lines = resolved_args.pop("lines", True)

    frame = block.to_pandas()
    frame.to_json(file, orient=orient, lines=lines, **resolved_args)
def _read_with_pyarrow_read_json(self, buffer: "pyarrow.lib.Buffer"):
    """Read with PyArrow JSON reader, trying to auto-increase the
    read block size in the case of the read object
    straddling block boundaries.

    Args:
        buffer: Raw contents of a single JSON/JSONL file.

    Yields:
        The table parsed from ``buffer`` (exactly one per call).

    Raises:
        pyarrow.ArrowInvalid: If PyArrow rejects the input for a reason other
            than a straddling object, or if the block size has already reached
            ``DataContext.get_current().target_max_block_size`` and the read
            still fails.
    """
    import pyarrow as pa

    # `import pyarrow` alone doesn't guarantee that the `json` submodule is
    # loaded, so import it explicitly rather than relying on `pa.json`.
    import pyarrow.json as pajson

    # When reading large files, the default block size configured in PyArrow can be
    # too small, resulting in the following error: `pyarrow.lib.ArrowInvalid:
    # straddling object straddles two block boundaries (try to increase block
    # size?)`. More information on this issue can be found here:
    # https://github.com/apache/arrow/issues/25674
    # The read will be retried with geometrically increasing block size
    # until the size reaches `DataContext.get_current().target_max_block_size`.
    # The initial block size will start at the PyArrow default block size
    # or it can be manually set through the `read_options` parameter as follows.
    # >>> import pyarrow.json as pajson
    # >>> block_size = 10 << 20 # Set block size to 10MB
    # >>> ray.data.read_json(  # doctest: +SKIP
    # ...     "s3://anonymous@ray-example-data/log.json",
    # ...     read_options=pajson.ReadOptions(block_size=block_size)
    # ... )

    init_block_size = self.read_options.block_size
    max_block_size = DataContext.get_current().target_max_block_size
    try:
        while True:
            try:
                table = pajson.read_json(
                    BytesIO(buffer),
                    read_options=self.read_options,
                    **self.arrow_json_args,
                )
                break
            except pa.ArrowInvalid as e:
                if "straddling object straddles two block boundaries" not in str(e):
                    # unrelated error, simply reraise
                    raise
                if self.read_options.block_size >= max_block_size:
                    raise pa.ArrowInvalid(
                        f"{e} - Auto-increasing block size to "
                        f"{self.read_options.block_size} bytes failed. "
                        f"Please try manually increasing the block size through "
                        f"the `read_options` parameter to a larger size. "
                        f"For example: `read_json(..., read_options="
                        f"pyarrow.json.ReadOptions(block_size=10 << 25))`. "
                        f"More information on this issue can be found here: "
                        f"https://github.com/apache/arrow/issues/25674"
                    ) from e
                # Increase the block size in case it was too small.
                logger.debug(
                    f"JSONDatasource read failed with "
                    f"block_size={self.read_options.block_size}. Retrying with "
                    f"block_size={self.read_options.block_size * 2}."
                )
                self.read_options.block_size *= 2
    finally:
        # `self.read_options` is shared by every file this datasource reads.
        # Always restore the original block size, including when the read
        # fails, so one file can't change the settings used for the next.
        self.read_options.block_size = init_block_size

    # Yield outside of the try blocks so that an exception thrown into the
    # generator by its consumer is never mistaken for a read failure.
    yield table
" + f"PyArrow error was:\n{e}" + ) + yield from self._read_with_python_json(buffer) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..2854aa0e62a551d4cab71602dbf64b17c8a356b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/lance_datasource.py @@ -0,0 +1,129 @@ +import logging +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional + +import numpy as np + +from ray.data._internal.util import _check_import, call_with_retry +from ray.data.block import BlockMetadata +from ray.data.context import DataContext +from ray.data.datasource.datasource import Datasource, ReadTask + +if TYPE_CHECKING: + import pyarrow + + +logger = logging.getLogger(__name__) + + +class LanceDatasource(Datasource): + """Lance datasource, for reading Lance dataset.""" + + # Errors to retry when reading Lance fragments. + READ_FRAGMENTS_ERRORS_TO_RETRY = ["LanceError(IO)"] + # Maximum number of attempts to read Lance fragments. + READ_FRAGMENTS_MAX_ATTEMPTS = 10 + # Maximum backoff seconds between attempts to read Lance fragments. 
def __init__(
    self,
    uri: str,
    columns: Optional[List[str]] = None,
    filter: Optional[str] = None,
    storage_options: Optional[Dict[str, str]] = None,
    scanner_options: Optional[Dict[str, Any]] = None,
):
    """Open the Lance dataset at ``uri`` and prepare scan and retry settings.

    Args:
        uri: Location passed to ``lance.dataset``.
        columns: If given, stored as the ``"columns"`` scanner option.
        filter: If given, stored as the ``"filter"`` scanner option.
        storage_options: Passed to ``lance.dataset`` unchanged.
        scanner_options: Extra scanner options. The dict is copied, so the
            caller's object is never modified.
    """
    _check_import(self, module="lance", package="pylance")

    import lance

    self.uri = uri
    # Work on a copy: `columns` and `filter` are merged in below, and the
    # caller's dict must not be mutated as a side effect of construction.
    self.scanner_options = dict(scanner_options) if scanner_options else {}
    if columns is not None:
        self.scanner_options["columns"] = columns
    if filter is not None:
        self.scanner_options["filter"] = filter
    self.storage_options = storage_options
    self.lance_ds = lance.dataset(uri=uri, storage_options=storage_options)

    # Retry on the Lance-specific errors plus the errors configured on the
    # current DataContext.
    match = [
        *self.READ_FRAGMENTS_ERRORS_TO_RETRY,
        *DataContext.get_current().retried_io_errors,
    ]
    self._retry_params = {
        "description": "read lance fragments",
        "match": match,
        "max_attempts": self.READ_FRAGMENTS_MAX_ATTEMPTS,
        "max_backoff_s": self.READ_FRAGMENTS_RETRY_MAX_BACKOFF_SECONDS,
    }
def _read_fragments(
    fragment_ids,
    lance_ds,
    scanner_options,
) -> Iterator["pyarrow.Table"]:
    """Read Lance fragments in batches.

    NOTE: Use fragment ids, instead of fragments as parameter, because pickling
    LanceFragment is expensive.

    Args:
        fragment_ids: Ids resolved to fragments via ``lance_ds.get_fragment``.
        lance_ds: Dataset object providing ``get_fragment`` and ``scanner``.
        scanner_options: Keyword arguments for ``lance_ds.scanner``. This dict
            is not modified; any ``"fragments"`` entry is overridden.

    Yields:
        One ``pyarrow.Table`` per record batch produced by the scanner.
    """
    import pyarrow

    fragments = [lance_ds.get_fragment(fragment_id) for fragment_id in fragment_ids]
    # Build a per-call options dict instead of writing into `scanner_options`:
    # the same dict object is handed to every read task, so mutating it would
    # leak this task's fragments into shared state.
    scanner = lance_ds.scanner(**{**scanner_options, "fragments": fragments})
    for batch in scanner.to_reader():
        yield pyarrow.Table.from_batches([batch])
def __init__(
    self,
    uri: str,
    database: str,
    collection: str,
    pipeline: Optional[List[Dict]] = None,
    schema: Optional["pymongoarrow.api.Schema"] = None,
    **mongo_args,
):
    """Store the connection parameters and the aggregation pipeline.

    No connection is opened here; ``self._client`` starts as ``None`` and is
    created later, when read tasks are built.
    """
    self._uri, self._database, self._collection = uri, database, collection
    self._schema = schema
    self._mongo_args = mongo_args
    # A missing or empty pipeline means "read the entire collection".
    self._pipeline = (
        pipeline if pipeline else [{"$match": {"_id": {"$exists": "true"}}}]
    )
    # Initialize Mongo client lazily later when creating read tasks.
    self._client = None
def _get_match_query(self, pipeline: List[Dict]) -> Dict:
    """Return the ``$match`` filter of the first pipeline stage.

    An empty dict is returned when the pipeline is empty or when its first
    stage is not a ``$match`` stage.
    """
    leading_stage = pipeline[0] if len(pipeline) > 0 else {}
    return leading_stage["$match"] if "$match" in leading_stage else {}
def _validate_database_collection_exist(client, database: str, collection: str):
    """Check that ``database`` and ``collection`` are both known to ``client``.

    Raises:
        ValueError: If ``database`` is not among ``client.list_database_names()``,
            or ``collection`` is not among the collection names of that database.
    """
    if database not in client.list_database_names():
        raise ValueError(f"The destination database {database} doesn't exist.")

    # Only query the collections once the database is known to exist.
    known_collections = client[database].list_collection_names()
    if collection in known_collections:
        return
    raise ValueError(f"The destination collection {collection} doesn't exist.")
def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
    """Load one ``.npy`` file and yield it as a single-column block.

    Args:
        f: Open input file; its whole content is read into memory.
        path: Path of the file being read (unused here).

    Yields:
        One block whose only column, named ``self._COLUMN_NAME``, holds the
        loaded array.
    """
    # TODO(ekl) Ideally numpy can read directly from the file, but it
    # seems like it requires the file to be seekable.
    buf = BytesIO(f.readall())
    # NOTE(security): `allow_pickle=True` lets `np.load` unpickle arbitrary
    # objects, which is only safe for trusted files. It stays the default for
    # backward compatibility, but is now merged as a default so callers can
    # pass `allow_pickle=False` in `numpy_load_args` (previously that raised
    # a duplicate-keyword TypeError).
    load_args = {"allow_pickle": True, **self.numpy_load_args}
    yield BlockAccessor.batch_to_block(
        {self._COLUMN_NAME: np.load(buf, **load_args)}
    )
def _read_stream(self, f: "pyarrow.NativeFile", path: str):
    """Read one Parquet file and yield it as a single table.

    Args:
        f: Open input file handed to ``pyarrow.parquet.read_table``.
        path: Path of the file being read (unused here).

    Yields:
        The table returned by ``pyarrow.parquet.read_table``.
    """
    import pyarrow.parquet as pq

    # `self.read_table_args` is reused for every file read by this
    # datasource. Merge the `use_threads=False` default into a copy instead
    # of popping it, so a user-supplied `use_threads` applies to all files
    # and not only to the first one.
    read_table_args = {"use_threads": False, **self.read_table_args}
    yield pq.read_table(f, **read_table_args)
Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + dataset_uuid: Optional[str] = None, + ): + if arrow_parquet_args_fn is None: + arrow_parquet_args_fn = lambda: {} # noqa: E731 + + if arrow_parquet_args is None: + arrow_parquet_args = {} + + self.arrow_parquet_args_fn = arrow_parquet_args_fn + self.arrow_parquet_args = arrow_parquet_args + self.min_rows_per_file = min_rows_per_file + self.partition_cols = partition_cols + + super().__init__( + path, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=open_stream_args, + filename_provider=filename_provider, + dataset_uuid=dataset_uuid, + file_format="parquet", + ) + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> None: + import pyarrow as pa + + blocks = list(blocks) + + if all(BlockAccessor.for_block(block).num_rows() == 0 for block in blocks): + return + + filename = self.filename_provider.get_filename_for_block( + blocks[0], ctx.task_idx, 0 + ) + write_kwargs = _resolve_kwargs( + self.arrow_parquet_args_fn, **self.arrow_parquet_args + ) + user_schema = write_kwargs.pop("schema", None) + + def write_blocks_to_path(): + tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks] + if user_schema is None: + output_schema = pa.unify_schemas([table.schema for table in tables]) + else: + output_schema = user_schema + + if not self.partition_cols: + self._write_single_file(tables, filename, output_schema, write_kwargs) + else: # partition writes + self._write_partition_files( + tables, filename, output_schema, write_kwargs + ) + + logger.debug(f"Writing {filename} file to {self.path}.") + + call_with_retry( + write_blocks_to_path, + description=f"write '{filename}' to '{self.path}'", + match=DataContext.get_current().retried_io_errors, + max_attempts=WRITE_FILE_MAX_ATTEMPTS, + 
def _write_partition_files(
    self,
    tables: List["pyarrow.Table"],
    filename: str,
    output_schema: "pyarrow.Schema",
    write_kwargs: Dict[str, Any],
) -> None:
    """Write ``tables`` as one Parquet file per partition-key combination.

    The tables are concatenated and grouped by ``self.partition_cols``. Each
    group is written to ``<self.path>/<col>=<value>/.../<filename>`` and holds
    only the non-partition columns.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    partition_cols = self.partition_cols
    combined = concat(tables)

    # Fields (and their names) that remain in the written files.
    data_fields = [fld for fld in output_schema if fld.name not in partition_cols]
    data_col_names = [fld.name for fld in data_fields]
    data_schema = pa.schema(data_fields)

    # Group the table by partition keys. For each key combination the
    # aggregation collects the values of every non-partition column into a
    # list column named "<column>_list".
    # Ex: for a table with columns (a, b) partitioned by column a, the
    # grouped table looks like
    #   b_list: [[[0,0],[1,1],[2,2]]]
    #   a: [[1,2,3]]
    grouped = combined.group_by(partition_cols).aggregate(
        [(name, "list") for name in data_col_names]
    )
    key_columns = [grouped.column(key) for key in partition_cols]

    for row_idx in range(grouped.num_rows):
        # See https://github.com/apache/arrow/issues/14882 for recommended approach
        arrays = [
            grouped.column(f"{name}_list")[row_idx].values
            for name in data_col_names
        ]
        group_table = pa.Table.from_arrays(arrays, names=data_col_names)

        # One "<col>=<value>" path segment per partition column.
        segments = [
            f"{key}={column[row_idx]}"
            for key, column in zip(partition_cols, key_columns)
        ]
        partition_dir = posixpath.join(self.path, "/".join(segments))
        self._create_dir(partition_dir)
        file_path = posixpath.join(partition_dir, filename)
        with self.open_output_stream(file_path) as file, pq.ParquetWriter(
            file, data_schema, **write_kwargs
        ) as writer:
            writer.write_table(group_table)
ray.data._internal.util import ( + _check_pyarrow_version, + _is_local_scheme, + call_with_retry, + iterate_with_retry, +) +from ray.data.block import Block +from ray.data.context import DataContext +from ray.data.datasource import Datasource +from ray.data.datasource.datasource import ReadTask +from ray.data.datasource.file_based_datasource import FileShuffleConfig +from ray.data.datasource.file_meta_provider import ( + DefaultFileMetadataProvider, + _handle_read_os_error, +) +from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider +from ray.data.datasource.partitioning import ( + PartitionDataType, + Partitioning, + PathPartitionFilter, + PathPartitionParser, +) +from ray.data.datasource.path_util import ( + _has_file_extension, + _resolve_paths_and_filesystem, +) + +if TYPE_CHECKING: + import pyarrow + from pyarrow.dataset import ParquetFileFragment + + +logger = logging.getLogger(__name__) + +# The `num_cpus` for each metadata prefetching task. +# Default to 0.5 instead of 1 because it is cheaper than normal read task. +NUM_CPUS_FOR_META_FETCH_TASK = 0.5 + +# The number of rows to read per batch. This is sized to generate 10MiB batches +# for rows about 1KiB in size. +PARQUET_READER_ROW_BATCH_SIZE = 10_000 +FILE_READING_RETRY = 8 + +# The default size multiplier for reading Parquet data source in Arrow. +# Parquet data format is encoded with various encoding techniques (such as +# dictionary, RLE, delta), so Arrow in-memory representation uses much more memory +# compared to Parquet encoded representation. Parquet file statistics only record +# encoded (i.e. uncompressed) data size information. +# +# To estimate real-time in-memory data size, Datasets will try to estimate the +# correct inflation ratio from Parquet to Arrow, using this constant as the default +# value for safety. See https://github.com/ray-project/ray/pull/26516 for more context. 
def check_for_legacy_tensor_type(schema):
    """Check for the legacy tensor extension type and raise an error if found.

    Ray Data uses an extension type to represent tensors in Arrow tables. Previously,
    the extension type extended `PyExtensionType`. However, this base type can expose
    users to arbitrary code execution. To prevent this, we don't load the type by
    default.

    Args:
        schema: Object exposing parallel ``names`` and ``types`` sequences.

    Raises:
        RuntimeError: If any column type is an instance of both
            ``pa.UnknownExtensionType`` and ``pa.PyExtensionType``.
    """
    import pyarrow as pa

    for name, column_type in zip(schema.names, schema.types):
        if isinstance(column_type, pa.UnknownExtensionType) and isinstance(
            column_type, pa.PyExtensionType
        ):
            raise RuntimeError(
                f"Ray Data couldn't infer the type of column '{name}'. This might mean "
                "you're trying to read data written with an older version of Ray. "
                "Reading data written with older versions of Ray might expose you to "
                "arbitrary code execution. To try reading the data anyway, set "
                "`RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE=1` on *all* nodes. "
                "To learn more, see https://github.com/ray-project/ray/issues/41314."
            )
+ """ + + def __init__( + self, + paths: Union[str, List[str]], + *, + columns: Optional[List[str]] = None, + dataset_kwargs: Optional[Dict[str, Any]] = None, + to_batch_kwargs: Optional[Dict[str, Any]] = None, + _block_udf: Optional[Callable[[Block], Block]] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None, + meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(), + partition_filter: PathPartitionFilter = None, + partitioning: Optional[Partitioning] = Partitioning("hive"), + shuffle: Union[Literal["files"], None] = None, + include_paths: bool = False, + file_extensions: Optional[List[str]] = None, + ): + _check_pyarrow_version() + + import pyarrow as pa + + self._supports_distributed_reads = not _is_local_scheme(paths) + if not self._supports_distributed_reads and ray.util.client.ray.is_connected(): + raise ValueError( + "Because you're using Ray Client, read tasks scheduled on the Ray " + "cluster can't access your local files. To fix this issue, store " + "files in cloud storage or a distributed filesystem like NFS." + ) + + self._local_scheduling = None + if not self._supports_distributed_reads: + from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + self._local_scheduling = NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), soft=False + ) + + paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem) + + # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet + # files. To avoid this, we expand the input paths with the default metadata + # provider and then apply the partition filter or file extensions. 
+ if partition_filter is not None or file_extensions is not None: + default_meta_provider = DefaultFileMetadataProvider() + expanded_paths, _ = map( + list, zip(*default_meta_provider.expand_paths(paths, filesystem)) + ) + + paths = list(expanded_paths) + if partition_filter is not None: + paths = partition_filter(paths) + if file_extensions is not None: + paths = [ + path for path in paths if _has_file_extension(path, file_extensions) + ] + + filtered_paths = set(expanded_paths) - set(paths) + if filtered_paths: + logger.info(f"Filtered out {len(filtered_paths)} paths") + + if dataset_kwargs is None: + dataset_kwargs = {} + + if "partitioning" in dataset_kwargs: + raise ValueError( + "The 'partitioning' parameter isn't supported in 'dataset_kwargs'. " + "Use the top-level 'partitioning' parameter instead." + ) + + # This datasource manually adds partition data at the Ray Data-level. To avoid + # duplicating the partition data, we disable PyArrow's partitioning. + dataset_kwargs["partitioning"] = None + + # `read_schema` is the schema object that will be used to perform + # read operations. + # It should be None, unless user has specified the schema or columns. + # We don't use the inferred schema for read, because the pyarrow only infers + # schema based on the first file. Thus, files with different schemas will end + # up producing blocks with wrong schema. + # See https://github.com/ray-project/ray/issues/47960 for more context. + read_schema = schema + pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs) + + if schema is None: + schema = pq_ds.schema + schema = _add_partition_fields_to_schema(partitioning, schema, pq_ds) + + if columns: + schema = pa.schema( + [schema.field(column) for column in columns], schema.metadata + ) + read_schema = schema + + check_for_legacy_tensor_type(schema) + + if _block_udf is not None: + # Try to infer dataset schema by passing dummy table through UDF. 
+ dummy_table = schema.empty_table() + try: + schema = _block_udf(dummy_table).schema.with_metadata(schema.metadata) + except Exception: + logger.debug( + "Failed to infer schema of dataset by passing dummy table " + "through UDF due to the following exception:", + exc_info=True, + ) + + try: + prefetch_remote_args = {} + prefetch_remote_args["num_cpus"] = NUM_CPUS_FOR_META_FETCH_TASK + if self._local_scheduling: + prefetch_remote_args["scheduling_strategy"] = self._local_scheduling + else: + # Use the scheduling strategy ("SPREAD" by default) provided in + # `DataContext``, to spread out prefetch tasks in cluster, avoid + # AWS S3 throttling error. + # Note: this is the same scheduling strategy used by read tasks. + prefetch_remote_args[ + "scheduling_strategy" + ] = DataContext.get_current().scheduling_strategy + + self._metadata = ( + meta_provider.prefetch_file_metadata( + pq_ds.fragments, **prefetch_remote_args + ) + or [] + ) + except OSError as e: + _handle_read_os_error(e, paths) + + if to_batch_kwargs is None: + to_batch_kwargs = {} + + # NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected + # network calls when `_ParquetDatasourceReader` is serialized. See + # `_SerializedFragment()` implementation for more details. 
+ self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments] + self._pq_paths = [p.path for p in pq_ds.fragments] + self._meta_provider = meta_provider + self._block_udf = _block_udf + self._to_batches_kwargs = to_batch_kwargs + self._columns = columns + self._read_schema = read_schema + self._schema = schema + self._file_metadata_shuffler = None + self._include_paths = include_paths + self._partitioning = partitioning + if shuffle == "files": + self._file_metadata_shuffler = np.random.default_rng() + elif isinstance(shuffle, FileShuffleConfig): + self._file_metadata_shuffler = np.random.default_rng(shuffle.seed) + + sample_infos = sample_fragments( + self._pq_fragments, + to_batches_kwargs=to_batch_kwargs, + columns=columns, + schema=self._read_schema, + local_scheduling=self._local_scheduling, + ) + self._encoding_ratio = estimate_files_encoding_ratio(sample_infos) + self._default_read_batch_size_rows = estimate_default_read_batch_size_rows( + sample_infos + ) + + def estimate_inmemory_data_size(self) -> Optional[int]: + total_size = 0 + for file_metadata in self._metadata: + total_size += file_metadata.total_byte_size + return total_size * self._encoding_ratio + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + # NOTE: We override the base class FileBasedDatasource.get_read_tasks() + # method in order to leverage pyarrow's ParquetDataset abstraction, + # which simplifies partitioning logic. We still use + # FileBasedDatasource's write side, however. + pq_metadata = self._metadata + if len(pq_metadata) < len(self._pq_fragments): + # Pad `pq_metadata` to be same length of `self._pq_fragments`. + # This can happen when no file metadata being prefetched. 
+ pq_metadata += [None] * (len(self._pq_fragments) - len(pq_metadata)) + + if self._file_metadata_shuffler is not None: + files_metadata = list(zip(self._pq_fragments, self._pq_paths, pq_metadata)) + shuffled_files_metadata = [ + files_metadata[i] + for i in self._file_metadata_shuffler.permutation(len(files_metadata)) + ] + pq_fragments, pq_paths, pq_metadata = list( + map(list, zip(*shuffled_files_metadata)) + ) + else: + pq_fragments, pq_paths, pq_metadata = ( + self._pq_fragments, + self._pq_paths, + pq_metadata, + ) + + read_tasks = [] + for fragments, paths, metadata in zip( + np.array_split(pq_fragments, parallelism), + np.array_split(pq_paths, parallelism), + np.array_split(pq_metadata, parallelism), + ): + if len(fragments) <= 0: + continue + + meta = self._meta_provider( + paths, + self._schema, + num_fragments=len(fragments), + prefetched_metadata=metadata, + ) + # If there is a filter operation, reset the calculated row count, + # since the resulting row count is unknown. + if self._to_batches_kwargs.get("filter") is not None: + meta.num_rows = None + + if meta.size_bytes is not None: + meta.size_bytes = int(meta.size_bytes * self._encoding_ratio) + + ( + block_udf, + to_batches_kwargs, + default_read_batch_size_rows, + columns, + read_schema, + include_paths, + partitioning, + ) = ( + self._block_udf, + self._to_batches_kwargs, + self._default_read_batch_size_rows, + self._columns, + self._read_schema, + self._include_paths, + self._partitioning, + ) + read_tasks.append( + ReadTask( + lambda f=fragments: read_fragments( + block_udf, + to_batches_kwargs, + default_read_batch_size_rows, + columns, + read_schema, + f, + include_paths, + partitioning, + ), + meta, + ) + ) + + return read_tasks + + def get_name(self): + """Return a human-readable name for this datasource. + + This will be used as the names of the read tasks. 
+ """ + return "Parquet" + + @property + def supports_distributed_reads(self) -> bool: + return self._supports_distributed_reads + + +def read_fragments( + block_udf, + to_batches_kwargs, + default_read_batch_size_rows, + columns, + schema, + serialized_fragments: List[SerializedFragment], + include_paths: bool, + partitioning: Partitioning, +) -> Iterator["pyarrow.Table"]: + # This import is necessary to load the tensor extension type. + from ray.data.extensions.tensor_extension import ArrowTensorType # noqa + + # Deserialize after loading the filesystem class. + fragments: List[ + "pyarrow._dataset.ParquetFileFragment" + ] = _deserialize_fragments_with_retry(serialized_fragments) + + # Ensure that we're reading at least one dataset fragment. + assert len(fragments) > 0 + + import pyarrow as pa + + logger.debug(f"Reading {len(fragments)} parquet fragments") + use_threads = to_batches_kwargs.pop("use_threads", False) + batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows) + for fragment in fragments: + partitions = {} + if partitioning is not None: + parse = PathPartitionParser(partitioning) + partitions = parse(fragment.path) + + # Filter out partitions that aren't in the user-specified columns list. + if columns is not None: + partitions = { + field_name: value + for field_name, value in partitions.items() + if field_name in columns + } + + def get_batch_iterable(): + return fragment.to_batches( + use_threads=use_threads, + columns=columns, + schema=schema, + batch_size=batch_size, + **to_batches_kwargs, + ) + + # S3 can raise transient errors during iteration, and PyArrow doesn't expose a + # way to retry specific batches. 
+ ctx = ray.data.DataContext.get_current() + for batch in iterate_with_retry( + get_batch_iterable, "load batch", match=ctx.retried_io_errors + ): + table = pa.Table.from_batches([batch], schema=schema) + if include_paths: + table = table.append_column("path", [[fragment.path]] * len(table)) + if partitions: + table = _add_partitions_to_table(partitions, table) + + # If the table is empty, drop it. + if table.num_rows > 0: + if block_udf is not None: + yield block_udf(table) + else: + yield table + + +def _deserialize_fragments_with_retry(fragments): + # The deserialization retry helps when the upstream datasource is not able to + # handle overloaded read request or failed with some retriable failures. + # For example when reading data from HA hdfs service, hdfs might + # lose connection for some unknown reason expecially when + # simutaneously running many hyper parameter tuning jobs + # with ray.data parallelism setting at high value like the default 200 + # Such connection failure can be restored with some waiting and retry. + return call_with_retry( + lambda: _deserialize_fragments(fragments), + description="deserialize fragments", + max_attempts=FILE_READING_RETRY, + ) + + +def _sample_fragment( + to_batches_kwargs, + columns, + schema, + file_fragment: SerializedFragment, +) -> _SampleInfo: + # Sample the first rows batch from file fragment `serialized_fragment`. + fragment = _deserialize_fragments_with_retry([file_fragment])[0] + + # Only sample the first row group. + fragment = fragment.subset(row_group_ids=[0]) + batch_size = max( + min(fragment.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1 + ) + # Use the batch_size calculated above, and ignore the one specified by user if set. + # This is to avoid sampling too few or too many rows. 
+ to_batches_kwargs.pop("batch_size", None) + batches = fragment.to_batches( + columns=columns, + schema=schema, + batch_size=batch_size, + **to_batches_kwargs, + ) + # Use first batch in-memory size for estimation. + try: + batch = next(batches) + except StopIteration: + sample_data = _SampleInfo( + actual_bytes_per_row=None, estimated_bytes_per_row=None + ) + else: + if batch.num_rows > 0: + metadata = fragment.metadata + total_size = 0 + for idx in range(metadata.num_row_groups): + total_size += metadata.row_group(idx).total_byte_size + sample_data = _SampleInfo( + actual_bytes_per_row=batch.nbytes / batch.num_rows, + estimated_bytes_per_row=total_size / metadata.num_rows, + ) + else: + sample_data = _SampleInfo( + actual_bytes_per_row=None, estimated_bytes_per_row=None + ) + return sample_data + + +def estimate_files_encoding_ratio(sample_infos: List[_SampleInfo]) -> float: + """Return an estimate of the Parquet files encoding ratio. + + To avoid OOMs, it is safer to return an over-estimate than an underestimate. + """ + if not DataContext.get_current().decoding_size_estimation: + return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT + + def compute_encoding_ratio(sample_info: _SampleInfo) -> float: + if ( + sample_info.actual_bytes_per_row is None + or sample_info.estimated_bytes_per_row is None + ): + return PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND + else: + return ( + sample_info.actual_bytes_per_row / sample_info.estimated_bytes_per_row + ) + + ratio = np.mean(list(map(compute_encoding_ratio, sample_infos))) + logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.") + return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND) + + +def estimate_default_read_batch_size_rows(sample_infos: List[_SampleInfo]) -> int: + def compute_batch_size_rows(sample_info: _SampleInfo) -> int: + # 'actual_bytes_per_row' is None if the sampled file was empty and 0 if the data + # was all null. 
+ if not sample_info.actual_bytes_per_row: + return PARQUET_READER_ROW_BATCH_SIZE + else: + max_parquet_reader_row_batch_size_bytes = ( + DataContext.get_current().target_max_block_size // 10 + ) + return max( + 1, + min( + PARQUET_READER_ROW_BATCH_SIZE, + max_parquet_reader_row_batch_size_bytes + // sample_info.actual_bytes_per_row, + ), + ) + + return np.mean(list(map(compute_batch_size_rows, sample_infos))) + + +def get_parquet_dataset(paths, filesystem, dataset_kwargs): + import pyarrow.parquet as pq + + # If you pass a list containing a single directory path to `ParquetDataset`, PyArrow + # errors with 'IsADirectoryError: Path ... points to a directory, but only file + # paths are supported'. To avoid this, we pass the directory path directly. + if len(paths) == 1: + paths = paths[0] + + try: + # The `use_legacy_dataset` parameter is deprecated in Arrow 15. + if parse_version(_get_pyarrow_version()) >= parse_version("15.0.0"): + dataset = pq.ParquetDataset( + paths, + **dataset_kwargs, + filesystem=filesystem, + ) + else: + dataset = pq.ParquetDataset( + paths, + **dataset_kwargs, + filesystem=filesystem, + use_legacy_dataset=False, + ) + except OSError as e: + _handle_read_os_error(e, paths) + + return dataset + + +def sample_fragments( + serialized_fragments, + *, + to_batches_kwargs, + columns, + schema, + local_scheduling=None, +) -> List[_SampleInfo]: + # Sample a few rows from Parquet files to estimate the encoding ratio. + # Launch tasks to sample multiple files remotely in parallel. + # Evenly distributed to sample N rows in i-th row group in i-th file. + # TODO(ekl/cheng) take into account column pruning. 
+ num_files = len(serialized_fragments) + num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO) + min_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files) + max_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES, num_files) + num_samples = max(min(num_samples, max_num_samples), min_num_samples) + + # Evenly distributed to choose which file to sample, to avoid biased prediction + # if data is skewed. + file_samples = [ + serialized_fragments[idx] + for idx in np.linspace(0, num_files - 1, num_samples).astype(int).tolist() + ] + + sample_fragment = cached_remote_fn(_sample_fragment) + futures = [] + scheduling = local_scheduling or DataContext.get_current().scheduling_strategy + for sample in file_samples: + # Sample the first rows batch in i-th file. + # Use SPREAD scheduling strategy to avoid packing many sampling tasks on + # same machine to cause OOM issue, as sampling can be memory-intensive. + futures.append( + sample_fragment.options( + scheduling_strategy=scheduling, + # Retry in case of transient errors during sampling. 
+ retry_exceptions=[OSError], + ).remote( + to_batches_kwargs, + columns, + schema, + sample, + ) + ) + sample_bar = ProgressBar("Parquet Files Sample", len(futures), unit="file") + sample_infos = sample_bar.fetch_until_complete(futures) + sample_bar.close() + + return sample_infos + + +def _add_partitions_to_table( + partitions: Dict[str, PartitionDataType], table: "pyarrow.Table" +) -> "pyarrow.Table": + import pyarrow as pa + + for field_name, value in partitions.items(): + column = pa.array([value] * len(table)) + field_index = table.schema.get_field_index(field_name) + if field_index != -1: + table = table.set_column(field_index, field_name, column) + else: + table = table.append_column(field_name, column) + + return table + + +def _add_partition_fields_to_schema( + partitioning: Partitioning, + schema: "pyarrow.Schema", + parquet_dataset: "pyarrow.dataset.Dataset", +) -> "pyarrow.Schema": + """Return a new schema with partition fields added. + + This function infers the partition fields from the first file path in the dataset. + """ + import pyarrow as pa + + # If the dataset is empty, we can't infer the partitioning. + if len(parquet_dataset.fragments) == 0: + return schema + + # If the dataset isn't partitioned, we don't need to add any fields. 
+ if partitioning is None: + return schema + + first_path = parquet_dataset.fragments[0].path + parse = PathPartitionParser(partitioning) + partitions = parse(first_path) + for field_name in partitions: + if field_name in partitioning.field_types: + field_type = pa.from_numpy_dtype(partitioning.field_types[field_name]) + else: + field_type = pa.string() + schema = schema.append(pa.field(field_name, field_type)) + + return schema diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..13f9dbf0015c8cde118c19b144d4884e2ad923b8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/range_datasource.py @@ -0,0 +1,139 @@ +import builtins +import functools +from copy import copy +from typing import Iterable, List, Optional, Tuple + +import numpy as np + +from ray.data._internal.util import _check_pyarrow_version +from ray.data.block import Block, BlockAccessor, BlockMetadata +from ray.data.context import DataContext +from ray.data.datasource import Datasource, ReadTask + + +class RangeDatasource(Datasource): + """An example datasource that generates ranges of numbers from [0..n).""" + + def __init__( + self, + n: int, + block_format: str = "arrow", + tensor_shape: Tuple = (1,), + column_name: Optional[str] = None, + ): + self._n = int(n) + self._block_format = block_format + self._tensor_shape = tensor_shape + self._column_name = column_name + + def estimate_inmemory_data_size(self) -> Optional[int]: + if self._block_format == "tensor": + element_size = int(np.prod(self._tensor_shape)) + else: + element_size = 1 + return 8 * self._n * element_size + + def get_read_tasks( + self, + parallelism: int, + ) -> List[ReadTask]: + read_tasks: List[ReadTask] = [] + n = self._n + block_format = self._block_format + tensor_shape = self._tensor_shape + 
block_size = max(1, n // parallelism) + # TODO(swang): This target block size may not match the driver's + # context if it was overridden. Set target max block size during + # optimizer stage to fix this. + ctx = DataContext.get_current() + if self._n == 0: + target_rows_per_block = 0 + else: + row_size_bytes = self.estimate_inmemory_data_size() // self._n + row_size_bytes = max(row_size_bytes, 1) + target_rows_per_block = max(1, ctx.target_max_block_size // row_size_bytes) + + # Example of a read task. In a real datasource, this would pull data + # from an external system instead of generating dummy data. + def make_block(start: int, count: int) -> Block: + if block_format == "arrow": + import pyarrow as pa + + return pa.Table.from_arrays( + [np.arange(start, start + count)], + names=[self._column_name or "value"], + ) + elif block_format == "tensor": + import pyarrow as pa + + tensor = np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( + np.arange(start, start + count), + tuple(range(1, 1 + len(tensor_shape))), + ) + return BlockAccessor.batch_to_block( + {self._column_name: tensor} if self._column_name else tensor + ) + else: + return list(builtins.range(start, start + count)) + + def make_blocks( + start: int, count: int, target_rows_per_block: int + ) -> Iterable[Block]: + while count > 0: + num_rows = min(count, target_rows_per_block) + yield make_block(start, num_rows) + start += num_rows + count -= num_rows + + if block_format == "tensor": + element_size = int(np.prod(tensor_shape)) + else: + element_size = 1 + + i = 0 + while i < n: + count = min(block_size, n - i) + meta = BlockMetadata( + num_rows=count, + size_bytes=8 * count * element_size, + schema=copy(self._schema), + input_files=None, + exec_stats=None, + ) + read_tasks.append( + ReadTask( + lambda i=i, count=count: make_blocks( + i, count, target_rows_per_block + ), + meta, + ) + ) + i += block_size + + return read_tasks + + @functools.cached_property + def _schema(self): + if self._n == 0: + 
return None + + if self._block_format == "arrow": + _check_pyarrow_version() + import pyarrow as pa + + schema = pa.Table.from_pydict({self._column_name or "value": [0]}).schema + elif self._block_format == "tensor": + _check_pyarrow_version() + import pyarrow as pa + + tensor = np.ones(self._tensor_shape, dtype=np.int64) * np.expand_dims( + np.arange(0, 10), tuple(range(1, 1 + len(self._tensor_shape))) + ) + schema = BlockAccessor.batch_to_block( + {self._column_name: tensor} if self._column_name else tensor + ).schema + elif self._block_format == "list": + schema = int + else: + raise ValueError("Unsupported block type", self._block_format) + return schema diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..f773a186c58c3facff499301a5b1bd5eff97f7c6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/sql_datasink.py @@ -0,0 +1,35 @@ +from typing import Callable, Iterable + +from ray.data._internal.datasource.sql_datasource import Connection, _connect +from ray.data._internal.execution.interfaces import TaskContext +from ray.data.block import Block, BlockAccessor +from ray.data.datasource.datasink import Datasink + + +class SQLDatasink(Datasink[None]): + + _MAX_ROWS_PER_WRITE = 128 + + def __init__(self, sql: str, connection_factory: Callable[[], Connection]): + self.sql = sql + self.connection_factory = connection_factory + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> None: + with _connect(self.connection_factory) as cursor: + for block in blocks: + block_accessor = BlockAccessor.for_block(block) + + values = [] + for row in block_accessor.iter_rows(public_row_format=False): + values.append(tuple(row.values())) + assert len(values) <= self._MAX_ROWS_PER_WRITE, len(values) + if len(values) == 
self._MAX_ROWS_PER_WRITE: + cursor.executemany(self.sql, values) + values = [] + + if values: + cursor.executemany(self.sql, values) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..b4213a0ec85475c4f6ee3220c13447b39e3d3b75 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/text_datasource.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING, Iterator, List + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data.block import Block +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + + +class TextDatasource(FileBasedDatasource): + """Text datasource, for reading and writing text files.""" + + _COLUMN_NAME = "text" + + def __init__( + self, + paths: List[str], + *, + drop_empty_lines: bool = False, + encoding: str = "utf-8", + **file_based_datasource_kwargs + ): + super().__init__(paths, **file_based_datasource_kwargs) + + self.drop_empty_lines = drop_empty_lines + self.encoding = encoding + + def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: + data = f.readall() + + builder = DelegatingBlockBuilder() + + lines = data.decode(self.encoding).split("\n") + for line in lines: + if self.drop_empty_lines and line.strip() == "": + continue + item = {self._COLUMN_NAME: line} + builder.add(item) + + block = builder.build() + yield block diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..ba89f26b9d66ad3a5f79cb3ecb9a12d2fc1791d3 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/tfrecords_datasink.py @@ -0,0 +1,205 @@ +import struct +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Union + +import numpy as np + +from .tfrecords_datasource import _get_single_true_type +from ray.data._internal.util import _check_import +from ray.data.block import BlockAccessor +from ray.data.datasource.file_datasink import BlockBasedFileDatasink + +if TYPE_CHECKING: + import pyarrow + import tensorflow as tf + from tensorflow_metadata.proto.v0 import schema_pb2 + + +class TFRecordDatasink(BlockBasedFileDatasink): + def __init__( + self, + path: str, + *, + tf_schema: Optional["schema_pb2.Schema"] = None, + file_format: str = "tar", + **file_datasink_kwargs, + ): + super().__init__(path, file_format=file_format, **file_datasink_kwargs) + + _check_import(self, module="crc32c", package="crc32c") + + self.tf_schema = tf_schema + + def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): + arrow_table = block.to_arrow() + + # It seems like TFRecords are typically row-based, + # https://www.tensorflow.org/tutorials/load_data/tfrecord#writing_a_tfrecord_file_2 + # so we must iterate through the rows of the block, + # serialize to tf.train.Example proto, and write to file. + + examples = _convert_arrow_table_to_examples(arrow_table, self.tf_schema) + + # Write each example to the arrow file in the TFRecord format. + for example in examples: + _write_record(file, example) + + +def _convert_arrow_table_to_examples( + arrow_table: "pyarrow.Table", + tf_schema: Optional["schema_pb2.Schema"] = None, +) -> Iterable["tf.train.Example"]: + import tensorflow as tf + + schema_dict = {} + # Convert user-specified schema into dict for convenient mapping + if tf_schema is not None: + for schema_feature in tf_schema.feature: + schema_dict[schema_feature.name] = schema_feature.type + + # Serialize each row[i] of the block to a tf.train.Example and yield it. 
+ for i in range(arrow_table.num_rows): + # First, convert row[i] to a dictionary. + features: Dict[str, "tf.train.Feature"] = {} + for name in arrow_table.column_names: + if tf_schema is not None and name not in schema_dict: + raise ValueError( + f"Found extra unexpected feature {name} " + f"not in specified schema: {tf_schema}" + ) + schema_feature_type = schema_dict.get(name) + features[name] = _value_to_feature( + arrow_table[name][i], + schema_feature_type, + ) + + # Convert the dictionary to an Example proto. + proto = tf.train.Example(features=tf.train.Features(feature=features)) + + yield proto + + +def _value_to_feature( + value: Union["pyarrow.Scalar", "pyarrow.Array"], + schema_feature_type: Optional["schema_pb2.FeatureType"] = None, +) -> "tf.train.Feature": + import pyarrow as pa + import tensorflow as tf + + if isinstance(value, pa.ListScalar): + # Use the underlying type of the ListScalar's value in + # determining the output feature's data type. + value_type = value.type.value_type + value = value.as_py() + else: + value_type = value.type + value = value.as_py() + if value is None: + value = [] + else: + value = [value] + + underlying_value_type = { + "bytes": pa.types.is_binary(value_type), + "string": pa.types.is_string(value_type), + "float": pa.types.is_floating(value_type), + "int": pa.types.is_integer(value_type), + } + assert sum(bool(value) for value in underlying_value_type.values()) <= 1 + + if schema_feature_type is not None: + try: + from tensorflow_metadata.proto.v0 import schema_pb2 + except ModuleNotFoundError: + raise ModuleNotFoundError( + "To use TensorFlow schemas, please install " + "the tensorflow-metadata package." 
+ ) + specified_feature_type = { + "bytes": schema_feature_type == schema_pb2.FeatureType.BYTES + and not underlying_value_type["string"], + "string": schema_feature_type == schema_pb2.FeatureType.BYTES + and underlying_value_type["string"], + "float": schema_feature_type == schema_pb2.FeatureType.FLOAT, + "int": schema_feature_type == schema_pb2.FeatureType.INT, + } + + und_type = _get_single_true_type(underlying_value_type) + spec_type = _get_single_true_type(specified_feature_type) + if und_type is not None and und_type != spec_type: + raise ValueError( + "Schema field type mismatch during write: specified type is " + f"{spec_type}, but underlying type is {und_type}", + ) + # Override the underlying value type with the type in the user-specified schema. + underlying_value_type = specified_feature_type + + if underlying_value_type["int"]: + return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + if underlying_value_type["float"]: + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + if underlying_value_type["bytes"]: + return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) + if underlying_value_type["string"]: + value = [v.encode() for v in value] # casting to bytes + return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) + if pa.types.is_null(value_type): + raise ValueError( + "Unable to infer type from partially missing column. " + "Try setting read parallelism = 1, or use an input data source which " + "explicitly specifies the schema." + ) + raise ValueError( + f"Value is of type {value_type}, " + "which we cannot convert to a supported tf.train.Feature storage type " + "(bytes, float, or int)." 
+ ) + + +# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501 +# +# MIT License +# +# Copyright (c) 2020 Vahid Kazemi +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ + +def _write_record( + file: "pyarrow.NativeFile", + example: "tf.train.Example", +) -> None: + record = example.SerializeToString() + length = len(record) + length_bytes = struct.pack(" bytes: + """CRC checksum.""" + import crc32c + + mask = 0xA282EAD8 + crc = crc32c.crc32(data) + masked = ((crc >> 15) | (crc << 17)) + mask + masked = np.uint32(masked & np.iinfo(np.uint32).max) + masked_bytes = struct.pack(" Iterator[Block]: + if self._tfx_read_options: + yield from self._tfx_read_stream(f, path) + else: + yield from self._default_read_stream(f, path) + + def _default_read_stream( + self, f: "pyarrow.NativeFile", path: str + ) -> Iterator[Block]: + import tensorflow as tf + from google.protobuf.message import DecodeError + + for record in _read_records(f, path): + example = tf.train.Example() + try: + example.ParseFromString(record) + except DecodeError as e: + raise ValueError( + "`TFRecordDatasource` failed to parse `tf.train.Example` " + f"record in '{path}'. This error can occur if your TFRecord " + f"file contains a message type other than `tf.train.Example`: {e}" + ) + + yield pyarrow_table_from_pydict( + _convert_example_to_dict(example, self._tf_schema) + ) + + def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: + import tensorflow as tf + from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder + + full_path = self._resolve_full_path(path) + + compression = (self._open_stream_args or {}).get("compression", None) + + if compression: + compression = compression.upper() + + tf_schema_string = ( + self._tf_schema.SerializeToString() if self._tf_schema else None + ) + + decoder = ExamplesToRecordBatchDecoder(tf_schema_string) + exception_thrown = None + try: + for record in tf.data.TFRecordDataset( + full_path, compression_type=compression + ).batch(self._tfx_read_options.batch_size): + yield _cast_large_list_to_list( + pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())]) + ) + except Exception 
as error: + logger.exception(f"Failed to read TFRecord file {full_path}") + exception_thrown = error + + # we need to do this hack were we raise an exception outside of the + # except block because tensorflow DataLossError is unpickable, and + # even if we raise a runtime error, ray keeps information about the + # original error, which makes it unpickable still. + if exception_thrown: + raise RuntimeError(f"Failed to read TFRecord file {full_path}.") + + def _resolve_full_path(self, relative_path): + if isinstance(self._filesystem, pyarrow.fs.S3FileSystem): + return f"s3://{relative_path}" + if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem): + return f"gs://{relative_path}" + if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem): + return f"hdfs:///{relative_path}" + if isinstance(self._filesystem, pyarrow.fs.PyFileSystem): + protocol = self._filesystem.handler.fs.protocol + if isinstance(protocol, list) or isinstance(protocol, tuple): + protocol = protocol[0] + if protocol == "gcs": + protocol = "gs" + return f"{protocol}://{relative_path}" + + return relative_path + + +def _convert_example_to_dict( + example: "tf.train.Example", + tf_schema: Optional["schema_pb2.Schema"], +) -> Dict[str, pyarrow.Array]: + record = {} + schema_dict = {} + # Convert user-specified schema into dict for convenient mapping + if tf_schema is not None: + for schema_feature in tf_schema.feature: + schema_dict[schema_feature.name] = schema_feature.type + + for feature_name, feature in example.features.feature.items(): + if tf_schema is not None and feature_name not in schema_dict: + raise ValueError( + f"Found extra unexpected feature {feature_name} " + f"not in specified schema: {tf_schema}" + ) + schema_feature_type = schema_dict.get(feature_name) + record[feature_name] = _get_feature_value(feature, schema_feature_type) + return record + + +def _get_single_true_type(dct) -> str: + """Utility function for getting the single key which has a `True` value in + a dict. 
Used to filter a dict of `{field_type: is_valid}` to get + the field type from a schema or data source.""" + filtered_types = iter([_type for _type in dct if dct[_type]]) + # In the case where there are no keys with a `True` value, return `None` + return next(filtered_types, None) + + +def _get_feature_value( + feature: "tf.train.Feature", + schema_feature_type: Optional["schema_pb2.FeatureType"] = None, +) -> pyarrow.Array: + import pyarrow as pa + + underlying_feature_type = { + "bytes": feature.HasField("bytes_list"), + "float": feature.HasField("float_list"), + "int": feature.HasField("int64_list"), + } + # At most one of `bytes_list`, `float_list`, and `int64_list` + # should contain values. If none contain data, this indicates + # an empty feature value. + assert sum(bool(value) for value in underlying_feature_type.values()) <= 1 + + if schema_feature_type is not None: + try: + from tensorflow_metadata.proto.v0 import schema_pb2 + except ModuleNotFoundError: + raise ModuleNotFoundError( + "To use TensorFlow schemas, please install " + "the tensorflow-metadata package." + ) + # If a schema is specified, compare to the underlying type + specified_feature_type = { + "bytes": schema_feature_type == schema_pb2.FeatureType.BYTES, + "float": schema_feature_type == schema_pb2.FeatureType.FLOAT, + "int": schema_feature_type == schema_pb2.FeatureType.INT, + } + und_type = _get_single_true_type(underlying_feature_type) + spec_type = _get_single_true_type(specified_feature_type) + if und_type is not None and und_type != spec_type: + raise ValueError( + "Schema field type mismatch during read: specified type is " + f"{spec_type}, but underlying type is {und_type}", + ) + # Override the underlying value type with the type in the user-specified schema. 
+ underlying_feature_type = specified_feature_type + + if underlying_feature_type["bytes"]: + value = feature.bytes_list.value + type_ = pa.binary() + elif underlying_feature_type["float"]: + value = feature.float_list.value + type_ = pa.float32() + elif underlying_feature_type["int"]: + value = feature.int64_list.value + type_ = pa.int64() + else: + value = [] + type_ = pa.null() + value = list(value) + if len(value) == 1 and schema_feature_type is None: + # Use the value itself if the features contains a single value. + # This is to give better user experience when writing preprocessing UDF on + # these single-value lists. + value = value[0] + else: + # If the feature value is empty and no type is specified in the user-provided + # schema, set the type to null for now to allow pyarrow to construct a valid + # Array; later, infer the type from other records which have non-empty values + # for the feature. + if len(value) == 0: + type_ = pa.null() + type_ = pa.list_(type_) + return pa.array([value], type=type_) + + +# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/reader.py#L16-L96 # noqa: E501 +# +# MIT License +# +# Copyright (c) 2020 Vahid Kazemi +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +def _read_records( + file: "pyarrow.NativeFile", + path: str, +) -> Iterable[memoryview]: + """ + Read records from TFRecord file. + + A TFRecord file contains a sequence of records. The file can only be read + sequentially. Each record is stored in the following formats: + uint64 length + uint32 masked_crc32_of_length + byte data[length] + uint32 masked_crc32_of_data + + See https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecords_format_details + for more details. + """ + length_bytes = bytearray(8) + crc_bytes = bytearray(4) + datum_bytes = bytearray(1024 * 1024) + row_count = 0 + while True: + try: + # Read "length" field. + num_length_bytes_read = file.readinto(length_bytes) + if num_length_bytes_read == 0: + break + elif num_length_bytes_read != 8: + raise ValueError( + "Failed to read the length of record data. Expected 8 bytes but " + "got {num_length_bytes_read} bytes." + ) + + # Read "masked_crc32_of_length" field. + num_length_crc_bytes_read = file.readinto(crc_bytes) + if num_length_crc_bytes_read != 4: + raise ValueError( + "Failed to read the length of CRC-32C hashes. Expected 4 bytes " + "but got {num_length_crc_bytes_read} bytes." + ) + + # Read "data[length]" field. 
+ (data_length,) = struct.unpack(" len(datum_bytes): + datum_bytes = datum_bytes.zfill(int(data_length * 1.5)) + datum_bytes_view = memoryview(datum_bytes)[:data_length] + num_datum_bytes_read = file.readinto(datum_bytes_view) + if num_datum_bytes_read != data_length: + raise ValueError( + f"Failed to read the record. Exepcted {data_length} bytes but got " + f"{num_datum_bytes_read} bytes." + ) + + # Read "masked_crc32_of_data" field. + # TODO(chengsu): ideally we should check CRC-32C against the actual data. + num_crc_bytes_read = file.readinto(crc_bytes) + if num_crc_bytes_read != 4: + raise ValueError( + "Failed to read the CRC-32C hashes. Expected 4 bytes but got " + f"{num_crc_bytes_read} bytes." + ) + + # Return the data. + yield datum_bytes_view + + row_count += 1 + data_length = None + except Exception as e: + error_message = ( + f"Failed to read TFRecord file {path}. Please ensure that the " + f"TFRecord file has correct format. Already read {row_count} rows." + ) + if data_length is not None: + error_message += f" Byte size of current record data is {data_length}." + raise RuntimeError(error_message) from e + + +def _cast_large_list_to_list(batch: pyarrow.Table): + """ + This function transform pyarrow.large_list into list and pyarrow.large_binary into + pyarrow.binary so that all types resulting from the tfrecord_datasource are usable + with dataset.to_tf(). 
+ """ + old_schema = batch.schema + fields = {} + + for column_name in old_schema.names: + field_type = old_schema.field(column_name).type + if type(field_type) is pyarrow.lib.LargeListType: + value_type = field_type.value_type + + if value_type == pyarrow.large_binary(): + value_type = pyarrow.binary() + + fields[column_name] = pyarrow.list_(value_type) + elif field_type == pyarrow.large_binary(): + fields[column_name] = pyarrow.binary() + else: + fields[column_name] = old_schema.field(column_name) + + new_schema = pyarrow.schema(fields) + return batch.cast(new_schema) + + +def _infer_schema_and_transform(dataset: "Dataset"): + list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names)) + + return dataset.map_batches( + _unwrap_single_value_lists, + fn_kwargs={"col_lengths": list_sizes["max_list_size"]}, + batch_format="pyarrow", + ) + + +def _unwrap_single_value_lists(batch: pyarrow.Table, col_lengths: Dict[str, int]): + """ + This function will transfrom the dataset converting list types that always + contain single values to thery underlying data type + (i.e. 
pyarrow.int64() and pyarrow.float64()) + """ + columns = {} + + for col in col_lengths: + value_type = batch[col].type.value_type + + if col_lengths[col] == 1: + if batch[col]: + columns[col] = pyarrow.array( + [x.as_py()[0] if x.as_py() else None for x in batch[col]], + type=value_type, + ) + else: + columns[col] = batch[col] + + return pyarrow.table(columns) + + +class _MaxListSize(AggregateFn): + def __init__(self, columns: List[str]): + self._columns = columns + super().__init__( + init=self._init, + merge=self._merge, + accumulate_row=self._accumulate_row, + finalize=lambda a: a, + name="max_list_size", + ) + + def _init(self, k: str): + return {col: 0 for col in self._columns} + + def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]): + merged = {} + for col in self._columns: + merged[col] = max(acc1[col], acc2[col]) + + return merged + + def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"): + for k in row: + value = row[k] + if value: + acc[k] = max(len(value), acc[k]) + + return acc diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..3c90a6cf7107665ee48c888ce479d89e05abcdc3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/torch_datasource.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data.block import BlockMetadata +from ray.data.datasource.datasource import Datasource, ReadTask + +if TYPE_CHECKING: + import torch + + +TORCH_DATASOURCE_READER_BATCH_SIZE = 32 + + +class TorchDatasource(Datasource): + """Torch datasource, for reading from `Torch + datasets `_. + This datasource implements a streaming read using a single read task. 
+ """ + + def __init__( + self, + dataset: "torch.utils.data.Dataset", + ): + self._dataset = dataset + + def get_read_tasks(self, parallelism): + assert parallelism == 1 + + meta = BlockMetadata( + num_rows=len(self._dataset), + size_bytes=None, + schema=None, + input_files=None, + exec_stats=None, + ) + read_task = ReadTask( + lambda subset=self._dataset: _read_subset( + subset, + ), + metadata=meta, + ) + + return [read_task] + + def estimate_inmemory_data_size(self): + return None + + +def _read_subset(subset: "torch.utils.data.Subset"): + batch = [] + for item in subset: + batch.append(item) + if len(batch) == TORCH_DATASOURCE_READER_BATCH_SIZE: + builder = DelegatingBlockBuilder() + builder.add_batch({"item": batch}) + yield builder.build() + batch.clear() + + if len(batch) > 0: + builder = DelegatingBlockBuilder() + builder.add_batch({"item": batch}) + yield builder.build() diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4a06e876e214b2f0e7820b7971b347e7eed43c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/video_datasource.py @@ -0,0 +1,59 @@ +import logging +from typing import TYPE_CHECKING, List, Union + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.util import _check_import +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + +logger = logging.getLogger(__name__) + + +class VideoDatasource(FileBasedDatasource): + _FILE_EXTENSIONS = [ + "mp4", + "mkv", + "mov", + "avi", + "wmv", + "flv", + "webm", + "m4v", + "3gp", + "mpeg", + "mpg", + "ts", + "ogv", + "rm", + "rmvb", + "vob", + "asf", + "f4v", + "m2ts", + "mts", + "divx", + "xvid", + "mxf", + ] + + def __init__( + self, + paths: Union[str, 
List[str]], + **file_based_datasource_kwargs, + ): + super().__init__(paths, **file_based_datasource_kwargs) + + _check_import(self, module="decord", package="decord") + + def _read_stream(self, f: "pyarrow.NativeFile", path: str): + from decord import VideoReader + + reader = VideoReader(f) + + for frame_index, frame in enumerate(reader): + item = {"frame": frame.asnumpy(), "frame_index": frame_index} + builder = DelegatingBlockBuilder() + builder.add(item) + yield builder.build() diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..57e44600c63cf931ba11a32f65a80af2615155af --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasink.py @@ -0,0 +1,53 @@ +import io +import tarfile +import time +import uuid +from typing import Optional, Union + +import pyarrow + +from ray.data._internal.datasource.webdataset_datasource import ( + _apply_list, + _default_encoder, + _make_iterable, +) +from ray.data.block import BlockAccessor +from ray.data.datasource.file_datasink import BlockBasedFileDatasink + + +class WebDatasetDatasink(BlockBasedFileDatasink): + def __init__( + self, + path: str, + encoder: Optional[Union[bool, str, callable, list]] = True, + *, + file_format: str = "tar", + **file_datasink_kwargs, + ): + super().__init__(path, file_format="tar", **file_datasink_kwargs) + + self.encoder = encoder + + def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): + stream = tarfile.open(fileobj=file, mode="w|") + samples = _make_iterable(block) + for sample in samples: + if not isinstance(sample, dict): + sample = sample.as_pydict() + if self.encoder is not None: + sample = _apply_list(self.encoder, sample, default=_default_encoder) + if "__key__" not in sample: + sample["__key__"] = uuid.uuid4().hex 
+ key = sample["__key__"] + for k, v in sample.items(): + if v is None or k.startswith("__"): + continue + assert isinstance(v, bytes) or isinstance(v, str) + if not isinstance(v, bytes): + v = v.encode("utf-8") + ti = tarfile.TarInfo(f"{key}.{k}") + ti.size = len(v) + ti.mtime = time.time() + ti.mode, ti.uname, ti.gname = 0o644, "data", "data" + stream.addfile(ti, io.BytesIO(v)) + stream.close() diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5661647b04784e91e84c0b7e2c463d643f3059 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/datasource/webdataset_datasource.py @@ -0,0 +1,385 @@ +# Copyright NVIDIA Corporation 2023 +# SPDX-License-Identifier: Apache-2.0 + +import fnmatch +import io +import json +import re +import tarfile +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import ray +from ray.data._internal.util import iterate_with_retry +from ray.data.block import BlockAccessor +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + + +def _base_plus_ext(path: str): + """Split off all file extensions. + + Returns base, allext. + + Args: + path: path with extensions + + Returns: + str: path with all extensions removed + """ + match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) + if not match: + return None, None + return match.group(1), match.group(2) + + +def _valid_sample(sample: Dict[str, Any]): + """Check whether a sample is valid. 
+ + Args: + sample: sample to be checked + """ + return ( + sample is not None + and isinstance(sample, dict) + and len(list(sample.keys())) > 0 + and not sample.get("__bad__", False) + ) + + +def _apply_list( + f: Union[Callable, List[Callable]], sample: Dict[str, Any], default: Callable = None +): + """Apply a list of functions to a sample. + + Args: + f: function or list of functions + sample: sample to be modified + default: default function to be applied to all keys. + Defaults to None. + + Returns: + modified sample + """ + if f is None: + return sample + if not isinstance(f, list): + f = [f] + for g in f: + if default is not None and not callable(g): + g = partial(default, format=g) + sample = g(sample) + return sample + + +def _check_suffix(suffix: str, suffixes: Union[list, callable]): + """Check whether a suffix is valid. + + Suffixes can be either None (=accept everything), a callable, + or a list of patterns. If the pattern contains */? it is treated + as a glob pattern, otherwise it is treated as a literal. + + Args: + suffix: suffix to be checked + suffixes: list of valid suffixes + """ + if suffixes is None: + return True + if callable(suffixes): + return suffixes(suffix) + for pattern in suffixes: + if "*" in pattern or "?" in pattern: + if fnmatch.fnmatch("." + suffix, pattern): + return True + elif suffix == pattern or "." + suffix == pattern: + return True + return False + + +def _tar_file_iterator( + fileobj: Any, + fileselect: Optional[Union[bool, callable, list]] = None, + filerename: Optional[Union[bool, callable, list]] = None, + verbose_open: bool = False, + meta: dict = None, +): + """Iterate over tar file, yielding filename, content pairs for the given tar stream. 
+ + Args: + fileobj: file object + fileselect: patterns or function selecting + files to be selected + meta: metadata to be added to each sample + """ + meta = meta or {} + stream = tarfile.open(fileobj=fileobj, mode="r|*") + if verbose_open: + print(f"start {meta}") + for tarinfo in stream: + fname = tarinfo.name + if not tarinfo.isreg() or fname is None: + continue + data = stream.extractfile(tarinfo).read() + fname = _apply_list(filerename, fname) + assert isinstance(fname, str) + if not _check_suffix(fname, fileselect): + continue + result = dict(fname=fname, data=data) + yield result + if verbose_open: + print(f"done {meta}") + + +def _group_by_keys( + data: List[Dict[str, Any]], + keys: callable = _base_plus_ext, + suffixes: Optional[Union[list, callable]] = None, + meta: dict = None, +): + """Return function over iterator that groups key, value pairs into samples. + + Args: + data: iterator over key, value pairs + keys: function that returns key, suffix for a given key + suffixes: list of suffixes to be included in the sample + meta: metadata to be added to each sample + """ + meta = meta or {} + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if current_sample is None or prefix != current_sample["__key__"]: + if _valid_sample(current_sample): + current_sample.update(meta) + yield current_sample + current_sample = dict(__key__=prefix) + if "__url__" in filesample: + current_sample["__url__"] = filesample["__url__"] + if suffix in current_sample: + raise ValueError( + f"{fname}: duplicate file name in tar file " + + f"{suffix} {current_sample.keys()}, tar is {meta['__url__']}" + ) + if suffixes is None or _check_suffix(suffix, suffixes): + current_sample[suffix] = value + if _valid_sample(current_sample): + current_sample.update(meta) + yield current_sample + + +def _default_decoder(sample: Dict[str, 
Any], format: Optional[Union[bool, str]] = True): + """A default decoder for webdataset. + + This handles common file extensions: .txt, .cls, .cls2, + .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl. + These are the most common extensions used in webdataset. + For other extensions, users can provide their own decoder. + + Args: + sample: sample, modified in place + """ + sample = dict(sample) + for key, value in sample.items(): + extension = key.split(".")[-1] + if key.startswith("__"): + continue + elif extension in ["txt", "text"]: + sample[key] = value.decode("utf-8") + elif extension in ["cls", "cls2"]: + sample[key] = int(value.decode("utf-8")) + elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]: + import numpy as np + import PIL.Image + + if format == "PIL": + sample[key] = PIL.Image.open(io.BytesIO(value)) + else: + sample[key] = np.asarray(PIL.Image.open(io.BytesIO(value))) + elif extension == "json": + sample[key] = json.loads(value) + elif extension == "npy": + import numpy as np + + sample[key] = np.load(io.BytesIO(value)) + elif extension == "mp": + import msgpack + + sample[key] = msgpack.unpackb(value, raw=False) + elif extension in ["pt", "pth"]: + import torch + + sample[key] = torch.load(io.BytesIO(value)) + elif extension in ["pickle", "pkl"]: + import pickle + + sample[key] = pickle.loads(value) + return sample + + +extension_to_format = {"jpg": "jpeg"} + + +def _default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True): + """A default encoder for webdataset. + + This handles common file extensions: .txt, .cls, .cls2, .jpg, + .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl + These are the most common extensions used in webdataset. + For other extensions, users can provide their own encoder. 
+ + Args: + sample (Dict[str, Any]): sample + """ + sample = dict(sample) + for key, value in sample.items(): + extension = key.split(".")[-1] + if key.startswith("__"): + continue + elif extension in ["txt"]: + sample[key] = value.encode("utf-8") + elif extension in ["cls", "cls2"]: + sample[key] = str(value).encode("utf-8") + elif extension in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]: + import numpy as np + import PIL.Image + + if isinstance(value, np.ndarray): + value = PIL.Image.fromarray(value) + assert isinstance(value, PIL.Image.Image) + stream = io.BytesIO() + value.save( + stream, format=extension_to_format.get(extension.lower(), extension) + ) + sample[key] = stream.getvalue() + elif extension == "json": + sample[key] = json.dumps(value).encode("utf-8") + elif extension == "npy": + import numpy as np + + stream = io.BytesIO() + np.save(stream, value) + sample[key] = stream.getvalue() + elif extension == "mp": + import msgpack + + sample[key] = msgpack.dumps(value) + elif extension in ["pt", "pth"]: + import torch + + stream = io.BytesIO() + torch.save(value, stream) + sample[key] = stream.getvalue() + elif extension in ["pickle", "pkl"]: + import pickle + + stream = io.BytesIO() + pickle.dump(value, stream) + sample[key] = stream.getvalue() + return sample + + +def _make_iterable(block: BlockAccessor): + """Make a block iterable. + + This is a placeholder for dealing with more complex blocks. 
+ + Args: + block: Ray Dataset block + + Returns: + Iterable[Dict[str,Any]]: Iterable of samples + """ + return block.iter_rows(public_row_format=False) + + +class WebDatasetDatasource(FileBasedDatasource): + """A Datasource for WebDataset datasets (tar format with naming conventions).""" + + _FILE_EXTENSIONS = ["tar"] + + def __init__( + self, + paths: Union[str, List[str]], + decoder: Optional[Union[bool, str, callable, list]] = True, + fileselect: Optional[Union[bool, callable, list]] = None, + filerename: Optional[Union[bool, callable, list]] = None, + suffixes: Optional[Union[bool, callable, list]] = None, + verbose_open: bool = False, + expand_json: bool = False, + **file_based_datasource_kwargs, + ): + super().__init__(paths, **file_based_datasource_kwargs) + + self.decoder = decoder + self.fileselect = fileselect + self.filerename = filerename + self.suffixes = suffixes + self.verbose_open = verbose_open + self.expand_json = expand_json + + def _read_stream(self, stream: "pyarrow.NativeFile", path: str): + """Read and decode samples from a stream. + + Note that fileselect selects files during reading, while suffixes + selects files during the grouping step. + + Args: + stream: File descriptor to read from. + path: Path to the data. + decoder: decoder or list of decoders to be applied to samples + fileselect: Predicate for skipping files in tar decoder. + Defaults to lambda_:False. + suffixes: List of suffixes to be extracted. Defaults to None. + verbose_open: Print message when opening files. Defaults to False. + + Yields: + List[Dict[str, Any]]: List of sample (list of length 1). 
+ """ + + import pandas as pd + + def get_tar_file_iterator(): + return _tar_file_iterator( + stream, + fileselect=self.fileselect, + filerename=self.filerename, + verbose_open=self.verbose_open, + ) + + # S3 can raise transient errors during iteration + ctx = ray.data.DataContext.get_current() + files = iterate_with_retry( + get_tar_file_iterator, "iterate tar file", match=ctx.retried_io_errors + ) + + samples = _group_by_keys(files, meta=dict(__url__=path), suffixes=self.suffixes) + for sample in samples: + if self.decoder is not None: + sample = _apply_list(self.decoder, sample, default=_default_decoder) + if self.expand_json: + if isinstance(sample["json"], bytes): + parsed_json = json.loads(sample["json"].decode("utf-8")) + elif isinstance(sample["json"], str): + parsed_json = json.loads(sample["json"]) + elif isinstance(sample["json"], dict): + parsed_json = sample["json"] + else: + raise TypeError( + f"Unsupported data type" f" {type(sample['json'])} for sample" + ) + for k, v in parsed_json.items(): + if k not in sample: + sample[k] = [] + sample[k].append(v) + yield pd.DataFrame( + { + k: v if isinstance(v, list) and len(v) == 1 else [v] + for k, v in sample.items() + } + ) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__init__.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614c9f0221859c776fc12ede6b630d38152a04db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c069819656eff8528f5106ea569e544edb5faa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/iterator_impl.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..537594eb9a3394e30577d67cda67b5f554b2a70d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/__pycache__/stream_split_iterator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..e08a0eb2eaafb88fee99286312a19ab17d279c74 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/iterator_impl.py @@ -0,0 +1,41 @@ +from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union + +from ray.data._internal.execution.interfaces.ref_bundle import RefBundle +from ray.data._internal.stats import DatasetStats +from ray.data._internal.util import create_dataset_tag +from ray.data.iterator import DataIterator + +if TYPE_CHECKING: + import pyarrow + + from ray.data import Dataset + + +class DataIteratorImpl(DataIterator): + def __init__( + self, + base_dataset: "Dataset", + ): + self._base_dataset = base_dataset + + def __repr__(self) -> str: + return f"DataIterator({self._base_dataset})" + + def _to_ref_bundle_iterator( + self, + ) -> 
Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]: + ds = self._base_dataset + ref_bundles_iterator, stats, executor = ds._plan.execute_to_iterator() + ds._current_executor = executor + return ref_bundles_iterator, stats, False + + def stats(self) -> str: + return self._base_dataset.stats() + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + return self._base_dataset.schema() + + def _get_dataset_tag(self): + return create_dataset_tag( + self._base_dataset._plan._dataset_name, self._base_dataset._uuid + ) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca875905ad3b17023c522bbacc4a2795ad83c2e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/iterator/stream_split_iterator.py @@ -0,0 +1,290 @@ +import logging +import threading +import time +from dataclasses import replace +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union + +import ray +from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle +from ray.data._internal.execution.legacy_compat import execute_to_legacy_bundle_iterator +from ray.data._internal.execution.operators.output_splitter import OutputSplitter +from ray.data._internal.execution.streaming_executor import StreamingExecutor +from ray.data._internal.stats import DatasetStats +from ray.data._internal.util import create_dataset_tag +from ray.data.block import Block, BlockMetadata +from ray.data.iterator import DataIterator +from ray.types import ObjectRef +from ray.util.debug import log_once +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +if TYPE_CHECKING: + import pyarrow + + from ray.data import Dataset + +logger = logging.getLogger(__name__) + + +BLOCKED_CLIENT_WARN_TIMEOUT = 30 + + +class StreamSplitDataIterator(DataIterator): + 
"""Implements a collection of iterators over a shared data stream.""" + + @staticmethod + def create( + base_dataset: "Dataset", + n: int, + equal: bool, + locality_hints: Optional[List[NodeIdStr]], + ) -> List["StreamSplitDataIterator"]: + """Create a split iterator from the given base Dataset and options. + + See also: `Dataset.streaming_split`. + """ + # To avoid deadlock, the concurrency on this actor must be set to at least `n`. + coord_actor = SplitCoordinator.options( + max_concurrency=n, + scheduling_strategy=NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), soft=False + ), + ).remote(base_dataset, n, equal, locality_hints) + + return [ + StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n) + ] + + def __init__( + self, + base_dataset: "Dataset", + coord_actor: ray.actor.ActorHandle, + output_split_idx: int, + world_size: int, + ): + self._base_dataset = base_dataset + self._coord_actor = coord_actor + self._output_split_idx = output_split_idx + self._world_size = world_size + self._iter_stats = DatasetStats(metadata={}, parent=None) + + def _to_ref_bundle_iterator( + self, + ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]: + def gen_blocks() -> Iterator[RefBundle]: + cur_epoch = ray.get( + self._coord_actor.start_epoch.remote(self._output_split_idx) + ) + future: ObjectRef[ + Optional[ObjectRef[Block]] + ] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx) + while True: + block_ref_and_md: Optional[ + Tuple[ObjectRef[Block], BlockMetadata] + ] = ray.get(future) + if not block_ref_and_md: + break + else: + future = self._coord_actor.get.remote( + cur_epoch, self._output_split_idx + ) + yield RefBundle(blocks=(block_ref_and_md,), owns_blocks=False) + + return gen_blocks(), self._iter_stats, False + + def stats(self) -> str: + """Implements DataIterator.""" + # Merge the locally recorded iter stats and the remotely recorded + # stream execution stats. 
+ stats = ray.get(self._coord_actor.stats.remote()) + summary = stats.to_summary() + summary.iter_stats = self._iter_stats.to_summary().iter_stats + summary.iter_stats.streaming_split_coord_time.add( + stats.streaming_split_coordinator_s.get() + ) + return summary.to_string() + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + """Implements DataIterator.""" + return self._base_dataset.schema() + + def world_size(self) -> int: + """Returns the number of splits total.""" + return self._world_size + + def _get_dataset_tag(self): + return create_dataset_tag( + self._base_dataset._plan._dataset_name, + self._base_dataset._uuid, + self._output_split_idx, + ) + + +@ray.remote(num_cpus=0) +class SplitCoordinator: + """Coordinator actor for routing blocks to output splits. + + This actor runs a streaming executor locally on its main thread. Clients can + retrieve results via actor calls running on other threads. + """ + + def __init__( + self, + dataset: "Dataset", + n: int, + equal: bool, + locality_hints: Optional[List[NodeIdStr]], + ): + # Set current DataContext. + self._data_context = dataset.context + ray.data.DataContext._set_current(self._data_context) + # Automatically set locality with output to the specified location hints. + if locality_hints: + self._data_context.execution_options.locality_with_output = locality_hints + logger.info(f"Auto configuring locality_with_output={locality_hints}") + + self._base_dataset = dataset + self._n = n + self._equal = equal + self._locality_hints = locality_hints + self._lock = threading.RLock() + self._executor = None + + # Guarded by self._lock. 
+ self._next_bundle: Dict[int, RefBundle] = {} + self._unfinished_clients_in_epoch = n + self._cur_epoch = -1 + + def gen_epochs(): + while True: + executor = StreamingExecutor( + self._data_context, + create_dataset_tag( + self._base_dataset._name, self._base_dataset._uuid + ), + ) + self._executor = executor + + def add_split_op(dag): + return OutputSplitter( + dag, + n, + equal, + self._data_context, + locality_hints, + ) + + output_iterator = execute_to_legacy_bundle_iterator( + executor, + dataset._plan, + dag_rewrite=add_split_op, + ) + yield output_iterator + + self._next_epoch = gen_epochs() + self._output_iterator = None + # Store the error raised from the `gen_epoch` call. + self._gen_epoch_error: Optional[Exception] = None + + def stats(self) -> DatasetStats: + """Returns stats from the base dataset.""" + if self._executor: + return self._executor.get_stats() + return self._base_dataset._plan.stats() + + def start_epoch(self, split_idx: int) -> str: + """Called to start an epoch. + + Returns: + UUID for the epoch, which must be used when accessing results via get(). + """ + + # Wait for all clients to arrive at the barrier before starting a new epoch. + epoch_id = self._barrier(split_idx) + return epoch_id + + def get( + self, epoch_id: int, output_split_idx: int + ) -> Optional[Tuple[ObjectRef[Block], BlockMetadata]]: + """Blocking get operation. + + This is intended to be called concurrently from multiple clients. + """ + start_time = time.perf_counter() + if epoch_id != self._cur_epoch: + raise ValueError( + "Invalid iterator: the dataset has moved on to another epoch." + ) + + try: + # Ensure there is at least one bundle. + with self._lock: + if output_split_idx in self._next_bundle: + next_bundle = self._next_bundle[output_split_idx] + else: + next_bundle = None + + # Fetch next bundle if needed. + while next_bundle is None or not next_bundle.blocks: + # This is a BLOCKING call, so do it outside the lock. 
+ next_bundle = self._output_iterator.get_next(output_split_idx) + + block = next_bundle.blocks[-1] + next_bundle = replace(next_bundle, blocks=next_bundle.blocks[:-1]) + + # Accumulate any remaining blocks in next_bundle map as needed. + with self._lock: + self._next_bundle[output_split_idx] = next_bundle + if not next_bundle.blocks: + del self._next_bundle[output_split_idx] + + return block + except StopIteration: + return None + finally: + stats = self.stats() + if stats and stats.streaming_split_coordinator_s: + stats.streaming_split_coordinator_s.add( + time.perf_counter() - start_time + ) + + def _barrier(self, split_idx: int) -> int: + """Arrive and block until the start of the given epoch.""" + + # Decrement and await all clients to arrive here. + with self._lock: + starting_epoch = self._cur_epoch + self._unfinished_clients_in_epoch -= 1 + + start_time = time.time() + while ( + self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0 + ): + if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT: + if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"): + logger.warning( + f"StreamSplitDataIterator(epoch={starting_epoch}, " + f"split={split_idx}) blocked waiting on other clients " + f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All " + "clients must read from the DataIterator splits at " + "the same time. This warning will not be printed again " + "for this epoch." + ) + time.sleep(0.1) + + # Advance to the next epoch. + with self._lock: + if self._cur_epoch == starting_epoch: + self._cur_epoch += 1 + self._unfinished_clients_in_epoch = self._n + try: + self._output_iterator = next(self._next_epoch) + except Exception as e: + self._gen_epoch_error = e + + if self._gen_epoch_error is not None: + # If there was an error when advancing to the next epoch, + # re-raise it for all threads. 
+ raise self._gen_epoch_error + + assert self._output_iterator is not None + return starting_epoch + 1 diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__init__.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1b255693265a89c211c195a0c404ddd61613369 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/optimizers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/optimizers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b9e49946ff23aa7896b9ada91a16e29db136a30 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/optimizers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__init__.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92822490b22d6e5685260b7de96a38e11a694181 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__init__.py @@ -0,0 +1,16 @@ +from .logical_operator import LogicalOperator +from .logical_plan import LogicalPlan +from .operator import Operator +from .optimizer import Optimizer, Rule +from .physical_plan import PhysicalPlan +from .plan import Plan + +__all__ = [ + "LogicalOperator", + "LogicalPlan", + 
"Operator", + "Optimizer", + "PhysicalPlan", + "Plan", + "Rule", +] diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41e165aff866020ef516731b6f0569f09b9cc925 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/logical_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/logical_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b230ae40f6c724aa427cad42675fec8c28c096 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/logical_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/logical_plan.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/logical_plan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a90c77172286fefb15ebefbc68574b2e51c0febc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/logical_plan.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d466a1d0829348c66ed444edcc6840d5e9abb0 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/optimizer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/optimizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff705d8db9ad4128637ea9ac0ef7e0ccf1cb0956 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/optimizer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/physical_plan.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/physical_plan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89fcf57d920bff2427c2bf95b0e389636df47b1b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/physical_plan.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/plan.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/plan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..070c5c8364e6fa7e695bc9214de00453b4f1cfa8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/__pycache__/plan.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/logical_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/logical_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..84535706cd5080061b88aa675944dc80dcdc32bb --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/logical_operator.py @@ -0,0 +1,79 @@ +from typing import TYPE_CHECKING, Iterator, List, Optional + +from .operator import Operator +from ray.data.block import BlockMetadata + +if TYPE_CHECKING: + from ray.data._internal.execution.interfaces import RefBundle + + +class LogicalOperator(Operator): + """Abstract class for logical operators. + + A logical operator describes transformation, and later is converted into + physical operator. + """ + + def __init__( + self, + name: str, + input_dependencies: List["LogicalOperator"], + num_outputs: Optional[int] = None, + ): + super().__init__( + name, + input_dependencies, + ) + for x in input_dependencies: + assert isinstance(x, LogicalOperator), x + self._num_outputs = num_outputs + + def estimated_num_outputs(self) -> Optional[int]: + """Returns the estimated number of blocks that + would be outputted by this logical operator. + + This method does not execute the plan, so it does not take into consideration + block splitting. This method only considers high-level block constraints like + `Dataset.repartition(num_blocks=X)`. A more accurate estimation can be given by + `PhysicalOperator.num_outputs_total()` during execution. + """ + if self._num_outputs is not None: + return self._num_outputs + elif len(self._input_dependencies) == 1: + return self._input_dependencies[0].estimated_num_outputs() + return None + + # Override the following 3 methods to correct type hints. 
+ + @property + def input_dependencies(self) -> List["LogicalOperator"]: + return super().input_dependencies # type: ignore + + @property + def output_dependencies(self) -> List["LogicalOperator"]: + return super().output_dependencies # type: ignore + + def post_order_iter(self) -> Iterator["LogicalOperator"]: + return super().post_order_iter() # type: ignore + + def output_data(self) -> Optional[List["RefBundle"]]: + """The output data of this operator, or ``None`` if not known.""" + return None + + def aggregate_output_metadata(self) -> BlockMetadata: + """A ``BlockMetadata`` that represents the aggregate metadata of the outputs. + + This method is used by methods like :meth:`~ray.data.Dataset.schema` to + efficiently return metadata. + """ + return BlockMetadata(None, None, None, None, None) + + def is_lineage_serializable(self) -> bool: + """Returns whether the lineage of this operator can be serialized. + + An operator is lineage serializable if you can serialize it on one machine and + deserialize it on another without losing information. Operators that store + object references (e.g., ``InputData``) aren't lineage serializable because the + objects aren't available on the deserialized machine. 
+ """ + return True diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/logical_plan.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/logical_plan.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0196bb440bbdb02ff2b112e2b68d8d2be58088 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/logical_plan.py @@ -0,0 +1,31 @@ +from typing import TYPE_CHECKING, List + +from .logical_operator import LogicalOperator +from .plan import Plan + +if TYPE_CHECKING: + from ray.data import DataContext + + +class LogicalPlan(Plan): + """The plan with a DAG of logical operators.""" + + def __init__(self, dag: LogicalOperator, context: "DataContext"): + super().__init__(context) + self._dag = dag + + @property + def dag(self) -> LogicalOperator: + """Get the DAG of logical operators.""" + return self._dag + + def sources(self) -> List[LogicalOperator]: + """List of operators that are sources for this plan's DAG.""" + # If an operator has no input dependencies, it's a source. + if not any(self._dag.input_dependencies): + return [self._dag] + + sources = [] + for op in self._dag.input_dependencies: + sources.extend(LogicalPlan(op, self._context).sources()) + return sources diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..76a320ef815a23cb319146221d32c7be10e5be52 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/operator.py @@ -0,0 +1,58 @@ +from typing import Iterator, List + + +class Operator: + """Abstract class for operators. + + Operators live on the driver side of the Dataset only. 
+ """ + + def __init__( + self, + name: str, + input_dependencies: List["Operator"], + ): + self._name = name + self._input_dependencies = input_dependencies + self._output_dependencies = [] + for x in input_dependencies: + assert isinstance(x, Operator), x + x._output_dependencies.append(self) + + @property + def name(self) -> str: + return self._name + + @property + def input_dependencies(self) -> List["Operator"]: + """List of operators that provide inputs for this operator.""" + assert hasattr( + self, "_input_dependencies" + ), "Operator.__init__() was not called." + return self._input_dependencies + + @property + def output_dependencies(self) -> List["Operator"]: + """List of operators that consume outputs from this operator.""" + assert hasattr( + self, "_output_dependencies" + ), "Operator.__init__() was not called." + return self._output_dependencies + + def post_order_iter(self) -> Iterator["Operator"]: + """Depth-first traversal of this operator and its input dependencies.""" + for op in self.input_dependencies: + yield from op.post_order_iter() + yield self + + def __repr__(self) -> str: + if self.input_dependencies: + out_str = ", ".join([str(x) for x in self.input_dependencies]) + out_str += " -> " + else: + out_str = "" + out_str += f"{self.__class__.__name__}[{self._name}]" + return out_str + + def __str__(self) -> str: + return repr(self) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/optimizer.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6a25a44afe624b2e3ad6182739f68ed59b5cf720 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/optimizer.py @@ -0,0 +1,29 @@ +from typing import List + +from .plan import Plan + + +class Rule: + """Abstract class for optimization rule.""" + + def apply(self, plan: Plan) -> Plan: + """Apply the optimization rule to the 
execution plan.""" + raise NotImplementedError + + +class Optimizer: + """Abstract class for optimizers. + + An optimizers transforms a DAG of operators with a list of predefined rules. + """ + + @property + def rules(self) -> List[Rule]: + """List of predefined rules for this optimizer.""" + raise NotImplementedError + + def optimize(self, plan: Plan) -> Plan: + """Optimize operators with a list of rules.""" + for rule in self.rules: + plan = rule.apply(plan) + return plan diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/physical_plan.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/physical_plan.py new file mode 100644 index 0000000000000000000000000000000000000000..29503831db85e7d87f1f044d7c910826fa970515 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/physical_plan.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING, Dict + +from .logical_operator import LogicalOperator +from .plan import Plan + +if TYPE_CHECKING: + from ray.data import DataContext + from ray.data._internal.execution.interfaces import PhysicalOperator + + +class PhysicalPlan(Plan): + """The plan with a DAG of physical operators.""" + + def __init__( + self, + dag: "PhysicalOperator", + op_map: Dict["PhysicalOperator", LogicalOperator], + context: "DataContext", + ): + super().__init__(context) + self._dag = dag + self._op_map = op_map + + @property + def dag(self) -> "PhysicalOperator": + """Get the DAG of physical operators.""" + return self._dag + + @property + def op_map(self) -> Dict["PhysicalOperator", LogicalOperator]: + """ + Get a mapping from physical operators to their corresponding logical operator. 
+ """ + return self._op_map diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/plan.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/plan.py new file mode 100644 index 0000000000000000000000000000000000000000..8dba60277071a690d55b9a9a166cce74dca01bbc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/interfaces/plan.py @@ -0,0 +1,25 @@ +from typing import TYPE_CHECKING + +from .operator import Operator + +if TYPE_CHECKING: + from ray.data import DataContext + + +class Plan: + """Abstract class for logical/physical execution plans. + + This plan should hold an operator representing the plan DAG and any auxiliary data + that's useful for plan optimization or execution. + """ + + def __init__(self, context: "DataContext"): + self._context = context + + @property + def dag(self) -> Operator: + raise NotImplementedError + + @property + def context(self) -> "DataContext": + return self._context diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__init__.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/count_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/count_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1115e7e69e81f5117d6b543dd5cc4574a06dbade Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/count_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/all_to_all_operator.py 
b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/all_to_all_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..16a09f44d8209247ae107811defe55b5fe1d42a7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/all_to_all_operator.py @@ -0,0 +1,163 @@ +from typing import Any, Dict, List, Optional + +from ray.data._internal.logical.interfaces import LogicalOperator +from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec +from ray.data._internal.planner.exchange.shuffle_task_spec import ShuffleTaskSpec +from ray.data._internal.planner.exchange.sort_task_spec import SortKey, SortTaskSpec +from ray.data.aggregate import AggregateFn +from ray.data.block import BlockMetadata + + +class AbstractAllToAll(LogicalOperator): + """Abstract class for logical operators should be converted to physical + AllToAllOperator. + """ + + def __init__( + self, + name: str, + input_op: LogicalOperator, + num_outputs: Optional[int] = None, + sub_progress_bar_names: Optional[List[str]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + """ + Args: + name: Name for this operator. This is the name that will appear when + inspecting the logical plan of a Dataset. + input_op: The operator preceding this operator in the plan DAG. The outputs + of `input_op` will be the inputs to this operator. + num_outputs: The number of expected output bundles outputted by this + operator. + ray_remote_args: Args to provide to :func:`ray.remote`. 
+ """ + super().__init__(name, [input_op], num_outputs) + self._num_outputs = num_outputs + self._ray_remote_args = ray_remote_args or {} + self._sub_progress_bar_names = sub_progress_bar_names + + +class RandomizeBlocks(AbstractAllToAll): + """Logical operator for randomize_block_order.""" + + def __init__( + self, + input_op: LogicalOperator, + seed: Optional[int] = None, + ): + super().__init__( + "RandomizeBlockOrder", + input_op, + ) + self._seed = seed + + def aggregate_output_metadata(self) -> BlockMetadata: + assert len(self._input_dependencies) == 1, len(self._input_dependencies) + return self._input_dependencies[0].aggregate_output_metadata() + + +class RandomShuffle(AbstractAllToAll): + """Logical operator for random_shuffle.""" + + def __init__( + self, + input_op: LogicalOperator, + name: str = "RandomShuffle", + seed: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + super().__init__( + name, + input_op, + sub_progress_bar_names=[ + ExchangeTaskSpec.MAP_SUB_PROGRESS_BAR_NAME, + ExchangeTaskSpec.REDUCE_SUB_PROGRESS_BAR_NAME, + ], + ray_remote_args=ray_remote_args, + ) + self._seed = seed + + def aggregate_output_metadata(self) -> BlockMetadata: + assert len(self._input_dependencies) == 1, len(self._input_dependencies) + return self._input_dependencies[0].aggregate_output_metadata() + + +class Repartition(AbstractAllToAll): + """Logical operator for repartition.""" + + def __init__( + self, + input_op: LogicalOperator, + num_outputs: int, + shuffle: bool, + ): + if shuffle: + sub_progress_bar_names = [ + ExchangeTaskSpec.MAP_SUB_PROGRESS_BAR_NAME, + ExchangeTaskSpec.REDUCE_SUB_PROGRESS_BAR_NAME, + ] + else: + sub_progress_bar_names = [ + ShuffleTaskSpec.SPLIT_REPARTITION_SUB_PROGRESS_BAR_NAME, + ] + super().__init__( + "Repartition", + input_op, + num_outputs=num_outputs, + sub_progress_bar_names=sub_progress_bar_names, + ) + self._shuffle = shuffle + + def aggregate_output_metadata(self) -> BlockMetadata: + assert 
len(self._input_dependencies) == 1, len(self._input_dependencies) + return self._input_dependencies[0].aggregate_output_metadata() + + +class Sort(AbstractAllToAll): + """Logical operator for sort.""" + + def __init__( + self, + input_op: LogicalOperator, + sort_key: SortKey, + batch_format: Optional[str] = "default", + ): + super().__init__( + "Sort", + input_op, + sub_progress_bar_names=[ + SortTaskSpec.SORT_SAMPLE_SUB_PROGRESS_BAR_NAME, + ExchangeTaskSpec.MAP_SUB_PROGRESS_BAR_NAME, + ExchangeTaskSpec.REDUCE_SUB_PROGRESS_BAR_NAME, + ], + ) + self._sort_key = sort_key + self._batch_format = batch_format + + def aggregate_output_metadata(self) -> BlockMetadata: + assert len(self._input_dependencies) == 1, len(self._input_dependencies) + return self._input_dependencies[0].aggregate_output_metadata() + + +class Aggregate(AbstractAllToAll): + """Logical operator for aggregate.""" + + def __init__( + self, + input_op: LogicalOperator, + key: Optional[str], + aggs: List[AggregateFn], + batch_format: Optional[str] = "default", + ): + super().__init__( + "Aggregate", + input_op, + sub_progress_bar_names=[ + SortTaskSpec.SORT_SAMPLE_SUB_PROGRESS_BAR_NAME, + ExchangeTaskSpec.MAP_SUB_PROGRESS_BAR_NAME, + ExchangeTaskSpec.REDUCE_SUB_PROGRESS_BAR_NAME, + ], + ) + self._key = key + self._aggs = aggs + self._batch_format = batch_format diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/input_data_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/input_data_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..2296b0ee315441bcdb0a3acd49791576aea3e191 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/input_data_operator.py @@ -0,0 +1,74 @@ +import functools +from typing import Callable, List, Optional + +from ray.data._internal.execution.interfaces import RefBundle +from ray.data._internal.logical.interfaces import LogicalOperator +from 
ray.data._internal.util import unify_block_metadata_schema +from ray.data.block import BlockMetadata + + +class InputData(LogicalOperator): + """Logical operator for input data. + + This may hold cached blocks from a previous Dataset execution, or + the arguments for read tasks. + """ + + def __init__( + self, + input_data: Optional[List[RefBundle]] = None, + input_data_factory: Optional[Callable[[int], List[RefBundle]]] = None, + ): + assert (input_data is None) != ( + input_data_factory is None + ), "Only one of input_data and input_data_factory should be set." + super().__init__( + "InputData", [], len(input_data) if input_data is not None else None + ) + self.input_data = input_data + self.input_data_factory = input_data_factory + + def output_data(self) -> Optional[List[RefBundle]]: + if self.input_data is None: + return None + return self.input_data + + def aggregate_output_metadata(self) -> BlockMetadata: + return self._cached_output_metadata + + @functools.cached_property + def _cached_output_metadata(self) -> BlockMetadata: + if self.input_data is None: + return BlockMetadata(None, None, None, None, None) + + return BlockMetadata( + num_rows=self._num_rows(), + size_bytes=self._size_bytes(), + schema=self._schema(), + input_files=None, + exec_stats=None, + ) + + def _num_rows(self): + assert self.input_data is not None + if all(bundle.num_rows() is not None for bundle in self.input_data): + return sum(bundle.num_rows() for bundle in self.input_data) + else: + return None + + def _size_bytes(self): + assert self.input_data is not None + metadata = [m for bundle in self.input_data for m in bundle.metadata] + if all(m.size_bytes is not None for m in metadata): + return sum(m.size_bytes for m in metadata) + else: + return None + + def _schema(self): + assert self.input_data is not None + metadata = [m for bundle in self.input_data for m in bundle.metadata] + return unify_block_metadata_schema(metadata) + + def is_lineage_serializable(self) -> bool: + # This 
operator isn't serializable because it contains ObjectRefs. + return False diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/map_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/map_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..be13cf76c7be2d0c2bf609f5ead3761e3ae4c2de --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/map_operator.py @@ -0,0 +1,315 @@ +import inspect +import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional + +from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy +from ray.data._internal.logical.interfaces import LogicalOperator +from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne +from ray.data.block import UserDefinedFunction +from ray.data.context import DEFAULT_BATCH_SIZE +from ray.data.preprocessor import Preprocessor + +if TYPE_CHECKING: + import pyarrow as pa + + +logger = logging.getLogger(__name__) + + +class AbstractMap(AbstractOneToOne): + """Abstract class for logical operators that should be converted to physical + MapOperator. + """ + + def __init__( + self, + name: str, + input_op: Optional[LogicalOperator] = None, + num_outputs: Optional[int] = None, + *, + min_rows_per_bundled_input: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + compute: Optional[ComputeStrategy] = None, + ): + """ + Args: + name: Name for this operator. This is the name that will appear when + inspecting the logical plan of a Dataset. + input_op: The operator preceding this operator in the plan DAG. The outputs + of `input_op` will be the inputs to this operator. + min_rows_per_bundled_input: The target number of rows to pass to + ``MapOperator._add_bundled_input()``. + ray_remote_args: Args to provide to :func:`ray.remote`. 
+ ray_remote_args_fn: A function that returns a dictionary of remote args + passed to each map worker. The purpose of this argument is to generate + dynamic arguments for each actor/task, and will be called each time + prior to initializing the worker. Args returned from this dict + always override the args in ``ray_remote_args``. Note: this is an + advanced, experimental feature. + """ + super().__init__(name, input_op, num_outputs) + self._min_rows_per_bundled_input = min_rows_per_bundled_input + self._ray_remote_args = ray_remote_args or {} + self._ray_remote_args_fn = ray_remote_args_fn + self._compute = compute or TaskPoolStrategy() + + +class AbstractUDFMap(AbstractMap): + """Abstract class for logical operators performing a UDF that should be converted + to physical MapOperator. + """ + + def __init__( + self, + name: str, + input_op: LogicalOperator, + fn: UserDefinedFunction, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + min_rows_per_bundled_input: Optional[int] = None, + compute: Optional[ComputeStrategy] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + """ + Args: + name: Name for this operator. This is the name that will appear when + inspecting the logical plan of a Dataset. + input_op: The operator preceding this operator in the plan DAG. The outputs + of `input_op` will be the inputs to this operator. + fn: User-defined function to be called. + fn_args: Arguments to `fn`. + fn_kwargs: Keyword arguments to `fn`. + fn_constructor_args: Arguments to provide to the initializor of `fn` if + `fn` is a callable class. + fn_constructor_kwargs: Keyword Arguments to provide to the initializor of + `fn` if `fn` is a callable class. 
+ min_rows_per_bundled_input: The target number of rows to pass to + ``MapOperator._add_bundled_input()``. + compute: The compute strategy, either ``TaskPoolStrategy`` (default) to use + Ray tasks, or ``ActorPoolStrategy`` to use an autoscaling actor pool. + ray_remote_args_fn: A function that returns a dictionary of remote args + passed to each map worker. The purpose of this argument is to generate + dynamic arguments for each actor/task, and will be called each time + prior to initializing the worker. Args returned from this dict will + always override the args in ``ray_remote_args``. Note: this is an + advanced, experimental feature. + ray_remote_args: Args to provide to :func:`ray.remote`. + """ + name = self._get_operator_name(name, fn) + super().__init__( + name, + input_op, + min_rows_per_bundled_input=min_rows_per_bundled_input, + ray_remote_args=ray_remote_args, + compute=compute, + ) + self._fn = fn + self._fn_args = fn_args + self._fn_kwargs = fn_kwargs + self._fn_constructor_args = fn_constructor_args + self._fn_constructor_kwargs = fn_constructor_kwargs + self._ray_remote_args_fn = ray_remote_args_fn + + def _get_operator_name(self, op_name: str, fn: UserDefinedFunction): + """Gets the Operator name including the map `fn` UDF name.""" + # If the input `fn` is a Preprocessor, the + # name is simply the name of the Preprocessor class. + if inspect.ismethod(fn) and isinstance(fn.__self__, Preprocessor): + return fn.__self__.__class__.__name__ + + # Otherwise, it takes the form of `()`, + # e.g. `MapBatches(my_udf)`. + try: + if inspect.isclass(fn): + # callable class + return f"{op_name}({fn.__name__})" + elif inspect.ismethod(fn): + # class method + return f"{op_name}({fn.__self__.__class__.__name__}.{fn.__name__})" + elif inspect.isfunction(fn): + # normal function or lambda function. + return f"{op_name}({fn.__name__})" + else: + # callable object. 
+ return f"{op_name}({fn.__class__.__name__})" + except AttributeError as e: + logger.error("Failed to get name of UDF %s: %s", fn, e) + return "" + + +class MapBatches(AbstractUDFMap): + """Logical operator for map_batches.""" + + def __init__( + self, + input_op: LogicalOperator, + fn: UserDefinedFunction, + batch_size: Optional[int] = DEFAULT_BATCH_SIZE, + batch_format: str = "default", + zero_copy_batch: bool = False, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + min_rows_per_bundled_input: Optional[int] = None, + compute: Optional[ComputeStrategy] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + super().__init__( + "MapBatches", + input_op, + fn, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + min_rows_per_bundled_input=min_rows_per_bundled_input, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + self._batch_size = batch_size + self._batch_format = batch_format + self._zero_copy_batch = zero_copy_batch + + @property + def can_modify_num_rows(self) -> bool: + return False + + +class MapRows(AbstractUDFMap): + """Logical operator for map.""" + + def __init__( + self, + input_op: LogicalOperator, + fn: UserDefinedFunction, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + compute: Optional[ComputeStrategy] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + super().__init__( + "Map", + input_op, + fn, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + 
fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + + @property + def can_modify_num_rows(self) -> bool: + return False + + +class Filter(AbstractUDFMap): + """Logical operator for filter.""" + + def __init__( + self, + input_op: LogicalOperator, + fn: Optional[UserDefinedFunction] = None, + filter_expr: Optional["pa.dataset.Expression"] = None, + compute: Optional[ComputeStrategy] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + # Ensure exactly one of fn or filter_expr is provided + if not ((fn is None) ^ (filter_expr is None)): + raise ValueError("Exactly one of 'fn' or 'filter_expr' must be provided") + self._filter_expr = filter_expr + + super().__init__( + "Filter", + input_op, + fn=fn, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + + @property + def can_modify_num_rows(self) -> bool: + return True + + +class Project(AbstractMap): + """Logical operator for select_columns.""" + + def __init__( + self, + input_op: LogicalOperator, + cols: Optional[List[str]] = None, + cols_rename: Optional[Dict[str, str]] = None, + compute: Optional[ComputeStrategy] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + super().__init__( + "Project", + input_op=input_op, + ray_remote_args=ray_remote_args, + compute=compute, + ) + self._batch_size = DEFAULT_BATCH_SIZE + self._cols = cols or [] + self._cols_rename = cols_rename or {} + self._batch_format = "pyarrow" + self._zero_copy_batch = True + + @property + def cols(self) -> Optional[List[str]]: + return self._cols + + @property + def cols_rename(self) -> Optional[Dict[str, str]]: + return self._cols_rename + + @property + def can_modify_num_rows(self) -> bool: + return False + + +class FlatMap(AbstractUDFMap): + """Logical operator for 
flat_map.""" + + def __init__( + self, + input_op: LogicalOperator, + fn: UserDefinedFunction, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + compute: Optional[ComputeStrategy] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + ): + super().__init__( + "FlatMap", + input_op, + fn, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + + @property + def can_modify_num_rows(self) -> bool: + return True diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/n_ary_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/n_ary_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f062578e2c1d1f31b42fe5f96dfbae6cd13654 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/n_ary_operator.py @@ -0,0 +1,60 @@ +from typing import Optional + +from ray.data._internal.logical.interfaces import LogicalOperator + + +class NAry(LogicalOperator): + """Base class for n-ary operators, which take multiple input operators.""" + + def __init__( + self, + *input_ops: LogicalOperator, + num_outputs: Optional[int] = None, + ): + """ + Args: + input_ops: The input operators. + """ + super().__init__(self.__class__.__name__, list(input_ops), num_outputs) + + +class Zip(NAry): + """Logical operator for zip.""" + + def __init__( + self, + left_input_op: LogicalOperator, + right_input_op: LogicalOperator, + ): + """ + Args: + left_input_ops: The input operator at left hand side. + right_input_op: The input operator at right hand side. 
+ """ + super().__init__(left_input_op, right_input_op) + + def estimated_num_outputs(self): + left_num_outputs = self._input_dependencies[0].estimated_num_outputs() + right_num_outputs = self._input_dependencies[1].estimated_num_outputs() + if left_num_outputs is None or right_num_outputs is None: + return None + return max(left_num_outputs, right_num_outputs) + + +class Union(NAry): + """Logical operator for union.""" + + def __init__( + self, + *input_ops: LogicalOperator, + ): + super().__init__(*input_ops) + + def estimated_num_outputs(self): + total_num_outputs = 0 + for input in self._input_dependencies: + num_outputs = input.estimated_num_outputs() + if num_outputs is None: + return None + total_num_outputs += num_outputs + return total_num_outputs diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/one_to_one_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/one_to_one_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..052d0b23ecda047b1876c8ca6adfa9bf6fd8dfa5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/one_to_one_operator.py @@ -0,0 +1,80 @@ +import abc +from typing import Optional + +from ray.data._internal.logical.interfaces import LogicalOperator +from ray.data.block import BlockMetadata + + +class AbstractOneToOne(LogicalOperator): + """Abstract class for one-to-one logical operators, which + have one input and one output dependency. + """ + + def __init__( + self, + name: str, + input_op: Optional[LogicalOperator], + num_outputs: Optional[int] = None, + ): + """ + Args: + name: Name for this operator. This is the name that will appear when + inspecting the logical plan of a Dataset. + input_op: The operator preceding this operator in the plan DAG. The outputs + of `input_op` will be the inputs to this operator. 
+ """ + super().__init__(name, [input_op] if input_op else [], num_outputs) + + @property + def input_dependency(self) -> LogicalOperator: + return self._input_dependencies[0] + + @property + @abc.abstractmethod + def can_modify_num_rows(self) -> bool: + """Whether this operator can modify the number of rows, + i.e. number of input rows != number of output rows.""" + + +class Limit(AbstractOneToOne): + """Logical operator for limit.""" + + def __init__( + self, + input_op: LogicalOperator, + limit: int, + ): + super().__init__( + f"limit={limit}", + input_op, + ) + self._limit = limit + + @property + def can_modify_num_rows(self) -> bool: + return True + + def aggregate_output_metadata(self) -> BlockMetadata: + return BlockMetadata( + num_rows=self._num_rows(), + size_bytes=None, + schema=self._schema(), + input_files=self._input_files(), + exec_stats=None, + ) + + def _schema(self): + assert len(self._input_dependencies) == 1, len(self._input_dependencies) + return self._input_dependencies[0].aggregate_output_metadata().schema + + def _num_rows(self): + assert len(self._input_dependencies) == 1, len(self._input_dependencies) + input_rows = self._input_dependencies[0].aggregate_output_metadata().num_rows + if input_rows is not None: + return min(input_rows, self._limit) + else: + return None + + def _input_files(self): + assert len(self._input_dependencies) == 1, len(self._input_dependencies) + return self._input_dependencies[0].aggregate_output_metadata().input_files diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/read_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/read_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..5d958dbc59fbd5c0c09312d2c50a59de9f689d41 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/read_operator.py @@ -0,0 +1,95 @@ +import functools +from typing import Any, Dict, Optional, Union + +from 
ray.data._internal.logical.operators.map_operator import AbstractMap +from ray.data._internal.util import unify_block_metadata_schema +from ray.data.block import BlockMetadata +from ray.data.datasource.datasource import Datasource, Reader + + +class Read(AbstractMap): + """Logical operator for read.""" + + def __init__( + self, + datasource: Datasource, + datasource_or_legacy_reader: Union[Datasource, Reader], + parallelism: int, + mem_size: Optional[int], + num_outputs: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + ): + super().__init__( + f"Read{datasource.get_name()}", + None, + num_outputs, + ray_remote_args=ray_remote_args, + ) + self._datasource = datasource + self._datasource_or_legacy_reader = datasource_or_legacy_reader + self._parallelism = parallelism + self._mem_size = mem_size + self._concurrency = concurrency + self._detected_parallelism = None + + def set_detected_parallelism(self, parallelism: int): + """ + Set the true parallelism that should be used during execution. This + should be specified by the user or detected by the optimizer. + """ + self._detected_parallelism = parallelism + + def get_detected_parallelism(self) -> int: + """ + Get the true parallelism that should be used during execution. + """ + return self._detected_parallelism + + def aggregate_output_metadata(self) -> BlockMetadata: + """A ``BlockMetadata`` that represents the aggregate metadata of the outputs. + + This method gets metadata from the read tasks. It doesn't trigger any actual + execution. + """ + return self._cached_output_metadata + + @functools.cached_property + def _cached_output_metadata(self) -> BlockMetadata: + # Legacy datasources might not implement `get_read_tasks`. + if self._datasource.should_create_reader: + return BlockMetadata(None, None, None, None, None) + + # HACK: Try to get a single read task to get the metadata. 
+ read_tasks = self._datasource.get_read_tasks(1) + if len(read_tasks) == 0: + # If there are no read tasks, the dataset is probably empty. + return BlockMetadata(None, None, None, None, None) + + # `get_read_tasks` isn't guaranteed to return exactly one read task. + metadata = [read_task.metadata for read_task in read_tasks] + + if all(meta.num_rows is not None for meta in metadata): + num_rows = sum(meta.num_rows for meta in metadata) + else: + num_rows = None + + if all(meta.size_bytes is not None for meta in metadata): + size_bytes = sum(meta.size_bytes for meta in metadata) + else: + size_bytes = None + + schema = unify_block_metadata_schema(metadata) + + input_files = [] + for meta in metadata: + if meta.input_files is not None: + input_files.extend(meta.input_files) + + return BlockMetadata( + num_rows=num_rows, + size_bytes=size_bytes, + schema=schema, + input_files=input_files, + exec_stats=None, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/write_operator.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/write_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf0159fa4f36aacc1c4ed82892226d20b7ca8f9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/write_operator.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, Optional, Union + +from ray.data._internal.logical.interfaces import LogicalOperator +from ray.data._internal.logical.operators.map_operator import AbstractMap +from ray.data.datasource.datasink import Datasink +from ray.data.datasource.datasource import Datasource + + +class Write(AbstractMap): + """Logical operator for write.""" + + def __init__( + self, + input_op: LogicalOperator, + datasink_or_legacy_datasource: Union[Datasink, Datasource], + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + **write_args, + ): + if isinstance(datasink_or_legacy_datasource, 
Datasink): + min_rows_per_bundled_input = ( + datasink_or_legacy_datasource.min_rows_per_write + ) + else: + min_rows_per_bundled_input = None + + super().__init__( + "Write", + input_op, + min_rows_per_bundled_input=min_rows_per_bundled_input, + ray_remote_args=ray_remote_args, + ) + self._datasink_or_legacy_datasource = datasink_or_legacy_datasource + self._write_args = write_args + self._concurrency = concurrency diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/__pycache__/operator_fusion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/__pycache__/operator_fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf180aa09fbabbbcdeea086c4c9a2bd29b920d89 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/__pycache__/operator_fusion.cpython-311.pyc differ