diff --git a/.venv/lib/python3.11/site-packages/ray/data/__init__.py b/.venv/lib/python3.11/site-packages/ray/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..087d41ab38d6478ac189c99b6b0e2446b633094d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/__init__.py @@ -0,0 +1,165 @@ +# Short term workaround for https://github.com/ray-project/ray/issues/32435 +# Dataset has a hard dependency on pandas, so it doesn't need to be delayed. +import pandas # noqa +from packaging.version import parse as parse_version + +from ray._private.utils import _get_pyarrow_version +from ray.data._internal.compute import ActorPoolStrategy +from ray.data._internal.datasource.tfrecords_datasource import TFXReadOptions +from ray.data._internal.execution.interfaces import ( + ExecutionOptions, + ExecutionResources, + NodeIdStr, +) +from ray.data._internal.logging import configure_logging +from ray.data.context import DataContext, DatasetContext +from ray.data.dataset import Dataset, Schema +from ray.data.datasource import ( + BlockBasedFileDatasink, + Datasink, + Datasource, + FileShuffleConfig, + ReadTask, + RowBasedFileDatasink, +) +from ray.data.iterator import DataIterator, DatasetIterator +from ray.data.preprocessor import Preprocessor +from ray.data.read_api import ( # noqa: F401 + from_arrow, + from_arrow_refs, + from_blocks, + from_dask, + from_huggingface, + from_items, + from_mars, + from_modin, + from_numpy, + from_numpy_refs, + from_pandas, + from_pandas_refs, + from_spark, + from_tf, + from_torch, + range, + range_tensor, + read_audio, + read_avro, + read_bigquery, + read_binary_files, + read_clickhouse, + read_csv, + read_databricks_tables, + read_datasource, + read_delta_sharing_tables, + read_hudi, + read_iceberg, + read_images, + read_json, + read_lance, + read_mongo, + read_numpy, + read_parquet, + read_parquet_bulk, + read_sql, + read_text, + read_tfrecords, + read_videos, + read_webdataset, +) + +# Module-level cached 
global functions for callable classes. It needs to be defined here +# since it has to be process-global across cloudpickled funcs. +_map_actor_context = None + +configure_logging() + +try: + import pyarrow as pa + + # https://github.com/apache/arrow/pull/38608 deprecated `PyExtensionType`, and + # disabled it's deserialization by default. To ensure that users can load data + # written with earlier version of Ray Data, we enable auto-loading of serialized + # tensor extensions. + pyarrow_version = _get_pyarrow_version() + if not isinstance(pyarrow_version, str): + # PyArrow is mocked in documentation builds. In this case, we don't need to do + # anything. + pass + else: + from ray._private.ray_constants import env_bool + + RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE = env_bool( + "RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE", False + ) + + if ( + parse_version(pyarrow_version) >= parse_version("14.0.1") + and RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE + ): + pa.PyExtensionType.set_auto_load(True) + # Import these arrow extension types to ensure that they are registered. + from ray.air.util.tensor_extensions.arrow import ( # noqa + ArrowTensorType, + ArrowVariableShapedTensorType, + ) +except ModuleNotFoundError: + pass + + +__all__ = [ + "ActorPoolStrategy", + "BlockBasedFileDatasink", + "Dataset", + "DataContext", + "DatasetContext", # Backwards compatibility alias. + "DataIterator", + "DatasetIterator", # Backwards compatibility alias. 
+ "Datasink", + "Datasource", + "ExecutionOptions", + "ExecutionResources", + "FileShuffleConfig", + "NodeIdStr", + "ReadTask", + "RowBasedFileDatasink", + "Schema", + "from_dask", + "from_items", + "from_arrow", + "from_arrow_refs", + "from_mars", + "from_modin", + "from_numpy", + "from_numpy_refs", + "from_pandas", + "from_pandas_refs", + "from_spark", + "from_tf", + "from_torch", + "from_huggingface", + "range", + "range_tensor", + "read_audio", + "read_avro", + "read_text", + "read_binary_files", + "read_clickhouse", + "read_csv", + "read_datasource", + "read_delta_sharing_tables", + "read_hudi", + "read_iceberg", + "read_images", + "read_json", + "read_lance", + "read_numpy", + "read_mongo", + "read_parquet", + "read_parquet_bulk", + "read_sql", + "read_tfrecords", + "read_videos", + "read_webdataset", + "Preprocessor", + "TFXReadOptions", +] diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16de397460c3dce21db2debfddf845daef418944 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/aggregate.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/aggregate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6995257375cba8098778c549c4249835054b9f82 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/aggregate.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/block.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..844fdcbc3fd5858ae1be78598d30c2130a1176c3 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/block.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27658a734d00f8fe9c70e1887232bfdff9196d87 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c6d591858c0eba8b9a7f6ece37f53bf1914f7d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/grouped_data.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/grouped_data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfbfcceb246c8b34543ba70aa9efcdbbe8674a1f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/grouped_data.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/iterator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/iterator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41c12b550c5aefc02f08de45f32546304cabb1f4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/iterator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/preprocessor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/preprocessor.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..26994e13823b28539bc85983158ca4ce0f23936b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/preprocessor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/__pycache__/random_access_dataset.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/random_access_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e01433e204de5aab41945642324d2f8dbc15e40 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/__pycache__/random_access_dataset.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/aggregate.py b/.venv/lib/python3.11/site-packages/ray/data/aggregate.py new file mode 100644 index 0000000000000000000000000000000000000000..86d8b7cb603c303ffc22f73dfd19bcf9cfa0ac9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/aggregate.py @@ -0,0 +1,76 @@ +from typing import TYPE_CHECKING, Callable, Optional, Union + +from ray.data.block import AggType, Block, BlockAccessor, KeyType, T, U +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + import pyarrow as pa + + +@PublicAPI +class AggregateFn: + """Defines an aggregate function in the accumulator style. + + Aggregates a collection of inputs of type T into + a single output value of type U. + See https://www.sigops.org/s/conferences/sosp/2009/papers/yu-sosp09.pdf + for more details about accumulator-based aggregation. + + Args: + init: This is called once for each group to return the empty accumulator. + For example, an empty accumulator for a sum would be 0. + merge: This may be called multiple times, each time to merge + two accumulators into one. + name: The name of the aggregation. This will be used as the column name + in the output Dataset. + accumulate_row: This is called once per row of the same group. + This combines the accumulator and the row, returns the updated + accumulator. 
Exactly one of accumulate_row and accumulate_block must + be provided. + accumulate_block: This is used to calculate the aggregation for a + single block, and is vectorized alternative to accumulate_row. This will + be given a base accumulator and the entire block, allowing for + vectorized accumulation of the block. Exactly one of accumulate_row and + accumulate_block must be provided. + finalize: This is called once to compute the final aggregation + result from the fully merged accumulator. + """ + + def __init__( + self, + init: Callable[[KeyType], AggType], + merge: Callable[[AggType, AggType], AggType], + name: str, + accumulate_row: Callable[[AggType, T], AggType] = None, + accumulate_block: Callable[[AggType, Block], AggType] = None, + finalize: Optional[Callable[[AggType], U]] = None, + ): + if (accumulate_row is None and accumulate_block is None) or ( + accumulate_row is not None and accumulate_block is not None + ): + raise ValueError( + "Exactly one of accumulate_row or accumulate_block must be provided." 
+ ) + if accumulate_block is None: + + def accumulate_block(a: AggType, block: Block) -> AggType: + block_acc = BlockAccessor.for_block(block) + for r in block_acc.iter_rows(public_row_format=False): + a = accumulate_row(a, r) + return a + + if not isinstance(name, str): + raise TypeError("`name` must be provided.") + + if finalize is None: + finalize = lambda a: a # noqa: E731 + + self.init = init + self.merge = merge + self.name = name + self.accumulate_block = accumulate_block + self.finalize = finalize + + def _validate(self, schema: Optional[Union[type, "pa.lib.Schema"]]) -> None: + """Raise an error if this cannot be applied to the given schema.""" + pass diff --git a/.venv/lib/python3.11/site-packages/ray/data/block.py b/.venv/lib/python3.11/site-packages/ray/data/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8e78b0736e4fc9b338029160aab25444652d18 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/block.py @@ -0,0 +1,561 @@ +import collections +import logging +import os +import time +from dataclasses import dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Protocol, + Tuple, + TypeVar, + Union, +) + +import numpy as np + +import ray +from ray import DynamicObjectRefGenerator +from ray.air.util.tensor_extensions.arrow import ArrowConversionError +from ray.data._internal.util import _check_pyarrow_version, _truncated_repr +from ray.types import ObjectRef +from ray.util import log_once +from ray.util.annotations import DeveloperAPI + +import psutil + +try: + import resource +except ImportError: + resource = None + +if TYPE_CHECKING: + import pandas + import pyarrow + + from ray.data._internal.block_builder import BlockBuilder + from ray.data._internal.planner.exchange.sort_task_spec import SortKey + from ray.data.aggregate import AggregateFn + + +T = TypeVar("T", contravariant=True) +U = TypeVar("U", covariant=True) + 
+KeyType = TypeVar("KeyType") +AggType = TypeVar("AggType") + + +# Represents a batch of records to be stored in the Ray object store. +# +# Block data can be accessed in a uniform way via ``BlockAccessors`` like` +# ``ArrowBlockAccessor``. +Block = Union["pyarrow.Table", "pandas.DataFrame"] + + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class BlockType(Enum): + ARROW = "arrow" + PANDAS = "pandas" + + +# User-facing data batch type. This is the data type for data that is supplied to and +# returned from batch UDFs. +DataBatch = Union["pyarrow.Table", "pandas.DataFrame", Dict[str, np.ndarray]] + +# User-facing data column type. This is the data type for data that is supplied to and +# returned from column UDFs. +DataBatchColumn = Union[ + "pyarrow.ChunkedArray", "pyarrow.Array", "pandas.Series", np.ndarray +] + + +# A class type that implements __call__. +CallableClass = type + + +class _CallableClassProtocol(Protocol[T, U]): + def __call__(self, __arg: T) -> Union[U, Iterator[U]]: + ... + + +# A user defined function passed to map, map_batches, ec. +UserDefinedFunction = Union[ + Callable[[T], U], + Callable[[T], Iterator[U]], + "_CallableClassProtocol", +] + +# A list of block references pending computation by a single task. For example, +# this may be the output of a task reading a file. +BlockPartition = List[Tuple[ObjectRef[Block], "BlockMetadata"]] + +# The metadata that describes the output of a BlockPartition. This has the +# same type as the metadata that describes each block in the partition. +BlockPartitionMetadata = List["BlockMetadata"] + +# TODO(ekl/chengsu): replace this with just +# `DynamicObjectRefGenerator` once block splitting +# is on by default. When block splitting is off, the type is a plain block. 
+MaybeBlockPartition = Union[Block, DynamicObjectRefGenerator] + +VALID_BATCH_FORMATS = ["pandas", "pyarrow", "numpy", None] +DEFAULT_BATCH_FORMAT = "numpy" + + +def _apply_batch_format(given_batch_format: Optional[str]) -> str: + if given_batch_format == "default": + given_batch_format = DEFAULT_BATCH_FORMAT + if given_batch_format not in VALID_BATCH_FORMATS: + raise ValueError( + f"The given batch format {given_batch_format} isn't allowed (must be one of" + f" {VALID_BATCH_FORMATS})." + ) + return given_batch_format + + +def _apply_batch_size( + given_batch_size: Optional[Union[int, Literal["default"]]] +) -> Optional[int]: + if given_batch_size == "default": + return ray.data.context.DEFAULT_BATCH_SIZE + else: + return given_batch_size + + +@DeveloperAPI +class BlockExecStats: + """Execution stats for this block. + + Attributes: + wall_time_s: The wall-clock time it took to compute this block. + cpu_time_s: The CPU time it took to compute this block. + node_id: A unique id for the node that computed this block. + """ + + def __init__(self): + self.start_time_s: Optional[float] = None + self.end_time_s: Optional[float] = None + self.wall_time_s: Optional[float] = None + self.udf_time_s: Optional[float] = 0 + self.cpu_time_s: Optional[float] = None + self.node_id = ray.runtime_context.get_runtime_context().get_node_id() + # Max memory usage. May be an overestimate since we do not + # differentiate from previous tasks on the same worker. + self.max_rss_bytes: int = 0 + self.task_idx: Optional[int] = None + + @staticmethod + def builder() -> "_BlockExecStatsBuilder": + return _BlockExecStatsBuilder() + + def __repr__(self): + return repr( + { + "wall_time_s": self.wall_time_s, + "cpu_time_s": self.cpu_time_s, + "udf_time_s": self.udf_time_s, + "node_id": self.node_id, + } + ) + + +class _BlockExecStatsBuilder: + """Helper class for building block stats. + + When this class is created, we record the start time. 
When build() is + called, the time delta is saved as part of the stats. + """ + + def __init__(self): + self.start_time = time.perf_counter() + self.start_cpu = time.process_time() + + def build(self) -> "BlockExecStats": + self.end_time = time.perf_counter() + self.end_cpu = time.process_time() + + stats = BlockExecStats() + stats.start_time_s = self.start_time + stats.end_time_s = self.end_time + stats.wall_time_s = self.end_time - self.start_time + stats.cpu_time_s = self.end_cpu - self.start_cpu + if resource is None: + # NOTE(swang): resource package is not supported on Windows. This + # is only the memory usage at the end of the task, not the peak + # memory. + process = psutil.Process(os.getpid()) + stats.max_rss_bytes = int(process.memory_info().rss) + else: + stats.max_rss_bytes = int( + resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1e3 + ) + return stats + + +@DeveloperAPI +@dataclass +class BlockMetadata: + """Metadata about the block.""" + + #: The number of rows contained in this block, or None. + num_rows: Optional[int] + #: The approximate size in bytes of this block, or None. + size_bytes: Optional[int] + #: The pyarrow schema or types of the block elements, or None. + schema: Optional[Union[type, "pyarrow.lib.Schema"]] + #: The list of file paths used to generate this block, or + #: the empty list if indeterminate. + input_files: Optional[List[str]] + #: Execution stats for this block. + exec_stats: Optional[BlockExecStats] + + def __post_init__(self): + if self.input_files is None: + self.input_files = [] + if self.size_bytes is not None: + # Require size_bytes to be int, ray.util.metrics objects + # will not take other types like numpy.int64 + assert isinstance(self.size_bytes, int) + + +@DeveloperAPI +class BlockAccessor: + """Provides accessor methods for a specific block. + + Ideally, we wouldn't need a separate accessor classes for blocks. 
However, + this is needed if we want to support storing ``pyarrow.Table`` directly + as a top-level Ray object, without a wrapping class (issue #17186). + """ + + def num_rows(self) -> int: + """Return the number of rows contained in this block.""" + raise NotImplementedError + + def iter_rows(self, public_row_format: bool) -> Iterator[T]: + """Iterate over the rows of this block. + + Args: + public_row_format: Whether to cast rows into the public Dict row + format (this incurs extra copy conversions). + """ + raise NotImplementedError + + def slice(self, start: int, end: int, copy: bool) -> Block: + """Return a slice of this block. + + Args: + start: The starting index of the slice (inclusive). + end: The ending index of the slice (exclusive). + copy: Whether to perform a data copy for the slice. + + Returns: + The sliced block result. + """ + raise NotImplementedError + + def take(self, indices: List[int]) -> Block: + """Return a new block containing the provided row indices. + + Args: + indices: The row indices to return. + + Returns: + A new block containing the provided row indices. + """ + raise NotImplementedError + + def select(self, columns: List[Optional[str]]) -> Block: + """Return a new block containing the provided columns.""" + raise NotImplementedError + + def rename_columns(self, columns_rename: Dict[str, str]) -> Block: + """Return the block reflecting the renamed columns.""" + raise NotImplementedError + + def random_shuffle(self, random_seed: Optional[int]) -> Block: + """Randomly shuffle this block.""" + raise NotImplementedError + + def to_pandas(self) -> "pandas.DataFrame": + """Convert this block into a Pandas dataframe.""" + raise NotImplementedError + + def to_numpy( + self, columns: Optional[Union[str, List[str]]] = None + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """Convert this block (or columns of block) into a NumPy ndarray. + + Args: + columns: Name of columns to convert, or None if converting all columns. 
+ """ + raise NotImplementedError + + def to_arrow(self) -> "pyarrow.Table": + """Convert this block into an Arrow table.""" + raise NotImplementedError + + def to_block(self) -> Block: + """Return the base block that this accessor wraps.""" + raise NotImplementedError + + def to_default(self) -> Block: + """Return the default data format for this accessor.""" + return self.to_block() + + def to_batch_format(self, batch_format: Optional[str]) -> DataBatch: + """Convert this block into the provided batch format. + + Args: + batch_format: The batch format to convert this block to. + + Returns: + This block formatted as the provided batch format. + """ + if batch_format is None: + return self.to_block() + elif batch_format == "default" or batch_format == "native": + return self.to_default() + elif batch_format == "pandas": + return self.to_pandas() + elif batch_format == "pyarrow": + return self.to_arrow() + elif batch_format == "numpy": + return self.to_numpy() + else: + raise ValueError( + f"The batch format must be one of {VALID_BATCH_FORMATS}, got: " + f"{batch_format}" + ) + + def size_bytes(self) -> int: + """Return the approximate size in bytes of this block.""" + raise NotImplementedError + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + """Return the Python type or pyarrow schema of this block.""" + raise NotImplementedError + + def get_metadata( + self, + input_files: Optional[List[str]] = None, + exec_stats: Optional[BlockExecStats] = None, + ) -> BlockMetadata: + """Create a metadata object from this block.""" + return BlockMetadata( + num_rows=self.num_rows(), + size_bytes=self.size_bytes(), + schema=self.schema(), + input_files=input_files, + exec_stats=exec_stats, + ) + + def zip(self, other: "Block") -> "Block": + """Zip this block with another block of the same type and size.""" + raise NotImplementedError + + @staticmethod + def builder() -> "BlockBuilder": + """Create a builder for this block type.""" + raise NotImplementedError + + 
@classmethod + def batch_to_block( + cls, + batch: DataBatch, + block_type: Optional[BlockType] = None, + ) -> Block: + """Create a block from user-facing data formats.""" + + if isinstance(batch, np.ndarray): + raise ValueError( + f"Error validating {_truncated_repr(batch)}: " + "Standalone numpy arrays are not " + "allowed in Ray 2.5. Return a dict of field -> array, " + "e.g., `{'data': array}` instead of `array`." + ) + + elif isinstance(batch, collections.abc.Mapping): + if block_type is None or block_type == BlockType.ARROW: + try: + return cls.batch_to_arrow_block(batch) + except ArrowConversionError as e: + if log_once("_fallback_to_pandas_block_warning"): + logger.warning( + f"Failed to convert batch to Arrow due to: {e}; " + f"falling back to Pandas block" + ) + + if block_type is None: + return cls.batch_to_pandas_block(batch) + else: + raise e + else: + assert block_type == BlockType.PANDAS + return cls.batch_to_pandas_block(batch) + return batch + + @classmethod + def batch_to_arrow_block(cls, batch: Dict[str, Any]) -> Block: + """Create an Arrow block from user-facing data formats.""" + from ray.data._internal.arrow_block import ArrowBlockBuilder + + return ArrowBlockBuilder._table_from_pydict(batch) + + @classmethod + def batch_to_pandas_block(cls, batch: Dict[str, Any]) -> Block: + """Create a Pandas block from user-facing data formats.""" + from ray.data._internal.pandas_block import PandasBlockAccessor + + return PandasBlockAccessor.numpy_to_block(batch) + + @staticmethod + def for_block(block: Block) -> "BlockAccessor[T]": + """Create a block accessor for the given block.""" + _check_pyarrow_version() + import pandas + import pyarrow + + if isinstance(block, pyarrow.Table): + from ray.data._internal.arrow_block import ArrowBlockAccessor + + return ArrowBlockAccessor(block) + elif isinstance(block, pandas.DataFrame): + from ray.data._internal.pandas_block import PandasBlockAccessor + + return PandasBlockAccessor(block) + elif isinstance(block, 
bytes): + from ray.data._internal.arrow_block import ArrowBlockAccessor + + return ArrowBlockAccessor.from_bytes(block) + elif isinstance(block, list): + raise ValueError( + f"Error validating {_truncated_repr(block)}: " + "Standalone Python objects are not " + "allowed in Ray 2.5. To use Python objects in a dataset, " + "wrap them in a dict of numpy arrays, e.g., " + "return `{'item': batch}` instead of just `batch`." + ) + else: + raise TypeError("Not a block type: {} ({})".format(block, type(block))) + + def sample(self, n_samples: int, sort_key: "SortKey") -> "Block": + """Return a random sample of items from this block.""" + raise NotImplementedError + + def sort_and_partition( + self, boundaries: List[T], sort_key: "SortKey" + ) -> List["Block"]: + """Return a list of sorted partitions of this block.""" + raise NotImplementedError + + def combine(self, key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block: + """Combine rows with the same key into an accumulator.""" + raise NotImplementedError + + @staticmethod + def merge_sorted_blocks( + blocks: List["Block"], sort_key: "SortKey" + ) -> Tuple[Block, BlockMetadata]: + """Return a sorted block by merging a list of sorted blocks.""" + raise NotImplementedError + + @staticmethod + def aggregate_combined_blocks( + blocks: List[Block], sort_key: "SortKey", aggs: Tuple["AggregateFn"] + ) -> Tuple[Block, BlockMetadata]: + """Aggregate partially combined and sorted blocks.""" + raise NotImplementedError + + def block_type(self) -> BlockType: + """Return the block type of this block.""" + raise NotImplementedError + + +def _get_block_boundaries(columns: list[np.ndarray]) -> np.ndarray: + """Compute boundaries of the groups within a block, which is represented + by a list of 1D numpy arrays for each column. In each column, + NaNs/None are considered to be the same group. + + Args: + columns: a list of 1D numpy arrays. This is generally given by the + dictionary values of ``BlockAccessor.to_numpy()``. 
+ + Returns: + A list of starting indices of each group and an end index of the last + group, i.e., there are ``num_groups + 1`` entries and the first and last + entries are 0 and ``len(array)`` respectively. + """ + + # There are 3 categories: general, numerics with NaN, and categorical with None. + # We only needed to check the last element for NaNs/None, as they are assumed to + # be sorted. + general_arrays = [] + num_arrays_with_nan = [] + cat_arrays_with_none = [] + for arr in columns: + if np.issubdtype(arr.dtype, np.number) and np.isnan(arr[-1]): + num_arrays_with_nan.append(arr) + elif not np.issubdtype(arr.dtype, np.number) and arr[-1] is None: + cat_arrays_with_none.append(arr) + else: + general_arrays.append(arr) + + # Compute the difference between each pair of elements. Handle the cases + # where neighboring elements are both NaN or None. Output as a list of + # boolean arrays. + diffs = [] + if len(general_arrays) > 0: + diffs.append( + np.vstack([arr[1:] != arr[:-1] for arr in general_arrays]).any(axis=0) + ) + if len(num_arrays_with_nan) > 0: + # Two neighboring numeric elements belong to the same group when they are + # 1) both finite and equal + # or 2) both np.nan + diffs.append( + np.vstack( + [ + (arr[1:] != arr[:-1]) + & (np.isfinite(arr[1:]) | np.isfinite(arr[:-1])) + for arr in num_arrays_with_nan + ] + ).any(axis=0) + ) + if len(cat_arrays_with_none) > 0: + # Two neighboring str/object elements belong to the same group when they are + # 1) both finite and equal + # or 2) both None + diffs.append( + np.vstack( + [ + (arr[1:] != arr[:-1]) + & ~(np.equal(arr[1:], None) & np.equal(arr[:-1], None)) + for arr in cat_arrays_with_none + ] + ).any(axis=0) + ) + + # A series of vectorized operations to compute the boundaries: + # - column_stack: stack the bool arrays into a single 2D bool array + # - any() and nonzero(): find the indices where any of the column diffs are True + # - add 1 to get the index of the first element of the next group + # - 
hstack(): include the 0 and last indices to the boundaries + boundaries = np.hstack( + [ + [0], + (np.column_stack(diffs).any(axis=1).nonzero()[0] + 1), + [len(columns[0])], + ] + ).astype(int) + + return boundaries diff --git a/.venv/lib/python3.11/site-packages/ray/data/context.py b/.venv/lib/python3.11/site-packages/ray/data/context.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9ad30584a38f53dc2ac78224af79f22a8dc3d1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/context.py @@ -0,0 +1,468 @@ +import logging +import os +import threading +import warnings +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import ray +from ray._private.ray_constants import env_bool, env_integer +from ray._private.worker import WORKER_MODE +from ray.util.annotations import DeveloperAPI +from ray.util.debug import log_once +from ray.util.scheduling_strategies import SchedulingStrategyT + +if TYPE_CHECKING: + from ray.data._internal.execution.interfaces import ExecutionOptions + +logger = logging.getLogger(__name__) + +# The context singleton on this process. +_default_context: "Optional[DataContext]" = None +_context_lock = threading.Lock() + + +# We chose 128MiB for default: With streaming execution and num_cpus many concurrent +# tasks, the memory footprint will be about 2 * num_cpus * target_max_block_size ~= RAM +# * DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION * 0.3 (default object store memory +# fraction set by Ray core), assuming typical memory:core ratio of 4:1. +DEFAULT_TARGET_MAX_BLOCK_SIZE = 128 * 1024 * 1024 + +# We set a higher target block size because we have to materialize +# all input blocks anyway, so there is no performance advantage to having +# smaller blocks. Setting a larger block size allows avoiding overhead from an +# excessive number of partitions. +# We choose 1GiB as 4x less than the typical memory:core ratio (4:1). 
DEFAULT_SHUFFLE_TARGET_MAX_BLOCK_SIZE = 1024 * 1024 * 1024

# We will attempt to slice blocks whose size exceeds this factor *
# target_max_block_size. We will warn the user if slicing fails and we produce
# blocks larger than this threshold.
MAX_SAFE_BLOCK_SIZE_FACTOR = 1.5

DEFAULT_TARGET_MIN_BLOCK_SIZE = 1 * 1024 * 1024

# This default appears to work well with most file sizes on remote storage systems,
# which is very sensitive to the buffer size.
DEFAULT_STREAMING_READ_BUFFER_SIZE = 32 * 1024 * 1024

DEFAULT_ENABLE_PANDAS_BLOCK = True

DEFAULT_READ_OP_MIN_NUM_BLOCKS = 200

DEFAULT_ACTOR_PREFETCHER_ENABLED = False

DEFAULT_USE_PUSH_BASED_SHUFFLE = bool(
    os.environ.get("RAY_DATA_PUSH_BASED_SHUFFLE", None)
)

DEFAULT_SCHEDULING_STRATEGY = "SPREAD"

# This default enables locality-based scheduling in Ray for tasks where arg data
# transfer is a bottleneck.
DEFAULT_SCHEDULING_STRATEGY_LARGE_ARGS = "DEFAULT"

DEFAULT_LARGE_ARGS_THRESHOLD = 50 * 1024 * 1024

DEFAULT_USE_POLARS = False

DEFAULT_EAGER_FREE = bool(int(os.environ.get("RAY_DATA_EAGER_FREE", "1")))

DEFAULT_DECODING_SIZE_ESTIMATION_ENABLED = True

DEFAULT_MIN_PARALLELISM = 200

DEFAULT_ENABLE_TENSOR_EXTENSION_CASTING = True

# NOTE: V1 tensor type format only supports tensors of no more than 2Gb in
# total cumulative size (due to it internally utilizing int32 offsets)
#
# V2 in turn relies on int64 offsets, therefore having a limit of ~9Eb (exabytes)
DEFAULT_USE_ARROW_TENSOR_V2 = env_bool("RAY_DATA_USE_ARROW_TENSOR_V2", True)

DEFAULT_AUTO_LOG_STATS = False

DEFAULT_VERBOSE_STATS_LOG = False

DEFAULT_TRACE_ALLOCATIONS = bool(int(os.environ.get("RAY_DATA_TRACE_ALLOCATIONS", "0")))

DEFAULT_LOG_INTERNAL_STACK_TRACE_TO_STDOUT = env_bool(
    "RAY_DATA_LOG_INTERNAL_STACK_TRACE_TO_STDOUT", False
)

DEFAULT_RAY_DATA_RAISE_ORIGINAL_MAP_EXCEPTION = env_bool(
    "RAY_DATA_RAISE_ORIGINAL_MAP_EXCEPTION", False
)

DEFAULT_USE_RAY_TQDM = bool(int(os.environ.get("RAY_TQDM", "1")))

# Globally enable or disable all progress bars.
# If this is False, both the global and operator-level progress bars are disabled.
DEFAULT_ENABLE_PROGRESS_BARS = not bool(
    env_integer("RAY_DATA_DISABLE_PROGRESS_BARS", 0)
)
DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = env_bool(
    "RAY_DATA_ENABLE_PROGRESS_BAR_NAME_TRUNCATION", True
)

DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS = False


# `write_file_retry_on_errors` is deprecated in favor of `retried_io_errors`. You
# shouldn't need to modify `DEFAULT_WRITE_FILE_RETRY_ON_ERRORS`.
DEFAULT_WRITE_FILE_RETRY_ON_ERRORS = (
    "AWS Error INTERNAL_FAILURE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error SLOW_DOWN",
    "AWS Error UNKNOWN (HTTP status 503)",
)

DEFAULT_RETRIED_IO_ERRORS = (
    "AWS Error INTERNAL_FAILURE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error SLOW_DOWN",
    "AWS Error UNKNOWN (HTTP status 503)",
    "AWS Error SERVICE_UNAVAILABLE",
)

DEFAULT_WARN_ON_DRIVER_MEMORY_USAGE_BYTES = 2 * 1024 * 1024 * 1024

DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS = False

DEFAULT_ENABLE_OP_RESOURCE_RESERVATION = env_bool(
    "RAY_DATA_ENABLE_OP_RESOURCE_RESERVATION", True
)

DEFAULT_OP_RESOURCE_RESERVATION_RATIO = float(
    os.environ.get("RAY_DATA_OP_RESERVATION_RATIO", "0.5")
)

DEFAULT_MAX_ERRORED_BLOCKS = 0

# Use this to prefix important warning messages for the user.
WARN_PREFIX = "⚠️ "

# Use this to prefix important success messages for the user.
OK_PREFIX = "✔️ "

# Default batch size for batch transformations.
DEFAULT_BATCH_SIZE = 1024

# Default value of the max number of blocks that can be buffered at the
# streaming generator of each `DataOpTask`.
# Note, if this value is too large, we'll need to allocate more memory
# buffer for the pending task outputs, which may lead to bad performance
# as we may not have enough memory buffer for the operator outputs.
# If the value is too small, the task may be frequently blocked due to
# streaming generator backpressure.
DEFAULT_MAX_NUM_BLOCKS_IN_STREAMING_GEN_BUFFER = 2

# Default value for whether or not to try to create directories for write
# calls if the URI is an S3 URI.
DEFAULT_S3_TRY_CREATE_DIR = False

DEFAULT_WAIT_FOR_MIN_ACTORS_S = env_integer(
    "RAY_DATA_DEFAULT_WAIT_FOR_MIN_ACTORS_S", 60 * 10
)


def _execution_options_factory() -> "ExecutionOptions":
    """Build the default ``ExecutionOptions`` for a fresh ``DataContext``."""
    # Lazily import to avoid circular dependencies.
    from ray.data._internal.execution.interfaces import ExecutionOptions

    return ExecutionOptions()


@DeveloperAPI
@dataclass
class DataContext:
    """Global settings for Ray Data.

    Configure this class to enable advanced features and tune performance.

    .. warning::
        Apply changes before creating a :class:`~ray.data.Dataset`. Changes made after
        won't take effect.

    .. note::
        This object is automatically propagated to workers. Access it from the driver
        and remote workers with :meth:`DataContext.get_current()`.

    Examples:
        >>> from ray.data import DataContext
        >>> DataContext.get_current().enable_progress_bars = False

    Args:
        target_max_block_size: The max target block size in bytes for reads and
            transformations.
        target_shuffle_max_block_size: The max target block size in bytes for shuffle
            ops like ``random_shuffle``, ``sort``, and ``repartition``.
        target_min_block_size: Ray Data avoids creating blocks smaller than this
            size in bytes on read. This takes precedence over
            ``read_op_min_num_blocks``.
        streaming_read_buffer_size: Buffer size when doing streaming reads from local or
            remote storage.
        enable_pandas_block: Whether pandas block format is enabled.
        actor_prefetcher_enabled: Whether to use actor based block prefetcher.
        use_push_based_shuffle: Whether to use push-based shuffle.
        pipeline_push_based_shuffle_reduce_tasks:
        scheduling_strategy: The global scheduling strategy. For tasks with large args,
            ``scheduling_strategy_large_args`` takes precedence.
        scheduling_strategy_large_args: Scheduling strategy for tasks with large args.
        large_args_threshold: Size in bytes after which point task arguments are
            considered large. Choose a value so that the data transfer overhead is
            significant in comparison to task scheduling (i.e., low tens of ms).
        use_polars: Whether to use Polars for tabular dataset sorts, groupbys, and
            aggregations.
        eager_free: Whether to eagerly free memory.
        decoding_size_estimation: Whether to estimate in-memory decoding data size for
            data source.
        min_parallelism: This setting is deprecated. Use ``read_op_min_num_blocks``
            instead.
        read_op_min_num_blocks: Minimum number of read output blocks for a dataset.
        enable_tensor_extension_casting: Whether to automatically cast NumPy ndarray
            columns in Pandas DataFrames to tensor extension columns.
        use_arrow_tensor_v2: Config enabling V2 version of ArrowTensorArray supporting
            tensors > 2Gb in size (off by default)
        enable_fallback_to_arrow_object_ext_type: Enables fallback to serialize column
            values not supported by Arrow natively (like user-defined custom Python
            classes for ex, etc) using `ArrowPythonObjectType` (simply serializing
            these as bytes)
        enable_auto_log_stats: Whether to automatically log stats after execution. If
            disabled, you can still manually print stats with ``Dataset.stats()``.
        verbose_stats_logs: Whether stats logs should be verbose. This includes fields
            such as `extra_metrics` in the stats output, which are excluded by default.
        trace_allocations: Whether to trace allocations / eager free. This adds
            significant performance overheads and should only be used for debugging.
        execution_options: The
            :class:`~ray.data._internal.execution.interfaces.execution_options.ExecutionOptions`
            to use.
        use_ray_tqdm: Whether to enable distributed tqdm.
        enable_progress_bars: Whether to enable progress bars.
        enable_progress_bar_name_truncation: If True, the name of the progress bar
            (often the operator name) will be truncated if it exceeds
            `ProgressBar.MAX_NAME_LENGTH`. Otherwise, the full operator name is shown.
        enable_get_object_locations_for_metrics: Whether to enable
            ``get_object_locations`` for metrics.
        write_file_retry_on_errors: A list of substrings of error messages that should
            trigger a retry when writing files. This is useful for handling transient
            errors when writing to remote storage systems.
        warn_on_driver_memory_usage_bytes: If driver memory exceeds this threshold,
            Ray Data warns you. For now, this only applies to shuffle ops because most
            other ops are unlikely to use as much driver memory.
        actor_task_retry_on_errors: The application-level errors that actor task should
            retry. This follows same format as ``retry_exceptions`` in
            Ray Core. Default to `False` to not retry on any errors. Set to `True` to
            retry all errors, or set to a list of errors to retry.
        enable_op_resource_reservation: Whether to reserve resources for each operator.
        op_resource_reservation_ratio: The ratio of the total resources to reserve for
            each operator.
        max_errored_blocks: Max number of blocks that are allowed to have errors,
            unlimited if negative. This option allows application-level exceptions in
            block processing tasks. These exceptions may be caused by UDFs (e.g., due to
            corrupted data samples) or IO errors. Data in the failed blocks are dropped.
            This option can be useful to prevent a long-running job from failing due to
            a small number of bad blocks.
        log_internal_stack_trace_to_stdout: Whether to include internal Ray Data/Ray
            Core code stack frames when logging to stdout. The full stack trace is
            always written to the Ray Data log file.
        raise_original_map_exception: Whether to raise the original exception
            encountered in map UDF instead of wrapping it in a `UserCodeException`.
        print_on_execution_start: If ``True``, print execution information when
            execution starts.
        s3_try_create_dir: If ``True``, try to create directories on S3 when a write
            call is made with a S3 URI.
        wait_for_min_actors_s: The default time to wait for minimum requested
            actors to start before raising a timeout, in seconds.
        retried_io_errors: A list of substrings of error messages that should
            trigger a retry when reading or writing files. This is useful for handling
            transient errors when reading from remote storage systems.
    """

    target_max_block_size: int = DEFAULT_TARGET_MAX_BLOCK_SIZE
    target_shuffle_max_block_size: int = DEFAULT_SHUFFLE_TARGET_MAX_BLOCK_SIZE
    target_min_block_size: int = DEFAULT_TARGET_MIN_BLOCK_SIZE
    streaming_read_buffer_size: int = DEFAULT_STREAMING_READ_BUFFER_SIZE
    enable_pandas_block: bool = DEFAULT_ENABLE_PANDAS_BLOCK
    actor_prefetcher_enabled: bool = DEFAULT_ACTOR_PREFETCHER_ENABLED
    use_push_based_shuffle: bool = DEFAULT_USE_PUSH_BASED_SHUFFLE
    pipeline_push_based_shuffle_reduce_tasks: bool = True
    scheduling_strategy: SchedulingStrategyT = DEFAULT_SCHEDULING_STRATEGY
    scheduling_strategy_large_args: SchedulingStrategyT = (
        DEFAULT_SCHEDULING_STRATEGY_LARGE_ARGS
    )
    large_args_threshold: int = DEFAULT_LARGE_ARGS_THRESHOLD
    use_polars: bool = DEFAULT_USE_POLARS
    eager_free: bool = DEFAULT_EAGER_FREE
    decoding_size_estimation: bool = DEFAULT_DECODING_SIZE_ESTIMATION_ENABLED
    min_parallelism: int = DEFAULT_MIN_PARALLELISM
    read_op_min_num_blocks: int = DEFAULT_READ_OP_MIN_NUM_BLOCKS
    enable_tensor_extension_casting: bool = DEFAULT_ENABLE_TENSOR_EXTENSION_CASTING
    use_arrow_tensor_v2: bool = DEFAULT_USE_ARROW_TENSOR_V2
    enable_fallback_to_arrow_object_ext_type: Optional[bool] = None
    enable_auto_log_stats: bool = DEFAULT_AUTO_LOG_STATS
    verbose_stats_logs: bool = DEFAULT_VERBOSE_STATS_LOG
    trace_allocations: bool = DEFAULT_TRACE_ALLOCATIONS
    execution_options: "ExecutionOptions" = field(
        default_factory=_execution_options_factory
    )
    use_ray_tqdm: bool = DEFAULT_USE_RAY_TQDM
    enable_progress_bars: bool = DEFAULT_ENABLE_PROGRESS_BARS
    # By default, enable the progress bar for operator-level progress.
    # In __post_init__(), we disable operator-level progress
    # bars when running in a Ray job.
    enable_operator_progress_bars: bool = True
    enable_progress_bar_name_truncation: bool = (
        DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION
    )
    enable_get_object_locations_for_metrics: bool = (
        DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS
    )
    # NOTE: the default is a tuple on purpose — dataclasses reject mutable (list)
    # defaults, and the deprecation check in __setattr__ compares against it.
    write_file_retry_on_errors: List[str] = DEFAULT_WRITE_FILE_RETRY_ON_ERRORS
    warn_on_driver_memory_usage_bytes: int = DEFAULT_WARN_ON_DRIVER_MEMORY_USAGE_BYTES
    actor_task_retry_on_errors: Union[
        bool, List[BaseException]
    ] = DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS
    op_resource_reservation_enabled: bool = DEFAULT_ENABLE_OP_RESOURCE_RESERVATION
    op_resource_reservation_ratio: float = DEFAULT_OP_RESOURCE_RESERVATION_RATIO
    max_errored_blocks: int = DEFAULT_MAX_ERRORED_BLOCKS
    log_internal_stack_trace_to_stdout: bool = (
        DEFAULT_LOG_INTERNAL_STACK_TRACE_TO_STDOUT
    )
    raise_original_map_exception: bool = DEFAULT_RAY_DATA_RAISE_ORIGINAL_MAP_EXCEPTION
    print_on_execution_start: bool = True
    s3_try_create_dir: bool = DEFAULT_S3_TRY_CREATE_DIR
    wait_for_min_actors_s: int = DEFAULT_WAIT_FOR_MIN_ACTORS_S
    retried_io_errors: List[str] = field(
        default_factory=lambda: list(DEFAULT_RETRIED_IO_ERRORS)
    )

    # `None` means "no override"; annotated Optional accordingly (a bare `float`
    # annotation with a `None` default is a type error for checkers).
    override_object_store_memory_limit_fraction: Optional[float] = None

    def __post_init__(self):
        """Initialize private state and job-dependent progress-bar defaults."""
        # The additional ray remote args that should be added to
        # the task-pool-based data tasks.
        self._task_pool_data_task_remote_args: Dict[str, Any] = {}
        # The extra key-value style configs.
        # These configs are managed by individual components or plugins via
        # `set_config`, `get_config` and `remove_config`.
        # The reason why we use a dict instead of individual fields is to decouple
        # the DataContext from the plugin implementations, as well as to avoid
        # circular dependencies.
        self._kv_configs: Dict[str, Any] = {}
        self._max_num_blocks_in_streaming_gen_buffer = (
            DEFAULT_MAX_NUM_BLOCKS_IN_STREAMING_GEN_BUFFER
        )

        is_ray_job = os.environ.get("RAY_JOB_ID") is not None
        if is_ray_job:
            is_driver = ray.get_runtime_context().worker.mode != WORKER_MODE
            # Only log the notice once per driver process; the behavior change
            # below applies regardless of whether the message is emitted.
            if is_driver and log_once(
                "ray_data_disable_operator_progress_bars_in_ray_jobs"
            ):
                logger.info(
                    "Disabling operator-level progress bars by default in Ray Jobs. "
                    "To enable progress bars for all operators, set "
                    "`ray.data.DataContext.get_current()"
                    ".enable_operator_progress_bars = True`."
                )
            # Disable operator-level progress bars by default in Ray jobs.
            # The global progress bar for the overall Dataset execution will
            # still be enabled, unless the user also sets
            # `ray.data.DataContext.get_current().enable_progress_bars = False`.
            self.enable_operator_progress_bars = False
        else:
            # When not running in Ray job, operator-level progress
            # bars are enabled by default.
            self.enable_operator_progress_bars = True

    def __setattr__(self, name: str, value: Any) -> None:
        # Warn users who still configure the deprecated field (but not when the
        # dataclass machinery assigns its unchanged default).
        if (
            name == "write_file_retry_on_errors"
            and value != DEFAULT_WRITE_FILE_RETRY_ON_ERRORS
        ):
            warnings.warn(
                "`write_file_retry_on_errors` is deprecated. Configure "
                "`retried_io_errors` instead.",
                DeprecationWarning,
            )

        super().__setattr__(name, value)

    @staticmethod
    def get_current() -> "DataContext":
        """Get or create the current DataContext.

        When a Dataset is created, the current DataContext will be sealed.
        Changes to `DataContext.get_current()` will not impact existing Datasets.

        Examples:

            .. testcode::

                import ray

                context = ray.data.DataContext.get_current()

                context.target_max_block_size = 100 * 1024 ** 2
                ds1 = ray.data.range(1)
                context.target_max_block_size = 1 * 1024 ** 2
                ds2 = ray.data.range(1)

                # ds1's target_max_block_size will be 100MB
                ds1.take_all()
                # ds2's target_max_block_size will be 1MB
                ds2.take_all()

        Developer notes: Avoid using `DataContext.get_current()` in data
        internal components, use the DataContext object captured in the
        Dataset and pass it around as arguments.
        """

        global _default_context

        with _context_lock:
            if _default_context is None:
                _default_context = DataContext()

        return _default_context

    @staticmethod
    def _set_current(context: "DataContext") -> None:
        """Set the current context in a remote worker.

        This is used internally by Dataset to propagate the driver context to
        remote workers used for parallelization.
        """
        global _default_context
        _default_context = context

    def get_config(self, key: str, default: Any = None) -> Any:
        """Get the value for a key-value style config.

        Args:
            key: The key of the config.
            default: The default value to return if the key is not found.
        Returns: The value for the key, or the default value if the key is not found.
        """
        return self._kv_configs.get(key, default)

    def set_config(self, key: str, value: Any) -> None:
        """Set the value for a key-value style config.

        Args:
            key: The key of the config.
            value: The value of the config.
        """
        self._kv_configs[key] = value

    def remove_config(self, key: str) -> None:
        """Remove a key-value style config.

        Args:
            key: The key of the config.
        """
        self._kv_configs.pop(key, None)


# Backwards compatibility alias.
+DatasetContext = DataContext diff --git a/.venv/lib/python3.11/site-packages/ray/data/dataset.py b/.venv/lib/python3.11/site-packages/ray/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..82d403ec012d7aa7a8773fc21eee36c6493562cb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/dataset.py @@ -0,0 +1,5621 @@ +import collections +import copy +import html +import itertools +import logging +import time +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) + +import numpy as np + +import ray +import ray.cloudpickle as pickle +from ray._private.thirdparty.tabulate.tabulate import tabulate +from ray._private.usage import usage_lib +from ray.air.util.tensor_extensions.arrow import ( + ArrowTensorTypeV2, + get_arrow_extension_fixed_shape_tensor_types, +) +from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray +from ray.data._internal.aggregate import Max, Mean, Min, Std, Sum, Unique +from ray.data._internal.compute import ComputeStrategy +from ray.data._internal.datasource.bigquery_datasink import BigQueryDatasink +from ray.data._internal.datasource.csv_datasink import CSVDatasink +from ray.data._internal.datasource.image_datasink import ImageDatasink +from ray.data._internal.datasource.json_datasink import JSONDatasink +from ray.data._internal.datasource.mongo_datasink import MongoDatasink +from ray.data._internal.datasource.numpy_datasink import NumpyDatasink +from ray.data._internal.datasource.parquet_datasink import ParquetDatasink +from ray.data._internal.datasource.sql_datasink import SQLDatasink +from ray.data._internal.datasource.tfrecords_datasink import TFRecordDatasink +from ray.data._internal.datasource.webdataset_datasink import WebDatasetDatasink +from ray.data._internal.equalize import _equalize +from ray.data._internal.execution.interfaces import 
RefBundle +from ray.data._internal.execution.interfaces.ref_bundle import ( + _ref_bundles_iterator_to_block_refs_list, +) +from ray.data._internal.execution.util import memory_string +from ray.data._internal.iterator.iterator_impl import DataIteratorImpl +from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator +from ray.data._internal.logical.operators.all_to_all_operator import ( + RandomizeBlocks, + RandomShuffle, + Repartition, + Sort, +) +from ray.data._internal.logical.operators.count_operator import Count +from ray.data._internal.logical.operators.input_data_operator import InputData +from ray.data._internal.logical.operators.map_operator import ( + Filter, + FlatMap, + MapBatches, + MapRows, + Project, +) +from ray.data._internal.logical.operators.n_ary_operator import ( + Union as UnionLogicalOperator, +) +from ray.data._internal.logical.operators.n_ary_operator import Zip +from ray.data._internal.logical.operators.one_to_one_operator import Limit +from ray.data._internal.logical.operators.write_operator import Write +from ray.data._internal.logical.optimizers import LogicalPlan +from ray.data._internal.pandas_block import PandasBlockBuilder, PandasBlockSchema +from ray.data._internal.plan import ExecutionPlan +from ray.data._internal.planner.exchange.sort_task_spec import SortKey +from ray.data._internal.planner.plan_write_op import gen_datasink_write_result +from ray.data._internal.remote_fn import cached_remote_fn +from ray.data._internal.split import _get_num_rows, _split_at_indices +from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager +from ray.data._internal.util import ( + AllToAllAPI, + ConsumptionAPI, + _validate_rows_per_file_args, + get_compute_strategy, +) +from ray.data.aggregate import AggregateFn +from ray.data.block import ( + VALID_BATCH_FORMATS, + Block, + BlockAccessor, + DataBatch, + DataBatchColumn, + T, + U, + UserDefinedFunction, + _apply_batch_format, + _apply_batch_size, 
+) +from ray.data.context import DataContext +from ray.data.datasource import Connection, Datasink, FilenameProvider +from ray.data.iterator import DataIterator +from ray.data.random_access_dataset import RandomAccessDataset +from ray.types import ObjectRef +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +from ray.widgets import Template +from ray.widgets.util import repr_with_fallback + +if TYPE_CHECKING: + import dask + import mars + import modin + import pandas + import pyarrow + import pyspark + import tensorflow as tf + import torch + import torch.utils.data + from tensorflow_metadata.proto.v0 import schema_pb2 + + from ray.data._internal.execution.interfaces import Executor, NodeIdStr + from ray.data.grouped_data import GroupedData + + +logger = logging.getLogger(__name__) + +TensorflowFeatureTypeSpec = Union[ + "tf.TypeSpec", List["tf.TypeSpec"], Dict[str, "tf.TypeSpec"] +] + +TensorFlowTensorBatchType = Union["tf.Tensor", Dict[str, "tf.Tensor"]] + +CollatedData = TypeVar("CollatedData") +TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData] + +BT_API_GROUP = "Basic Transformations" +SSR_API_GROUP = "Sorting, Shuffling and Repartitioning" +SMD_API_GROUP = "Splitting and Merging datasets" +GGA_API_GROUP = "Grouped and Global aggregations" +CD_API_GROUP = "Consuming Data" +IOC_API_GROUP = "I/O and Conversion" +IM_API_GROUP = "Inspecting Metadata" +E_API_GROUP = "Execution" + + +@PublicAPI +class Dataset: + """A Dataset is a distributed data collection for data loading and processing. + + Datasets are distributed pipelines that produce ``ObjectRef[Block]`` outputs, + where each block holds data in Arrow format, representing a shard of the overall + data collection. The block also determines the unit of parallelism. For more + details, see :ref:`Ray Data Internals `. 
+ + Datasets can be created in multiple ways: from synthetic data via ``range_*()`` + APIs, from existing memory data via ``from_*()`` APIs (this creates a subclass + of Dataset called ``MaterializedDataset``), or from external storage + systems such as local disk, S3, HDFS etc. via the ``read_*()`` APIs. The + (potentially processed) Dataset can be saved back to external storage systems + via the ``write_*()`` APIs. + + Examples: + .. testcode:: + :skipif: True + + import ray + # Create dataset from synthetic data. + ds = ray.data.range(1000) + # Create dataset from in-memory data. + ds = ray.data.from_items( + [{"col1": i, "col2": i * 2} for i in range(1000)] + ) + # Create dataset from external storage system. + ds = ray.data.read_parquet("s3://bucket/path") + # Save dataset back to external storage system. + ds.write_csv("s3://bucket/output") + + Dataset has two kinds of operations: transformation, which takes in Dataset + and outputs a new Dataset (e.g. :py:meth:`.map_batches()`); and consumption, + which produces values (not a data stream) as output + (e.g. :meth:`.iter_batches()`). + + Dataset transformations are lazy, with execution of the transformations being + triggered by downstream consumption. + + Dataset supports parallel processing at scale: transformations such as + :py:meth:`.map_batches()`, aggregations such as + :py:meth:`.min()`/:py:meth:`.max()`/:py:meth:`.mean()`, grouping via + :py:meth:`.groupby()`, shuffling operations such as :py:meth:`.sort()`, + :py:meth:`.random_shuffle()`, and :py:meth:`.repartition()`. + + Examples: + >>> import ray + >>> ds = ray.data.range(1000) + >>> # Transform batches (Dict[str, np.ndarray]) with map_batches(). + >>> ds.map_batches(lambda batch: {"id": batch["id"] * 2}) # doctest: +ELLIPSIS + MapBatches() + +- Dataset(num_rows=1000, schema={id: int64}) + >>> # Compute the maximum. + >>> ds.max("id") + 999 + >>> # Shuffle this dataset randomly. 
+ >>> ds.random_shuffle() # doctest: +ELLIPSIS + RandomShuffle + +- Dataset(num_rows=1000, schema={id: int64}) + >>> # Sort it back in order. + >>> ds.sort("id") # doctest: +ELLIPSIS + Sort + +- Dataset(num_rows=1000, schema={id: int64}) + + Both unexecuted and materialized Datasets can be passed between Ray tasks and + actors without incurring a copy. Dataset supports conversion to/from several + more featureful dataframe libraries (e.g., Spark, Dask, Modin, MARS), and are also + compatible with distributed TensorFlow / PyTorch. + """ + + def __init__( + self, + plan: ExecutionPlan, + logical_plan: LogicalPlan, + ): + """Construct a Dataset (internal API). + + The constructor is not part of the Dataset API. Use the ``ray.data.*`` + read methods to construct a dataset. + """ + assert isinstance(plan, ExecutionPlan), type(plan) + usage_lib.record_library_usage("dataset") # Legacy telemetry name. + + self._plan = plan + self._logical_plan = logical_plan + self._plan.link_logical_plan(logical_plan) + + # Handle to currently running executor for this dataset. 
+ self._current_executor: Optional["Executor"] = None + self._write_ds = None + + self._set_uuid(StatsManager.get_dataset_id_from_stats_actor()) + + @staticmethod + def copy( + ds: "Dataset", _deep_copy: bool = False, _as: Optional[type] = None + ) -> "Dataset": + if not _as: + _as = type(ds) + if _deep_copy: + return _as(ds._plan.deep_copy(), ds._logical_plan) + else: + return _as(ds._plan.copy(), ds._logical_plan) + + @PublicAPI(api_group=BT_API_GROUP) + def map( + self, + fn: UserDefinedFunction[Dict[str, Any], Dict[str, Any]], + *, + compute: Optional[ComputeStrategy] = None, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Apply the given function to each row of this dataset. + + Use this method to transform your data. To learn more, see + :ref:`Transforming rows `. + + You can use either a function or a callable class to perform the transformation. + For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses + stateful Ray actors. For more information, see + :ref:`Stateful Transforms `. + + .. tip:: + + If your transformation is vectorized like most NumPy or pandas operations, + :meth:`~Dataset.map_batches` might be faster. + + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + + Examples: + + .. 
testcode:: + + import os + from typing import Any, Dict + import ray + + def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: + row["filename"] = os.path.basename(row["path"]) + return row + + ds = ( + ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple", include_paths=True) + .map(parse_filename) + ) + print(ds.schema()) + + .. testoutput:: + + Column Type + ------ ---- + image numpy.ndarray(shape=(32, 32, 3), dtype=uint8) + path string + filename string + + Time complexity: O(dataset size / parallelism) + + Args: + fn: The function to apply to each row, or a class type + that can be instantiated to create such a callable. + compute: This argument is deprecated. Use ``concurrency`` argument. + fn_args: Positional arguments to pass to ``fn`` after the first argument. + These arguments are top-level arguments to the underlying Ray task. + fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are + top-level arguments to the underlying Ray task. + fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. + You can only provide this if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor. + This can only be provided if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + num_cpus: The number of CPUs to reserve for each parallel map worker. + num_gpus: The number of GPUs to reserve for each parallel map worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel map + worker. + concurrency: The number of Ray workers to use concurrently. For a fixed-sized + worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling + worker pool from ``m`` to ``n`` workers, specify ``concurrency=(m, n)``. 
+ ray_remote_args_fn: A function that returns a dictionary of remote args + passed to each map worker. The purpose of this argument is to generate + dynamic arguments for each actor/task, and will be called each time prior + to initializing the worker. Args returned from this dict will always + override the args in ``ray_remote_args``. Note: this is an advanced, + experimental feature. + ray_remote_args: Additional resource requirements to request from + Ray for each map worker. See :func:`ray.remote` for details. + + .. seealso:: + + :meth:`~Dataset.flat_map` + Call this method to create new rows from existing ones. Unlike + :meth:`~Dataset.map`, a function passed to + :meth:`~Dataset.flat_map` can return multiple rows. + + :meth:`~Dataset.map_batches` + Call this method to transform batches of data. + """ # noqa: E501 + compute = get_compute_strategy( + fn, + fn_constructor_args=fn_constructor_args, + compute=compute, + concurrency=concurrency, + ) + + if num_cpus is not None: + ray_remote_args["num_cpus"] = num_cpus + + if num_gpus is not None: + ray_remote_args["num_gpus"] = num_gpus + + plan = self._plan.copy() + map_op = MapRows( + self._logical_plan.dag, + fn, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + logical_plan = LogicalPlan(map_op, self.context) + return Dataset(plan, logical_plan) + + def _set_name(self, name: Optional[str]): + """Set the name of the dataset. + + Used as a prefix for metrics tags. 
+ """ + self._plan._dataset_name = name + + @property + def _name(self) -> Optional[str]: + """Returns the dataset name""" + return self._plan._dataset_name + + @PublicAPI(api_group=BT_API_GROUP) + def map_batches( + self, + fn: UserDefinedFunction[DataBatch, DataBatch], + *, + batch_size: Union[int, None, Literal["default"]] = "default", + compute: Optional[ComputeStrategy] = None, + batch_format: Optional[str] = "default", + zero_copy_batch: bool = False, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Apply the given function to batches of data. + + This method is useful for preprocessing data and performing inference. To learn + more, see :ref:`Transforming batches `. + + You can use either a function or a callable class to perform the transformation. + For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses + stateful Ray actors. For more information, see + :ref:`Stateful Transforms `. + + .. tip:: + To understand the format of the input to ``fn``, call :meth:`~Dataset.take_batch` + on the dataset to get a batch in the same format as will be passed to ``fn``. + + .. tip:: + If ``fn`` doesn't mutate its input, set ``zero_copy_batch=True`` to improve + performance and decrease memory utilization. + + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + + Examples: + + Call :meth:`~Dataset.map_batches` to transform your data. + + .. 
testcode:: + + from typing import Dict + import numpy as np + import ray + + def add_dog_years(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + batch["age_in_dog_years"] = 7 * batch["age"] + return batch + + ds = ( + ray.data.from_items([ + {"name": "Luna", "age": 4}, + {"name": "Rory", "age": 14}, + {"name": "Scout", "age": 9}, + ]) + .map_batches(add_dog_years) + ) + ds.show() + + .. testoutput:: + + {'name': 'Luna', 'age': 4, 'age_in_dog_years': 28} + {'name': 'Rory', 'age': 14, 'age_in_dog_years': 98} + {'name': 'Scout', 'age': 9, 'age_in_dog_years': 63} + + If your function returns large objects, yield outputs in chunks. + + .. testcode:: + + from typing import Dict + import ray + import numpy as np + + def map_fn_with_large_output(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + for i in range(3): + yield {"large_output": np.ones((100, 1000))} + + ds = ( + ray.data.from_items([1]) + .map_batches(map_fn_with_large_output) + ) + + If you require stateful transfomation, + use Python callable class. Here is an example showing how to use stateful transforms to create model inference workers, without having to reload the model on each call. + + .. testcode:: + + from typing import Dict + import numpy as np + import torch + import ray + + class TorchPredictor: + + def __init__(self): + self.model = torch.nn.Identity().cuda() + self.model.eval() + + def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + inputs = torch.as_tensor(batch["data"], dtype=torch.float32).cuda() + with torch.inference_mode(): + batch["output"] = self.model(inputs).detach().cpu().numpy() + return batch + + ds = ( + ray.data.from_numpy(np.ones((32, 100))) + .map_batches( + TorchPredictor, + # Two workers with one GPU each + concurrency=2, + # Batch size is required if you're using GPUs. + batch_size=4, + num_gpus=1 + ) + ) + + To learn more, see + :ref:`End-to-end: Offline Batch Inference `. 
+ + Args: + fn: The function or generator to apply to a record batch, or a class type + that can be instantiated to create such a callable. Note ``fn`` must be + pickle-able. + batch_size: The desired number of rows in each batch, or ``None`` to use + entire blocks as batches (blocks may contain different numbers of rows). + The actual size of the batch provided to ``fn`` may be smaller than + ``batch_size`` if ``batch_size`` doesn't evenly divide the block(s) sent + to a given map task. Default batch_size is 1024 with "default". + compute: This argument is deprecated. Use ``concurrency`` argument. + batch_format: If ``"default"`` or ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are + ``pandas.DataFrame``. If ``"pyarrow"``, batches are + ``pyarrow.Table``. + zero_copy_batch: Whether ``fn`` should be provided zero-copy, read-only + batches. If this is ``True`` and no copy is required for the + ``batch_format`` conversion, the batch is a zero-copy, read-only + view on data in Ray's object store, which can decrease memory + utilization and improve performance. If this is ``False``, the batch + is writable, which requires an extra copy to guarantee. + If ``fn`` mutates its input, this needs to be ``False`` in order to + avoid "assignment destination is read-only" or "buffer source array is + read-only" errors. Default is ``False``. + fn_args: Positional arguments to pass to ``fn`` after the first argument. + These arguments are top-level arguments to the underlying Ray task. + fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are + top-level arguments to the underlying Ray task. + fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. + You can only provide this if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor. 
+ This can only be provided if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + num_cpus: The number of CPUs to reserve for each parallel map worker. + num_gpus: The number of GPUs to reserve for each parallel map worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel map worker. + concurrency: The number of Ray workers to use concurrently. For a fixed-sized + worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling + worker pool from ``m`` to ``n`` workers, specify ``concurrency=(m, n)``. + ray_remote_args_fn: A function that returns a dictionary of remote args + passed to each map worker. The purpose of this argument is to generate + dynamic arguments for each actor/task, and will be called each time prior + to initializing the worker. Args returned from this dict will always + override the args in ``ray_remote_args``. Note: this is an advanced, + experimental feature. + ray_remote_args: Additional resource requirements to request from + Ray for each map worker. See :func:`ray.remote` for details. + + .. note:: + + The size of the batches provided to ``fn`` might be smaller than the + specified ``batch_size`` if ``batch_size`` doesn't evenly divide the + block(s) sent to a given map task. + + If ``batch_size`` is set and each input block is smaller than the + ``batch_size``, Ray Data will bundle up many blocks as the input for one + task, until their total size is equal to or greater than the given + ``batch_size``. + If ``batch_size`` is not set, the bundling will not be performed. Each task + will receive only one input block. + + .. seealso:: + + :meth:`~Dataset.iter_batches` + Call this function to iterate over batches of data. + + :meth:`~Dataset.take_batch` + Call this function to get a batch of data from the dataset + in the same format as will be passed to the `fn` function of + :meth:`~Dataset.map_batches`. 
+ + :meth:`~Dataset.flat_map` + Call this method to create new records from existing ones. Unlike + :meth:`~Dataset.map`, a function passed to :meth:`~Dataset.flat_map` + can return multiple records. + + :meth:`~Dataset.map` + Call this method to transform one record at time. + + """ # noqa: E501 + use_gpus = num_gpus is not None and num_gpus > 0 + if use_gpus and (batch_size is None or batch_size == "default"): + raise ValueError( + "You must provide `batch_size` to `map_batches` when requesting GPUs. " + "The optimal batch size depends on the model, data, and GPU used. " + "We recommend using the largest batch size that doesn't result " + "in your GPU device running out of memory. You can view the GPU memory " + "usage via the Ray dashboard." + ) + + if isinstance(batch_size, int) and batch_size < 1: + raise ValueError("Batch size can't be negative or 0") + + return self._map_batches_without_batch_size_validation( + fn, + batch_size=batch_size, + compute=compute, + batch_format=batch_format, + zero_copy_batch=zero_copy_batch, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + num_cpus=num_cpus, + num_gpus=num_gpus, + concurrency=concurrency, + ray_remote_args_fn=ray_remote_args_fn, + **ray_remote_args, + ) + + def _map_batches_without_batch_size_validation( + self, + fn: UserDefinedFunction[DataBatch, DataBatch], + *, + batch_size: Union[int, None, Literal["default"]], + compute: Optional[ComputeStrategy], + batch_format: Optional[str], + zero_copy_batch: bool, + fn_args: Optional[Iterable[Any]], + fn_kwargs: Optional[Dict[str, Any]], + fn_constructor_args: Optional[Iterable[Any]], + fn_constructor_kwargs: Optional[Dict[str, Any]], + num_cpus: Optional[float], + num_gpus: Optional[float], + concurrency: Optional[Union[int, Tuple[int, int]]], + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]], + **ray_remote_args, + ): + # NOTE: The `map_groups` implementation calls 
`map_batches` with + # `batch_size=None`. The issue is that if you request GPUs with + # `batch_size=None`, then `map_batches` raises a value error. So, to allow users + # to call `map_groups` with GPUs, we need a separate method that doesn't + # perform batch size validation. + + compute = get_compute_strategy( + fn, + fn_constructor_args=fn_constructor_args, + compute=compute, + concurrency=concurrency, + ) + + if num_cpus is not None: + ray_remote_args["num_cpus"] = num_cpus + + if num_gpus is not None: + ray_remote_args["num_gpus"] = num_gpus + + batch_format = _apply_batch_format(batch_format) + + min_rows_per_bundled_input = None + if batch_size is not None and batch_size != "default": + # Enable blocks bundling when batch_size is specified by caller. + min_rows_per_bundled_input = batch_size + batch_size = _apply_batch_size(batch_size) + + if batch_format not in VALID_BATCH_FORMATS: + raise ValueError( + f"The batch format must be one of {VALID_BATCH_FORMATS}, got: " + f"{batch_format}" + ) + + plan = self._plan.copy() + map_batches_op = MapBatches( + self._logical_plan.dag, + fn, + batch_size=batch_size, + batch_format=batch_format, + zero_copy_batch=zero_copy_batch, + min_rows_per_bundled_input=min_rows_per_bundled_input, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + logical_plan = LogicalPlan(map_batches_op, self.context) + return Dataset(plan, logical_plan) + + @PublicAPI(api_group=BT_API_GROUP) + def add_column( + self, + col: str, + fn: Callable[ + [DataBatch], + DataBatchColumn, + ], + *, + batch_format: Optional[str] = "pandas", + compute: Optional[str] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Add the given column to the dataset. 
+ + A function generating the new column values given the batch in pyarrow or pandas + format must be specified. This function must operate on batches of + `batch_format`. + + Examples: + + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.schema() + Column Type + ------ ---- + id int64 + + Add a new column equal to ``id * 2``. + + >>> ds.add_column("new_id", lambda df: df["id"] * 2).schema() + Column Type + ------ ---- + id int64 + new_id int64 + + Time complexity: O(dataset size / parallelism) + + Args: + col: Name of the column to add. If the name already exists, the + column is overwritten. + fn: Map function generating the column values given a batch of + records in pandas format. + batch_format: If ``"default"`` or ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are + ``pandas.DataFrame``. If ``"pyarrow"``, batches are + ``pyarrow.Table``. If ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. + compute: This argument is deprecated. Use ``concurrency`` argument. + concurrency: The number of Ray workers to use concurrently. For a + fixed-sized worker pool of size ``n``, specify ``concurrency=n``. For + an autoscaling worker pool from ``m`` to ``n`` workers, specify + ``concurrency=(m, n)``. + ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See + :func:`ray.remote` for details. 
+ """ + # Check that batch_format + accepted_batch_formats = ["pandas", "pyarrow", "numpy"] + if batch_format not in accepted_batch_formats: + raise ValueError( + f"batch_format argument must be on of {accepted_batch_formats}, " + f"got: {batch_format}" + ) + + def add_column(batch: DataBatch) -> DataBatch: + column = fn(batch) + if batch_format == "pandas": + batch.loc[:, col] = column + return batch + elif batch_format == "pyarrow": + import pyarrow as pa + + assert isinstance(column, (pa.Array, pa.ChunkedArray)), ( + f"For pyarrow batch format, the function must return a pyarrow " + f"Array, got: {type(column)}" + ) + # Historically, this method was written for pandas batch format. + # To resolve https://github.com/ray-project/ray/issues/48090, + # we also allow pyarrow batch format which is preferred but would be + # a breaking change to enforce. + + # For pyarrow, the index of the column will be -1 if it is missing in + # which case we'll want to append it + column_idx = batch.schema.get_field_index(col) + if column_idx == -1: + return batch.append_column(col, column) + else: + return batch.set_column(column_idx, col, column) + + else: + # batch format is assumed to be numpy since we checked at the + # beginning of the add_column function + assert isinstance(column, np.ndarray), ( + f"For numpy batch format, the function must return a " + f"numpy.ndarray, got: {type(column)}" + ) + batch[col] = column + return batch + + if not callable(fn): + raise ValueError("`fn` must be callable, got {}".format(fn)) + + return self.map_batches( + add_column, + batch_format=batch_format, + compute=compute, + concurrency=concurrency, + zero_copy_batch=False, + **ray_remote_args, + ) + + @PublicAPI(api_group=BT_API_GROUP) + def drop_columns( + self, + cols: List[str], + *, + compute: Optional[str] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Drop one or more columns from the dataset. 
+ + Examples: + + >>> import ray + >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet") + >>> ds.schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width double + variety string + >>> ds.drop_columns(["variety"]).schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width double + + Time complexity: O(dataset size / parallelism) + + Args: + cols: Names of the columns to drop. If any name does not exist, + an exception is raised. Column names must be unique. + compute: This argument is deprecated. Use ``concurrency`` argument. + concurrency: The number of Ray workers to use concurrently. For a fixed-sized + worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling + worker pool from ``m`` to ``n`` workers, specify ``concurrency=(m, n)``. + ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See + :func:`ray.remote` for details. + """ # noqa: E501 + + if len(cols) != len(set(cols)): + raise ValueError(f"drop_columns expects unique column names, got: {cols}") + + def drop_columns(batch): + return batch.drop(cols) + + return self.map_batches( + drop_columns, + batch_format="pyarrow", + zero_copy_batch=True, + compute=compute, + concurrency=concurrency, + **ray_remote_args, + ) + + @PublicAPI(api_group=BT_API_GROUP) + def select_columns( + self, + cols: Union[str, List[str]], + *, + compute: Union[str, ComputeStrategy] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Select one or more columns from the dataset. + + Specified columns must be in the dataset schema. + + .. tip:: + If you're reading parquet files with :meth:`ray.data.read_parquet`, + you might be able to speed it up by using projection pushdown; see + :ref:`Parquet column pruning ` for details. 
+ + Examples: + + >>> import ray + >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet") + >>> ds.schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width double + variety string + >>> ds.select_columns(["sepal.length", "sepal.width"]).schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + + Time complexity: O(dataset size / parallelism) + + Args: + cols: Names of the columns to select. If a name isn't in the + dataset schema, an exception is raised. Columns also should be unique. + compute: This argument is deprecated. Use ``concurrency`` argument. + concurrency: The number of Ray workers to use concurrently. For a fixed-sized + worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling + worker pool from ``m`` to ``n`` workers, specify ``concurrency=(m, n)``. + ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See + :func:`ray.remote` for details. + """ # noqa: E501 + if isinstance(cols, str): + cols = [cols] + elif isinstance(cols, list): + if not all(isinstance(col, str) for col in cols): + raise ValueError( + "select_columns requires all elements of 'cols' to be strings." + ) + else: + raise TypeError( + "select_columns requires 'cols' to be a string or a list of strings." 
+ )
+
+ if not cols:
+ raise ValueError("select_columns requires at least one column to select.")
+
+ if len(cols) != len(set(cols)):
+ raise ValueError(
+ "select_columns expected unique column names, "
+ f"got duplicate column names: {cols}"
+ )
+
+ # Run the projection with a plain task pool sized by `concurrency`.
+ # NOTE(review): the deprecated `compute` argument is accepted but ignored
+ # here, and the local import presumably avoids a circular import at module
+ # load — confirm both.
+ from ray.data._internal.compute import TaskPoolStrategy
+
+ compute = TaskPoolStrategy(size=concurrency)
+
+ plan = self._plan.copy()
+ select_op = Project(
+ self._logical_plan.dag,
+ cols=cols,
+ cols_rename=None,
+ compute=compute,
+ ray_remote_args=ray_remote_args,
+ )
+ logical_plan = LogicalPlan(select_op, self.context)
+ return Dataset(plan, logical_plan)
+
+ @PublicAPI(api_group=BT_API_GROUP)
+ def rename_columns(
+ self,
+ names: Union[List[str], Dict[str, str]],
+ *,
+ concurrency: Optional[Union[int, Tuple[int, int]]] = None,
+ **ray_remote_args,
+ ):
+ """Rename columns in the dataset.
+
+ Examples:
+
+ >>> import ray
+ >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
+ >>> ds.schema()
+ Column Type
+ ------ ----
+ sepal.length double
+ sepal.width double
+ petal.length double
+ petal.width double
+ variety string
+
+ You can pass a dictionary mapping old column names to new column names.
+
+ >>> ds.rename_columns({"variety": "category"}).schema()
+ Column Type
+ ------ ----
+ sepal.length double
+ sepal.width double
+ petal.length double
+ petal.width double
+ category string
+
+ Or you can pass a list of new column names.
+
+ >>> ds.rename_columns(
+ ... ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"]
+ ... ).schema()
+ Column Type
+ ------ ----
+ sepal_length double
+ sepal_width double
+ petal_length double
+ petal_width double
+ variety string
+
+ Args:
+ names: A dictionary that maps old column names to new column names, or a
+ list of new column names.
+ concurrency: The maximum number of Ray workers to use concurrently.
+ ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See + :func:`ray.remote` for details. + """ # noqa: E501 + + if isinstance(names, dict): + if not names: + raise ValueError("rename_columns received 'names' with no entries.") + + if len(names.values()) != len(set(names.values())): + raise ValueError( + f"rename_columns received duplicate values in the 'names': " + f"{names}" + ) + + if not all( + isinstance(k, str) and isinstance(v, str) for k, v in names.items() + ): + raise ValueError( + "rename_columns requires both keys and values in the 'names' " + "to be strings." + ) + + cols_rename = names + elif isinstance(names, list): + if not names: + raise ValueError( + "rename_columns requires 'names' with at least one column name." + ) + + if len(names) != len(set(names)): + raise ValueError( + f"rename_columns received duplicate values in the 'names': {names}" + ) + + if not all(isinstance(col, str) for col in names): + raise ValueError( + "rename_columns requires all elements in the 'names' to be strings." + ) + + current_names = self.schema().names + if len(current_names) != len(names): + raise ValueError( + f"rename_columns requires 'names': {names} length match current " + f"schema names: {current_names}." + ) + + cols_rename = dict(zip(current_names, names)) + else: + raise TypeError( + f"rename_columns expected names to be either List[str] or " + f"Dict[str, str], got {type(names)}." + ) + + if concurrency is not None and not isinstance(concurrency, int): + raise ValueError( + f"Expected `concurrency` to be an integer or `None`, but " + f"got {concurrency}." 
+ ) + + # Construct the plan and project operation + from ray.data._internal.compute import TaskPoolStrategy + + compute = TaskPoolStrategy(size=concurrency) + + plan = self._plan.copy() + select_op = Project( + self._logical_plan.dag, + cols=None, + cols_rename=cols_rename, + compute=compute, + ray_remote_args=ray_remote_args, + ) + logical_plan = LogicalPlan(select_op, self.context) + return Dataset(plan, logical_plan) + + @PublicAPI(api_group=BT_API_GROUP) + def flat_map( + self, + fn: UserDefinedFunction[Dict[str, Any], List[Dict[str, Any]]], + *, + compute: Optional[ComputeStrategy] = None, + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Apply the given function to each row and then flatten results. + + Use this method if your transformation returns multiple rows for each input + row. + + You can use either a function or a callable class to perform the transformation. + For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses + stateful Ray actors. For more information, see + :ref:`Stateful Transforms `. + + .. tip:: + :meth:`~Dataset.map_batches` can also modify the number of rows. If your + transformation is vectorized like most NumPy and pandas operations, + it might be faster. + + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + + Examples: + + .. 
testcode:: + + from typing import Any, Dict, List + import ray + + def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: + return [row] * 2 + + print( + ray.data.range(3) + .flat_map(duplicate_row) + .take_all() + ) + + .. testoutput:: + + [{'id': 0}, {'id': 0}, {'id': 1}, {'id': 1}, {'id': 2}, {'id': 2}] + + Time complexity: O(dataset size / parallelism) + + Args: + fn: The function or generator to apply to each record, or a class type + that can be instantiated to create such a callable. + compute: This argument is deprecated. Use ``concurrency`` argument. + fn_args: Positional arguments to pass to ``fn`` after the first argument. + These arguments are top-level arguments to the underlying Ray task. + fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are + top-level arguments to the underlying Ray task. + fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. + You can only provide this if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor. + This can only be provided if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + num_cpus: The number of CPUs to reserve for each parallel map worker. + num_gpus: The number of GPUs to reserve for each parallel map worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel map + worker. + concurrency: The number of Ray workers to use concurrently. For a + fixed-sized worker pool of size ``n``, specify ``concurrency=n``. + For an autoscaling worker pool from ``m`` to ``n`` workers, specify + ``concurrency=(m, n)``. + ray_remote_args_fn: A function that returns a dictionary of remote args + passed to each map worker. 
The purpose of this argument is to generate + dynamic arguments for each actor/task, and will be called each time + prior to initializing the worker. Args returned from this dict will + always override the args in ``ray_remote_args``. Note: this is an + advanced, experimental feature. + ray_remote_args: Additional resource requirements to request from + Ray for each map worker. See :func:`ray.remote` for details. + + .. seealso:: + + :meth:`~Dataset.map_batches` + Call this method to transform batches of data. + + :meth:`~Dataset.map` + Call this method to transform one row at time. + """ + compute = get_compute_strategy( + fn, + fn_constructor_args=fn_constructor_args, + compute=compute, + concurrency=concurrency, + ) + + if num_cpus is not None: + ray_remote_args["num_cpus"] = num_cpus + + if num_gpus is not None: + ray_remote_args["num_gpus"] = num_gpus + + plan = self._plan.copy() + op = FlatMap( + input_op=self._logical_plan.dag, + fn=fn, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @PublicAPI(api_group=BT_API_GROUP) + def filter( + self, + fn: Optional[UserDefinedFunction[Dict[str, Any], bool]] = None, + expr: Optional[str] = None, + *, + compute: Union[str, ComputeStrategy] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Filter out rows that don't satisfy the given predicate. + + You can use either a function or a callable class or an expression string to + perform the transformation. + For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses + stateful Ray actors. For more information, see + :ref:`Stateful Transforms `. + + .. 
tip:: + If you use the `expr` parameter with a Python expression string, Ray Data + optimizes your filter with native Arrow interfaces. + + Examples: + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.filter(expr="id <= 4").take_all() + [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}] + + Time complexity: O(dataset size / parallelism) + + Args: + fn: The predicate to apply to each row, or a class type + that can be instantiated to create such a callable. + expr: An expression string needs to be a valid Python expression that + will be converted to ``pyarrow.dataset.Expression`` type. + compute: This argument is deprecated. Use ``concurrency`` argument. + concurrency: The number of Ray workers to use concurrently. For a + fixed-sized worker pool of size ``n``, specify ``concurrency=n``. + For an autoscaling worker pool from ``m`` to ``n`` workers, specify + ``concurrency=(m, n)``. + ray_remote_args_fn: A function that returns a dictionary of remote args + passed to each map worker. The purpose of this argument is to generate + dynamic arguments for each actor/task, and will be called each time + prior to initializing the worker. Args returned from this dict will + always override the args in ``ray_remote_args``. Note: this is an + advanced, experimental feature. + ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See + :func:`ray.remote` for details. + """ + # Ensure exactly one of fn or expr is provided + resolved_expr = None + if not ((fn is None) ^ (expr is None)): + raise ValueError("Exactly one of 'fn' or 'expr' must be provided.") + elif expr is not None: + from ray.data._internal.compute import TaskPoolStrategy + from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501 + ExpressionEvaluator, + ) + + # TODO: (srinathk) bind the expression to the actual schema. 
+ # If fn is a string, convert it to a pyarrow.dataset.Expression + # Initialize ExpressionEvaluator with valid columns, if available + evaluator = ExpressionEvaluator() + resolved_expr = evaluator.get_filters(expression=expr) + + compute = TaskPoolStrategy(size=concurrency) + else: + warnings.warn( + "Use 'expr' instead of 'fn' when possible for performant filters." + ) + + if callable(fn): + compute = get_compute_strategy( + fn=fn, + compute=compute, + concurrency=concurrency, + ) + else: + raise ValueError( + f"fn must be a UserDefinedFunction, but got " + f"{type(fn).__name__} instead." + ) + + plan = self._plan.copy() + op = Filter( + input_op=self._logical_plan.dag, + fn=fn, + filter_expr=resolved_expr, + compute=compute, + ray_remote_args_fn=ray_remote_args_fn, + ray_remote_args=ray_remote_args, + ) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @AllToAllAPI + @PublicAPI(api_group=SSR_API_GROUP) + def repartition( + self, + num_blocks: int, + *, + shuffle: bool = False, + ) -> "Dataset": + """Repartition the :class:`Dataset` into exactly this number of :ref:`blocks `. + + This method can be useful to tune the performance of your pipeline. To learn + more, see :ref:`Advanced: Performance Tips and Tuning `. + + If you're writing data to files, you can also use this method to change the + number of output files. To learn more, see + :ref:`Changing the number of output files `. + + .. note:: + + Repartition has two modes. If ``shuffle=False``, Ray Data performs the + minimal data movement needed to equalize block sizes. Otherwise, Ray Data + performs a full distributed shuffle. + + .. image:: /data/images/dataset-shuffle.svg + :align: center + + .. 
+
+ https://docs.google.com/drawings/d/132jhE3KXZsf29ho1yUdPrCHB9uheHBWHJhDQMXqIVPA/edit
+
+ Examples:
+ >>> import ray
+ >>> ds = ray.data.range(100).repartition(10).materialize()
+ >>> ds.num_blocks()
+ 10
+
+ Time complexity: O(dataset size / parallelism)
+
+ Args:
+ num_blocks: The number of blocks.
+ shuffle: Whether to perform a distributed shuffle during the
+ repartition. When shuffle is enabled, each output block
+ contains a subset of data rows from each input block, which
+ requires all-to-all data movement. When shuffle is disabled,
+ output blocks are created from adjacent input blocks,
+ minimizing data movement.
+
+ Returns:
+ The repartitioned :class:`Dataset`.
+ """ # noqa: E501
+ plan = self._plan.copy()
+ op = Repartition(
+ self._logical_plan.dag,
+ num_outputs=num_blocks,
+ shuffle=shuffle,
+ )
+ logical_plan = LogicalPlan(op, self.context)
+ return Dataset(plan, logical_plan)
+
+ @AllToAllAPI
+ @PublicAPI(api_group=SSR_API_GROUP)
+ def random_shuffle(
+ self,
+ *,
+ seed: Optional[int] = None,
+ num_blocks: Optional[int] = None,
+ **ray_remote_args,
+ ) -> "Dataset":
+ """Randomly shuffle the rows of this :class:`Dataset`.
+
+ .. tip::
+
+ This method can be slow. For better performance, try
+ :ref:`Iterating over batches with shuffling `.
+ Also, see :ref:`Optimizing shuffles `.
+
+ Examples:
+ >>> import ray
+ >>> ds = ray.data.range(100)
+ >>> ds.random_shuffle().take(3) # doctest: +SKIP
+ [{'id': 41}, {'id': 21}, {'id': 92}]
+ >>> ds.random_shuffle(seed=42).take(3) # doctest: +SKIP
+ [{'id': 77}, {'id': 21}, {'id': 63}]
+
+ Time complexity: O(dataset size / parallelism)
+
+ Args:
+ seed: Fix the random seed to use, otherwise one is chosen
+ based on system randomness.
+
+ Returns:
+ The shuffled :class:`Dataset`.
+ """ # noqa: E501
+
+ if num_blocks is not None:
+ raise DeprecationWarning(
+ "`num_blocks` parameter is deprecated in Ray 2.9. random_shuffle() "
+ "does not support to change the number of output blocks.
Use " + "repartition() instead.", # noqa: E501 + ) + plan = self._plan.copy() + op = RandomShuffle( + self._logical_plan.dag, + seed=seed, + ray_remote_args=ray_remote_args, + ) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @AllToAllAPI + @PublicAPI(api_group=SSR_API_GROUP) + def randomize_block_order( + self, + *, + seed: Optional[int] = None, + ) -> "Dataset": + """Randomly shuffle the :ref:`blocks ` of this :class:`Dataset`. + + This method is useful if you :meth:`~Dataset.split` your dataset into shards and + want to randomize the data in each shard without performing a full + :meth:`~Dataset.random_shuffle`. + + Examples: + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.take(5) + [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}] + >>> ds.randomize_block_order().take(5) # doctest: +SKIP + {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}] + + Args: + seed: Fix the random seed to use, otherwise one is chosen + based on system randomness. + + Returns: + The block-shuffled :class:`Dataset`. + """ # noqa: E501 + + plan = self._plan.copy() + op = RandomizeBlocks( + self._logical_plan.dag, + seed=seed, + ) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @PublicAPI(api_group=BT_API_GROUP) + def random_sample( + self, fraction: float, *, seed: Optional[int] = None + ) -> "Dataset": + """Returns a new :class:`Dataset` containing a random fraction of the rows. + + .. note:: + + This method returns roughly ``fraction * total_rows`` rows. An exact number + of rows isn't guaranteed. + + Examples: + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.random_sample(0.1).count() # doctest: +SKIP + 10 + + Args: + fraction: The fraction of elements to sample. + seed: Seeds the python random pRNG generator. + + Returns: + Returns a :class:`Dataset` containing the sampled rows. 
+ """ + import random + + import pandas as pd + import pyarrow as pa + + if self._plan.initial_num_blocks() == 0: + raise ValueError("Cannot sample from an empty Dataset.") + + if fraction < 0 or fraction > 1: + raise ValueError("Fraction must be between 0 and 1.") + + if seed is not None: + random.seed(seed) + + def random_sample(batch): + if isinstance(batch, list): + return [row for row in batch if random.random() <= fraction] + if isinstance(batch, pa.Table): + # Lets the item pass if weight generated for that item <= fraction + return batch.filter( + pa.array(random.random() <= fraction for _ in range(len(batch))) + ) + if isinstance(batch, pd.DataFrame): + return batch.sample(frac=fraction) + if isinstance(batch, np.ndarray): + return _create_possibly_ragged_ndarray( + [row for row in batch if random.random() <= fraction] + ) + raise ValueError(f"Unsupported batch type: {type(batch)}") + + return self.map_batches(random_sample, batch_format=None) + + @ConsumptionAPI + @PublicAPI(api_group=SMD_API_GROUP) + def streaming_split( + self, + n: int, + *, + equal: bool = False, + locality_hints: Optional[List["NodeIdStr"]] = None, + ) -> List[DataIterator]: + """Returns ``n`` :class:`DataIterators ` that can + be used to read disjoint subsets of the dataset in parallel. + + This method is the recommended way to consume :class:`Datasets ` for + distributed training. + + Streaming split works by delegating the execution of this :class:`Dataset` to a + coordinator actor. The coordinator pulls block references from the executed + stream, and divides those blocks among ``n`` output iterators. Iterators pull + blocks from the coordinator actor to return to their caller on ``next``. + + The returned iterators are also repeatable; each iteration will trigger a + new execution of the Dataset. There is an implicit barrier at the start of + each iteration, which means that `next` must be called on all iterators before + the iteration starts. + + .. 
warning:: + + Because iterators are pulling blocks from the same :class:`Dataset` + execution, if one iterator falls behind, other iterators may be stalled. + + Examples: + + .. testcode:: + + import ray + + ds = ray.data.range(100) + it1, it2 = ds.streaming_split(2, equal=True) + + Consume data from iterators in parallel. + + .. testcode:: + + @ray.remote + def consume(it): + for batch in it.iter_batches(): + pass + + ray.get([consume.remote(it1), consume.remote(it2)]) + + You can loop over the iterators multiple times (multiple epochs). + + .. testcode:: + + @ray.remote + def train(it): + NUM_EPOCHS = 2 + for _ in range(NUM_EPOCHS): + for batch in it.iter_batches(): + pass + + ray.get([train.remote(it1), train.remote(it2)]) + + The following remote function call blocks waiting for a read on ``it2`` to + start. + + .. testcode:: + :skipif: True + + ray.get(train.remote(it1)) + + Args: + n: Number of output iterators to return. + equal: If ``True``, each output iterator sees an exactly equal number + of rows, dropping data if necessary. If ``False``, some iterators may + see slightly more or less rows than others, but no data is dropped. + locality_hints: Specify the node ids corresponding to each iterator + location. Dataset will try to minimize data movement based on the + iterator output locations. This list must have length ``n``. You can + get the current node id of a task or actor by calling + ``ray.get_runtime_context().get_node_id()``. + + Returns: + The output iterator splits. These iterators are Ray-serializable and can + be freely passed to any Ray task or actor. + + .. seealso:: + + :meth:`Dataset.split` + Unlike :meth:`~Dataset.streaming_split`, :meth:`~Dataset.split` + materializes the dataset in memory. 
+        """
+        return StreamSplitDataIterator.create(self, n, equal, locality_hints)
+
+    @ConsumptionAPI
+    @PublicAPI(api_group=SMD_API_GROUP)
+    def split(
+        self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None
+    ) -> List["MaterializedDataset"]:
+        """Materialize and split the dataset into ``n`` disjoint pieces.
+
+        This method returns a list of ``MaterializedDataset`` that can be passed to Ray
+        Tasks and Actors and used to read the dataset rows in parallel.
+
+        Examples:
+
+            .. testcode::
+
+                @ray.remote
+                class Worker:
+
+                    def train(self, data_iterator):
+                        for batch in data_iterator.iter_batches(batch_size=8):
+                            pass
+
+                workers = [Worker.remote() for _ in range(4)]
+                shards = ray.data.range(100).split(n=4, equal=True)
+                ray.get([w.train.remote(s) for w, s in zip(workers, shards)])
+
+        Time complexity: O(1)
+
+        Args:
+            n: Number of child datasets to return.
+            equal: Whether to guarantee each split has an equal
+                number of records. This might drop records if the rows can't be
+                divided equally among the splits.
+            locality_hints: [Experimental] A list of Ray actor handles of size ``n``.
+                The system tries to co-locate the blocks of the i-th dataset
+                with the i-th actor to maximize data locality.
+
+        Returns:
+            A list of ``n`` disjoint dataset splits.
+
+        .. seealso::
+
+            :meth:`Dataset.split_at_indices`
+                Unlike :meth:`~Dataset.split`, which splits a dataset into approximately
+                equal splits, :meth:`Dataset.split_at_indices` lets you split a
+                dataset into different sizes.
+
+            :meth:`Dataset.split_proportionately`
+                This method is equivalent to :meth:`Dataset.split_at_indices` if
+                you compute indices manually.
+
+            :meth:`Dataset.streaming_split`.
+                Unlike :meth:`~Dataset.split`, :meth:`~Dataset.streaming_split`
+                doesn't materialize the dataset in memory.
+        """
+        if n <= 0:
+            raise ValueError(f"The number of splits {n} is not positive.")
+
+        # fallback to split_at_indices for equal split without locality hints.
+        # simple benchmarks show split_at_indices yields more stable performance.
+        # https://github.com/ray-project/ray/pull/26641 for more context.
+        if equal and locality_hints is None:
+            count = self.count()
+            split_index = count // n
+            # we are creating n split_indices which will generate
+            # n + 1 splits; the last split will at most contain (n - 1)
+            # rows, which could be safely dropped.
+            split_indices = [split_index * i for i in range(1, n + 1)]
+            shards = self.split_at_indices(split_indices)
+            return shards[:n]
+
+        if locality_hints and len(locality_hints) != n:
+            raise ValueError(
+                f"The length of locality_hints {len(locality_hints)} "
+                f"doesn't equal the number of splits {n}."
+            )
+
+        bundle = self._plan.execute()
+        # We should not free blocks since we will materialize the Datasets.
+        owned_by_consumer = False
+        stats = self._plan.stats()
+        block_refs, metadata = zip(*bundle.blocks)
+
+        if locality_hints is None:
+            block_refs_splits = np.array_split(block_refs, n)
+            metadata_splits = np.array_split(metadata, n)
+
+            split_datasets = []
+            for block_refs_split, metadata_split in zip(
+                block_refs_splits, metadata_splits
+            ):
+                ref_bundles = [
+                    RefBundle([(b, m)], owns_blocks=owned_by_consumer)
+                    for b, m in zip(block_refs_split, metadata_split)
+                ]
+                logical_plan = LogicalPlan(
+                    InputData(input_data=ref_bundles), self.context
+                )
+                split_datasets.append(
+                    MaterializedDataset(
+                        ExecutionPlan(stats),
+                        logical_plan,
+                    )
+                )
+            return split_datasets
+
+        metadata_mapping = dict(zip(block_refs, metadata))
+
+        # If the locality_hints is set, we use a two-round greedy algorithm
+        # to co-locate the blocks with the actors based on block
+        # and actor's location (node_id).
+        #
+        # The split algorithm tries to allocate equally-sized blocks regardless
+        # of locality. Thus we first calculate the expected number of blocks
+        # for each split.
+        #
+        # In the first round, for each actor, we look for all blocks that
+        # match the actor's node_id, then allocate those matched blocks to
+        # this actor until we reach the limit(expected number).
+        #
+        # In the second round: fill each actor's allocation with
+        # remaining unallocated blocks until we reach the limit.
+
+        def build_allocation_size_map(
+            num_blocks: int, actors: List[Any]
+        ) -> Dict[Any, int]:
+            """Given the total number of blocks and a list of actors, calculate
+            the expected number of blocks to allocate for each actor.
+            """
+            num_actors = len(actors)
+            num_blocks_per_actor = num_blocks // num_actors
+            num_blocks_left = num_blocks - num_blocks_per_actor * n
+            num_blocks_by_actor = {}
+            for i, actor in enumerate(actors):
+                num_blocks_by_actor[actor] = num_blocks_per_actor
+                if i < num_blocks_left:
+                    num_blocks_by_actor[actor] += 1
+            return num_blocks_by_actor
+
+        def build_block_refs_by_node_id(
+            blocks: List[ObjectRef[Block]],
+        ) -> Dict[str, List[ObjectRef[Block]]]:
+            """Build the reverse index from node_id to block_refs. For
+            simplicity, if the block is stored on multiple nodes we
+            only pick the first one.
+ """ + block_ref_locations = ray.experimental.get_object_locations(blocks) + block_refs_by_node_id = collections.defaultdict(list) + for block_ref in blocks: + node_ids = block_ref_locations.get(block_ref, {}).get("node_ids", []) + node_id = node_ids[0] if node_ids else None + block_refs_by_node_id[node_id].append(block_ref) + return block_refs_by_node_id + + def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: + """Build a map from a actor to its node_id.""" + actors_state = ray._private.state.actors() + return { + actor: actors_state.get(actor._actor_id.hex(), {}) + .get("Address", {}) + .get("NodeID") + for actor in actors + } + + # expected number of blocks to be allocated for each actor + expected_block_count_by_actor = build_allocation_size_map( + len(block_refs), locality_hints + ) + # the reverse index from node_id to block_refs + block_refs_by_node_id = build_block_refs_by_node_id(block_refs) + # the map from actor to its node_id + node_id_by_actor = build_node_id_by_actor(locality_hints) + + allocation_per_actor = collections.defaultdict(list) + + # In the first round, for each actor, we look for all blocks that + # match the actor's node_id, then allocate those matched blocks to + # this actor until we reach the limit(expected number) + for actor in locality_hints: + node_id = node_id_by_actor[actor] + matching_blocks = block_refs_by_node_id[node_id] + expected_block_count = expected_block_count_by_actor[actor] + allocation = [] + while matching_blocks and len(allocation) < expected_block_count: + allocation.append(matching_blocks.pop()) + allocation_per_actor[actor] = allocation + + # In the second round: fill each actor's allocation with + # remaining unallocated blocks until we reach the limit + remaining_block_refs = list( + itertools.chain.from_iterable(block_refs_by_node_id.values()) + ) + for actor in locality_hints: + while ( + len(allocation_per_actor[actor]) < expected_block_count_by_actor[actor] + ): + 
allocation_per_actor[actor].append(remaining_block_refs.pop()) + + assert len(remaining_block_refs) == 0, len(remaining_block_refs) + + per_split_bundles = [] + for actor in locality_hints: + blocks = allocation_per_actor[actor] + metadata = [metadata_mapping[b] for b in blocks] + bundle = RefBundle( + tuple(zip(blocks, metadata)), owns_blocks=owned_by_consumer + ) + per_split_bundles.append(bundle) + + if equal: + # equalize the splits + per_split_bundles = _equalize(per_split_bundles, owned_by_consumer) + + split_datasets = [] + for bundle in per_split_bundles: + logical_plan = LogicalPlan(InputData(input_data=[bundle]), self.context) + split_datasets.append( + MaterializedDataset( + ExecutionPlan(stats), + logical_plan, + ) + ) + return split_datasets + + @ConsumptionAPI + @PublicAPI(api_group=SMD_API_GROUP) + def split_at_indices(self, indices: List[int]) -> List["MaterializedDataset"]: + """Materialize and split the dataset at the given indices (like ``np.split``). + + Examples: + >>> import ray + >>> ds = ray.data.range(10) + >>> d1, d2, d3 = ds.split_at_indices([2, 5]) + >>> d1.take_batch() + {'id': array([0, 1])} + >>> d2.take_batch() + {'id': array([2, 3, 4])} + >>> d3.take_batch() + {'id': array([5, 6, 7, 8, 9])} + + Time complexity: O(num splits) + + Args: + indices: List of sorted integers which indicate where the dataset + are split. If an index exceeds the length of the dataset, + an empty dataset is returned. + + Returns: + The dataset splits. + + .. seealso:: + + :meth:`Dataset.split` + Unlike :meth:`~Dataset.split_at_indices`, which lets you split a + dataset into different sizes, :meth:`Dataset.split` splits a dataset + into approximately equal splits. + + :meth:`Dataset.split_proportionately` + This method is equivalent to :meth:`Dataset.split_at_indices` if + you compute indices manually. + + :meth:`Dataset.streaming_split`. + Unlike :meth:`~Dataset.split`, :meth:`~Dataset.streaming_split` + doesn't materialize the dataset in memory. 
+        """
+
+        if len(indices) < 1:
+            raise ValueError("indices must be at least of length 1")
+        if sorted(indices) != indices:
+            raise ValueError("indices must be sorted")
+        if indices[0] < 0:
+            raise ValueError("indices must be positive")
+        start_time = time.perf_counter()
+        bundle = self._plan.execute()
+        blocks, metadata = _split_at_indices(
+            bundle.blocks,
+            indices,
+            False,
+        )
+        split_duration = time.perf_counter() - start_time
+        parent_stats = self._plan.stats()
+        splits = []
+
+        for bs, ms in zip(blocks, metadata):
+            stats = DatasetStats(metadata={"Split": ms}, parent=parent_stats)
+            stats.time_total_s = split_duration
+            ref_bundles = [
+                RefBundle([(b, m)], owns_blocks=False) for b, m in zip(bs, ms)
+            ]
+            logical_plan = LogicalPlan(InputData(input_data=ref_bundles), self.context)
+
+            splits.append(
+                MaterializedDataset(
+                    ExecutionPlan(stats),
+                    logical_plan,
+                )
+            )
+        return splits
+
+    @ConsumptionAPI
+    @PublicAPI(api_group=SMD_API_GROUP)
+    def split_proportionately(
+        self, proportions: List[float]
+    ) -> List["MaterializedDataset"]:
+        """Materialize and split the dataset using proportions.
+
+        A common use case for this is splitting the dataset into train
+        and test sets (equivalent to e.g. scikit-learn's ``train_test_split``).
+        For a higher level abstraction, see :meth:`Dataset.train_test_split`.
+
+        This method splits datasets so that all splits
+        always contain at least one row. If that isn't possible,
+        an exception is raised.
+
+        This is equivalent to calculating the indices manually and calling
+        :meth:`Dataset.split_at_indices`.
+
+        Examples:
+            >>> import ray
+            >>> ds = ray.data.range(10)
+            >>> d1, d2, d3 = ds.split_proportionately([0.2, 0.5])
+            >>> d1.take_batch()
+            {'id': array([0, 1])}
+            >>> d2.take_batch()
+            {'id': array([2, 3, 4, 5, 6])}
+            >>> d3.take_batch()
+            {'id': array([7, 8, 9])}
+
+        Time complexity: O(num splits)
+
+        Args:
+            proportions: List of proportions to split the dataset according to.
+ Must sum up to less than 1, and each proportion must be bigger + than 0. + + Returns: + The dataset splits. + + .. seealso:: + + :meth:`Dataset.split` + Unlike :meth:`~Dataset.split_proportionately`, which lets you split a + dataset into different sizes, :meth:`Dataset.split` splits a dataset + into approximately equal splits. + + :meth:`Dataset.split_at_indices` + :meth:`Dataset.split_proportionately` uses this method under the hood. + + :meth:`Dataset.streaming_split`. + Unlike :meth:`~Dataset.split`, :meth:`~Dataset.streaming_split` + doesn't materialize the dataset in memory. + """ + + if len(proportions) < 1: + raise ValueError("proportions must be at least of length 1") + if sum(proportions) >= 1: + raise ValueError("proportions must sum to less than 1") + if any(p <= 0 for p in proportions): + raise ValueError("proportions must be bigger than 0") + + dataset_length = self.count() + cumulative_proportions = np.cumsum(proportions) + split_indices = [ + int(dataset_length * proportion) for proportion in cumulative_proportions + ] + + # Ensure each split has at least one element + subtract = 0 + for i in range(len(split_indices) - 2, -1, -1): + split_indices[i] -= subtract + if split_indices[i] == split_indices[i + 1]: + subtract += 1 + split_indices[i] -= 1 + if any(i <= 0 for i in split_indices): + raise ValueError( + "Couldn't create non-empty splits with the given proportions." + ) + + return self.split_at_indices(split_indices) + + @ConsumptionAPI + @PublicAPI(api_group=SMD_API_GROUP) + def train_test_split( + self, + test_size: Union[int, float], + *, + shuffle: bool = False, + seed: Optional[int] = None, + ) -> Tuple["MaterializedDataset", "MaterializedDataset"]: + """Materialize and split the dataset into train and test subsets. 
+ + Examples: + + >>> import ray + >>> ds = ray.data.range(8) + >>> train, test = ds.train_test_split(test_size=0.25) + >>> train.take_batch() + {'id': array([0, 1, 2, 3, 4, 5])} + >>> test.take_batch() + {'id': array([6, 7])} + + Args: + test_size: If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the test split. If int, + represents the absolute number of test samples. The train split + always complements the test split. + shuffle: Whether or not to globally shuffle the dataset before splitting. + Defaults to ``False``. This may be a very expensive operation with a + large dataset. + seed: Fix the random seed to use for shuffle, otherwise one is chosen + based on system randomness. Ignored if ``shuffle=False``. + + Returns: + Train and test subsets as two ``MaterializedDatasets``. + + .. seealso:: + + :meth:`Dataset.split_proportionately` + """ + ds = self + + if shuffle: + ds = ds.random_shuffle(seed=seed) + + if not isinstance(test_size, (int, float)): + raise TypeError(f"`test_size` must be int or float got {type(test_size)}.") + if isinstance(test_size, float): + if test_size <= 0 or test_size >= 1: + raise ValueError( + "If `test_size` is a float, it must be bigger than 0 and smaller " + f"than 1. Got {test_size}." + ) + return ds.split_proportionately([1 - test_size]) + else: + ds_length = ds.count() + if test_size <= 0 or test_size >= ds_length: + raise ValueError( + "If `test_size` is an int, it must be bigger than 0 and smaller " + f"than the size of the dataset ({ds_length}). " + f"Got {test_size}." + ) + return ds.split_at_indices([ds_length - test_size]) + + @PublicAPI(api_group=SMD_API_GROUP) + def union(self, *other: List["Dataset"]) -> "Dataset": + """Concatenate :class:`Datasets ` across rows. + + The order of the blocks in the datasets is preserved, as is the + relative ordering between the datasets passed in the argument list. + + .. caution:: + Unioned datasets aren't lineage-serializable. 
As a result, they can't be + used as a tunable hyperparameter in Ray Tune. + + Examples: + + >>> import ray + >>> ds1 = ray.data.range(2) + >>> ds2 = ray.data.range(3) + >>> ds1.union(ds2).take_all() + [{'id': 0}, {'id': 1}, {'id': 0}, {'id': 1}, {'id': 2}] + + Args: + other: List of datasets to combine with this one. The datasets + must have the same schema as this dataset, otherwise the + behavior is undefined. + + Returns: + A new dataset holding the rows of the input datasets. + """ + start_time = time.perf_counter() + + datasets = [self] + list(other) + logical_plans = [union_ds._plan._logical_plan for union_ds in datasets] + op = UnionLogicalOperator( + *[plan.dag for plan in logical_plans], + ) + logical_plan = LogicalPlan(op, self.context) + + stats = DatasetStats( + metadata={"Union": []}, + parent=[d._plan.stats() for d in datasets], + ) + stats.time_total_s = time.perf_counter() - start_time + return Dataset( + ExecutionPlan(stats), + logical_plan, + ) + + @AllToAllAPI + @PublicAPI(api_group=GGA_API_GROUP) + def groupby( + self, + key: Union[str, List[str], None], + ) -> "GroupedData": + """Group rows of a :class:`Dataset` according to a column. + + Use this method to transform data based on a + categorical variable. + + Examples: + + .. testcode:: + + import pandas as pd + import ray + + def normalize_variety(group: pd.DataFrame) -> pd.DataFrame: + for feature in group.drop("variety").columns: + group[feature] = group[feature] / group[feature].abs().max() + return group + + ds = ( + ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet") + .groupby("variety") + .map_groups(normalize_variety, batch_format="pandas") + ) + + Time complexity: O(dataset size * log(dataset size / parallelism)) + + Args: + key: A column name or list of column names. + If this is ``None``, place all rows in a single group. + + Returns: + A lazy :class:`~ray.data.grouped_data.GroupedData`. + + .. 
seealso:: + + :meth:`~ray.data.grouped_data.GroupedData.map_groups` + Call this method to transform groups of data. + """ + from ray.data.grouped_data import GroupedData + + # Always allow None since groupby interprets that as grouping all + # records into a single global group. + if key is not None: + # Fetching the schema can trigger execution, so don't fetch it for + # input validation. + SortKey(key).validate_schema(self.schema(fetch_if_missing=False)) + + return GroupedData(self, key) + + @AllToAllAPI + @ConsumptionAPI + @PublicAPI(api_group=GGA_API_GROUP) + def unique(self, column: str) -> List[Any]: + """List the unique elements in a given column. + + Examples: + + >>> import ray + >>> ds = ray.data.from_items([1, 2, 3, 2, 3]) + >>> ds.unique("item") + [1, 2, 3] + + This function is very useful for computing labels + in a machine learning dataset: + + >>> import ray + >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") + >>> ds.unique("target") + [0, 1, 2] + + One common use case is to convert the class labels + into integers for training and inference: + + >>> classes = {0: 'Setosa', 1: 'Versicolor', 2: 'Virginica'} + >>> def preprocessor(df, classes): + ... df["variety"] = df["target"].map(classes) + ... return df + >>> train_ds = ds.map_batches( + ... preprocessor, fn_kwargs={"classes": classes}, batch_format="pandas") + >>> train_ds.sort("sepal length (cm)").take(1) # Sort to make it deterministic + [{'sepal length (cm)': 4.3, ..., 'variety': 'Setosa'}] + + Time complexity: O(dataset size / parallelism) + + Args: + column: The column to collect unique elements over. + + Returns: + A list with unique elements in the given column. + """ # noqa: E501 + ret = self._aggregate_on(Unique, column) + return self._aggregate_result(ret) + + @AllToAllAPI + @ConsumptionAPI + @PublicAPI(api_group=GGA_API_GROUP) + def aggregate(self, *aggs: AggregateFn) -> Union[Any, Dict[str, Any]]: + """Aggregate values using one or more functions. 
+
+        Use this method to compute metrics like the product of a column.
+
+        Examples:
+
+            .. testcode::
+
+                import ray
+                from ray.data.aggregate import AggregateFn
+
+                ds = ray.data.from_items([{"number": i} for i in range(1, 10)])
+                aggregation = AggregateFn(
+                    init=lambda column: 1,
+                    # Apply this to each row to produce a partial aggregate result
+                    accumulate_row=lambda a, row: a * row["number"],
+                    # Apply this to merge partial aggregate results into a final result
+                    merge=lambda a1, a2: a1 * a2,
+                    name="prod"
+                )
+                print(ds.aggregate(aggregation))
+
+            .. testoutput::
+
+                {'prod': 362880}
+
+        Time complexity: O(dataset size / parallelism)
+
+        Args:
+            *aggs: :class:`Aggregations <ray.data.aggregate.AggregateFn>` to perform.
+
+        Returns:
+            A ``dict`` where each value is an aggregation for a given column.
+        """
+        ret = self.groupby(None).aggregate(*aggs).take(1)
+        return ret[0] if len(ret) > 0 else None
+
+    @AllToAllAPI
+    @ConsumptionAPI
+    @PublicAPI(api_group=GGA_API_GROUP)
+    def sum(
+        self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
+    ) -> Union[Any, Dict[str, Any]]:
+        """Compute the sum of one or more columns.
+
+        Examples:
+            >>> import ray
+            >>> ray.data.range(100).sum("id")
+            4950
+            >>> ray.data.from_items([
+            ...     {"A": i, "B": i**2}
+            ...     for i in range(100)
+            ... ]).sum(["A", "B"])
+            {'sum(A)': 4950, 'sum(B)': 328350}
+
+        Args:
+            on: a column name or a list of column names to aggregate.
+            ignore_nulls: Whether to ignore null values. If ``True``, null
+                values are ignored when computing the sum. If ``False``,
+                when a null value is encountered, the output is ``None``.
+                Ray Data considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
+                values. Default is ``True``.
+
+        Returns:
+            The sum result.
+ + For different values of ``on``, the return varies: + + - ``on=None``: a dict containing the column-wise sum of all + columns, + - ``on="col"``: a scalar representing the sum of all items in + column ``"col"``, + - ``on=["col_1", ..., "col_n"]``: an n-column ``dict`` + containing the column-wise sum of the provided columns. + + If the dataset is empty, all values are null. If ``ignore_nulls`` is + ``False`` and any value is null, then the output is ``None``. + """ + ret = self._aggregate_on(Sum, on, ignore_nulls=ignore_nulls) + return self._aggregate_result(ret) + + @AllToAllAPI + @ConsumptionAPI + @PublicAPI(api_group=GGA_API_GROUP) + def min( + self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True + ) -> Union[Any, Dict[str, Any]]: + """Return the minimum of one or more columns. + + Examples: + >>> import ray + >>> ray.data.range(100).min("id") + 0 + >>> ray.data.from_items([ + ... {"A": i, "B": i**2} + ... for i in range(100) + ... ]).min(["A", "B"]) + {'min(A)': 0, 'min(B)': 0} + + Args: + on: a column name or a list of column names to aggregate. + ignore_nulls: Whether to ignore null values. If ``True``, null + values are ignored when computing the min; if ``False``, + when a null value is encountered, the output is ``None``. + This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null + values. Default is ``True``. + + Returns: + The min result. + + For different values of ``on``, the return varies: + + - ``on=None``: an dict containing the column-wise min of + all columns, + - ``on="col"``: a scalar representing the min of all items in + column ``"col"``, + - ``on=["col_1", ..., "col_n"]``: an n-column dict + containing the column-wise min of the provided columns. + + If the dataset is empty, all values are null. If ``ignore_nulls`` is + ``False`` and any value is null, then the output is ``None``. 
+ """ + ret = self._aggregate_on(Min, on, ignore_nulls=ignore_nulls) + return self._aggregate_result(ret) + + @AllToAllAPI + @ConsumptionAPI + @PublicAPI(api_group=GGA_API_GROUP) + def max( + self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True + ) -> Union[Any, Dict[str, Any]]: + """Return the maximum of one or more columns. + + Examples: + >>> import ray + >>> ray.data.range(100).max("id") + 99 + >>> ray.data.from_items([ + ... {"A": i, "B": i**2} + ... for i in range(100) + ... ]).max(["A", "B"]) + {'max(A)': 99, 'max(B)': 9801} + + Args: + on: a column name or a list of column names to aggregate. + ignore_nulls: Whether to ignore null values. If ``True``, null + values are ignored when computing the max; if ``False``, + when a null value is encountered, the output is ``None``. + This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null + values. Default is ``True``. + + Returns: + The max result. + + For different values of ``on``, the return varies: + + - ``on=None``: an dict containing the column-wise max of + all columns, + - ``on="col"``: a scalar representing the max of all items in + column ``"col"``, + - ``on=["col_1", ..., "col_n"]``: an n-column dict + containing the column-wise max of the provided columns. + + If the dataset is empty, all values are null. If ``ignore_nulls`` is + ``False`` and any value is null, then the output is ``None``. + """ + ret = self._aggregate_on(Max, on, ignore_nulls=ignore_nulls) + return self._aggregate_result(ret) + + @AllToAllAPI + @ConsumptionAPI + @PublicAPI(api_group=GGA_API_GROUP) + def mean( + self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True + ) -> Union[Any, Dict[str, Any]]: + """Compute the mean of one or more columns. + + Examples: + >>> import ray + >>> ray.data.range(100).mean("id") + 49.5 + >>> ray.data.from_items([ + ... {"A": i, "B": i**2} + ... for i in range(100) + ... 
]).mean(["A", "B"])
+            {'mean(A)': 49.5, 'mean(B)': 3283.5}
+
+        Args:
+            on: a column name or a list of column names to aggregate.
+            ignore_nulls: Whether to ignore null values. If ``True``, null
+                values are ignored when computing the mean; if ``False``,
+                when a null value is encountered, the output is ``None``.
+                This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
+                values. Default is ``True``.
+
+        Returns:
+            The mean result.
+
+            For different values of ``on``, the return varies:
+
+            - ``on=None``: an dict containing the column-wise mean of
+              all columns,
+            - ``on="col"``: a scalar representing the mean of all items in
+              column ``"col"``,
+            - ``on=["col_1", ..., "col_n"]``: an n-column dict
+              containing the column-wise mean of the provided columns.
+
+            If the dataset is empty, all values are null. If ``ignore_nulls`` is
+            ``False`` and any value is null, then the output is ``None``.
+        """
+        ret = self._aggregate_on(Mean, on, ignore_nulls=ignore_nulls)
+        return self._aggregate_result(ret)
+
+    @AllToAllAPI
+    @ConsumptionAPI
+    @PublicAPI(api_group=GGA_API_GROUP)
+    def std(
+        self,
+        on: Optional[Union[str, List[str]]] = None,
+        ddof: int = 1,
+        ignore_nulls: bool = True,
+    ) -> Union[Any, Dict[str, Any]]:
+        """Compute the standard deviation of one or more columns.
+
+        .. note::
+            This method uses Welford's online method for an accumulator-style
+            computation of the standard deviation. This method has
+            numerical stability, and is computable in a single pass. This may give
+            different (but more accurate) results than NumPy, Pandas, and sklearn, which
+            use a less numerically stable two-pass algorithm.
+            To learn more, see
+            `the Wikipedia article <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>`_.
+
+        Examples:
+            >>> import ray
+            >>> round(ray.data.range(100).std("id", ddof=0), 5)
+            28.86607
+            >>> ray.data.from_items([
+            ...     {"A": i, "B": i**2}
+            ...     for i in range(100)
+            ...
]).std(["A", "B"]) + {'std(A)': 29.011491975882016, 'std(B)': 2968.1748039269296} + + Args: + on: a column name or a list of column names to aggregate. + ddof: Delta Degrees of Freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + ignore_nulls: Whether to ignore null values. If ``True``, null + values are ignored when computing the std; if ``False``, + when a null value is encountered, the output is ``None``. + This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null + values. Default is ``True``. + + Returns: + The standard deviation result. + + For different values of ``on``, the return varies: + + - ``on=None``: an dict containing the column-wise std of + all columns, + - ``on="col"``: a scalar representing the std of all items in + column ``"col"``, + - ``on=["col_1", ..., "col_n"]``: an n-column dict + containing the column-wise std of the provided columns. + + If the dataset is empty, all values are null. If ``ignore_nulls`` is + ``False`` and any value is null, then the output is ``None``. + """ # noqa: E501 + ret = self._aggregate_on(Std, on, ignore_nulls=ignore_nulls, ddof=ddof) + return self._aggregate_result(ret) + + @AllToAllAPI + @PublicAPI(api_group=SSR_API_GROUP) + def sort( + self, + key: Union[str, List[str]], + descending: Union[bool, List[bool]] = False, + boundaries: List[Union[int, float]] = None, + ) -> "Dataset": + """Sort the dataset by the specified key column or key function. + The `key` parameter must be specified (i.e., it cannot be `None`). + + .. note:: + If provided, the `boundaries` parameter can only be used to partition + the first sort key. + + Examples: + >>> import ray + >>> ds = ray.data.range(15) + >>> ds = ds.sort("id", descending=False, boundaries=[5, 10]) + >>> for df in ray.get(ds.to_pandas_refs()): + ... 
print(df) + id + 0 0 + 1 1 + 2 2 + 3 3 + 4 4 + id + 0 5 + 1 6 + 2 7 + 3 8 + 4 9 + id + 0 10 + 1 11 + 2 12 + 3 13 + 4 14 + + Time complexity: O(dataset size * log(dataset size / parallelism)) + + Args: + key: The column or a list of columns to sort by. + descending: Whether to sort in descending order. Must be a boolean or a list + of booleans matching the number of the columns. + boundaries: The list of values based on which to repartition the dataset. + For example, if the input boundary is [10,20], rows with values less + than 10 will be divided into the first block, rows with values greater + than or equal to 10 and less than 20 will be divided into the + second block, and rows with values greater than or equal to 20 + will be divided into the third block. If not provided, the + boundaries will be sampled from the input blocks. This feature + only supports numeric columns right now. + + Returns: + A new, sorted :class:`Dataset`. + + Raises: + ``ValueError``: if the sort key is None. + """ + if key is None: + raise ValueError("The 'key' parameter cannot be None for sorting.") + sort_key = SortKey(key, descending, boundaries) + plan = self._plan.copy() + op = Sort( + self._logical_plan.dag, + sort_key=sort_key, + ) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @PublicAPI(api_group=SMD_API_GROUP) + def zip(self, other: "Dataset") -> "Dataset": + """Zip the columns of this dataset with the columns of another. + + The datasets must have the same number of rows. Their column sets are + merged, and any duplicate column names are disambiguated with suffixes like + ``"_1"``. + + .. note:: + The smaller of the two datasets is repartitioned to align the number + of rows per block with the larger dataset. + + .. note:: + Zipped datasets aren't lineage-serializable. As a result, they can't be used + as a tunable hyperparameter in Ray Tune. 
+ + Examples: + >>> import ray + >>> ds1 = ray.data.range(5) + >>> ds2 = ray.data.range(5) + >>> ds1.zip(ds2).take_batch() + {'id': array([0, 1, 2, 3, 4]), 'id_1': array([0, 1, 2, 3, 4])} + + Args: + other: The dataset to zip with on the right hand side. + + Returns: + A :class:`Dataset` containing the columns of the second dataset + concatenated horizontally with the columns of the first dataset, + with duplicate column names disambiguated with suffixes like ``"_1"``. + """ + plan = self._plan.copy() + op = Zip(self._logical_plan.dag, other._logical_plan.dag) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @PublicAPI(api_group=BT_API_GROUP) + def limit(self, limit: int) -> "Dataset": + """Truncate the dataset to the first ``limit`` rows. + + Unlike :meth:`~Dataset.take`, this method doesn't move data to the caller's + machine. Instead, it returns a new :class:`Dataset` pointing to the truncated + distributed data. + + Examples: + >>> import ray + >>> ds = ray.data.range(1000) + >>> ds.limit(5).count() + 5 + + Time complexity: O(limit specified) + + Args: + limit: The size of the dataset to truncate to. + + Returns: + The truncated dataset. + """ + plan = self._plan.copy() + op = Limit(self._logical_plan.dag, limit=limit) + logical_plan = LogicalPlan(op, self.context) + return Dataset(plan, logical_plan) + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def take_batch( + self, batch_size: int = 20, *, batch_format: Optional[str] = "default" + ) -> DataBatch: + """Return up to ``batch_size`` rows from the :class:`Dataset` in a batch. + + Ray Data represents batches as NumPy arrays or pandas DataFrames. You can + configure the batch type by specifying ``batch_format``. + + This method is useful for inspecting inputs to :meth:`~Dataset.map_batches`. + + .. warning:: + + :meth:`~Dataset.take_batch` moves up to ``batch_size`` rows to the caller's + machine. 
If ``batch_size`` is large, this method can cause an ` + ``OutOfMemory`` error on the caller. + + Examples: + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.take_batch(5) + {'id': array([0, 1, 2, 3, 4])} + + Time complexity: O(batch_size specified) + + Args: + batch_size: The maximum number of rows to return. + batch_format: If ``"default"`` or ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are + ``pandas.DataFrame``. + + Returns: + A batch of up to ``batch_size`` rows from the dataset. + + Raises: + ``ValueError``: if the dataset is empty. + """ + batch_format = _apply_batch_format(batch_format) + limited_ds = self.limit(batch_size) + + try: + res = next( + iter( + limited_ds.iter_batches( + batch_size=batch_size, + prefetch_batches=0, + batch_format=batch_format, + ) + ) + ) + except StopIteration: + raise ValueError("The dataset is empty.") + self._synchronize_progress_bar() + + # Save the computed stats to the original dataset. + self._plan._snapshot_stats = limited_ds._plan.stats() + return res + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def take(self, limit: int = 20) -> List[Dict[str, Any]]: + """Return up to ``limit`` rows from the :class:`Dataset`. + + This method is useful for inspecting data. + + .. warning:: + + :meth:`~Dataset.take` moves up to ``limit`` rows to the caller's machine. If + ``limit`` is large, this method can cause an ``OutOfMemory`` error on the + caller. + + Examples: + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.take(3) + [{'id': 0}, {'id': 1}, {'id': 2}] + + Time complexity: O(limit specified) + + Args: + limit: The maximum number of rows to return. + + Returns: + A list of up to ``limit`` rows from the dataset. + + .. seealso:: + + :meth:`~Dataset.take_all` + Call this method to return all rows. 
+ """ + if ray.util.log_once("dataset_take"): + logger.info( + "Tip: Use `take_batch()` instead of `take() / show()` to return " + "records in pandas or numpy batch format." + ) + output = [] + + limited_ds = self.limit(limit) + for row in limited_ds.iter_rows(): + output.append(row) + if len(output) >= limit: + break + self._synchronize_progress_bar() + + # Save the computed stats to the original dataset. + self._plan._snapshot_stats = limited_ds._plan.stats() + return output + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def take_all(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: + """Return all of the rows in this :class:`Dataset`. + + This method is useful for inspecting small datasets. + + .. warning:: + + :meth:`~Dataset.take_all` moves the entire dataset to the caller's + machine. If the dataset is large, this method can cause an + ``OutOfMemory`` error on the caller. + + Examples: + >>> import ray + >>> ds = ray.data.range(5) + >>> ds.take_all() + [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}] + + Time complexity: O(dataset size) + + Args: + limit: Raise an error if the size exceeds the specified limit. + + Returns: + A list of all the rows in the dataset. + + .. seealso:: + + :meth:`~Dataset.take` + Call this method to return a specific number of rows. + """ + output = [] + for row in self.iter_rows(): + output.append(row) + if limit is not None and len(output) > limit: + raise ValueError( + f"The dataset has more than the given limit of {limit} records." + ) + self._synchronize_progress_bar() + return output + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def show(self, limit: int = 20) -> None: + """Print up to the given number of rows from the :class:`Dataset`. + + This method is useful for inspecting data. 
+ + Examples: + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.show(3) + {'id': 0} + {'id': 1} + {'id': 2} + + Time complexity: O(limit specified) + + Args: + limit: The maximum number of row to print. + + .. seealso:: + + :meth:`~Dataset.take` + Call this method to get (not print) a given number of rows. + """ + for row in self.take(limit): + print(row) + + @ConsumptionAPI( + if_more_than_read=True, + datasource_metadata="row count", + pattern="Examples:", + ) + @PublicAPI(api_group=IM_API_GROUP) + def count(self) -> int: + """Count the number of rows in the dataset. + + For Datasets which only read Parquet files (created with + :meth:`~ray.data.read_parquet`), this method reads the file metadata to + efficiently count the number of rows without reading in the entire data. + + Examples: + >>> import ray + >>> ds = ray.data.range(10) + >>> ds.count() + 10 + + Returns: + The number of records in the dataset. + """ + # Handle empty dataset. + if self._plan.initial_num_blocks() == 0: + return 0 + + # For parquet, we can return the count directly from metadata. + meta_count = self._meta_count() + if meta_count is not None: + return meta_count + + plan = self._plan.copy() + count_op = Count([self._logical_plan.dag]) + logical_plan = LogicalPlan(count_op, self.context) + count_ds = Dataset(plan, logical_plan) + + count = 0 + for batch in count_ds.iter_batches(batch_size=None): + assert Count.COLUMN_NAME in batch, ( + "Outputs from the 'Count' logical operator should contain a column " + f"named '{Count.COLUMN_NAME}'" + ) + count += batch[Count.COLUMN_NAME].sum() + # Explicitly cast to int to avoid returning `np.int64`, which is the result + # from calculating `sum()` from numpy batches. 
+ return int(count) + + @ConsumptionAPI( + if_more_than_read=True, + datasource_metadata="schema", + extra_condition="or if ``fetch_if_missing=True`` (the default)", + pattern="Time complexity:", + ) + @PublicAPI(api_group=IM_API_GROUP) + def schema(self, fetch_if_missing: bool = True) -> Optional["Schema"]: + """Return the schema of the dataset. + + Examples: + >>> import ray + >>> ds = ray.data.range(10) + >>> ds.schema() + Column Type + ------ ---- + id int64 + + Time complexity: O(1) + + Args: + fetch_if_missing: If True, synchronously fetch the schema if it's + not known. If False, None is returned if the schema is not known. + Default is True. + + Returns: + The :class:`ray.data.Schema` class of the records, or None if the + schema is not known and fetch_if_missing is False. + """ + + context = self._plan._context + + # First check if the schema is already known from materialized blocks. + base_schema = self._plan.schema(fetch_if_missing=False) + if base_schema is not None: + return Schema(base_schema, data_context=context) + + # Lazily execute only the first block to minimize computation. We achieve this + # by appending a Limit[1] operation to a copy of this Dataset, which we then + # execute to get its schema. + base_schema = self.limit(1)._plan.schema(fetch_if_missing=fetch_if_missing) + if base_schema is not None: + self._plan.cache_schema(base_schema) + return Schema(base_schema, data_context=context) + else: + return None + + @ConsumptionAPI( + if_more_than_read=True, + datasource_metadata="schema", + extra_condition="or if ``fetch_if_missing=True`` (the default)", + pattern="Time complexity:", + ) + @PublicAPI(api_group=IM_API_GROUP) + def columns(self, fetch_if_missing: bool = True) -> Optional[List[str]]: + """Returns the columns of this Dataset. + + Time complexity: O(1) + + Example: + >>> import ray + >>> # Create dataset from synthetic data. 
+ >>> ds = ray.data.range(1000) + >>> ds.columns() + ['id'] + + Args: + fetch_if_missing: If True, synchronously fetch the column names from the + schema if it's not known. If False, None is returned if the schema is + not known. Default is True. + + Returns: + A list of the column names for this Dataset or None if schema is not known + and `fetch_if_missing` is False. + + """ + schema = self.schema(fetch_if_missing=fetch_if_missing) + if schema is not None: + return schema.names + return None + + @PublicAPI(api_group=IM_API_GROUP) + def num_blocks(self) -> int: + """Return the number of blocks of this :class:`Dataset`. + + This method is only implemented for :class:`~ray.data.MaterializedDataset`, + since the number of blocks may dynamically change during execution. + For instance, during read and transform operations, Ray Data may dynamically + adjust the number of blocks to respect memory limits, increasing the + number of blocks at runtime. + + Returns: + The number of blocks of this :class:`Dataset`. + """ + raise NotImplementedError( + "Number of blocks is only available for `MaterializedDataset`," + "because the number of blocks may dynamically change during execution." + "Call `ds.materialize()` to get a `MaterializedDataset`." + ) + + @ConsumptionAPI + @PublicAPI(api_group=IM_API_GROUP) + def size_bytes(self) -> int: + """Return the in-memory size of the dataset. + + Examples: + >>> import ray + >>> ds = ray.data.range(10) + >>> ds.size_bytes() + 80 + + Returns: + The in-memory size of the dataset in bytes, or None if the + in-memory size is not known. + """ + # If the size is known from metadata, return it. 
+ if self._logical_plan.dag.aggregate_output_metadata().size_bytes is not None: + return self._logical_plan.dag.aggregate_output_metadata().size_bytes + + metadata = self._plan.execute().metadata + if not metadata or metadata[0].size_bytes is None: + return None + return sum(m.size_bytes for m in metadata) + + @ConsumptionAPI + @PublicAPI(api_group=IM_API_GROUP) + def input_files(self) -> List[str]: + """Return the list of input files for the dataset. + + Examples: + >>> import ray + >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") + >>> ds.input_files() + ['ray-example-data/iris.csv'] + + Returns: + The list of input files used to create the dataset, or an empty + list if the input files is not known. + """ + return list(set(self._plan.input_files())) + + @ConsumptionAPI + @PublicAPI(api_group=IOC_API_GROUP) + def write_parquet( + self, + path: str, + *, + partition_cols: Optional[List[str]] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + min_rows_per_file: Optional[int] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, + **arrow_parquet_args, + ) -> None: + """Writes the :class:`~ray.data.Dataset` to parquet files under the provided ``path``. + + The number of files is determined by the number of blocks in the dataset. + To control the number of number of blocks, call + :meth:`~ray.data.Dataset.repartition`. + + If pyarrow can't represent your data, this method errors. + + By default, the format of the output files is ``{uuid}_{block_idx}.parquet``, + where ``uuid`` is a unique id for the dataset. 
To modify this behavior, + implement a custom :class:`~ray.data.datasource.FilenameProvider` and pass it in + as the ``filename_provider`` argument. + + Examples: + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.write_parquet("local:///tmp/data/") + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where + parquet files are written to. + partition_cols: Column names by which to partition the dataset. + Files are writted in Hive partition style. + filesystem: The pyarrow filesystem implementation to write to. + These filesystems are specified in the + `pyarrow docs `_. + Specify this if you need to provide specific configurations to the + filesystem. By default, the filesystem is automatically selected based + on the scheme of the paths. For example, if the path begins with + ``s3://``, the ``S3FileSystem`` is used. + try_create_dir: If ``True``, attempts to create all directories in the + destination path. Does nothing if all directories already + exist. Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_output_stream `_, which is used when + opening the file to write to. + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + arrow_parquet_args_fn: Callable that returns a dictionary of write + arguments that are provided to `pyarrow.parquet.write_table() `_ + when writing each block to a file. Overrides + any duplicate keys from ``arrow_parquet_args``. Use this argument + instead of ``arrow_parquet_args`` if any of your write arguments + can't pickled, or if you'd like to lazily resolve the write + arguments for each dataset block. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. 
If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. + ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. + arrow_parquet_args: Options to pass to + `pyarrow.parquet.write_table() `_, which is used to write out each + block to a file. + """ # noqa: E501 + if arrow_parquet_args_fn is None: + arrow_parquet_args_fn = lambda: {} # noqa: E731 + + if partition_cols and (num_rows_per_file or min_rows_per_file): + raise ValueError( + "Cannot pass num_rows_per_file or min_rows_per_file when partition_cols " + "argument is specified" + ) + + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + + datasink = ParquetDatasink( + path, + partition_cols=partition_cols, + arrow_parquet_args_fn=arrow_parquet_args_fn, + arrow_parquet_args=arrow_parquet_args, + min_rows_per_file=effective_min_rows, # Pass through to datasink + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + @PublicAPI(api_group=IOC_API_GROUP) + def write_json( + self, + path: str, + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] 
= None, + pandas_json_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + min_rows_per_file: Optional[int] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, + **pandas_json_args, + ) -> None: + """Writes the :class:`~ray.data.Dataset` to JSON and JSONL files. + + The number of files is determined by the number of blocks in the dataset. + To control the number of number of blocks, call + :meth:`~ray.data.Dataset.repartition`. + + This method is only supported for datasets with records that are convertible to + pandas dataframes. + + By default, the format of the output files is ``{uuid}_{block_idx}.json``, + where ``uuid`` is a unique id for the dataset. To modify this behavior, + implement a custom :class:`~ray.data.datasource.FilenameProvider` and pass it in + as the ``filename_provider`` argument. + + Examples: + Write the dataset as JSON file to a local directory. + + >>> import ray + >>> import pandas as pd + >>> ds = ray.data.from_pandas([pd.DataFrame({"one": [1], "two": ["a"]})]) + >>> ds.write_json("local:///tmp/data") + + Write the dataset as JSONL files to a local directory. + + >>> ds = ray.data.read_json("s3://anonymous@ray-example-data/train.jsonl") + >>> ds.write_json("local:///tmp/data") + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where + the JSON files are written to. + filesystem: The pyarrow filesystem implementation to write to. + These filesystems are specified in the + `pyarrow docs `_. + Specify this if you need to provide specific configurations to the + filesystem. By default, the filesystem is automatically selected based + on the scheme of the paths. For example, if the path begins with + ``s3://``, the ``S3FileSystem`` is used. + try_create_dir: If ``True``, attempts to create all directories in the + destination path. Does nothing if all directories already + exist. 
Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_output_stream `_, which is used when + opening the file to write to. + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + pandas_json_args_fn: Callable that returns a dictionary of write + arguments that are provided to + `pandas.DataFrame.to_json() `_ + when writing each block to a file. Overrides + any duplicate keys from ``pandas_json_args``. Use this parameter + instead of ``pandas_json_args`` if any of your write arguments + can't be pickled, or if you'd like to lazily resolve the write + arguments for each dataset block. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + num_rows_per_file: Deprecated. Use ``min_rows_per_file`` instead. + pandas_json_args: These args are passed to + `pandas.DataFrame.to_json() `_, + which is used under the hood to write out each + :class:`~ray.data.Dataset` block. These + are dict(orient="records", lines=True) by default. 
+ """ + if pandas_json_args_fn is None: + pandas_json_args_fn = lambda: {} # noqa: E731 + + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + + datasink = JSONDatasink( + path, + pandas_json_args_fn=pandas_json_args_fn, + pandas_json_args=pandas_json_args, + min_rows_per_file=effective_min_rows, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @PublicAPI(stability="alpha", api_group=IOC_API_GROUP) + @ConsumptionAPI + def write_images( + self, + path: str, + column: str, + file_format: str = "png", + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + ) -> None: + """Writes the :class:`~ray.data.Dataset` to images. + + Examples: + >>> import ray + >>> ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple") + >>> ds.write_images("local:///tmp/images", column="image") + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where + the images are written to. + column: The column containing the data you want to write to images. + file_format: The image file format to write with. For available options, + see `Image file formats `_. + filesystem: The pyarrow filesystem implementation to write to. + These filesystems are specified in the + `pyarrow docs `_. + Specify this if you need to provide specific configurations to the + filesystem. By default, the filesystem is automatically selected based + on the scheme of the paths. 
For example, if the path begins with + ``s3://``, the ``S3FileSystem`` is used. + try_create_dir: If ``True``, attempts to create all directories in the + destination path. Does nothing if all directories already + exist. Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_output_stream `_, which is used when + opening the file to write to. + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + """ # noqa: E501 + datasink = ImageDatasink( + path, + column, + file_format, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + @PublicAPI(api_group=IOC_API_GROUP) + def write_csv( + self, + path: str, + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + arrow_csv_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + min_rows_per_file: Optional[int] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, + **arrow_csv_args, + ) -> None: + """Writes the :class:`~ray.data.Dataset` to CSV files. + + The number of files is determined by the number of blocks in the dataset. 
+ To control the number of number of blocks, call + :meth:`~ray.data.Dataset.repartition`. + + This method is only supported for datasets with records that are convertible to + pyarrow tables. + + By default, the format of the output files is ``{uuid}_{block_idx}.csv``, + where ``uuid`` is a unique id for the dataset. To modify this behavior, + implement a custom :class:`~ray.data.datasource.FilenameProvider` + and pass it in as the ``filename_provider`` argument. + + + Examples: + Write the dataset as CSV files to a local directory. + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.write_csv("local:///tmp/data") + + Write the dataset as CSV files to S3. + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.write_csv("s3://bucket/folder/) # doctest: +SKIP + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where + the CSV files are written to. + filesystem: The pyarrow filesystem implementation to write to. + These filesystems are specified in the + `pyarrow docs `_. + Specify this if you need to provide specific configurations to the + filesystem. By default, the filesystem is automatically selected based + on the scheme of the paths. For example, if the path begins with + ``s3://``, the ``S3FileSystem`` is used. + try_create_dir: If ``True``, attempts to create all directories in the + destination path if ``True``. Does nothing if all directories already + exist. Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_output_stream `_, which is used when + opening the file to write to. + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + arrow_csv_args_fn: Callable that returns a dictionary of write + arguments that are provided to `pyarrow.write.write_csv `_ when writing each + block to a file. Overrides any duplicate keys from ``arrow_csv_args``. 
+ Use this argument instead of ``arrow_csv_args`` if any of your write + arguments cannot be pickled, or if you'd like to lazily resolve the + write arguments for each dataset block. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. + arrow_csv_args: Options to pass to `pyarrow.write.write_csv `_ + when writing each block to a file. 
+ """ + if arrow_csv_args_fn is None: + arrow_csv_args_fn = lambda: {} # noqa: E731 + + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + + datasink = CSVDatasink( + path, + arrow_csv_args_fn=arrow_csv_args_fn, + arrow_csv_args=arrow_csv_args, + min_rows_per_file=effective_min_rows, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + @PublicAPI(api_group=IOC_API_GROUP) + def write_tfrecords( + self, + path: str, + *, + tf_schema: Optional["schema_pb2.Schema"] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + min_rows_per_file: Optional[int] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, + ) -> None: + """Write the :class:`~ray.data.Dataset` to TFRecord files. + + The `TFRecord `_ + files contain + `tf.train.Example `_ + records, with one Example record for each row in the dataset. + + .. warning:: + tf.train.Feature only natively stores ints, floats, and bytes, + so this function only supports datasets with these data types, + and will error if the dataset contains unsupported types. + + The number of files is determined by the number of blocks in the dataset. + To control the number of number of blocks, call + :meth:`~ray.data.Dataset.repartition`. + + This method is only supported for datasets with records that are convertible to + pyarrow tables. + + By default, the format of the output files is ``{uuid}_{block_idx}.tfrecords``, + where ``uuid`` is a unique id for the dataset. 
To modify this behavior, + implement a custom :class:`~ray.data.datasource.FilenameProvider` + and pass it in as the ``filename_provider`` argument. + + Examples: + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.write_tfrecords("local:///tmp/data/") + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where tfrecords + files are written to. + filesystem: The pyarrow filesystem implementation to write to. + These filesystems are specified in the + `pyarrow docs `_. + Specify this if you need to provide specific configurations to the + filesystem. By default, the filesystem is automatically selected based + on the scheme of the paths. For example, if the path begins with + ``s3://``, the ``S3FileSystem`` is used. + try_create_dir: If ``True``, attempts to create all directories in the + destination path. Does nothing if all directories already + exist. Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_output_stream `_, which is used when + opening the file to write to. + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. 
+ num_rows_per_file: [Deprecated] Use min_rows_per_file instead. + """ + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + + datasink = TFRecordDatasink( + path=path, + tf_schema=tf_schema, + min_rows_per_file=effective_min_rows, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + @PublicAPI(stability="alpha", api_group=IOC_API_GROUP) + def write_webdataset( + self, + path: str, + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + min_rows_per_file: Optional[int] = None, + ray_remote_args: Dict[str, Any] = None, + encoder: Optional[Union[bool, str, callable, list]] = True, + concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, + ) -> None: + """Writes the dataset to `WebDataset `_ files. + + The `TFRecord `_ + files will contain + `tf.train.Example `_ # noqa: E501 + records, with one Example record for each row in the dataset. + + .. warning:: + tf.train.Feature only natively stores ints, floats, and bytes, + so this function only supports datasets with these data types, + and will error if the dataset contains unsupported types. + + This is only supported for datasets convertible to Arrow records. + To control the number of files, use :meth:`Dataset.repartition`. + + Unless a custom filename provider is given, the format of the output + files is ``{uuid}_{block_idx}.tfrecords``, where ``uuid`` is a unique id + for the dataset. + + Examples: + + .. 
testcode:: + :skipif: True + + import ray + + ds = ray.data.range(100) + ds.write_webdataset("s3://bucket/folder/") + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where tfrecords + files are written to. + filesystem: The filesystem implementation to write to. + try_create_dir: If ``True``, attempts to create all + directories in the destination path. Does nothing if all directories + already exist. Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + ``pyarrow.fs.FileSystem.open_output_stream`` + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. + ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. 
+ """ + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + + datasink = WebDatasetDatasink( + path, + encoder=encoder, + min_rows_per_file=effective_min_rows, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + @PublicAPI(api_group=IOC_API_GROUP) + def write_numpy( + self, + path: str, + *, + column: str, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + filename_provider: Optional[FilenameProvider] = None, + min_rows_per_file: Optional[int] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, + ) -> None: + """Writes a column of the :class:`~ray.data.Dataset` to .npy files. + + This is only supported for columns in the datasets that can be converted to + NumPy arrays. + + The number of files is determined by the number of blocks in the dataset. + To control the number of number of blocks, call + :meth:`~ray.data.Dataset.repartition`. + + + By default, the format of the output files is ``{uuid}_{block_idx}.npy``, + where ``uuid`` is a unique id for the dataset. To modify this behavior, + implement a custom :class:`~ray.data.datasource.FilenameProvider` + and pass it in as the ``filename_provider`` argument. + + Examples: + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.write_numpy("local:///tmp/data/", column="id") + + Time complexity: O(dataset size / parallelism) + + Args: + path: The path to the destination root directory, where + the npy files are written to. + column: The name of the column that contains the data to + be written. 
+ filesystem: The pyarrow filesystem implementation to write to. + These filesystems are specified in the + `pyarrow docs `_. + Specify this if you need to provide specific configurations to the + filesystem. By default, the filesystem is automatically selected based + on the scheme of the paths. For example, if the path begins with + ``s3://``, the ``S3FileSystem`` is used. + try_create_dir: If ``True``, attempts to create all directories in + destination path. Does nothing if all directories already + exist. Defaults to ``True``. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_output_stream `_, which is used when + opening the file to write to. + filename_provider: A :class:`~ray.data.datasource.FilenameProvider` + implementation. Use this parameter to customize what your filenames + look like. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. 
+ """ + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + + datasink = NumpyDatasink( + path, + column, + min_rows_per_file=effective_min_rows, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + filename_provider=filename_provider, + dataset_uuid=self._uuid, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + def write_sql( + self, + sql: str, + connection_factory: Callable[[], Connection], + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + ) -> None: + """Write to a database that provides a + `Python DB API2-compliant `_ connector. + + .. note:: + + This method writes data in parallel using the DB API2 ``executemany`` + method. To learn more about this method, see + `PEP 249 `_. + + Examples: + + .. testcode:: + + import sqlite3 + import ray + + connection = sqlite3.connect("example.db") + connection.cursor().execute("CREATE TABLE movie(title, year, score)") + dataset = ray.data.from_items([ + {"title": "Monty Python and the Holy Grail", "year": 1975, "score": 8.2}, + {"title": "And Now for Something Completely Different", "year": 1971, "score": 7.5} + ]) + + dataset.write_sql( + "INSERT INTO movie VALUES(?, ?, ?)", lambda: sqlite3.connect("example.db") + ) + + result = connection.cursor().execute("SELECT * FROM movie ORDER BY year") + print(result.fetchall()) + + .. testoutput:: + + [('And Now for Something Completely Different', 1971, 7.5), ('Monty Python and the Holy Grail', 1975, 8.2)] + + .. testcode:: + :hide: + + import os + os.remove("example.db") + + Arguments: + sql: An ``INSERT INTO`` statement that specifies the table to write to. The + number of parameters must match the number of columns in the table. 
+            connection_factory: A function that takes no arguments and returns a +                Python DB API2 +                `Connection object `_. +            ray_remote_args: Keyword arguments passed to :func:`ray.remote` in the +                write tasks. +            concurrency: The maximum number of Ray tasks to run concurrently. Set this +                to control number of tasks to run concurrently. This doesn't change the +                total number of tasks run. By default, concurrency is dynamically +                decided based on the available resources. +        """  # noqa: E501 +        datasink = SQLDatasink(sql=sql, connection_factory=connection_factory) +        self.write_datasink( +            datasink, +            ray_remote_args=ray_remote_args, +            concurrency=concurrency, +        ) + +    @PublicAPI(stability="alpha", api_group=IOC_API_GROUP) +    @ConsumptionAPI +    def write_mongo( +        self, +        uri: str, +        database: str, +        collection: str, +        ray_remote_args: Dict[str, Any] = None, +        concurrency: Optional[int] = None, +    ) -> None: +        """Writes the :class:`~ray.data.Dataset` to a MongoDB database. + +        This method is only supported for datasets convertible to pyarrow tables. + +        The number of parallel writes is determined by the number of blocks in the +        dataset. To control the number of blocks, call +        :meth:`~ray.data.Dataset.repartition`. + +        .. warning:: +            This method supports only a subset of PyArrow's types, due to the +            limitation of pymongoarrow which is used underneath. Writing unsupported +            types fails on type checking. See all the supported types at: +            https://mongo-arrow.readthedocs.io/en/latest/data_types.html. + +        .. note:: +            The records are inserted into MongoDB as new documents. If a record has +            the _id field, this _id must be non-existent in MongoDB, otherwise the write +            is rejected and fails (hence preexisting documents are protected from +            being mutated). It's fine to not have an _id field in the record and MongoDB will +            auto generate one at insertion. + +        Examples: + +            .. 
testcode:: + :skipif: True + + import ray + + ds = ray.data.range(100) + ds.write_mongo( + uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin", + database="my_db", + collection="my_collection" + ) + + Args: + uri: The URI to the destination MongoDB where the dataset is + written to. For the URI format, see details in the + `MongoDB docs `_. + database: The name of the database. This database must exist otherwise + a ValueError is raised. + collection: The name of the collection in the database. This collection + must exist otherwise a ValueError is raised. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + + Raises: + ValueError: if ``database`` doesn't exist. + ValueError: if ``collection`` doesn't exist. + """ + datasink = MongoDatasink( + uri=uri, + database=database, + collection=collection, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI + def write_bigquery( + self, + project_id: str, + dataset: str, + max_retry_cnt: int = 10, + overwrite_table: Optional[bool] = True, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + ) -> None: + """Write the dataset to a BigQuery dataset table. + + To control the number of parallel write tasks, use ``.repartition()`` + before calling this method. + + Examples: + .. 
testcode:: + :skipif: True + + import ray + import pandas as pd + + docs = [{"title": "BigQuery Datasource test"} for key in range(4)] + ds = ray.data.from_pandas(pd.DataFrame(docs)) + ds.write_bigquery( + project_id="my_project_id", + dataset="my_dataset_table", + overwrite_table=True + ) + + Args: + project_id: The name of the associated Google Cloud Project that hosts + the dataset to read. For more information, see details in + `Creating and managing projects `_. + dataset: The name of the dataset in the format of ``dataset_id.table_id``. + The dataset is created if it doesn't already exist. + max_retry_cnt: The maximum number of retries that an individual block write + is retried due to BigQuery rate limiting errors. This isn't + related to Ray fault tolerance retries. The default number of retries + is 10. + overwrite_table: Whether the write will overwrite the table if it already + exists. The default behavior is to overwrite the table. + ``overwrite_table=False`` will append to the table if it exists. + ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + """ # noqa: E501 + if ray_remote_args is None: + ray_remote_args = {} + + # Each write task will launch individual remote tasks to write each block + # To avoid duplicate block writes, the write task should not be retried + if ray_remote_args.get("max_retries", 0) != 0: + warnings.warn( + "The max_retries of a BigQuery Write Task should be set to 0" + " to avoid duplicate writes." 
+ ) + else: + ray_remote_args["max_retries"] = 0 + + datasink = BigQueryDatasink( + project_id=project_id, + dataset=dataset, + max_retry_cnt=max_retry_cnt, + overwrite_table=overwrite_table, + ) + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + + @ConsumptionAPI(pattern="Time complexity:") + def write_datasink( + self, + datasink: Datasink, + *, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + ) -> None: + """Writes the dataset to a custom :class:`~ray.data.Datasink`. + + Time complexity: O(dataset size / parallelism) + + Args: + datasink: The :class:`~ray.data.Datasink` to write to. + ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. + """ # noqa: E501 + if ray_remote_args is None: + ray_remote_args = {} + + if not datasink.supports_distributed_writes: + if ray.util.client.ray.is_connected(): + raise ValueError( + "If you're using Ray Client, Ray Data won't schedule write tasks " + "on the driver's node." + ) + ray_remote_args["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), + soft=False, + ) + + plan = self._plan.copy() + write_op = Write( + self._logical_plan.dag, + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + logical_plan = LogicalPlan(write_op, self.context) + + try: + + datasink.on_write_start() + + self._write_ds = Dataset(plan, logical_plan).materialize() + # TODO: Get and handle the blocks with an iterator instead of getting + # everything in a blocking way, so some blocks can be freed earlier. 
+ raw_write_results = ray.get(self._write_ds._plan.execute().block_refs) + write_result = gen_datasink_write_result(raw_write_results) + logger.info( + "Data sink %s finished. %d rows and %s data written.", + datasink.get_name(), + write_result.num_rows, + memory_string(write_result.size_bytes), + ) + datasink.on_write_complete(write_result) + + except Exception as e: + datasink.on_write_failed(e) + raise + + @ConsumptionAPI( + delegate=( + "Calling any of the consumption methods on the returned ``DataIterator``" + ), + pattern="Returns:", + ) + @PublicAPI(api_group=CD_API_GROUP) + def iterator(self) -> DataIterator: + """Return a :class:`~ray.data.DataIterator` over this dataset. + + Don't call this method directly. Use it internally. + + Returns: + A :class:`~ray.data.DataIterator` over this dataset. + """ + return DataIteratorImpl(self) + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def iter_rows(self) -> Iterable[Dict[str, Any]]: + """Return an iterable over the rows in this dataset. + + Examples: + >>> import ray + >>> for row in ray.data.range(3).iter_rows(): + ... print(row) + {'id': 0} + {'id': 1} + {'id': 2} + + Time complexity: O(1) + + Returns: + An iterable over the rows in this dataset. + """ + return self.iterator().iter_rows() + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def iter_batches( + self, + *, + prefetch_batches: int = 1, + batch_size: Optional[int] = 256, + batch_format: Optional[str] = "default", + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + _collate_fn: Optional[Callable[[DataBatch], CollatedData]] = None, + ) -> Iterable[DataBatch]: + """Return an iterable over batches of data. + + This method is useful for model training. + + Examples: + + .. testcode:: + + import ray + + ds = ray.data.read_images("example://image-datasets/simple") + + for batch in ds.iter_batches(batch_size=2, batch_format="numpy"): + print(batch) + + .. 
testoutput:: + :options: +MOCK + + {'image': array([[[[...]]]], dtype=uint8)} + ... + {'image': array([[[[...]]]], dtype=uint8)} + + Time complexity: O(1) + + Args: + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool is used + to fetch the objects to the local node and format the batches. Defaults + to 1. + batch_size: The number of rows in each batch, or ``None`` to use entire + blocks as batches (blocks may contain different numbers of rows). + The final batch may include fewer than ``batch_size`` rows if + ``drop_last`` is ``False``. Defaults to 256. + batch_format: If ``"default"`` or ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are + ``pandas.DataFrame``. + drop_last: Whether to drop the last batch if it's incomplete. + local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled + using a local in-memory shuffle buffer, and this value serves as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer are drained. + local_shuffle_seed: The seed to use for the local random shuffle. + + Returns: + An iterable over batches of data. 
+ """ + batch_format = _apply_batch_format(batch_format) + return self.iterator().iter_batches( + prefetch_batches=prefetch_batches, + batch_size=batch_size, + batch_format=batch_format, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + _collate_fn=_collate_fn, + ) + + @ConsumptionAPI + @PublicAPI(api_group=CD_API_GROUP) + def iter_torch_batches( + self, + *, + prefetch_batches: int = 1, + batch_size: Optional[int] = 256, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: str = "auto", + collate_fn: Optional[Callable[[Dict[str, np.ndarray]], CollatedData]] = None, + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + ) -> Iterable[TorchBatchType]: + """Return an iterable over batches of data represented as Torch tensors. + + This iterable yields batches of type ``Dict[str, torch.Tensor]``. + For more flexibility, call :meth:`~Dataset.iter_batches` and manually convert + your data to Torch tensors. + + Examples: + >>> import ray + >>> for batch in ray.data.range( + ... 12, + ... ).iter_torch_batches(batch_size=4): + ... print(batch) + {'id': tensor([0, 1, 2, 3])} + {'id': tensor([4, 5, 6, 7])} + {'id': tensor([ 8, 9, 10, 11])} + + Use the ``collate_fn`` to customize how the tensor batch is created. + + >>> from typing import Any, Dict + >>> import torch + >>> import numpy as np + >>> import ray + >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any: + ... return torch.stack( + ... [torch.as_tensor(array) for array in batch.values()], + ... axis=1 + ... ) + >>> dataset = ray.data.from_items([ + ... {"col_1": 1, "col_2": 2}, + ... {"col_1": 3, "col_2": 4}]) + >>> for batch in dataset.iter_torch_batches(collate_fn=collate_fn): + ... 
print(batch) + tensor([[1, 2], + [3, 4]]) + + + Time complexity: O(1) + + Args: + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool is used + to fetch the objects to the local node, format the batches, and apply + the ``collate_fn``. Defaults to 1. + batch_size: The number of rows in each batch, or ``None`` to use entire + blocks as batches (blocks may contain different number of rows). + The final batch may include fewer than ``batch_size`` rows if + ``drop_last`` is ``False``. Defaults to 256. + dtypes: The Torch dtype(s) for the created tensor(s); if ``None``, the dtype + is inferred from the tensor data. You can't use this parameter with + ``collate_fn``. + device: The device on which the tensor should be placed. Defaults to + "auto" which moves the tensors to the appropriate device when the + Dataset is passed to Ray Train and ``collate_fn`` is not provided. + Otherwise, defaults to CPU. You can't use this parameter with + ``collate_fn``. + collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch. + When this parameter is specified, the user should manually handle the + host to device data transfer outside of collate_fn. + This is useful for further processing the data after it has been + batched. Potential use cases include collating along a dimension other + than the first, padding sequences of various lengths, or generally + handling batches of different length tensors. If not provided, the + default collate function is used which simply converts the batch of + numpy arrays to a batch of PyTorch tensors. This API is still + experimental and is subject to change. You can't use this parameter in + conjunction with ``dtypes`` or ``device``. + drop_last: Whether to drop the last batch if it's incomplete. 
+ local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled + using a local in-memory shuffle buffer, and this value serves as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer are drained. + ``batch_size`` must also be specified when using local shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + + Returns: + An iterable over Torch Tensor batches. + + .. seealso:: + :meth:`Dataset.iter_batches` + Call this method to manually convert your data to Torch tensors. + """ # noqa: E501 + return self.iterator().iter_torch_batches( + prefetch_batches=prefetch_batches, + batch_size=batch_size, + dtypes=dtypes, + device=device, + collate_fn=collate_fn, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + ) + + @ConsumptionAPI + @Deprecated + def iter_tf_batches( + self, + *, + prefetch_batches: int = 1, + batch_size: Optional[int] = 256, + dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None, + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + ) -> Iterable[TensorFlowTensorBatchType]: + """Return an iterable over batches of data represented as TensorFlow tensors. + + This iterable yields batches of type ``Dict[str, tf.Tensor]``. + For more flexibility, call :meth:`~Dataset.iter_batches` and manually convert + your data to TensorFlow tensors. + + .. tip:: + If you don't need the additional flexibility provided by this method, + consider using :meth:`~ray.data.Dataset.to_tf` instead. It's easier + to use. + + Examples: + + .. 
testcode:: + + import ray + + ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") + + tf_dataset = ds.to_tf( + feature_columns="sepal length (cm)", + label_columns="target", + batch_size=2 + ) + for features, labels in tf_dataset: + print(features, labels) + + .. testoutput:: + + tf.Tensor([5.1 4.9], shape=(2,), dtype=float64) tf.Tensor([0 0], shape=(2,), dtype=int64) + ... + tf.Tensor([6.2 5.9], shape=(2,), dtype=float64) tf.Tensor([2 2], shape=(2,), dtype=int64) + + Time complexity: O(1) + + Args: + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool is used + to fetch the objects to the local node, format the batches, and apply + the ``collate_fn``. Defaults to 1. + batch_size: The number of rows in each batch, or ``None`` to use entire + blocks as batches (blocks may contain different numbers of rows). + The final batch may include fewer than ``batch_size`` rows if + ``drop_last`` is ``False``. Defaults to 256. + dtypes: The TensorFlow dtype(s) for the created tensor(s); if ``None``, the + dtype is inferred from the tensor data. + drop_last: Whether to drop the last batch if it's incomplete. + local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled + using a local in-memory shuffle buffer, and this value serves as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer are drained. + ``batch_size`` must also be specified when using local shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + + Returns: + An iterable over TensorFlow Tensor batches. + + .. seealso:: + :meth:`Dataset.iter_batches` + Call this method to manually convert your data to TensorFlow tensors. + """ # noqa: E501 + warnings.warn( + "`iter_tf_batches` is deprecated and will be removed after May 2025. 
Use " + "`to_tf` instead.", + DeprecationWarning, + ) + return self.iterator().iter_tf_batches( + prefetch_batches=prefetch_batches, + batch_size=batch_size, + dtypes=dtypes, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + ) + + @ConsumptionAPI(pattern="Time complexity:") + @Deprecated + def to_torch( + self, + *, + label_column: Optional[str] = None, + feature_columns: Optional[ + Union[List[str], List[List[str]], Dict[str, List[str]]] + ] = None, + label_column_dtype: Optional["torch.dtype"] = None, + feature_column_dtypes: Optional[ + Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]] + ] = None, + batch_size: int = 1, + prefetch_batches: int = 1, + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + unsqueeze_label_tensor: bool = True, + unsqueeze_feature_tensors: bool = True, + ) -> "torch.utils.data.IterableDataset": + """Return a + `Torch IterableDataset `_ + over this :class:`~ray.data.Dataset`. + + This is only supported for datasets convertible to Arrow records. + + It is recommended to use the returned ``IterableDataset`` directly + instead of passing it into a torch ``DataLoader``. + + Each element in ``IterableDataset`` is a tuple consisting of 2 + elements. The first item contains the feature tensor(s), and the + second item is the label tensor. Those can take on different + forms, depending on the specified arguments. 
+ + For the features tensor (N is the ``batch_size`` and n, m, k + are the number of features per tensor): + + * If ``feature_columns`` is a ``List[str]``, the features is + a tensor of shape (N, n), with columns corresponding to + ``feature_columns`` + + * If ``feature_columns`` is a ``List[List[str]]``, the features is + a list of tensors of shape [(N, m),...,(N, k)], with columns of each + tensor corresponding to the elements of ``feature_columns`` + + * If ``feature_columns`` is a ``Dict[str, List[str]]``, the features + is a dict of key-tensor pairs of shape + {key1: (N, m),..., keyN: (N, k)}, with columns of each + tensor corresponding to the value of ``feature_columns`` under the + key. + + If ``unsqueeze_label_tensor=True`` (default), the label tensor is + of shape (N, 1). Otherwise, it is of shape (N,). + If ``label_column`` is specified as ``None``, then no column from the + ``Dataset`` is treated as the label, and the output label tensor + is ``None``. + + Note that you probably want to call :meth:`Dataset.split` on this dataset if + there are to be multiple Torch workers consuming the data. + + Time complexity: O(1) + + Args: + label_column: The name of the column used as the + label (second element of the output list). Can be None for + prediction, in which case the second element of returned + tuple will also be None. + feature_columns: The names of the columns + to use as the features. Can be a list of lists or + a dict of string-list pairs for multi-tensor output. + If ``None``, then use all columns except the label column as + the features. + label_column_dtype: The torch dtype to + use for the label column. If ``None``, then automatically infer + the dtype. + feature_column_dtypes: The dtypes to use for the feature + tensors. This should match the format of ``feature_columns``, + or be a single dtype, in which case it is applied to + all tensors. If ``None``, then automatically infer the dtype. 
+ batch_size: How many samples per batch to yield at a time. + Defaults to 1. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool is used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 1. + drop_last: Set to True to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If + False and the size of the stream is not divisible by the batch + size, then the last batch is smaller. Defaults to False. + local_shuffle_buffer_size: If non-None, the data is randomly shuffled + using a local in-memory shuffle buffer, and this value will serve as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer is drained. This + buffer size must be greater than or equal to ``batch_size``, and + therefore ``batch_size`` must also be specified when using local + shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + unsqueeze_label_tensor: If set to True, the label tensor + is unsqueezed (reshaped to (N, 1)). Otherwise, it will + be left as is, that is (N, ). In general, regression loss + functions expect an unsqueezed tensor, while classification + loss functions expect a squeezed one. Defaults to True. + unsqueeze_feature_tensors: If set to True, the features tensors + are unsqueezed (reshaped to (N, 1)) before being concatenated into + the final features tensor. Otherwise, they are left as is, that is + (N, ). Defaults to True. + + Returns: + A `Torch IterableDataset`_. + """ # noqa: E501 + warnings.warn( + "`to_torch` is deprecated and will be removed after May 2025. 
Use " + "`iter_torch_batches` instead.", + DeprecationWarning, + ) + return self.iterator().to_torch( + label_column=label_column, + feature_columns=feature_columns, + label_column_dtype=label_column_dtype, + feature_column_dtypes=feature_column_dtypes, + batch_size=batch_size, + prefetch_batches=prefetch_batches, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + unsqueeze_label_tensor=unsqueeze_label_tensor, + unsqueeze_feature_tensors=unsqueeze_feature_tensors, + ) + + @ConsumptionAPI + @PublicAPI(api_group=IOC_API_GROUP) + def to_tf( + self, + feature_columns: Union[str, List[str]], + label_columns: Union[str, List[str]], + *, + additional_columns: Union[str, List[str]] = None, + prefetch_batches: int = 1, + batch_size: int = 1, + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None, + label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None, + additional_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None, + ) -> "tf.data.Dataset": + """Return a `TensorFlow Dataset `_ + over this :class:`~ray.data.Dataset`. + + .. warning:: + If your :class:`~ray.data.Dataset` contains ragged tensors, this method errors. + To prevent errors, :ref:`resize your tensors `. + + Examples: + >>> import ray + >>> ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") + >>> ds + Dataset( + num_rows=?, + schema={ + sepal length (cm): double, + sepal width (cm): double, + petal length (cm): double, + petal width (cm): double, + target: int64 + } + ) + + If your model accepts a single tensor as input, specify a single feature column. 
+ + >>> ds.to_tf(feature_columns="sepal length (cm)", label_columns="target") + <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))> + + If your model accepts a dictionary as input, specify a list of feature columns. + + >>> ds.to_tf(["sepal length (cm)", "sepal width (cm)"], "target") + <_OptionsDataset element_spec=({'sepal length (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), 'sepal width (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal width (cm)')}, TensorSpec(shape=(None,), dtype=tf.int64, name='target'))> + + If your dataset contains multiple features but your model accepts a single + tensor as input, combine features with + :class:`~ray.data.preprocessors.Concatenator`. + + >>> from ray.data.preprocessors import Concatenator + >>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"] + >>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features") + >>> ds = preprocessor.transform(ds) + >>> ds + Concatenator + +- Dataset( + num_rows=?, + schema={ + sepal length (cm): double, + sepal width (cm): double, + petal length (cm): double, + petal width (cm): double, + target: int64 + } + ) + >>> ds.to_tf("features", "target") + <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))> + + If your model accepts different types, shapes, or names of tensors as input, specify the type spec. + If type specs are not specified, they are automatically inferred from the schema of the dataset. + + >>> import tensorflow as tf + >>> ds.to_tf( + ... feature_columns="features", + ... label_columns="target", + ... feature_type_spec=tf.TensorSpec(shape=(None, 4), dtype=tf.float32, name="features"), + ... 
label_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="label") + ... ) + <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name='features'), TensorSpec(shape=(None,), dtype=tf.float32, name='label'))> + + If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns. + A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``. + + >>> import pandas as pd + >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df))) + >>> ds.to_tf(feature_columns="features", label_columns="target", additional_columns="sample weights") + <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))> + + If your model accepts different types, shapes, or names for the additional metadata, specify the type spec of the additional column. + + >>> ds.to_tf( + ... feature_columns="features", + ... label_columns="target", + ... additional_columns="sample weights", + ... additional_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="weight") + ... ) + <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.float32, name='weight'))> + + Args: + feature_columns: Columns that correspond to model inputs. If this is a + string, the input data is a tensor. If this is a list, the input data + is a ``dict`` that maps column names to their tensor representation. + label_columns: Columns that correspond to model targets. If this is a + string, the target data is a tensor. If this is a list, the target data + is a ``dict`` that maps column names to their tensor representation. 
+ additional_columns: Columns that correspond to sample weights or other metadata. + If this is a string, the weight data is a tensor. If this is a list, the + weight data is a ``dict`` that maps column names to their tensor representation. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool is used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 1. + batch_size: Record batch size. Defaults to 1. + drop_last: Set to True to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If + False and the size of the stream is not divisible by the batch + size, then the last batch is smaller. Defaults to False. + local_shuffle_buffer_size: If non-None, the data is randomly shuffled + using a local in-memory shuffle buffer, and this value will serve as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer is drained. This + buffer size must be greater than or equal to ``batch_size``, and + therefore ``batch_size`` must also be specified when using local + shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is + only one column, specify a `tf.TypeSpec`. If there are multiple columns, + specify a ``dict`` that maps column names to their `tf.TypeSpec`. + Default is `None` to automatically infer the type of each column. + label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is + only one column, specify a `tf.TypeSpec`. If there are multiple columns, + specify a ``dict`` that maps column names to their `tf.TypeSpec`. + Default is `None` to automatically infer the type of each column. + additional_type_spec: The `tf.TypeSpec` of `additional_columns`. 
If there + is only one column, specify a `tf.TypeSpec`. If there are multiple + columns, specify a ``dict`` that maps column names to their `tf.TypeSpec`. + Default is `None` to automatically infer the type of each column. + + Returns: + A `TensorFlow Dataset`_ that yields inputs and targets. + + .. seealso:: + + :meth:`~ray.data.Dataset.iter_tf_batches` + Call this method if you need more flexibility. + """ # noqa: E501 + + return self.iterator().to_tf( + feature_columns=feature_columns, + label_columns=label_columns, + additional_columns=additional_columns, + prefetch_batches=prefetch_batches, + drop_last=drop_last, + batch_size=batch_size, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + feature_type_spec=feature_type_spec, + label_type_spec=label_type_spec, + additional_type_spec=additional_type_spec, + ) + + @ConsumptionAPI(pattern="Time complexity:") + @PublicAPI(api_group=IOC_API_GROUP) + def to_dask( + self, + meta: Union[ + "pandas.DataFrame", + "pandas.Series", + Dict[str, Any], + Iterable[Any], + Tuple[Any], + None, + ] = None, + verify_meta: bool = True, + ) -> "dask.dataframe.DataFrame": + """Convert this :class:`~ray.data.Dataset` into a + `Dask DataFrame `_. + + This is only supported for datasets convertible to Arrow records. + + Note that this function will set the Dask scheduler to Dask-on-Ray + globally, via the config. + + Time complexity: O(dataset size / parallelism) + + Args: + meta: An empty `pandas DataFrame`_ or `Series`_ that matches the dtypes and column + names of the stream. This metadata is necessary for many algorithms in + dask dataframe to work. For ease of use, some alternative inputs are + also available. Instead of a DataFrame, a dict of ``{name: dtype}`` or + iterable of ``(name, dtype)`` can be provided (note that the order of + the names should match the order of the columns). Instead of a series, a + tuple of ``(name, dtype)`` can be used. 
+ By default, this is inferred from the underlying Dataset schema, + with this argument supplying an optional override. + verify_meta: If True, Dask will check that the partitions have consistent + metadata. Defaults to True. + + Returns: + A `Dask DataFrame`_ created from this dataset. + + .. _pandas DataFrame: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html + .. _Series: https://pandas.pydata.org/docs/reference/api/pandas.Series.html + """ # noqa: E501 + import dask + import dask.dataframe as dd + import pandas as pd + + try: + import pyarrow as pa + except Exception: + pa = None + + from ray.data._internal.pandas_block import PandasBlockSchema + from ray.util.client.common import ClientObjectRef + from ray.util.dask import ray_dask_get + + dask.config.set(scheduler=ray_dask_get) + + @dask.delayed + def block_to_df(block_ref: ObjectRef[Block]) -> pd.DataFrame: + if isinstance(block_ref, (ray.ObjectRef, ClientObjectRef)): + raise ValueError( + "Dataset.to_dask() must be used with Dask-on-Ray, please " + "set the Dask scheduler to ray_dask_get (located in " + "ray.util.dask)." + ) + return _block_to_df(block_ref) + + if meta is None: + from ray.data.extensions import TensorDtype + + # Infer Dask metadata from Dataset schema. 
+ schema = self.schema(fetch_if_missing=True) + if isinstance(schema, PandasBlockSchema): + meta = pd.DataFrame( + { + col: pd.Series( + dtype=( + dtype + if not isinstance(dtype, TensorDtype) + else np.object_ + ) + ) + for col, dtype in zip(schema.names, schema.types) + } + ) + elif pa is not None and isinstance(schema, pa.Schema): + arrow_tensor_ext_types = get_arrow_extension_fixed_shape_tensor_types() + + if any( + isinstance(type_, arrow_tensor_ext_types) for type_ in schema.types + ): + meta = pd.DataFrame( + { + col: pd.Series( + dtype=( + dtype.to_pandas_dtype() + if not isinstance(dtype, arrow_tensor_ext_types) + else np.object_ + ) + ) + for col, dtype in zip(schema.names, schema.types) + } + ) + else: + meta = schema.empty_table().to_pandas() + + dfs = [] + for ref_bundle in self.iter_internal_ref_bundles(): + for block_ref in ref_bundle.block_refs: + dfs.append(block_to_df(block_ref)) + + ddf = dd.from_delayed( + dfs, + meta=meta, + verify_meta=verify_meta, + ) + return ddf + + @ConsumptionAPI(pattern="Time complexity:") + @PublicAPI(api_group=IOC_API_GROUP) + def to_mars(self) -> "mars.dataframe.DataFrame": + """Convert this :class:`~ray.data.Dataset` into a + `Mars DataFrame `_. + + Time complexity: O(dataset size / parallelism) + + Returns: + A `Mars DataFrame`_ created from this dataset. 
+ """ # noqa: E501 + import pandas as pd + import pyarrow as pa + from mars.dataframe.datasource.read_raydataset import DataFrameReadRayDataset + from mars.dataframe.utils import parse_index + + from ray.data._internal.pandas_block import PandasBlockSchema + + refs = self.to_pandas_refs() + # remove this when https://github.com/mars-project/mars/issues/2945 got fixed + schema = self.schema() + if isinstance(schema, Schema): + schema = schema.base_schema + if isinstance(schema, PandasBlockSchema): + dtypes = pd.Series(schema.types, index=schema.names) + elif isinstance(schema, pa.Schema): + dtypes = schema.empty_table().to_pandas().dtypes + else: + raise NotImplementedError(f"Unsupported format of schema {schema}") + index_value = parse_index(pd.RangeIndex(-1)) + columns_value = parse_index(dtypes.index, store_data=True) + op = DataFrameReadRayDataset(refs=refs) + return op(index_value=index_value, columns_value=columns_value, dtypes=dtypes) + + @ConsumptionAPI(pattern="Time complexity:") + @PublicAPI(api_group=IOC_API_GROUP) + def to_modin(self) -> "modin.pandas.dataframe.DataFrame": + """Convert this :class:`~ray.data.Dataset` into a + `Modin DataFrame `_. + + This works by first converting this dataset into a distributed set of + Pandas DataFrames (using :meth:`Dataset.to_pandas_refs`). + See caveats there. Then the individual DataFrames are used to + create the Modin DataFrame using + ``modin.distributed.dataframe.pandas.partitions.from_partitions()``. + + This is only supported for datasets convertible to Arrow records. + This function induces a copy of the data. For zero-copy access to the + underlying data, consider using :meth:`.to_arrow_refs` or + :meth:`.iter_internal_ref_bundles`. + + Time complexity: O(dataset size / parallelism) + + Returns: + A `Modin DataFrame`_ created from this dataset. 
+ """ # noqa: E501 + + from modin.distributed.dataframe.pandas.partitions import from_partitions + + pd_objs = self.to_pandas_refs() + return from_partitions(pd_objs, axis=0) + + @ConsumptionAPI(pattern="Time complexity:") + @PublicAPI(api_group=IOC_API_GROUP) + def to_spark(self, spark: "pyspark.sql.SparkSession") -> "pyspark.sql.DataFrame": + """Convert this :class:`~ray.data.Dataset` into a + `Spark DataFrame `_. + + Time complexity: O(dataset size / parallelism) + + Args: + spark: A `SparkSession`_, which must be created by RayDP (Spark-on-Ray). + + Returns: + A `Spark DataFrame`_ created from this dataset. + + .. _SparkSession: https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.SparkSession.html + """ # noqa: E501 + import raydp + + schema = self.schema() + if isinstance(schema, Schema): + schema = schema.base_schema + + ref_bundles = self.iter_internal_ref_bundles() + block_refs = _ref_bundles_iterator_to_block_refs_list(ref_bundles) + return raydp.spark.ray_dataset_to_spark_dataframe(spark, schema, block_refs) + + @ConsumptionAPI(pattern="Time complexity:") + @PublicAPI(api_group=IOC_API_GROUP) + def to_pandas(self, limit: int = None) -> "pandas.DataFrame": + """Convert this :class:`~ray.data.Dataset` to a single pandas DataFrame. + + This method errors if the number of rows exceeds the provided ``limit``. + To truncate the dataset beforehand, call :meth:`.limit`. + + Examples: + >>> import ray + >>> ds = ray.data.from_items([{"a": i} for i in range(3)]) + >>> ds.to_pandas() + a + 0 0 + 1 1 + 2 2 + + Time complexity: O(dataset size) + + Args: + limit: The maximum number of rows to return. An error is + raised if the dataset has more rows than this limit. Defaults to + ``None``, which means no limit. + + Returns: + A pandas DataFrame created from this dataset, containing a limited + number of rows. + + Raises: + ValueError: if the number of rows in the :class:`~ray.data.Dataset` exceeds + ``limit``. 
+ """ + if limit is not None: + count = self.count() + if count > limit: + raise ValueError( + f"the dataset has more than the given limit of {limit} " + f"rows: {count}. If you are sure that a DataFrame with " + f"{count} rows will fit in local memory, set " + "ds.to_pandas(limit=None) to disable limits." + ) + + builder = PandasBlockBuilder() + for batch in self.iter_batches(batch_format="pandas", batch_size=None): + builder.add_block(batch) + block = builder.build() + + # `PandasBlockBuilder` creates a dataframe with internal extension types like + # 'TensorDtype'. We use the `to_pandas` method to convert these extension + # types to regular types. + return BlockAccessor.for_block(block).to_pandas() + + @ConsumptionAPI(pattern="Time complexity:") + @DeveloperAPI + def to_pandas_refs(self) -> List[ObjectRef["pandas.DataFrame"]]: + """Converts this :class:`~ray.data.Dataset` into a distributed set of Pandas + dataframes. + + One DataFrame is created for each block in this Dataset. + + This function induces a copy of the data. For zero-copy access to the + underlying data, consider using :meth:`Dataset.to_arrow_refs` or + :meth:`Dataset.iter_internal_ref_bundles`. + + Examples: + >>> import ray + >>> ds = ray.data.range(10, override_num_blocks=2) + >>> refs = ds.to_pandas_refs() + >>> len(refs) + 2 + + Time complexity: O(dataset size / parallelism) + + Returns: + A list of remote pandas DataFrames created from this dataset. + """ + + block_to_df = cached_remote_fn(_block_to_df) + pandas_refs = [] + for bundle in self.iter_internal_ref_bundles(): + for block_ref in bundle.block_refs: + pandas_refs.append(block_to_df.remote(block_ref)) + return pandas_refs + + @DeveloperAPI + def to_numpy_refs( + self, *, column: Optional[str] = None + ) -> List[ObjectRef[np.ndarray]]: + """Converts this :class:`~ray.data.Dataset` into a distributed set of NumPy + ndarrays or dictionary of NumPy ndarrays. + + This is only supported for datasets convertible to NumPy ndarrays. 
+ This function induces a copy of the data. For zero-copy access to the + underlying data, consider using :meth:`Dataset.to_arrow_refs` or + :meth:`Dataset.iter_internal_ref_bundles`. + + Examples: + >>> import ray + >>> ds = ray.data.range(10, override_num_blocks=2) + >>> refs = ds.to_numpy_refs() + >>> len(refs) + 2 + + Time complexity: O(dataset size / parallelism) + + Args: + column: The name of the column to convert to numpy. If ``None``, all columns + are used. If multiple columns are specified, each returned + future represents a dict of ndarrays. Defaults to None. + + Returns: + A list of remote NumPy ndarrays created from this dataset. + """ + block_to_ndarray = cached_remote_fn(_block_to_ndarray) + numpy_refs = [] + for bundle in self.iter_internal_ref_bundles(): + for block_ref in bundle.block_refs: + numpy_refs.append(block_to_ndarray.remote(block_ref, column=column)) + return numpy_refs + + @ConsumptionAPI(pattern="Time complexity:") + @DeveloperAPI + def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]: + """Convert this :class:`~ray.data.Dataset` into a distributed set of PyArrow + tables. + + One PyArrow table is created for each block in this Dataset. + + This method is only supported for datasets convertible to PyArrow tables. + This function is zero-copy if the existing data is already in PyArrow + format. Otherwise, the data is converted to PyArrow format. + + Examples: + >>> import ray + >>> ds = ray.data.range(10, override_num_blocks=2) + >>> refs = ds.to_arrow_refs() + >>> len(refs) + 2 + + Time complexity: O(1) unless conversion is required. + + Returns: + A list of remote PyArrow tables created from this dataset. + """ + import pyarrow as pa + + ref_bundles: Iterator[RefBundle] = self.iter_internal_ref_bundles() + block_refs: List[ + ObjectRef["pyarrow.Table"] + ] = _ref_bundles_iterator_to_block_refs_list(ref_bundles) + # Schema is safe to call since we have already triggered execution with + # iter_internal_ref_bundles. 
+ schema = self.schema(fetch_if_missing=True) + if isinstance(schema, Schema): + schema = schema.base_schema + if isinstance(schema, pa.Schema): + # Zero-copy path. + return block_refs + + block_to_arrow = cached_remote_fn(_block_to_arrow) + return [block_to_arrow.remote(block) for block in block_refs] + + @ConsumptionAPI(pattern="Args:") + def to_random_access_dataset( + self, + key: str, + num_workers: Optional[int] = None, + ) -> RandomAccessDataset: + """Convert this dataset into a distributed RandomAccessDataset (EXPERIMENTAL). + + RandomAccessDataset partitions the dataset across the cluster by the given + sort key, providing efficient random access to records via binary search. A + number of worker actors are created, each of which has zero-copy access to the + underlying sorted data blocks of the dataset. + + Note that the key must be unique in the dataset. If there are duplicate keys, + an arbitrary value is returned. + + This is only supported for Arrow-format datasets. + + Args: + key: The key column over which records can be queried. + num_workers: The number of actors to use to serve random access queries. + By default, this is determined by multiplying the number of Ray nodes + in the cluster by four. As a rule of thumb, you can expect each worker + to provide ~3000 records / second via ``get_async()``, and + ~10000 records / second via ``multiget()``. + """ + if num_workers is None: + num_workers = 4 * len(ray.nodes()) + return RandomAccessDataset(self, key, num_workers=num_workers) + + @ConsumptionAPI(pattern="store memory.", insert_after=True) + @PublicAPI(api_group=E_API_GROUP) + def materialize(self) -> "MaterializedDataset": + """Execute and materialize this dataset into object store memory. + + This can be used to read all blocks into memory. By default, Dataset + doesn't read blocks from the datasource until the first transform. + + Note that this does not mutate the original Dataset. 
Only the blocks of the + returned MaterializedDataset class are pinned in memory. + + Examples: + >>> import ray + >>> ds = ray.data.range(10) + >>> materialized_ds = ds.materialize() + >>> materialized_ds + MaterializedDataset(num_blocks=..., num_rows=10, schema={id: int64}) + + Returns: + A MaterializedDataset holding the materialized data blocks. + """ + copy = Dataset.copy(self, _deep_copy=True, _as=MaterializedDataset) + copy._plan.execute() + + bundle = copy._plan._snapshot_bundle + blocks_with_metadata = bundle.blocks + # TODO(hchen): Here we generate the same number of blocks as + # the original Dataset. Because the old code path does this, and + # some unit tests implicitly depend on this behavior. + # After we remove the old code path, we should consider merging + # some blocks for better perf. + ref_bundles = [ + RefBundle( + blocks=[block_with_metadata], + owns_blocks=False, + ) + for block_with_metadata in blocks_with_metadata + ] + logical_plan = LogicalPlan(InputData(input_data=ref_bundles), self.context) + output = MaterializedDataset( + ExecutionPlan(copy._plan.stats()), + logical_plan, + ) + # Metrics are tagged with `copy`'s uuid, update the output uuid with + # this so the user can access the metrics label. + output._set_name(copy._name) + output._set_uuid(copy._get_uuid()) + output._plan.execute() # No-op that marks the plan as fully executed. + return output + + @PublicAPI(api_group=IM_API_GROUP) + def stats(self) -> str: + """Returns a string containing execution timing information. + + Note that this does not trigger execution, so if the dataset has not yet + executed, an empty string is returned. + + Examples: + + .. testcode:: + + import ray + + ds = ray.data.range(10) + assert ds.stats() == "" + + ds = ds.materialize() + print(ds.stats()) + + ..
testoutput:: + :options: +MOCK + + Operator 0 Read: 1 tasks executed, 5 blocks produced in 0s + * Remote wall time: 16.29us min, 7.29ms max, 1.21ms mean, 24.17ms total + * Remote cpu time: 16.0us min, 2.54ms max, 810.45us mean, 16.21ms total + * Peak heap memory usage (MiB): 137968.75 min, 142734.38 max, 139846 mean + * Output num rows: 0 min, 1 max, 0 mean, 10 total + * Output size bytes: 0 min, 8 max, 4 mean, 80 total + * Tasks per node: 20 min, 20 max, 20 mean; 1 nodes used + + """ + if self._current_executor: + return self._current_executor.get_stats().to_summary().to_string() + elif self._write_ds is not None and self._write_ds._plan.has_computed_output(): + return self._write_ds.stats() + return self._get_stats_summary().to_string() + + def _get_stats_summary(self) -> DatasetStatsSummary: + return self._plan.stats().to_summary() + + @ConsumptionAPI(pattern="Examples:") + @DeveloperAPI + def iter_internal_ref_bundles(self) -> Iterator[RefBundle]: + """Get an iterator over ``RefBundles`` + belonging to this Dataset. Calling this function doesn't keep + the data materialized in-memory. + + Examples: + >>> import ray + >>> ds = ray.data.range(1) + >>> for ref_bundle in ds.iter_internal_ref_bundles(): + ... for block_ref, block_md in ref_bundle.blocks: + ... block = ray.get(block_ref) + + Returns: + An iterator over this Dataset's ``RefBundles``. + """ + + iter_ref_bundles, _, _ = self._plan.execute_to_iterator() + self._synchronize_progress_bar() + return iter_ref_bundles + + @Deprecated + @ConsumptionAPI(pattern="Examples:") + def get_internal_block_refs(self) -> List[ObjectRef[Block]]: + """Get a list of references to the underlying blocks of this dataset. + + This function can be used for zero-copy access to the data. It blocks + until the underlying blocks are computed. + + Examples: + >>> import ray + >>> ds = ray.data.range(1) + >>> ds.get_internal_block_refs() + [ObjectRef(...)] + + Returns: + A list of references to this dataset's blocks. 
+ """ + logger.warning( + "`Dataset.get_internal_block_refs()` is deprecated. Use " + "`Dataset.iter_internal_ref_bundles()` instead.", + ) + block_refs = self._plan.execute().block_refs + self._synchronize_progress_bar() + return block_refs + + @DeveloperAPI + def has_serializable_lineage(self) -> bool: + """Whether this dataset's lineage is able to be serialized for storage and + later deserialized, possibly on a different cluster. + + Only datasets that are created from data that we know will still exist at + deserialization time, e.g. data external to this Ray cluster such as persistent + cloud object stores, support lineage-based serialization. All of the + ray.data.read_*() APIs support lineage-based serialization. + + Examples: + + >>> import ray + >>> ray.data.from_items(list(range(10))).has_serializable_lineage() + False + >>> ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv").has_serializable_lineage() + True + """ # noqa: E501 + return all( + op.is_lineage_serializable() + for op in self._logical_plan.dag.post_order_iter() + ) + + @DeveloperAPI + def serialize_lineage(self) -> bytes: + """ + Serialize this dataset's lineage, not the actual data or the existing data + futures, to bytes that can be stored and later deserialized, possibly on a + different cluster. + + Note that this uses pickle and will drop all computed data, and that everything + is recomputed from scratch after deserialization. + + Use :py:meth:`Dataset.deserialize_lineage` to deserialize the serialized + bytes returned from this method into a Dataset. + + .. note:: + Unioned and zipped datasets, produced by :py:meth`Dataset.union` and + :py:meth:`Dataset.zip`, are not lineage-serializable. + + Examples: + + .. testcode:: + + import ray + + ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") + serialized_ds = ds.serialize_lineage() + ds = ray.data.Dataset.deserialize_lineage(serialized_ds) + print(ds) + + .. 
testoutput:: + + Dataset( + num_rows=?, + schema={ + sepal length (cm): double, + sepal width (cm): double, + petal length (cm): double, + petal width (cm): double, + target: int64 + } + ) + + + Returns: + Serialized bytes containing the lineage of this dataset. + """ + if not self.has_serializable_lineage(): + raise ValueError( + "Lineage-based serialization is not supported for this stream, which " + "means that it cannot be used as a tunable hyperparameter. " + "Lineage-based serialization is explicitly NOT supported for unioned " + "or zipped datasets (see docstrings for those methods), and is only " + "supported for datasets created from data that we know will still " + "exist at deserialization time, e.g. external data in persistent cloud " + "object stores or in-memory data from long-lived clusters. Concretely, " + "all ray.data.read_*() APIs should support lineage-based " + "serialization, while all of the ray.data.from_*() APIs do not. To " + "allow this stream to be serialized to storage, write the data to an " + "external store (such as AWS S3, GCS, or Azure Blob Storage) using the " + "Dataset.write_*() APIs, and serialize a new dataset reading " + "from the external store using the ray.data.read_*() APIs." + ) + # Copy Dataset and clear the blocks from the execution plan so only the + # Dataset's lineage is serialized. + plan_copy = self._plan.deep_copy() + logical_plan_copy = copy.copy(self._plan._logical_plan) + ds = Dataset(plan_copy, logical_plan_copy) + ds._plan.clear_snapshot() + ds._set_uuid(self._get_uuid()) + + def _reduce_remote_fn(rf: ray.remote_function.RemoteFunction): + # Custom reducer for Ray remote function handles that allows for + # cross-cluster serialization. + # This manually unsets the last export session and job to force re-exporting + # of the function when the handle is deserialized on a new cluster. + # TODO(Clark): Fix this in core Ray, see issue: + # https://github.com/ray-project/ray/issues/24152. 
+ reconstructor, args, state = rf.__reduce__() + state["_last_export_session_and_job"] = None + return reconstructor, args, state + + context = ray._private.worker.global_worker.get_serialization_context() + try: + context._register_cloudpickle_reducer( + ray.remote_function.RemoteFunction, _reduce_remote_fn + ) + serialized = pickle.dumps(ds) + finally: + context._unregister_cloudpickle_reducer(ray.remote_function.RemoteFunction) + return serialized + + @staticmethod + @DeveloperAPI + def deserialize_lineage(serialized_ds: bytes) -> "Dataset": + """ + Deserialize the provided lineage-serialized Dataset. + + This uses pickle, and assumes that the provided serialized bytes were + serialized using :py:meth:`Dataset.serialize_lineage`. + + Examples: + + .. testcode:: + + import ray + + ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") + serialized_ds = ds.serialize_lineage() + ds = ray.data.Dataset.deserialize_lineage(serialized_ds) + print(ds) + + .. testoutput:: + + Dataset( + num_rows=?, + schema={ + sepal length (cm): double, + sepal width (cm): double, + petal length (cm): double, + petal width (cm): double, + target: int64 + } + ) + + Args: + serialized_ds: The serialized Dataset that we wish to deserialize. + + Returns: + A deserialized ``Dataset`` instance. + """ + return pickle.loads(serialized_ds) + + @property + @DeveloperAPI + def context(self) -> DataContext: + """Return the DataContext used to create this Dataset.""" + return self._plan._context + + def _aggregate_on( + self, agg_cls: type, on: Optional[Union[str, List[str]]], *args, **kwargs + ): + """Helper for aggregating on a particular subset of the dataset. + + This validates the `on` argument, and converts a list of column names + or lambdas to a multi-aggregation. A null `on` results in a + multi-aggregation on all columns for an Arrow Dataset, and a single + aggregation on the entire row for a simple Dataset. 
+ """ + aggs = self._build_multicolumn_aggs(agg_cls, on, *args, **kwargs) + return self.aggregate(*aggs) + + def _build_multicolumn_aggs( + self, + agg_cls: type, + on: Optional[Union[str, List[str]]], + *args, + skip_cols: Optional[List[str]] = None, + **kwargs, + ): + """Build set of aggregations for applying a single aggregation to + multiple columns. + """ + # Expand None into an aggregation for each column. + if on is None: + schema = self.schema(fetch_if_missing=True) + if schema is not None and not isinstance(schema, type): + if not skip_cols: + skip_cols = [] + if len(schema.names) > 0: + on = [col for col in schema.names if col not in skip_cols] + + if not isinstance(on, list): + on = [on] + return [agg_cls(on_, *args, **kwargs) for on_ in on] + + def _aggregate_result(self, result: Union[Tuple, Mapping]) -> U: + if result is not None and len(result) == 1: + if isinstance(result, tuple): + return result[0] + else: + # NOTE (kfstorm): We cannot call `result[0]` directly on + # `PandasRow` because indexing a column with position is not + # supported by pandas. + return list(result.values())[0] + else: + return result + + @repr_with_fallback(["ipywidgets", "8"]) + def _repr_mimebundle_(self, **kwargs): + """Return a mimebundle with an ipywidget repr and a simple text repr. + + Depending on the frontend where the data is being displayed, + different mimetypes are used from this bundle. + See https://ipython.readthedocs.io/en/stable/config/integrating.html + for information about this method, and + https://ipywidgets.readthedocs.io/en/latest/embedding.html + for more information about the jupyter widget mimetype. + + Returns: + A mimebundle containing an ipywidget repr and a simple text repr. + """ + import ipywidgets + + title = ipywidgets.HTML(f"

{self.__class__.__name__}

") + tab = self._tab_repr_() + widget = ipywidgets.VBox([title, tab], layout=ipywidgets.Layout(width="100%")) + + # Get the widget mime bundle, but replace the plaintext + # with the Datastream repr + bundle = widget._repr_mimebundle_(**kwargs) + bundle.update( + { + "text/plain": repr(self), + } + ) + return bundle + + def _tab_repr_(self): + from ipywidgets import HTML, Tab + + metadata = { + "num_blocks": self._plan.initial_num_blocks(), + "num_rows": self._meta_count(), + } + # Show metadata if available, but don't trigger execution. + schema = self.schema(fetch_if_missing=False) + if schema is None: + schema_repr = Template("rendered_html_common.html.j2").render( + content="
Unknown schema
" + ) + elif isinstance(schema, type): + schema_repr = Template("rendered_html_common.html.j2").render( + content=f"
Data type: {html.escape(str(schema))}
" + ) + else: + schema_data = {} + for sname, stype in zip(schema.names, schema.types): + schema_data[sname] = getattr(stype, "__name__", str(stype)) + + schema_repr = Template("scrollableTable.html.j2").render( + table=tabulate( + tabular_data=schema_data.items(), + tablefmt="html", + showindex=False, + headers=["Name", "Type"], + ), + max_height="300px", + ) + + children = [] + children.append( + HTML( + Template("scrollableTable.html.j2").render( + table=tabulate( + tabular_data=metadata.items(), + tablefmt="html", + showindex=False, + headers=["Field", "Value"], + ), + max_height="300px", + ) + ) + ) + children.append(HTML(schema_repr)) + return Tab(children, titles=["Metadata", "Schema"]) + + def __repr__(self) -> str: + return self._plan.get_plan_as_string(self.__class__) + + def __str__(self) -> str: + return repr(self) + + def __bool__(self) -> bool: + # Prevents `__len__` from being called to check if it is None + # see: issue #25152 + return True + + def __len__(self) -> int: + raise AttributeError( + "Use `ds.count()` to compute the length of a distributed Dataset. " + "This may be an expensive operation." + ) + + def __iter__(self): + raise TypeError( + "`Dataset` objects aren't iterable. To iterate records, call " + "`ds.iter_rows()` or `ds.iter_batches()`. For more information, read " + "https://docs.ray.io/en/latest/data/iterating-over-data.html." 
+ ) + + def _block_num_rows(self) -> List[int]: + get_num_rows = cached_remote_fn(_get_num_rows) + num_rows = [] + for ref_bundle in self.iter_internal_ref_bundles(): + for block_ref in ref_bundle.block_refs: + num_rows.append(get_num_rows.remote(block_ref)) + return ray.get(num_rows) + + def _meta_count(self) -> Optional[int]: + return self._plan.meta_count() + + def _get_uuid(self) -> str: + return self._uuid + + def _set_uuid(self, uuid: str) -> None: + self._uuid = uuid + self._plan._dataset_uuid = uuid + self._plan._in_stats.dataset_uuid = uuid + + def _synchronize_progress_bar(self): + """Flush progress bar output by shutting down the current executor. + + This should be called at the end of all blocking APIs (e.g., `take`), but not + async APIs (e.g., `iter_batches`). + + The streaming executor runs in a separate generator / thread, so it is + possible the shutdown logic runs even after a call to retrieve rows from the + stream has finished. Explicit shutdown avoids this, which can clobber console + output (https://github.com/ray-project/ray/issues/32414). + """ + if self._current_executor: + self._current_executor.shutdown() + self._current_executor = None + + def __getstate__(self): + # Note: excludes _current_executor which is not serializable. + return { + "plan": self._plan, + "uuid": self._uuid, + "logical_plan": self._logical_plan, + } + + def __setstate__(self, state): + self._plan = state["plan"] + self._uuid = state["uuid"] + self._logical_plan = state["logical_plan"] + self._current_executor = None + + def __del__(self): + if not self._current_executor: + return + + # When Python shuts down, `ray` might evaluate to ``. + # This value is truthy and not `None`, so we use a try-catch in addition to + # `if ray is not None`. For more information, see #42382. 
+ try: + if ray is not None and ray.is_initialized(): + self._current_executor.shutdown() + except TypeError: + pass + + +@PublicAPI +class MaterializedDataset(Dataset, Generic[T]): + """A Dataset materialized in Ray memory, e.g., via `.materialize()`. + + The blocks of a MaterializedDataset object are materialized into Ray object store + memory, which means that this class can be shared or iterated over by multiple Ray + tasks without re-executing the underlying computations for producing the stream. + """ + + def num_blocks(self) -> int: + """Return the number of blocks of this :class:`MaterializedDataset`. + + Examples: + >>> import ray + >>> ds = ray.data.range(100).repartition(10).materialize() + >>> ds.num_blocks() + 10 + + Time complexity: O(1) + + Returns: + The number of blocks of this :class:`Dataset`. + """ + return self._plan.initial_num_blocks() + + +@PublicAPI(stability="beta") +class Schema: + """Dataset schema. + + Attributes: + base_schema: The underlying Arrow or Pandas schema. + """ + + def __init__( + self, + base_schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"], + *, + data_context: Optional[DataContext] = None, + ): + self.base_schema = base_schema + + # Snapshot the current context, so that the config of Datasets is always + # determined by the config at the time it was created. + self._context = data_context or copy.deepcopy(DataContext.get_current()) + + @property + def names(self) -> List[str]: + """Lists the columns of this Dataset.""" + return self.base_schema.names + + @property + def types(self) -> List[Union[type[object], "pyarrow.lib.DataType"]]: + """Lists the types of this Dataset in Arrow format + + For non-Arrow compatible types, we return "object". 
+ """ + import pyarrow as pa + + from ray.data.extensions import ArrowTensorType, TensorDtype + + if isinstance(self.base_schema, pa.lib.Schema): + return list(self.base_schema.types) + + arrow_types = [] + for dtype in self.base_schema.types: + if isinstance(dtype, TensorDtype): + + if self._context.use_arrow_tensor_v2: + pa_tensor_type_class = ArrowTensorTypeV2 + else: + pa_tensor_type_class = ArrowTensorType + + # Manually convert our Pandas tensor extension type to Arrow. + arrow_types.append( + pa_tensor_type_class( + shape=dtype._shape, dtype=pa.from_numpy_dtype(dtype._dtype) + ) + ) + + else: + try: + arrow_types.append(pa.from_numpy_dtype(dtype)) + except pa.ArrowNotImplementedError: + arrow_types.append(object) + except Exception: + logger.exception(f"Error converting dtype {dtype} to Arrow.") + arrow_types.append(None) + return arrow_types + + def __eq__(self, other): + return ( + isinstance(other, Schema) + and other.types == self.types + and other.names == self.names + ) + + def __repr__(self): + column_width = max([len(name) for name in self.names] + [len("Column")]) + padding = 2 + + output = "Column" + output += " " * ((column_width + padding) - len("Column")) + output += "Type\n" + + output += "-" * len("Column") + output += " " * ((column_width + padding) - len("Column")) + output += "-" * len("Type") + "\n" + + for name, type in zip(self.names, self.types): + output += name + output += " " * ((column_width + padding) - len(name)) + output += f"{type}\n" + + output = output.rstrip() + return output + + +def _block_to_df(block: Block) -> "pandas.DataFrame": + block = BlockAccessor.for_block(block) + return block.to_pandas() + + +def _block_to_ndarray(block: Block, column: Optional[str]): + block = BlockAccessor.for_block(block) + return block.to_numpy(column) + + +def _block_to_arrow(block: Block): + block = BlockAccessor.for_block(block) + return block.to_arrow() diff --git a/.venv/lib/python3.11/site-packages/ray/data/datasource/datasink.py 
b/.venv/lib/python3.11/site-packages/ray/data/datasource/datasink.py new file mode 100644 index 0000000000000000000000000000000000000000..fee67910ba98a6e27e70cc050784ced5917af4f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/datasource/datasink.py @@ -0,0 +1,164 @@ +import logging +from dataclasses import dataclass +from typing import Generic, Iterable, List, Optional, TypeVar + +import ray +from ray.data._internal.execution.interfaces import TaskContext +from ray.data.block import Block, BlockAccessor +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger(__name__) + + +WriteReturnType = TypeVar("WriteReturnType") +"""Generic type for the return value of `Datasink.write`.""" + + +@dataclass +@DeveloperAPI +class WriteResult(Generic[WriteReturnType]): + """Aggregated result of the Datasink write operations.""" + + # Total number of written rows. + num_rows: int + # Total size in bytes of written data. + size_bytes: int + # All returned values of `Datasink.write`. + write_returns: List[WriteReturnType] + + +@DeveloperAPI +class Datasink(Generic[WriteReturnType]): + """Interface for defining write-related logic. + + If you want to write data to something that isn't built-in, subclass this class + and call :meth:`~ray.data.Dataset.write_datasink`. + """ + + def on_write_start(self) -> None: + """Callback for when a write job starts. + + Use this method to perform setup for write tasks. For example, creating a + staging bucket in S3. + """ + pass + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> WriteReturnType: + """Write blocks. This is used by a single write task. + + Args: + blocks: Generator of data blocks. + ctx: ``TaskContext`` for the write task. + + Returns: + Result of this write task. When the entire write operator finishes, + All returned values will be passed as `WriteResult.write_returns` + to `Datasink.on_write_complete`. 
+ """ + raise NotImplementedError + + def on_write_complete(self, write_result: WriteResult[WriteReturnType]): + """Callback for when a write job completes. + + This can be used to "commit" a write output. This method must + succeed prior to ``write_datasink()`` returning to the user. If this + method fails, then ``on_write_failed()`` is called. + + Args: + write_result: Aggregated result of the + the Write operator, containing write results and stats. + """ + pass + + def on_write_failed(self, error: Exception) -> None: + """Callback for when a write job fails. + + This is called on a best-effort basis on write failures. + + Args: + error: The first error encountered. + """ + pass + + def get_name(self) -> str: + """Return a human-readable name for this datasink. + + This is used as the names of the write tasks. + """ + name = type(self).__name__ + datasink_suffix = "Datasink" + if name.startswith("_"): + name = name[1:] + if name.endswith(datasink_suffix): + name = name[: -len(datasink_suffix)] + return name + + @property + def supports_distributed_writes(self) -> bool: + """If ``False``, only launch write tasks on the driver's node.""" + return True + + @property + def min_rows_per_write(self) -> Optional[int]: + """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call. + + If ``None``, Ray Data passes a system-chosen number of rows. + """ + return None + + +@DeveloperAPI +class DummyOutputDatasink(Datasink[None]): + """An example implementation of a writable datasource for testing. + Examples: + >>> import ray + >>> from ray.data.datasource import DummyOutputDatasink + >>> output = DummyOutputDatasink() + >>> ray.data.range(10).write_datasink(output) + >>> assert output.num_ok == 1 + """ + + def __init__(self): + ctx = ray.data.DataContext.get_current() + + # Setup a dummy actor to send the data. In a real datasource, write + # tasks would send data to an external system instead of a Ray actor. 
+ @ray.remote(scheduling_strategy=ctx.scheduling_strategy) + class DataSink: + def __init__(self): + self.rows_written = 0 + self.enabled = True + + def write(self, block: Block) -> None: + block = BlockAccessor.for_block(block) + self.rows_written += block.num_rows() + + def get_rows_written(self): + return self.rows_written + + self.data_sink = DataSink.remote() + self.num_ok = 0 + self.num_failed = 0 + self.enabled = True + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> None: + tasks = [] + if not self.enabled: + raise ValueError("disabled") + for b in blocks: + tasks.append(self.data_sink.write.remote(b)) + ray.get(tasks) + + def on_write_complete(self, write_result: WriteResult[None]): + self.num_ok += 1 + + def on_write_failed(self, error: Exception) -> None: + self.num_failed += 1 diff --git a/.venv/lib/python3.11/site-packages/ray/data/datasource/datasource.py b/.venv/lib/python3.11/site-packages/ray/data/datasource/datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..c09460d9bdf99ee08b793cffb0850dfbfe952367 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/datasource/datasource.py @@ -0,0 +1,243 @@ +from typing import Callable, Iterable, List, Optional + +import numpy as np + +from ray.data._internal.util import _check_pyarrow_version +from ray.data.block import Block, BlockMetadata +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI + + +@PublicAPI +class Datasource: + """Interface for defining a custom :class:`~ray.data.Dataset` datasource. + + To read a datasource into a dataset, use :meth:`~ray.data.read_datasource`. + """ # noqa: E501 + + @Deprecated + def create_reader(self, **read_args) -> "Reader": + """ + Deprecated: Implement :meth:`~ray.data.Datasource.get_read_tasks` and + :meth:`~ray.data.Datasource.estimate_inmemory_data_size` instead. 
+ """ + return _LegacyDatasourceReader(self, **read_args) + + @Deprecated + def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask"]: + """ + Deprecated: Implement :meth:`~ray.data.Datasource.get_read_tasks` and + :meth:`~ray.data.Datasource.estimate_inmemory_data_size` instead. + """ + raise NotImplementedError + + def get_name(self) -> str: + """Return a human-readable name for this datasource. + This will be used as the names of the read tasks. + """ + name = type(self).__name__ + datasource_suffix = "Datasource" + if name.endswith(datasource_suffix): + name = name[: -len(datasource_suffix)] + return name + + def estimate_inmemory_data_size(self) -> Optional[int]: + """Return an estimate of the in-memory data size, or None if unknown. + + Note that the in-memory data size may be larger than the on-disk data size. + """ + raise NotImplementedError + + def get_read_tasks(self, parallelism: int) -> List["ReadTask"]: + """Execute the read and return read tasks. + + Args: + parallelism: The requested read parallelism. The number of read + tasks should equal to this value if possible. + + Returns: + A list of read tasks that can be executed to read blocks from the + datasource in parallel. + """ + raise NotImplementedError + + @property + def should_create_reader(self) -> bool: + has_implemented_get_read_tasks = ( + type(self).get_read_tasks is not Datasource.get_read_tasks + ) + has_implemented_estimate_inmemory_data_size = ( + type(self).estimate_inmemory_data_size + is not Datasource.estimate_inmemory_data_size + ) + return ( + not has_implemented_get_read_tasks + or not has_implemented_estimate_inmemory_data_size + ) + + @property + def supports_distributed_reads(self) -> bool: + """If ``False``, only launch read tasks on the driver's node.""" + return True + + +@Deprecated +class Reader: + """A bound read operation for a :class:`~ray.data.Datasource`. + + This is a stateful class so that reads can be prepared in multiple stages. 
+ For example, it is useful for :class:`Datasets ` to know the + in-memory size of the read prior to executing it. + """ + + def estimate_inmemory_data_size(self) -> Optional[int]: + """Return an estimate of the in-memory data size, or None if unknown. + + Note that the in-memory data size may be larger than the on-disk data size. + """ + raise NotImplementedError + + def get_read_tasks(self, parallelism: int) -> List["ReadTask"]: + """Execute the read and return read tasks. + + Args: + parallelism: The requested read parallelism. The number of read + tasks should equal to this value if possible. + read_args: Additional kwargs to pass to the datasource impl. + + Returns: + A list of read tasks that can be executed to read blocks from the + datasource in parallel. + """ + raise NotImplementedError + + +class _LegacyDatasourceReader(Reader): + def __init__(self, datasource: Datasource, **read_args): + self._datasource = datasource + self._read_args = read_args + + def estimate_inmemory_data_size(self) -> Optional[int]: + return None + + def get_read_tasks(self, parallelism: int) -> List["ReadTask"]: + return self._datasource.prepare_read(parallelism, **self._read_args) + + +@DeveloperAPI +class ReadTask(Callable[[], Iterable[Block]]): + """A function used to read blocks from the :class:`~ray.data.Dataset`. + + Read tasks are generated by :meth:`~ray.data.Datasource.get_read_tasks`, + and return a list of ``ray.data.Block`` when called. Initial metadata about the read + operation can be retrieved via the ``metadata`` attribute prior to executing the + read. Final metadata is returned after the read along with the blocks. + + Ray will execute read tasks in remote functions to parallelize execution. + Note that the number of blocks returned can vary at runtime. For example, + if a task is reading a single large file it can return multiple blocks to + avoid running out of memory during the read. 
+ + The initial metadata should reflect all the blocks returned by the read, + e.g., if the metadata says ``num_rows=1000``, the read can return a single + block of 1000 rows, or multiple blocks with 1000 rows altogether. + + The final metadata (returned with the actual block) reflects the exact + contents of the block itself. + """ + + def __init__(self, read_fn: Callable[[], Iterable[Block]], metadata: BlockMetadata): + self._metadata = metadata + self._read_fn = read_fn + + @property + def metadata(self) -> BlockMetadata: + return self._metadata + + @property + def read_fn(self) -> Callable[[], Iterable[Block]]: + return self._read_fn + + def __call__(self) -> Iterable[Block]: + result = self._read_fn() + if not hasattr(result, "__iter__"): + DeprecationWarning( + "Read function must return Iterable[Block], got {}. " + "Probably you need to return `[block]` instead of " + "`block`.".format(result) + ) + yield from result + + +@DeveloperAPI +class RandomIntRowDatasource(Datasource): + """An example datasource that generates rows with random int64 columns. + + Examples: + >>> import ray + >>> from ray.data.datasource import RandomIntRowDatasource + >>> source = RandomIntRowDatasource() # doctest: +SKIP + >>> ray.data.read_datasource( # doctest: +SKIP + ... 
source, n=10, num_columns=2).take() + {'c_0': 1717767200176864416, 'c_1': 999657309586757214} + {'c_0': 4983608804013926748, 'c_1': 1160140066899844087} + """ + + def __init__(self, n: int, num_columns: int): + self._n = n + self._num_columns = num_columns + + def estimate_inmemory_data_size(self) -> Optional[int]: + return self._n * self._num_columns * 8 + + def get_read_tasks( + self, + parallelism: int, + ) -> List[ReadTask]: + _check_pyarrow_version() + import pyarrow + + read_tasks: List[ReadTask] = [] + n = self._n + num_columns = self._num_columns + block_size = max(1, n // parallelism) + + def make_block(count: int, num_columns: int) -> Block: + return pyarrow.Table.from_arrays( + np.random.randint( + np.iinfo(np.int64).max, size=(num_columns, count), dtype=np.int64 + ), + names=[f"c_{i}" for i in range(num_columns)], + ) + + schema = pyarrow.Table.from_pydict( + {f"c_{i}": [0] for i in range(num_columns)} + ).schema + + i = 0 + while i < n: + count = min(block_size, n - i) + meta = BlockMetadata( + num_rows=count, + size_bytes=8 * count * num_columns, + schema=schema, + input_files=None, + exec_stats=None, + ) + read_tasks.append( + ReadTask( + lambda count=count, num_columns=num_columns: [ + make_block(count, num_columns) + ], + meta, + ) + ) + i += block_size + + return read_tasks + + def get_name(self) -> str: + """Return a human-readable name for this datasource. + This will be used as the names of the read tasks. + Note: overrides the base `Datasource` method. 
+ """ + return "RandomInt" diff --git a/.venv/lib/python3.11/site-packages/ray/data/datasource/file_based_datasource.py b/.venv/lib/python3.11/site-packages/ray/data/datasource/file_based_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..4320cd64ff0c81cbff99f51d47622c972bf0105f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/datasource/file_based_datasource.py @@ -0,0 +1,572 @@ +import io +import logging +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Union, +) + +import numpy as np + +import ray +from ray.data._internal.util import ( + _check_pyarrow_version, + _is_local_scheme, + call_with_retry, + make_async_gen, +) +from ray.data.block import Block, BlockAccessor +from ray.data.context import DataContext +from ray.data.datasource.datasource import Datasource, ReadTask +from ray.data.datasource.file_meta_provider import ( + BaseFileMetadataProvider, + DefaultFileMetadataProvider, +) +from ray.data.datasource.partitioning import ( + Partitioning, + PathPartitionFilter, + PathPartitionParser, +) +from ray.data.datasource.path_util import ( + _has_file_extension, + _resolve_paths_and_filesystem, +) +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import pandas as pd + import pyarrow + + +logger = logging.getLogger(__name__) + + +# We should parallelize file size fetch operations beyond this threshold. +FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD = 16 + +# 16 file size fetches from S3 takes ~1.5 seconds with Arrow's S3FileSystem. +PATHS_PER_FILE_SIZE_FETCH_TASK = 16 + +# The max retry backoff in seconds for opening file. +OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS = 32 + +# The max number of attempts for opening file. +OPEN_FILE_MAX_ATTEMPTS = 10 + + +@DeveloperAPI +@dataclass +class FileShuffleConfig: + """Configuration for file shuffling. 
+ + This configuration object controls how files are shuffled while reading file-based + datasets. + + .. note:: + Even if you provided a seed, you might still observe a non-deterministic row + order. This is because tasks are executed in parallel and their completion + order might vary. If you need to preserve the order of rows, set + `DataContext.get_current().execution_options.preserve_order`. + + Args: + seed: An optional integer seed for the file shuffler. If provided, Ray Data + shuffles files deterministically based on this seed. + + Example: + >>> import ray + >>> from ray.data import FileShuffleConfig + >>> shuffle = FileShuffleConfig(seed=42) + >>> ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea", shuffle=shuffle) + """ # noqa: E501 + + seed: Optional[int] = None + + def __post_init__(self): + """Ensure that the seed is either None or an integer.""" + if self.seed is not None and not isinstance(self.seed, int): + raise ValueError("Seed must be an integer or None.") + + +@DeveloperAPI +class FileBasedDatasource(Datasource): + """File-based datasource for reading files. + + Don't use this class directly. Instead, subclass it and implement `_read_stream()`. + """ + + # If `_WRITE_FILE_PER_ROW` is `True`, this datasource calls `_write_row` and writes + # each row to a file. Otherwise, this datasource calls `_write_block` and writes + # each block to a file. + _WRITE_FILE_PER_ROW = False + _FILE_EXTENSIONS: Optional[Union[str, List[str]]] = None + # Number of threads for concurrent reading within each read task. + # If zero or negative, reading will be performed in the main thread. 
+ _NUM_THREADS_PER_TASK = 0 + + def __init__( + self, + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None, + open_stream_args: Optional[Dict[str, Any]] = None, + meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(), + partition_filter: PathPartitionFilter = None, + partitioning: Partitioning = None, + ignore_missing_paths: bool = False, + shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None, + include_paths: bool = False, + file_extensions: Optional[List[str]] = None, + ): + _check_pyarrow_version() + + self._supports_distributed_reads = not _is_local_scheme(paths) + if not self._supports_distributed_reads and ray.util.client.ray.is_connected(): + raise ValueError( + "Because you're using Ray Client, read tasks scheduled on the Ray " + "cluster can't access your local files. To fix this issue, store " + "files in cloud storage or a distributed filesystem like NFS." + ) + + self._schema = schema + self._open_stream_args = open_stream_args + self._meta_provider = meta_provider + self._partition_filter = partition_filter + self._partitioning = partitioning + self._ignore_missing_paths = ignore_missing_paths + self._include_paths = include_paths + paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem) + paths, file_sizes = map( + list, + zip( + *meta_provider.expand_paths( + paths, + self._filesystem, + partitioning, + ignore_missing_paths=ignore_missing_paths, + ) + ), + ) + + if ignore_missing_paths and len(paths) == 0: + raise ValueError( + "None of the provided paths exist. " + "The 'ignore_missing_paths' field is set to True." + ) + + if self._partition_filter is not None: + # Use partition filter to skip files which are not needed. 
+ path_to_size = dict(zip(paths, file_sizes)) + paths = self._partition_filter(paths) + file_sizes = [path_to_size[p] for p in paths] + if len(paths) == 0: + raise ValueError( + "No input files found to read. Please double check that " + "'partition_filter' field is set properly." + ) + + if file_extensions is not None: + path_to_size = dict(zip(paths, file_sizes)) + paths = [p for p in paths if _has_file_extension(p, file_extensions)] + file_sizes = [path_to_size[p] for p in paths] + if len(paths) == 0: + raise ValueError( + "No input files found to read with the following file extensions: " + f"{file_extensions}. Please double check that " + "'file_extensions' field is set properly." + ) + + _validate_shuffle_arg(shuffle) + self._file_metadata_shuffler = None + if shuffle == "files": + self._file_metadata_shuffler = np.random.default_rng() + elif isinstance(shuffle, FileShuffleConfig): + # Create a NumPy random generator with a fixed seed if provided + self._file_metadata_shuffler = np.random.default_rng(shuffle.seed) + + # Read tasks serialize `FileBasedDatasource` instances, and the list of paths + # can be large. To avoid slow serialization speeds, we store a reference to + # the paths rather than the paths themselves. 
+ self._paths_ref = ray.put(paths) + self._file_sizes_ref = ray.put(file_sizes) + + def _paths(self) -> List[str]: + return ray.get(self._paths_ref) + + def _file_sizes(self) -> List[float]: + return ray.get(self._file_sizes_ref) + + def estimate_inmemory_data_size(self) -> Optional[int]: + total_size = 0 + for sz in self._file_sizes(): + if sz is not None: + total_size += sz + return total_size + + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: + import numpy as np + + ctx = DataContext.get_current() + open_stream_args = self._open_stream_args + partitioning = self._partitioning + + paths = self._paths() + file_sizes = self._file_sizes() + + if self._file_metadata_shuffler is not None: + files_metadata = list(zip(paths, file_sizes)) + shuffled_files_metadata = [ + files_metadata[i] + for i in self._file_metadata_shuffler.permutation(len(files_metadata)) + ] + paths, file_sizes = list(map(list, zip(*shuffled_files_metadata))) + + read_stream = self._read_stream + filesystem = _wrap_s3_serialization_workaround(self._filesystem) + + if open_stream_args is None: + open_stream_args = {} + + open_input_source = self._open_input_source + + def read_files( + read_paths: Iterable[str], + ) -> Iterable[Block]: + nonlocal filesystem, open_stream_args, partitioning + + DataContext._set_current(ctx) + fs = _unwrap_s3_serialization_workaround(filesystem) + for read_path in read_paths: + partitions: Dict[str, str] = {} + if partitioning is not None: + parse = PathPartitionParser(partitioning) + partitions = parse(read_path) + + with _open_file_with_retry( + read_path, + lambda read_path=read_path: open_input_source( + fs, read_path, **open_stream_args + ), + ) as f: + for block in read_stream(f, read_path): + if partitions: + block = _add_partitions(block, partitions) + if self._include_paths: + block_accessor = BlockAccessor.for_block(block) + block = block_accessor.append_column( + "path", [read_path] * block_accessor.num_rows() + ) + yield block + + def 
create_read_task_fn(read_paths, num_threads): + def read_task_fn(): + nonlocal num_threads, read_paths + + # TODO: We should refactor the code so that we can get the results in + # order even when using multiple threads. + if ctx.execution_options.preserve_order: + num_threads = 0 + + if num_threads > 0: + if len(read_paths) < num_threads: + num_threads = len(read_paths) + + logger.debug( + f"Reading {len(read_paths)} files with {num_threads} threads." + ) + + yield from make_async_gen( + iter(read_paths), + read_files, + num_workers=num_threads, + ) + else: + logger.debug(f"Reading {len(read_paths)} files.") + yield from read_files(read_paths) + + return read_task_fn + + # fix https://github.com/ray-project/ray/issues/24296 + parallelism = min(parallelism, len(paths)) + + read_tasks = [] + split_paths = np.array_split(paths, parallelism) + split_file_sizes = np.array_split(file_sizes, parallelism) + + for read_paths, file_sizes in zip(split_paths, split_file_sizes): + if len(read_paths) <= 0: + continue + + meta = self._meta_provider( + read_paths, + self._schema, + rows_per_file=self._rows_per_file(), + file_sizes=file_sizes, + ) + + read_task_fn = create_read_task_fn(read_paths, self._NUM_THREADS_PER_TASK) + + read_task = ReadTask(read_task_fn, meta) + + read_tasks.append(read_task) + + return read_tasks + + def _open_input_source( + self, + filesystem: "pyarrow.fs.FileSystem", + path: str, + **open_args, + ) -> "pyarrow.NativeFile": + """Opens a source path for reading and returns the associated Arrow NativeFile. + + The default implementation opens the source path as a sequential input stream, + using ctx.streaming_read_buffer_size as the buffer size if none is given by the + caller. + + Implementations that do not support streaming reads (e.g. that require random + access) should override this method. 
+ """ + import pyarrow as pa + from pyarrow.fs import HadoopFileSystem + + ctx = DataContext.get_current() + + compression = open_args.get("compression", None) + if compression is None: + try: + # If no compression manually given, try to detect + # compression codec from path. + compression = pa.Codec.detect(path).name + except (ValueError, TypeError): + # Arrow's compression inference on the file path + # doesn't work for Snappy, so we double-check ourselves. + import pathlib + + suffix = pathlib.Path(path).suffix + if suffix and suffix[1:] == "snappy": + compression = "snappy" + else: + compression = None + + buffer_size = open_args.pop("buffer_size", None) + if buffer_size is None: + buffer_size = ctx.streaming_read_buffer_size + + if compression == "snappy": + # Arrow doesn't support streaming Snappy decompression since the canonical + # C++ Snappy library doesn't natively support streaming decompression. We + # works around this by manually decompressing the file with python-snappy. + open_args["compression"] = None + else: + open_args["compression"] = compression + + file = call_with_retry( + lambda: filesystem.open_input_stream( + path, buffer_size=buffer_size, **open_args + ), + description=f"open file {path}", + match=ctx.retried_io_errors, + ) + + if compression == "snappy": + import snappy + + stream = io.BytesIO() + if isinstance(filesystem, HadoopFileSystem): + snappy.hadoop_snappy.stream_decompress(src=file, dst=stream) + else: + snappy.stream_decompress(src=file, dst=stream) + stream.seek(0) + + file = pa.PythonFile(stream, mode="r") + + return file + + def _rows_per_file(self): + """Returns the number of rows per file, or None if unknown.""" + return None + + def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: + """Streaming read a single file. + + This method should be implemented by subclasses. + """ + raise NotImplementedError( + "Subclasses of FileBasedDatasource must implement _read_stream()." 
+ ) + + @property + def supports_distributed_reads(self) -> bool: + return self._supports_distributed_reads + + +def _add_partitions( + data: Union["pyarrow.Table", "pd.DataFrame"], partitions: Dict[str, Any] +) -> Union["pyarrow.Table", "pd.DataFrame"]: + import pandas as pd + import pyarrow as pa + + assert isinstance(data, (pa.Table, pd.DataFrame)) + if isinstance(data, pa.Table): + return _add_partitions_to_table(data, partitions) + if isinstance(data, pd.DataFrame): + return _add_partitions_to_dataframe(data, partitions) + + +def _add_partitions_to_table( + table: "pyarrow.Table", partitions: Dict[str, Any] +) -> "pyarrow.Table": + import pyarrow as pa + import pyarrow.compute as pc + + column_names = set(table.column_names) + for field, value in partitions.items(): + column = pa.array([value] * len(table)) + if field in column_names: + # TODO: Handle cast error. + column_type = table.schema.field(field).type + column = column.cast(column_type) + + values_are_equal = pc.all(pc.equal(column, table[field])) + values_are_equal = values_are_equal.as_py() + + if not values_are_equal: + raise ValueError( + f"Partition column {field} exists in table data, but partition " + f"value '{value}' is different from in-data values: " + f"{table[field].unique().to_pylist()}." + ) + + i = table.schema.get_field_index(field) + table = table.set_column(i, field, column) + else: + table = table.append_column(field, column) + + return table + + +def _add_partitions_to_dataframe( + df: "pd.DataFrame", partitions: Dict[str, Any] +) -> "pd.DataFrame": + import pandas as pd + + for field, value in partitions.items(): + column = pd.Series(data=[value] * len(df), name=field) + + if field in df: + column = column.astype(df[field].dtype) + mask = df[field].notna() + if not df[field][mask].equals(column[mask]): + raise ValueError( + f"Partition column {field} exists in table data, but partition " + f"value '{value}' is different from in-data values: " + f"{list(df[field].unique())}." 
+ ) + + df[field] = column + + return df + + +def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"): + # This is needed because pa.fs.S3FileSystem assumes pa.fs is already + # imported before deserialization. See #17085. + import pyarrow as pa + import pyarrow.fs + + if isinstance(filesystem, pa.fs.S3FileSystem): + return _S3FileSystemWrapper(filesystem) + return filesystem + + +def _unwrap_s3_serialization_workaround( + filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"] +): + if isinstance(filesystem, _S3FileSystemWrapper): + return filesystem.unwrap() + else: + return filesystem + + +class _S3FileSystemWrapper: + def __init__(self, fs: "pyarrow.fs.S3FileSystem"): + self._fs = fs + + def unwrap(self): + return self._fs + + @classmethod + def _reconstruct(cls, fs_reconstruct, fs_args): + # Implicitly trigger S3 subsystem initialization by importing + # pyarrow.fs. + import pyarrow.fs # noqa: F401 + + return cls(fs_reconstruct(*fs_args)) + + def __reduce__(self): + return _S3FileSystemWrapper._reconstruct, self._fs.__reduce__() + + +def _wrap_arrow_serialization_workaround(kwargs: dict) -> dict: + if "filesystem" in kwargs: + kwargs["filesystem"] = _wrap_s3_serialization_workaround(kwargs["filesystem"]) + + return kwargs + + +def _unwrap_arrow_serialization_workaround(kwargs: dict) -> dict: + if isinstance(kwargs.get("filesystem"), _S3FileSystemWrapper): + kwargs["filesystem"] = kwargs["filesystem"].unwrap() + return kwargs + + +def _resolve_kwargs( + kwargs_fn: Callable[[], Dict[str, Any]], **kwargs +) -> Dict[str, Any]: + if kwargs_fn: + kwarg_overrides = kwargs_fn() + kwargs.update(kwarg_overrides) + return kwargs + + +def _open_file_with_retry( + file_path: str, + open_file: Callable[[], "pyarrow.NativeFile"], +) -> "pyarrow.NativeFile": + """Open file with an exponential backoff retry strategy. + + This is to avoid transient task failure with remote storage (such as S3), + when the remote storage throttles the requests. 
+ """ + if OPEN_FILE_MAX_ATTEMPTS < 1: + raise ValueError( + "OPEN_FILE_MAX_ATTEMPTS cannot be negative or 0. Get: " + f"{OPEN_FILE_MAX_ATTEMPTS}" + ) + + return call_with_retry( + open_file, + description=f"open file {file_path}", + match=DataContext.get_current().retried_io_errors, + max_attempts=OPEN_FILE_MAX_ATTEMPTS, + max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS, + ) + + +def _validate_shuffle_arg(shuffle: Optional[str]) -> None: + if not ( + shuffle is None or shuffle == "files" or isinstance(shuffle, FileShuffleConfig) + ): + raise ValueError( + f"Invalid value for 'shuffle': {shuffle}. " + "Valid values are None, 'files', `FileShuffleConfig`." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/data/datasource/file_meta_provider.py b/.venv/lib/python3.11/site-packages/ray/data/datasource/file_meta_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..c6654e9e2708f32ee50577ac186c84961a9e7396 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/datasource/file_meta_provider.py @@ -0,0 +1,484 @@ +import itertools +import logging +import os +import pathlib +import re +from typing import ( + TYPE_CHECKING, + Callable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) + +import numpy as np + +import ray +from ray.data._internal.progress_bar import ProgressBar +from ray.data._internal.remote_fn import cached_remote_fn +from ray.data._internal.util import call_with_retry +from ray.data.block import BlockMetadata +from ray.data.datasource.partitioning import Partitioning +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import pyarrow + + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class FileMetadataProvider: + """Abstract callable that provides metadata for the files of a single dataset block. 
+ + Current subclasses: + - :class:`BaseFileMetadataProvider` + - :class:`ParquetMetadataProvider` + """ + + def _get_block_metadata( + self, + paths: List[str], + schema: Optional[Union[type, "pyarrow.lib.Schema"]], + **kwargs, + ) -> BlockMetadata: + """Resolves and returns block metadata for files in the given paths. + + All file paths provided should belong to a single dataset block. + + Args: + paths: The file paths for a single dataset block. + schema: The user-provided or inferred schema for the given paths, + if any. + + Returns: + BlockMetadata aggregated across the given paths. + """ + raise NotImplementedError + + def __call__( + self, + paths: List[str], + schema: Optional[Union[type, "pyarrow.lib.Schema"]], + **kwargs, + ) -> BlockMetadata: + return self._get_block_metadata(paths, schema, **kwargs) + + +@DeveloperAPI +class BaseFileMetadataProvider(FileMetadataProvider): + """Abstract callable that provides metadata for + :class:`~ray.data.datasource.file_based_datasource.FileBasedDatasource` + implementations that reuse the base :meth:`~ray.data.Datasource.prepare_read` + method. + + Also supports file and file size discovery in input directory paths. + + Current subclasses: + - :class:`DefaultFileMetadataProvider` + """ + + def _get_block_metadata( + self, + paths: List[str], + schema: Optional[Union[type, "pyarrow.lib.Schema"]], + *, + rows_per_file: Optional[int], + file_sizes: List[Optional[int]], + ) -> BlockMetadata: + """Resolves and returns block metadata for files of a single dataset block. + + Args: + paths: The file paths for a single dataset block. These + paths will always be a subset of those previously returned from + :meth:`.expand_paths`. + schema: The user-provided or inferred schema for the given file + paths, if any. + rows_per_file: The fixed number of rows per input file, or None. 
+ file_sizes: Optional file size per input file previously returned + from :meth:`.expand_paths`, where `file_sizes[i]` holds the size of + the file at `paths[i]`. + + Returns: + BlockMetadata aggregated across the given file paths. + """ + raise NotImplementedError + + def expand_paths( + self, + paths: List[str], + filesystem: Optional["pyarrow.fs.FileSystem"], + partitioning: Optional[Partitioning] = None, + ignore_missing_paths: bool = False, + ) -> Iterator[Tuple[str, int]]: + """Expands all paths into concrete file paths by walking directories. + + Also returns a sidecar of file sizes. + + The input paths must be normalized for compatibility with the input + filesystem prior to invocation. + + Args: + paths: A list of file and/or directory paths compatible with the + given filesystem. + filesystem: The filesystem implementation that should be used for + expanding all paths and reading their files. + ignore_missing_paths: If True, ignores any file paths in ``paths`` that + are not found. Defaults to False. + + Returns: + An iterator of `(file_path, file_size)` pairs. None may be returned for the + file size if it is either unknown or will be fetched later by + `_get_block_metadata()`, but the length of + both lists must be equal. + """ + raise NotImplementedError + + +@DeveloperAPI +class DefaultFileMetadataProvider(BaseFileMetadataProvider): + """Default metadata provider for + :class:`~ray.data.datasource.file_based_datasource.FileBasedDatasource` + implementations that reuse the base `prepare_read` method. + + Calculates block size in bytes as the sum of its constituent file sizes, + and assumes a fixed number of rows per file. 
+ """ + + def _get_block_metadata( + self, + paths: List[str], + schema: Optional[Union[type, "pyarrow.lib.Schema"]], + *, + rows_per_file: Optional[int], + file_sizes: List[Optional[int]], + ) -> BlockMetadata: + if rows_per_file is None: + num_rows = None + else: + num_rows = len(paths) * rows_per_file + return BlockMetadata( + num_rows=num_rows, + size_bytes=None if None in file_sizes else int(sum(file_sizes)), + schema=schema, + input_files=paths, + exec_stats=None, + ) # Exec stats filled in later. + + def expand_paths( + self, + paths: List[str], + filesystem: "pyarrow.fs.FileSystem", + partitioning: Optional[Partitioning] = None, + ignore_missing_paths: bool = False, + ) -> Iterator[Tuple[str, int]]: + yield from _expand_paths(paths, filesystem, partitioning, ignore_missing_paths) + + +@DeveloperAPI +class FastFileMetadataProvider(DefaultFileMetadataProvider): + """Fast Metadata provider for + :class:`~ray.data.datasource.file_based_datasource.FileBasedDatasource` + implementations. + + Offers improved performance vs. + :class:`DefaultFileMetadataProvider` + by skipping directory path expansion and file size collection. + While this performance improvement may be negligible for local filesystems, + it can be substantial for cloud storage service providers. + + This should only be used when all input paths exist and are known to be files. + """ + + def expand_paths( + self, + paths: List[str], + filesystem: "pyarrow.fs.FileSystem", + partitioning: Optional[Partitioning] = None, + ignore_missing_paths: bool = False, + ) -> Iterator[Tuple[str, int]]: + if ignore_missing_paths: + raise ValueError( + "`ignore_missing_paths` cannot be set when used with " + "`FastFileMetadataProvider`. All paths must exist when " + "using `FastFileMetadataProvider`." + ) + + logger.warning( + f"Skipping expansion of {len(paths)} path(s). 
If your paths contain " + f"directories or if file size collection is required, try rerunning this " + f"read with `meta_provider=DefaultFileMetadataProvider()`." + ) + + yield from zip(paths, itertools.repeat(None, len(paths))) + + +def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> str: + # NOTE: this is not comprehensive yet, and should be extended as more errors arise. + # NOTE: The latter patterns are raised in Arrow 10+, while the former is raised in + # Arrow < 10. + aws_error_pattern = ( + r"^(?:(.*)AWS Error \[code \d+\]: No response body\.(.*))|" + r"(?:(.*)AWS Error UNKNOWN \(HTTP status 400\) during HeadObject operation: " + r"No response body\.(.*))|" + r"(?:(.*)AWS Error ACCESS_DENIED during HeadObject operation: No response " + r"body\.(.*))$" + ) + if re.match(aws_error_pattern, str(error)): + # Specially handle AWS error when reading files, to give a clearer error + # message to avoid confusing users. The real issue is most likely that the AWS + # S3 file credentials have not been properly configured yet. + if isinstance(paths, str): + # Quote to highlight single file path in error message for better + # readability. List of file paths will be shown up as ['foo', 'boo'], + # so only quote single file path here. + paths = f'"{paths}"' + raise OSError( + ( + f"Failing to read AWS S3 file(s): {paths}. " + "Please check that file exists and has properly configured access. " + "You can also run AWS CLI command to get more detailed error message " + "(e.g., aws s3 ls ). " + "See https://awscli.amazonaws.com/v2/documentation/api/latest/reference/s3/index.html " # noqa + "and https://docs.ray.io/en/latest/data/creating-datasets.html#reading-from-remote-storage " # noqa + "for more information." 
+ ) + ) + else: + raise error + + +def _expand_paths( + paths: List[str], + filesystem: "pyarrow.fs.FileSystem", + partitioning: Optional[Partitioning], + ignore_missing_paths: bool = False, +) -> Iterator[Tuple[str, int]]: + """Get the file sizes for all provided file paths.""" + from pyarrow.fs import LocalFileSystem + + from ray.data.datasource.file_based_datasource import ( + FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD, + ) + from ray.data.datasource.path_util import _unwrap_protocol + + # We break down our processing paths into a few key cases: + # 1. If len(paths) < threshold, fetch the file info for the individual files/paths + # serially. + # 2. If all paths are contained under the same parent directory (or base directory, + # if using partitioning), fetch all file infos at this prefix and filter to the + # provided paths on the client; this should be a single file info request. + # 3. If more than threshold requests required, parallelize them via Ray tasks. + # 1. Small # of paths case. + if ( + len(paths) < FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD + # Local file systems are very fast to hit. + or isinstance(filesystem, LocalFileSystem) + ): + yield from _get_file_infos_serial(paths, filesystem, ignore_missing_paths) + else: + # 2. Common path prefix case. + # Get longest common path of all paths. + common_path = os.path.commonpath(paths) + # If parent directory (or base directory, if using partitioning) is common to + # all paths, fetch all file infos at that prefix and filter the response to the + # provided paths. + if ( + partitioning is not None + and common_path == _unwrap_protocol(partitioning.base_dir) + ) or all(str(pathlib.Path(path).parent) == common_path for path in paths): + yield from _get_file_infos_common_path_prefix( + paths, common_path, filesystem, ignore_missing_paths + ) + # 3. Parallelization case. + else: + # Parallelize requests via Ray tasks. 
+ yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths) + + +def _get_file_infos_serial( + paths: List[str], + filesystem: "pyarrow.fs.FileSystem", + ignore_missing_paths: bool = False, +) -> Iterator[Tuple[str, int]]: + for path in paths: + yield from _get_file_infos(path, filesystem, ignore_missing_paths) + + +def _get_file_infos_common_path_prefix( + paths: List[str], + common_path: str, + filesystem: "pyarrow.fs.FileSystem", + ignore_missing_paths: bool = False, +) -> Iterator[Tuple[str, int]]: + path_to_size = {path: None for path in paths} + for path, file_size in _get_file_infos( + common_path, filesystem, ignore_missing_paths + ): + if path in path_to_size: + path_to_size[path] = file_size + + # Check if all `paths` have file size metadata. + # If any of paths has no file size, fall back to get files metadata in parallel. + # This can happen when path is a directory, but not a file. + have_missing_path = False + for path in paths: + if path_to_size[path] is None: + logger.debug( + f"Finding path {path} not have file size metadata. " + "Fall back to get files metadata in parallel for all paths." + ) + have_missing_path = True + break + + if have_missing_path: + # Parallelize requests via Ray tasks. + yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths) + else: + # Iterate over `paths` to yield each path in original order. + # NOTE: do not iterate over `path_to_size` because the dictionary skips + # duplicated path, while `paths` might contain duplicated path if one wants + # to read same file multiple times. 
+ for path in paths: + yield path, path_to_size[path] + + +def _get_file_infos_parallel( + paths: List[str], + filesystem: "pyarrow.fs.FileSystem", + ignore_missing_paths: bool = False, +) -> Iterator[Tuple[str, int]]: + from ray.data.datasource.file_based_datasource import ( + PATHS_PER_FILE_SIZE_FETCH_TASK, + _unwrap_s3_serialization_workaround, + _wrap_s3_serialization_workaround, + ) + + logger.warning( + f"Expanding {len(paths)} path(s). This may be a HIGH LATENCY " + f"operation on some cloud storage services. Moving all the " + "paths to a common parent directory will lead to faster " + "metadata fetching." + ) + + # Capture the filesystem in the fetcher func closure, but wrap it in our + # serialization workaround to make sure that the pickle roundtrip works as expected. + filesystem = _wrap_s3_serialization_workaround(filesystem) + + def _file_infos_fetcher(paths: List[str]) -> List[Tuple[str, int]]: + fs = _unwrap_s3_serialization_workaround(filesystem) + return list( + itertools.chain.from_iterable( + _get_file_infos(path, fs, ignore_missing_paths) for path in paths + ) + ) + + yield from _fetch_metadata_parallel( + paths, _file_infos_fetcher, PATHS_PER_FILE_SIZE_FETCH_TASK + ) + + +Uri = TypeVar("Uri") +Meta = TypeVar("Meta") + + +def _fetch_metadata_parallel( + uris: List[Uri], + fetch_func: Callable[[List[Uri]], List[Meta]], + desired_uris_per_task: int, + **ray_remote_args, +) -> Iterator[Meta]: + """Fetch file metadata in parallel using Ray tasks.""" + remote_fetch_func = cached_remote_fn(fetch_func) + if ray_remote_args: + remote_fetch_func = remote_fetch_func.options(**ray_remote_args) + # Choose a parallelism that results in a # of metadata fetches per task that + # dominates the Ray task overhead while ensuring good parallelism. + # Always launch at least 2 parallel fetch tasks. 
+ parallelism = max(len(uris) // desired_uris_per_task, 2) + metadata_fetch_bar = ProgressBar( + "Metadata Fetch Progress", total=parallelism, unit="task" + ) + fetch_tasks = [] + for uri_chunk in np.array_split(uris, parallelism): + if len(uri_chunk) == 0: + continue + fetch_tasks.append(remote_fetch_func.remote(uri_chunk)) + results = metadata_fetch_bar.fetch_until_complete(fetch_tasks) + yield from itertools.chain.from_iterable(results) + + +def _get_file_infos( + path: str, filesystem: "pyarrow.fs.FileSystem", ignore_missing_path: bool = False +) -> List[Tuple[str, int]]: + """Get the file info for all files at or under the provided path.""" + from pyarrow.fs import FileType + + file_infos = [] + try: + ctx = ray.data.DataContext.get_current() + file_info = call_with_retry( + lambda: filesystem.get_file_info(path), + description="get file info", + match=ctx.retried_io_errors, + ) + except OSError as e: + _handle_read_os_error(e, path) + if file_info.type == FileType.Directory: + for file_path, file_size in _expand_directory(path, filesystem): + file_infos.append((file_path, file_size)) + elif file_info.type == FileType.File: + file_infos.append((path, file_info.size)) + elif file_info.type == FileType.NotFound and ignore_missing_path: + pass + else: + raise FileNotFoundError(path) + + return file_infos + + +def _expand_directory( + path: str, + filesystem: "pyarrow.fs.FileSystem", + exclude_prefixes: Optional[List[str]] = None, + ignore_missing_path: bool = False, +) -> List[Tuple[str, int]]: + """ + Expand the provided directory path to a list of file paths. + + Args: + path: The directory path to expand. + filesystem: The filesystem implementation that should be used for + reading these files. + exclude_prefixes: The file relative path prefixes that should be + excluded from the returned file set. Default excluded prefixes are + "." and "_". + + Returns: + An iterator of (file_path, file_size) tuples. 
+ """ + if exclude_prefixes is None: + exclude_prefixes = [".", "_"] + + from pyarrow.fs import FileSelector + + selector = FileSelector(path, recursive=True, allow_not_found=ignore_missing_path) + files = filesystem.get_file_info(selector) + base_path = selector.base_dir + out = [] + for file_ in files: + if not file_.is_file: + continue + file_path = file_.path + if not file_path.startswith(base_path): + continue + relative = file_path[len(base_path) :] + if any(relative.startswith(prefix) for prefix in exclude_prefixes): + continue + out.append((file_path, file_.size)) + # We sort the paths to guarantee a stable order. + return sorted(out) diff --git a/.venv/lib/python3.11/site-packages/ray/data/datasource/filename_provider.py b/.venv/lib/python3.11/site-packages/ray/data/datasource/filename_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..592db59a2fe52c8dcdac60a7de33fb531304d5a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/datasource/filename_provider.py @@ -0,0 +1,122 @@ +from typing import Any, Dict, Optional + +from ray.data.block import Block +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class FilenameProvider: + """Generates filenames when you write a :class:`~ray.data.Dataset`. + + Use this class to customize the filenames used when writing a Dataset. + + Some methods write each row to a separate file, while others write each block to a + separate file. For example, :meth:`ray.data.Dataset.write_images` writes individual + rows, and :func:`ray.data.Dataset.write_parquet` writes blocks of data. For more + information about blocks, see :ref:`Data internals `. + + If you're writing each row to a separate file, implement + :meth:`~FilenameProvider.get_filename_for_row`. Otherwise, implement + :meth:`~FilenameProvider.get_filename_for_block`. + + Example: + + This snippet shows you how to encode labels in written files. 
For example, if + `"cat"` is a label, you might write a file named `cat_000000_000000_000000.png`. + + .. testcode:: + + import ray + from ray.data.datasource import FilenameProvider + + class ImageFilenameProvider(FilenameProvider): + + def __init__(self, file_format: str): + self.file_format = file_format + + def get_filename_for_row(self, row, task_index, block_index, row_index): + return ( + f"{row['label']}_{task_index:06}_{block_index:06}" + f"_{row_index:06}.{self.file_format}" + ) + + ds = ray.data.read_parquet("s3://anonymous@ray-example-data/images.parquet") + ds.write_images( + "/tmp/results", + column="image", + filename_provider=ImageFilenameProvider("png") + ) + """ # noqa: E501 + + def get_filename_for_block( + self, block: Block, task_index: int, block_index: int + ) -> str: + """Generate a filename for a block of data. + + .. note:: + Filenames must be unique and deterministic for a given task and block index. + + A block consists of multiple rows and corresponds to a single output file. + Each task might produce a different number of blocks. + + Args: + block: The block that will be written to a file. + task_index: The index of the the write task. + block_index: The index of the block *within* the write task. + """ + raise NotImplementedError + + def get_filename_for_row( + self, row: Dict[str, Any], task_index: int, block_index: int, row_index: int + ) -> str: + """Generate a filename for a row. + + .. note:: + Filenames must be unique and deterministic for a given task, block, and row + index. + + A block consists of multiple rows, and each row corresponds to a single + output file. Each task might produce a different number of blocks, and each + block might contain a different number of rows. + + .. tip:: + If you require a contiguous row index into the global dataset, use + :meth:`~ray.data.Dataset.iter_rows`. This method is single-threaded and + isn't recommended for large datasets. + + Args: + row: The row that will be written to a file. 
+            task_index: The index of the write task.
+            block_index: The index of the block *within* the write task.
+            row_index: The index of the row *within* the block.
+        """
+        raise NotImplementedError
+
+
+class _DefaultFilenameProvider(FilenameProvider):
+    def __init__(
+        self, dataset_uuid: Optional[str] = None, file_format: Optional[str] = None
+    ):
+        self._dataset_uuid = dataset_uuid
+        self._file_format = file_format
+
+    def get_filename_for_block(
+        self, block: Block, task_index: int, block_index: int
+    ) -> str:
+        file_id = f"{task_index:06}_{block_index:06}"
+        return self._generate_filename(file_id)
+
+    def get_filename_for_row(
+        self, row: Dict[str, Any], task_index: int, block_index: int, row_index: int
+    ) -> str:
+        file_id = f"{task_index:06}_{block_index:06}_{row_index:06}"
+        return self._generate_filename(file_id)
+
+    def _generate_filename(self, file_id: str) -> str:
+        filename = ""
+        if self._dataset_uuid is not None:
+            filename += f"{self._dataset_uuid}_"
+        filename += file_id
+        if self._file_format is not None:
+            filename += f".{self._file_format}"
+        return filename
diff --git a/.venv/lib/python3.11/site-packages/ray/data/datasource/parquet_meta_provider.py b/.venv/lib/python3.11/site-packages/ray/data/datasource/parquet_meta_provider.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43272dec77900a05fcd17559e87f130dbc7950b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/data/datasource/parquet_meta_provider.py
@@ -0,0 +1,252 @@
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import ray.cloudpickle as cloudpickle
+from ray.data._internal.util import call_with_retry
+from ray.data.block import BlockMetadata
+from ray.data.datasource.file_meta_provider import (
+    FileMetadataProvider,
+    _fetch_metadata_parallel,
+)
+from ray.util.annotations import DeveloperAPI
+
+if TYPE_CHECKING:
+    import pyarrow
+
+    from ray.data._internal.datasource.parquet_datasource import SerializedFragment
+
+
+FRAGMENTS_PER_META_FETCH = 6 +PARALLELIZE_META_FETCH_THRESHOLD = 24 + +# The application-level exceptions to retry for metadata prefetching task. +# Default to retry on access denied and read timeout errors because AWS S3 would throw +# these transient errors when load is too high. +RETRY_EXCEPTIONS_FOR_META_FETCH_TASK = ["AWS Error ACCESS_DENIED", "Timeout"] +# Maximum number of retries for metadata prefetching task due to transient errors. +RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK = 32 +# Maximum retry back-off interval in seconds for failed metadata prefetching task. +RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK = 64 + + +class _ParquetFileFragmentMetaData: + """Class to store metadata of a Parquet file fragment. This includes + all attributes from `pyarrow.parquet.FileMetaData` except for `schema`, + which is stored in `self.schema_pickled` as a pickled object from + `cloudpickle.loads()`, used in deduplicating schemas across multiple fragments.""" + + def __init__(self, fragment_metadata: "pyarrow.parquet.FileMetaData"): + self.created_by = fragment_metadata.created_by + self.format_version = fragment_metadata.format_version + self.num_columns = fragment_metadata.num_columns + self.num_row_groups = fragment_metadata.num_row_groups + self.num_rows = fragment_metadata.num_rows + self.serialized_size = fragment_metadata.serialized_size + # This is a pickled schema object, to be set later with + # `self.set_schema_pickled()`. To get the underlying schema, use + # `cloudpickle.loads(self.schema_pickled)`. + self.schema_pickled = None + + # Calculate the total byte size of the file fragment using the original + # object, as it is not possible to access row groups from this class. 
+ self.total_byte_size = 0 + for row_group_idx in range(fragment_metadata.num_row_groups): + row_group_metadata = fragment_metadata.row_group(row_group_idx) + self.total_byte_size += row_group_metadata.total_byte_size + + def set_schema_pickled(self, schema_pickled: bytes): + """Note: to get the underlying schema, use + `cloudpickle.loads(self.schema_pickled)`.""" + self.schema_pickled = schema_pickled + + +@DeveloperAPI +class ParquetMetadataProvider(FileMetadataProvider): + """Provides block metadata for Arrow Parquet file fragments.""" + + def _get_block_metadata( + self, + paths: List[str], + schema: Optional[Union[type, "pyarrow.lib.Schema"]], + *, + num_fragments: int, + prefetched_metadata: Optional[List["_ParquetFileFragmentMetaData"]], + ) -> BlockMetadata: + """Resolves and returns block metadata for files of a single dataset block. + + Args: + paths: The file paths for a single dataset block. + schema: The user-provided or inferred schema for the given file + paths, if any. + num_fragments: The number of Parquet file fragments derived from the input + file paths. + prefetched_metadata: Metadata previously returned from + `prefetch_file_metadata()` for each file fragment, where + `prefetched_metadata[i]` contains the metadata for `fragments[i]`. + + Returns: + BlockMetadata aggregated across the given file paths. + """ + if ( + prefetched_metadata is not None + and len(prefetched_metadata) == num_fragments + and all(m is not None for m in prefetched_metadata) + ): + # Fragment metadata was available, construct a normal + # BlockMetadata. + block_metadata = BlockMetadata( + num_rows=sum(m.num_rows for m in prefetched_metadata), + size_bytes=sum(m.total_byte_size for m in prefetched_metadata), + schema=schema, + input_files=paths, + exec_stats=None, + ) # Exec stats filled in later. + else: + # Fragment metadata was not available, construct an empty + # BlockMetadata. 
+ block_metadata = BlockMetadata( + num_rows=None, + size_bytes=None, + schema=schema, + input_files=paths, + exec_stats=None, + ) + return block_metadata + + def prefetch_file_metadata( + self, + fragments: List["pyarrow.dataset.ParquetFileFragment"], + **ray_remote_args, + ) -> Optional[List[_ParquetFileFragmentMetaData]]: + """Pre-fetches file metadata for all Parquet file fragments in a single batch. + + Subsets of the metadata returned will be provided as input to subsequent calls + to ``_get_block_metadata`` together with their corresponding Parquet file + fragments. + + Args: + fragments: The Parquet file fragments to fetch metadata for. + + Returns: + Metadata resolved for each input file fragment, or `None`. Metadata + must be returned in the same order as all input file fragments, such + that `metadata[i]` always contains the metadata for `fragments[i]`. + """ + from ray.data._internal.datasource.parquet_datasource import SerializedFragment + + if len(fragments) > PARALLELIZE_META_FETCH_THRESHOLD: + # Wrap Parquet fragments in serialization workaround. + fragments = [SerializedFragment(fragment) for fragment in fragments] + # Fetch Parquet metadata in parallel using Ray tasks. + + def fetch_func(fragments): + return _fetch_metadata_serialization_wrapper( + fragments, + # Ensure that retry settings are propagated to remote tasks. 
+ retry_match=RETRY_EXCEPTIONS_FOR_META_FETCH_TASK, + retry_max_attempts=RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK, + retry_max_interval=RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK, + ) + + raw_metadata = list( + _fetch_metadata_parallel( + fragments, + fetch_func, + FRAGMENTS_PER_META_FETCH, + **ray_remote_args, + ) + ) + else: + raw_metadata = _fetch_metadata(fragments) + + return _dedupe_metadata(raw_metadata) + + +def _fetch_metadata_serialization_wrapper( + fragments: List["SerializedFragment"], + retry_match: Optional[List[str]], + retry_max_attempts: int, + retry_max_interval: int, +) -> List["pyarrow.parquet.FileMetaData"]: + from ray.data._internal.datasource.parquet_datasource import ( + _deserialize_fragments_with_retry, + ) + + deserialized_fragments = _deserialize_fragments_with_retry(fragments) + try: + metadata = call_with_retry( + lambda: _fetch_metadata(deserialized_fragments), + description="fetch metdata", + match=retry_match, + max_attempts=retry_max_attempts, + max_backoff_s=retry_max_interval, + ) + except OSError as e: + raise RuntimeError( + f"Exceeded maximum number of attempts ({retry_max_attempts}) to retry " + "metadata fetching task. Metadata fetching tasks can fail due to transient " + "errors like rate limiting.\n" + "\n" + "To increase the maximum number of attempts, configure " + "`RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK`. For example:\n" + "```\n" + "ray.data._internal.datasource.parquet_datasource.RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK = 64\n" # noqa: E501 + "```\n" + "To increase the maximum retry backoff interval, configure " + "`RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK`. For example:\n" + "```\n" + "ray.data._internal.datasource.parquet_datasource.RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK = 128\n" # noqa: E501 + "```\n" + "If the error continues to occur, you can also try decresasing the " + "concurency of metadata fetching tasks by setting " + "`NUM_CPUS_FOR_META_FETCH_TASK` to a larger value. 
For example:\n" + "```\n" + "ray.data._internal.datasource.parquet_datasource.NUM_CPUS_FOR_META_FETCH_TASK = 4.\n" # noqa: E501 + "```\n" + "To change which exceptions to retry on, set " + "`RETRY_EXCEPTIONS_FOR_META_FETCH_TASK` to a list of error messages. For " + "example:\n" + "```\n" + 'ray.data._internal.datasource.parquet_datasource.RETRY_EXCEPTIONS_FOR_META_FETCH_TASK = ["AWS Error ACCESS_DENIED", "Timeout"]\n' # noqa: E501 + "```" + ) from e + return metadata + + +def _fetch_metadata( + fragments: List["pyarrow.dataset.ParquetFileFragment"], +) -> List["pyarrow.parquet.FileMetaData"]: + fragment_metadata = [] + for f in fragments: + try: + fragment_metadata.append(f.metadata) + except AttributeError: + break + return fragment_metadata + + +def _dedupe_metadata( + raw_metadatas: List["pyarrow.parquet.FileMetaData"], +) -> List[_ParquetFileFragmentMetaData]: + """For datasets with a large number of columns, the FileMetaData + (in particular the schema) can be very large. We can reduce the + memory usage by only keeping unique schema objects across all + file fragments. 
import functools
import logging
from typing import Callable

from ray.data._internal.logging import get_log_directory
from ray.data.context import DataContext
from ray.exceptions import UserCodeException
from ray.util import log_once
from ray.util.annotations import DeveloperAPI
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled

logger = logging.getLogger(__name__)


@DeveloperAPI
class RayDataUserCodeException(UserCodeException):
    """Represents an Exception originating from user code, e.g. a
    user-specified UDF used in a Ray Data transformation.

    By default, the frames corresponding to Ray Data internal files are
    omitted from the stack trace logged to stdout, but will still be
    emitted to the Ray Data specific log file. To emit all stack frames to
    stdout, set `DataContext.log_internal_stack_trace_to_stdout` to True."""

    pass


@DeveloperAPI
class SystemException(Exception):
    """Represents an Exception originating from Ray Data internal code
    or Ray Core private code paths, as opposed to user code. When
    Exceptions of this form are raised, it likely indicates a bug
    in Ray Data or Ray Core."""

    pass


@DeveloperAPI
def omit_traceback_stdout(fn: Callable) -> Callable:
    """Decorator which runs the function and, if an exception is raised,
    drops the stack trace before re-raising the exception.

    The original exception, including the full unmodified stack trace, is
    always written to the Ray Data log file; only the stdout rendering is
    abbreviated. This is useful for stripping long stack traces of internal
    Ray Data code, which can otherwise obfuscate user code errors.

    Args:
        fn: The function to wrap.

    Returns:
        A wrapper with the same call signature as ``fn``.
    """

    # `functools.wraps` preserves fn's __name__/__doc__ so the wrapped
    # function remains identifiable in logs and introspection (the original
    # wrapper reported itself as `handle_trace`).
    @functools.wraps(fn)
    def handle_trace(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Only log the full internal stack trace to stdout when configured
            # via DataContext, or when the Ray Debugger is enabled.
            # The full stack trace will always be emitted to the Ray Data log file.
            log_to_stdout = (
                DataContext.get_current().log_internal_stack_trace_to_stdout
            )
            if _is_ray_debugger_post_mortem_enabled():
                logger.exception("Full stack trace:")
                raise e

            is_user_code_exception = isinstance(e, UserCodeException)
            if is_user_code_exception:
                # Exception has occurred in user code.
                if not log_to_stdout and log_once(
                    "ray_data_exception_internal_hidden"
                ):
                    logger.error(
                        "Exception occurred in user code, with the abbreviated stack "
                        "trace below. By default, the Ray Data internal stack trace "
                        "is omitted from stdout, and only written to the Ray Data log "
                        f"files at {get_log_directory()}. To "
                        "output the full stack trace to stdout, set "
                        "`DataContext.log_internal_stack_trace_to_stdout` to True."
                    )
            else:
                # Exception has occurred in internal Ray Data / Ray Core code.
                logger.error(
                    "Exception occurred in Ray Data or Ray Core internal code. "
                    "If you continue to see this error, please open an issue on "
                    "the Ray project GitHub page with the full stack trace below: "
                    "https://github.com/ray-project/ray/issues/new/choose"
                )

            # NOTE: `logger.exception` records exc_info by default; the
            # redundant `exc_info=True` kwarg from the original was dropped.
            should_hide_traceback = is_user_code_exception and not log_to_stdout
            logger.exception(
                "Full stack trace:",
                extra={"hide": should_hide_traceback},
            )
            if is_user_code_exception:
                # User error: re-raise without the (internal-heavy) traceback.
                raise e.with_traceback(None)
            else:
                # Internal error: mark the chain with SystemException so the
                # origin is attributable to Ray Data / Ray Core.
                raise e.with_traceback(None) from SystemException()

    return handle_trace
+ """ + + def __init__( + self, + dataset: Dataset, + key: Optional[Union[str, List[str]]], + ): + """Construct a dataset grouped by key (internal API). + + The constructor is not part of the GroupedData API. + Use the ``Dataset.groupby()`` method to construct one. + """ + self._dataset = dataset + self._key = key + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(dataset={self._dataset}, " f"key={self._key!r})" + ) + + @PublicAPI(api_group=FA_API_GROUP) + def aggregate(self, *aggs: AggregateFn) -> Dataset: + """Implements an accumulator-based aggregation. + + Args: + aggs: Aggregations to do. + + Returns: + The output is an dataset of ``n + 1`` columns where the first column + is the groupby key and the second through ``n + 1`` columns are the + results of the aggregations. + If groupby key is ``None`` then the key part of return is omitted. + """ + + plan = self._dataset._plan.copy() + op = Aggregate( + self._dataset._logical_plan.dag, + key=self._key, + aggs=aggs, + ) + logical_plan = LogicalPlan(op, self._dataset.context) + return Dataset( + plan, + logical_plan, + ) + + def _aggregate_on( + self, + agg_cls: type, + on: Union[str, List[str]], + *args, + **kwargs, + ): + """Helper for aggregating on a particular subset of the dataset. + + This validates the `on` argument, and converts a list of column names + to a multi-aggregation. A null `on` results in a + multi-aggregation on all columns for an Arrow Dataset, and a single + aggregation on the entire row for a simple Dataset. 
+ """ + aggs = self._dataset._build_multicolumn_aggs( + agg_cls, on, *args, skip_cols=self._key, **kwargs + ) + return self.aggregate(*aggs) + + @PublicAPI(api_group=FA_API_GROUP) + def map_groups( + self, + fn: UserDefinedFunction[DataBatch, DataBatch], + *, + compute: Union[str, ComputeStrategy] = None, + batch_format: Optional[str] = "default", + fn_args: Optional[Iterable[Any]] = None, + fn_kwargs: Optional[Dict[str, Any]] = None, + fn_constructor_args: Optional[Iterable[Any]] = None, + fn_constructor_kwargs: Optional[Dict[str, Any]] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + concurrency: Optional[Union[int, Tuple[int, int]]] = None, + **ray_remote_args, + ) -> "Dataset": + """Apply the given function to each group of records of this dataset. + + While map_groups() is very flexible, note that it comes with downsides: + * It may be slower than using more specific methods such as min(), max(). + * It requires that each group fits in memory on a single node. + + In general, prefer to use aggregate() instead of map_groups(). + + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + + Examples: + >>> # Return a single record per group (list of multiple records in, + >>> # list of a single record out). + >>> import ray + >>> import pandas as pd + >>> import numpy as np + >>> # Get first value per group. + >>> ds = ray.data.from_items([ # doctest: +SKIP + ... {"group": 1, "value": 1}, + ... {"group": 1, "value": 2}, + ... {"group": 2, "value": 3}, + ... {"group": 2, "value": 4}]) + >>> ds.groupby("group").map_groups( # doctest: +SKIP + ... lambda g: {"result": np.array([g["value"][0]])}) + + >>> # Return multiple records per group (dataframe in, dataframe out). + >>> df = pd.DataFrame( + ... {"A": ["a", "a", "b"], "B": [1, 1, 3], "C": [4, 6, 5]} + ... 
) + >>> ds = ray.data.from_pandas(df) # doctest: +SKIP + >>> grouped = ds.groupby("A") # doctest: +SKIP + >>> grouped.map_groups( # doctest: +SKIP + ... lambda g: g.apply( + ... lambda c: c / g[c.name].sum() if c.name in ["B", "C"] else c + ... ) + ... ) # doctest: +SKIP + + Args: + fn: The function to apply to each group of records, or a class type + that can be instantiated to create such a callable. It takes as + input a batch of all records from a single group, and returns a + batch of zero or more records, similar to map_batches(). + compute: The compute strategy, either "tasks" (default) to use Ray + tasks, ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed-size actor + pool, or ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` for an + autoscaling actor pool. + batch_format: Specify ``"default"`` to use the default block format + (NumPy), ``"pandas"`` to select ``pandas.DataFrame``, "pyarrow" to + select ``pyarrow.Table``, or ``"numpy"`` to select + ``Dict[str, numpy.ndarray]``, or None to return the underlying block + exactly as is with no additional formatting. + fn_args: Arguments to `fn`. + fn_kwargs: Keyword arguments to `fn`. + fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. + You can only provide this if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor. + This can only be provided if ``fn`` is a callable class. These arguments + are top-level arguments in the underlying Ray actor construction task. + num_cpus: The number of CPUs to reserve for each parallel map worker. + num_gpus: The number of GPUs to reserve for each parallel map worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel map + worker. + ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). 
See + :func:`ray.remote` for details. + + Returns: + The return type is determined by the return type of ``fn``, and the return + value is combined from results of all groups. + """ + # Globally sort records by key. + # Note that sort() will ensure that records of the same key partitioned + # into the same block. + if self._key is not None: + sorted_ds = self._dataset.sort(self._key) + else: + sorted_ds = self._dataset.repartition(1) + + # The batch is the entire block, because we have batch_size=None for + # map_batches() below. + def apply_udf_to_groups(udf, batch, *args, **kwargs): + block = BlockAccessor.batch_to_block(batch) + block_accessor = BlockAccessor.for_block(block) + + # Get the list of boundaries including first start and last end indices + if self._key: + projected_block = block_accessor.to_numpy(self._key) + + # get_block_boundaries() expects a list of arrays + if isinstance(self._key, str): + projected_block = [projected_block] + else: + # projected_block is a dict of arrays + projected_block = list(projected_block.values()) + + boundaries = _get_block_boundaries(projected_block) + else: + boundaries = [0, block_accessor.num_rows()] + + for start, end in zip(boundaries[:-1], boundaries[1:]): + group_block = block_accessor.slice(start, end, copy=False) + group_block_accessor = BlockAccessor.for_block(group_block) + # Convert block of each group to batch format here, because the + # block format here can be different from batch format + # (e.g. block is Arrow format, and batch is NumPy format). 
+ group_batch = group_block_accessor.to_batch_format(batch_format) + applied = udf(group_batch, *args, **kwargs) + yield applied + + if isinstance(fn, CallableClass): + + class wrapped_fn: + def __init__(self, *args, **kwargs): + self.fn = fn(*args, **kwargs) + + def __call__(self, batch, *args, **kwargs): + yield from apply_udf_to_groups(self.fn, batch, *args, **kwargs) + + else: + + def wrapped_fn(batch, *args, **kwargs): + yield from apply_udf_to_groups(fn, batch, *args, **kwargs) + + # Change the name of the wrapped function so that users see the name of their + # function rather than `wrapped_fn` in the progress bar. + if isinstance(fn, partial): + wrapped_fn.__name__ = fn.func.__name__ + else: + wrapped_fn.__name__ = fn.__name__ + + # Note we set batch_size=None here, so it will use the entire block as a batch, + # which ensures that each group will be contained within a batch in entirety. + return sorted_ds._map_batches_without_batch_size_validation( + wrapped_fn, + batch_size=None, + compute=compute, + batch_format=batch_format, + zero_copy_batch=False, + fn_args=fn_args, + fn_kwargs=fn_kwargs, + fn_constructor_args=fn_constructor_args, + fn_constructor_kwargs=fn_constructor_kwargs, + num_cpus=num_cpus, + num_gpus=num_gpus, + concurrency=concurrency, + ray_remote_args_fn=None, + **ray_remote_args, + ) + + @PublicAPI(api_group=CDS_API_GROUP) + def count(self) -> Dataset: + """Compute count aggregation. + + Examples: + >>> import ray + >>> ray.data.from_items([ # doctest: +SKIP + ... {"A": x % 3, "B": x} for x in range(100)]).groupby( # doctest: +SKIP + ... "A").count() # doctest: +SKIP + + Returns: + A dataset of ``[k, v]`` columns where ``k`` is the groupby key and + ``v`` is the number of rows with that key. + If groupby key is ``None`` then the key part of return is omitted. 
+ """ + return self.aggregate(Count()) + + @PublicAPI(api_group=CDS_API_GROUP) + def sum( + self, on: Union[str, List[str]] = None, ignore_nulls: bool = True + ) -> Dataset: + r"""Compute grouped sum aggregation. + + Examples: + >>> import ray + >>> ray.data.from_items([ # doctest: +SKIP + ... (i % 3, i, i**2) # doctest: +SKIP + ... for i in range(100)]) \ # doctest: +SKIP + ... .groupby(lambda x: x[0] % 3) \ # doctest: +SKIP + ... .sum(lambda x: x[2]) # doctest: +SKIP + >>> ray.data.range(100).groupby("id").sum() # doctest: +SKIP + >>> ray.data.from_items([ # doctest: +SKIP + ... {"A": i % 3, "B": i, "C": i**2} # doctest: +SKIP + ... for i in range(100)]) \ # doctest: +SKIP + ... .groupby("A") \ # doctest: +SKIP + ... .sum(["B", "C"]) # doctest: +SKIP + + Args: + on: a column name or a list of column names to aggregate. + ignore_nulls: Whether to ignore null values. If ``True``, null + values will be ignored when computing the sum; if ``False``, + if a null value is encountered, the output will be null. + We consider np.nan, None, and pd.NaT to be null values. + Default is ``True``. + + Returns: + The sum result. + + For different values of ``on``, the return varies: + + - ``on=None``: a dataset containing a groupby key column, + ``"k"``, and a column-wise sum column for each original column + in the dataset. + - ``on=["col_1", ..., "col_n"]``: a dataset of ``n + 1`` + columns where the first column is the groupby key and the second + through ``n + 1`` columns are the results of the aggregations. + + If groupby key is ``None`` then the key part of return is omitted. + """ + return self._aggregate_on(Sum, on, ignore_nulls=ignore_nulls) + + @PublicAPI(api_group=CDS_API_GROUP) + def min( + self, on: Union[str, List[str]] = None, ignore_nulls: bool = True + ) -> Dataset: + """Compute grouped min aggregation. + + Examples: + >>> import ray + >>> ray.data.le(100).groupby("value").min() # doctest: +SKIP + >>> ray.data.from_items([ # doctest: +SKIP + ... 
{"A": i % 3, "B": i, "C": i**2} # doctest: +SKIP + ... for i in range(100)]) \ # doctest: +SKIP + ... .groupby("A") \ # doctest: +SKIP + ... .min(["B", "C"]) # doctest: +SKIP + + Args: + on: a column name or a list of column names to aggregate. + ignore_nulls: Whether to ignore null values. If ``True``, null + values will be ignored when computing the min; if ``False``, + if a null value is encountered, the output will be null. + We consider np.nan, None, and pd.NaT to be null values. + Default is ``True``. + + Returns: + The min result. + + For different values of ``on``, the return varies: + + - ``on=None``: a dataset containing a groupby key column, + ``"k"``, and a column-wise min column for each original column in + the dataset. + - ``on=["col_1", ..., "col_n"]``: a dataset of ``n + 1`` + columns where the first column is the groupby key and the second + through ``n + 1`` columns are the results of the aggregations. + + If groupby key is ``None`` then the key part of return is omitted. + """ + return self._aggregate_on(Min, on, ignore_nulls=ignore_nulls) + + @PublicAPI(api_group=CDS_API_GROUP) + def max( + self, on: Union[str, List[str]] = None, ignore_nulls: bool = True + ) -> Dataset: + """Compute grouped max aggregation. + + Examples: + >>> import ray + >>> ray.data.le(100).groupby("value").max() # doctest: +SKIP + >>> ray.data.from_items([ # doctest: +SKIP + ... {"A": i % 3, "B": i, "C": i**2} # doctest: +SKIP + ... for i in range(100)]) \ # doctest: +SKIP + ... .groupby("A") \ # doctest: +SKIP + ... .max(["B", "C"]) # doctest: +SKIP + + Args: + on: a column name or a list of column names to aggregate. + ignore_nulls: Whether to ignore null values. If ``True``, null + values will be ignored when computing the max; if ``False``, + if a null value is encountered, the output will be null. + We consider np.nan, None, and pd.NaT to be null values. + Default is ``True``. + + Returns: + The max result. 
+ + For different values of ``on``, the return varies: + + - ``on=None``: a dataset containing a groupby key column, + ``"k"``, and a column-wise max column for each original column in + the dataset. + - ``on=["col_1", ..., "col_n"]``: a dataset of ``n + 1`` + columns where the first column is the groupby key and the second + through ``n + 1`` columns are the results of the aggregations. + + If groupby key is ``None`` then the key part of return is omitted. + """ + return self._aggregate_on(Max, on, ignore_nulls=ignore_nulls) + + @PublicAPI(api_group=CDS_API_GROUP) + def mean( + self, on: Union[str, List[str]] = None, ignore_nulls: bool = True + ) -> Dataset: + """Compute grouped mean aggregation. + + Examples: + >>> import ray + >>> ray.data.le(100).groupby("value").mean() # doctest: +SKIP + >>> ray.data.from_items([ # doctest: +SKIP + ... {"A": i % 3, "B": i, "C": i**2} # doctest: +SKIP + ... for i in range(100)]) \ # doctest: +SKIP + ... .groupby("A") \ # doctest: +SKIP + ... .mean(["B", "C"]) # doctest: +SKIP + + Args: + on: a column name or a list of column names to aggregate. + ignore_nulls: Whether to ignore null values. If ``True``, null + values will be ignored when computing the mean; if ``False``, + if a null value is encountered, the output will be null. + We consider np.nan, None, and pd.NaT to be null values. + Default is ``True``. + + Returns: + The mean result. + + For different values of ``on``, the return varies: + + - ``on=None``: a dataset containing a groupby key column, + ``"k"``, and a column-wise mean column for each original column + in the dataset. + - ``on=["col_1", ..., "col_n"]``: a dataset of ``n + 1`` + columns where the first column is the groupby key and the second + through ``n + 1`` columns are the results of the aggregations. + + If groupby key is ``None`` then the key part of return is omitted. 
+ """ + return self._aggregate_on(Mean, on, ignore_nulls=ignore_nulls) + + @PublicAPI(api_group=CDS_API_GROUP) + def std( + self, + on: Union[str, List[str]] = None, + ddof: int = 1, + ignore_nulls: bool = True, + ) -> Dataset: + """Compute grouped standard deviation aggregation. + + Examples: + >>> import ray + >>> ray.data.range(100).groupby("id").std(ddof=0) # doctest: +SKIP + >>> ray.data.from_items([ # doctest: +SKIP + ... {"A": i % 3, "B": i, "C": i**2} # doctest: +SKIP + ... for i in range(100)]) \ # doctest: +SKIP + ... .groupby("A") \ # doctest: +SKIP + ... .std(["B", "C"]) # doctest: +SKIP + + NOTE: This uses Welford's online method for an accumulator-style + computation of the standard deviation. This method was chosen due to + it's numerical stability, and it being computable in a single pass. + This may give different (but more accurate) results than NumPy, Pandas, + and sklearn, which use a less numerically stable two-pass algorithm. + See + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + + Args: + on: a column name or a list of column names to aggregate. + ddof: Delta Degrees of Freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + ignore_nulls: Whether to ignore null values. If ``True``, null + values will be ignored when computing the std; if ``False``, + if a null value is encountered, the output will be null. + We consider np.nan, None, and pd.NaT to be null values. + Default is ``True``. + + Returns: + The standard deviation result. + + For different values of ``on``, the return varies: + + - ``on=None``: a dataset containing a groupby key column, + ``"k"``, and a column-wise std column for each original column in + the dataset. + - ``on=["col_1", ..., "col_n"]``: a dataset of ``n + 1`` + columns where the first column is the groupby key and the second + through ``n + 1`` columns are the results of the aggregations. 
+ + If groupby key is ``None`` then the key part of return is omitted. + """ + return self._aggregate_on(Std, on, ignore_nulls=ignore_nulls, ddof=ddof) + + +# Backwards compatibility alias. +GroupedDataset = GroupedData diff --git a/.venv/lib/python3.11/site-packages/ray/data/iterator.py b/.venv/lib/python3.11/site-packages/ray/data/iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..d22271328ce97211f184c6d08709ff88e5371184 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/iterator.py @@ -0,0 +1,931 @@ +import abc +import time +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) + +import numpy as np + +from ray.data._internal.block_batching.iter_batches import iter_batches +from ray.data._internal.execution.interfaces import RefBundle +from ray.data._internal.logical.operators.input_data_operator import InputData +from ray.data._internal.logical.optimizers import LogicalPlan +from ray.data._internal.plan import ExecutionPlan +from ray.data._internal.stats import DatasetStats, StatsManager +from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + import tensorflow as tf + import torch + + from ray.data.dataset import ( + CollatedData, + MaterializedDataset, + Schema, + TensorFlowTensorBatchType, + TorchBatchType, + ) + + +T = TypeVar("T") + + +class _IterableFromIterator(Iterable[T]): + def __init__(self, iterator_gen: Callable[[], Iterator[T]]): + """Constructs an Iterable from an iterator generator. + + Args: + iterator_gen: A function that returns an iterator each time it + is called. For example, this can be a generator function. + """ + self.iterator_gen = iterator_gen + + def __iter__(self): + return self.iterator_gen() + + +@PublicAPI +class DataIterator(abc.ABC): + """An iterator for reading records from a :class:`~Dataset`. 
    @PublicAPI
    def iter_batches(
        self,
        *,
        prefetch_batches: int = 1,
        batch_size: int = 256,
        batch_format: Optional[str] = "default",
        drop_last: bool = False,
        local_shuffle_buffer_size: Optional[int] = None,
        local_shuffle_seed: Optional[int] = None,
        _collate_fn: Optional[Callable[[DataBatch], "CollatedData"]] = None,
        _finalize_fn: Optional[Callable[[Any], Any]] = None,
    ) -> Iterable[DataBatch]:
        """Return a batched iterable over the dataset.

        Examples:
            >>> import ray
            >>> for batch in ray.data.range(
            ...     1000000
            ... ).iterator().iter_batches(): # doctest: +SKIP
            ...     print(batch) # doctest: +SKIP

        Time complexity: O(1)

        Args:
            prefetch_batches: The number of batches to fetch ahead of the current batch
                to fetch. If set to greater than 0, a separate threadpool will be used
                to fetch the objects to the local node, format the batches, and apply
                the collate_fn. Defaults to 1.
            batch_size: The number of rows in each batch, or None to use entire blocks
                as batches (blocks may contain different number of rows).
                The final batch may include fewer than ``batch_size`` rows if
                ``drop_last`` is ``False``. Defaults to 256.
            batch_format: Specify ``"default"`` to use the default block format
                (NumPy), ``"pandas"`` to select ``pandas.DataFrame``, "pyarrow" to
                select ``pyarrow.Table``, or ``"numpy"`` to select
                ``Dict[str, numpy.ndarray]``, or None to return the underlying block
                exactly as is with no additional formatting.
            drop_last: Whether to drop the last batch if it's incomplete.
            local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                using a local in-memory shuffle buffer, and this value will serve as
                the minimum number of rows that must be in the local in-memory shuffle
                buffer in order to yield a batch. When there are no more rows to add
                to the buffer, the remaining rows in the buffer will be drained.
            local_shuffle_seed: The seed to use for the local random shuffle.
            _collate_fn: Internal. Applied to each formatted batch before yielding.
            _finalize_fn: Internal. Applied after collation (e.g. device transfer).

        Returns:
            An iterable over record batches.
        """
        # Normalize user-facing batch_format aliases to the canonical value.
        batch_format = _apply_batch_format(batch_format)

        def _create_iterator() -> Iterator[DataBatch]:
            time_start = time.perf_counter()
            # Iterate through the dataset from the start each time
            # _iterator_gen is called.
            # This allows multiple iterations of the dataset without
            # needing to explicitly call `iter_batches()` multiple times.
            (
                ref_bundles_iterator,
                stats,
                blocks_owned_by_consumer,
            ) = self._to_ref_bundle_iterator()

            # Delegate batching/shuffling/prefetching to the shared
            # block-batching implementation.
            iterator = iter(
                iter_batches(
                    ref_bundles_iterator,
                    stats=stats,
                    clear_block_after_read=blocks_owned_by_consumer,
                    batch_size=batch_size,
                    batch_format=batch_format,
                    drop_last=drop_last,
                    collate_fn=_collate_fn,
                    finalize_fn=_finalize_fn,
                    shuffle_buffer_min_size=local_shuffle_buffer_size,
                    shuffle_seed=local_shuffle_seed,
                    prefetch_batches=prefetch_batches,
                )
            )

            dataset_tag = self._get_dataset_tag()

            # Record time spent setting up the iterator before the first batch.
            if stats:
                stats.iter_initialize_s.add(time.perf_counter() - time_start)

            for batch in iterator:
                yield batch
                # Update per-batch iteration metrics after the consumer has
                # processed each yielded batch.
                StatsManager.update_iteration_metrics(stats, dataset_tag)
            StatsManager.clear_iteration_metrics(dataset_tag)

            if stats:
                stats.iter_total_s.add(time.perf_counter() - time_start)

        # Wrap in an Iterable so callers can iterate the dataset repeatedly.
        return _IterableFromIterator(_create_iterator)
+ """ + batch_iterable = self.iter_batches( + batch_size=None, batch_format=None, prefetch_batches=1 + ) + + def _wrapped_iterator(): + for batch in batch_iterable: + batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch)) + for row in batch.iter_rows(public_row_format=True): + yield row + + return _IterableFromIterator(_wrapped_iterator) + + @abc.abstractmethod + @PublicAPI + def stats(self) -> str: + """Returns a string containing execution timing information.""" + raise NotImplementedError + + @abc.abstractmethod + def schema(self) -> "Schema": + """Return the schema of the dataset iterated over.""" + raise NotImplementedError + + @PublicAPI + def iter_torch_batches( + self, + *, + prefetch_batches: int = 1, + batch_size: Optional[int] = 256, + dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, + device: str = "auto", + collate_fn: Optional[Callable[[Dict[str, np.ndarray]], "CollatedData"]] = None, + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + ) -> Iterable["TorchBatchType"]: + """Return a batched iterable of Torch Tensors over the dataset. + + This iterable yields a dictionary of column-tensors. If you are looking for + more flexibility in the tensor conversion (e.g. casting dtypes) or the batch + format, try using :meth:`~ray.data.DataIterator.iter_batches` directly. + + Examples: + >>> import ray + >>> for batch in ray.data.range( + ... 12, + ... ).iterator().iter_torch_batches(batch_size=4): + ... print(batch) + {'id': tensor([0, 1, 2, 3])} + {'id': tensor([4, 5, 6, 7])} + {'id': tensor([ 8, 9, 10, 11])} + + Use the ``collate_fn`` to customize how the tensor batch is created. + + >>> from typing import Any, Dict + >>> import torch + >>> import numpy as np + >>> import ray + >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any: + ... return torch.stack( + ... [torch.as_tensor(array) for array in batch.values()], + ... axis=1 + ... 
) + >>> iterator = ray.data.from_items([ + ... {"col_1": 1, "col_2": 2}, + ... {"col_1": 3, "col_2": 4}]).iterator() + >>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn): + ... print(batch) + tensor([[1, 2], + [3, 4]]) + + Time complexity: O(1) + + Args: + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 1. + batch_size: The number of rows in each batch, or None to use entire blocks + as batches (blocks may contain different number of rows). + The final batch may include fewer than ``batch_size`` rows if + ``drop_last`` is ``False``. Defaults to 256. + dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype + will be inferred from the tensor data. You can't use this parameter + with ``collate_fn``. + device: The device on which the tensor should be placed. Defaults to + "auto" which moves the tensors to the appropriate device when the + Dataset is passed to Ray Train and ``collate_fn`` is not provided. + Otherwise, defaults to CPU. You can't use this parameter with + ``collate_fn``. + collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch. + When this parameter is specified, the user should manually handle the + host to device data transfer outside of ``collate_fn``. + This is useful for further processing the data after it has been + batched. Potential use cases include collating along a dimension other + than the first, padding sequences of various lengths, or generally + handling batches of different length tensors. If not provided, the + default collate function is used which simply converts the batch of + numpy arrays to a batch of PyTorch tensors. This API is still + experimental and is subject to change. You can't use this parameter in + conjunction with ``dtypes`` or ``device``. 
+ drop_last: Whether to drop the last batch if it's incomplete. + local_shuffle_buffer_size: If non-None, the data will be randomly shuffled + using a local in-memory shuffle buffer, and this value will serve as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer will be drained. This + buffer size must be greater than or equal to ``batch_size``, and + therefore ``batch_size`` must also be specified when using local + shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + + Returns: + An iterable over Torch Tensor batches. + """ + + from ray.air._internal.torch_utils import ( + convert_ndarray_batch_to_torch_tensor_batch, + ) + from ray.train.torch import get_device + + if collate_fn is not None and (dtypes is not None or device != "auto"): + raise ValueError( + "collate_fn cannot be used with dtypes and device." + "You should manually move the output Torch tensors to the" + "desired dtype and device outside of collate_fn." + ) + + if device == "auto": + # Use the appropriate device for Ray Train, or falls back to CPU if + # Ray Train is not being used. + device = get_device() + + if collate_fn is None: + # The default collate_fn handles formatting and Tensor creation. + # Here, we set device=None to defer host to device data transfer + # to the subsequent finalize_fn. + def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): + return convert_ndarray_batch_to_torch_tensor_batch( + batch, + dtypes=dtypes, + device=None, + ) + + # The default finalize_fn handles the host to device data transfer. + # This is executed in a 1-thread pool separately from collate_fn + # to allow independent parallelism of these steps. 
            # Moves each collated tensor (or dict of tensors) onto the target
            # device. Runs in a separate 1-thread pool (see comment above) so
            # host-to-device transfer overlaps with batch formatting.
            def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]):
                if device is not None:
                    if isinstance(batch, dict):
                        for k, t in batch.items():
                            batch[k] = t.to(device=device)
                    else:
                        batch = batch.to(device=device)
                return batch

        else:
            # A user-supplied collate_fn is responsible for dtype/device
            # placement itself, so no separate finalize step is used.
            finalize_fn = None

        return self.iter_batches(
            prefetch_batches=prefetch_batches,
            batch_size=batch_size,
            drop_last=drop_last,
            local_shuffle_buffer_size=local_shuffle_buffer_size,
            local_shuffle_seed=local_shuffle_seed,
            _collate_fn=collate_fn,
            _finalize_fn=finalize_fn,
        )

    def iter_tf_batches(
        self,
        *,
        prefetch_batches: int = 1,
        batch_size: Optional[int] = 256,
        dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None,
        drop_last: bool = False,
        local_shuffle_buffer_size: Optional[int] = None,
        local_shuffle_seed: Optional[int] = None,
    ) -> Iterable["TensorFlowTensorBatchType"]:
        """Return a batched iterable of TensorFlow Tensors over the dataset.

        This iterable will yield single-tensor batches if the underlying dataset
        consists of a single column; otherwise, it will yield a dictionary of
        column-tensors.

        .. tip::
            If you don't need the additional flexibility provided by this method,
            consider using :meth:`~ray.data.Dataset.to_tf` instead. It's easier
            to use.

        Examples:
            >>> import ray
            >>> for batch in ray.data.range( # doctest: +SKIP
            ...     12,
            ... ).iter_tf_batches(batch_size=4):
            ...     print(batch.shape) # doctest: +SKIP
            (4, 1)
            (4, 1)
            (4, 1)

        Time complexity: O(1)

        Args:
            prefetch_batches: The number of batches to fetch ahead of the current batch
                to fetch. If set to greater than 0, a separate threadpool will be used
                to fetch the objects to the local node, format the batches, and apply
                the collate_fn. Defaults to 1.
            batch_size: The number of rows in each batch, or None to use entire blocks
                as batches (blocks may contain different number of rows).
                The final batch may include fewer than ``batch_size`` rows if
                ``drop_last`` is ``False``. Defaults to 256.
            dtypes: The TensorFlow dtype(s) for the created tensor(s); if None, the
                dtype will be inferred from the tensor data.
            drop_last: Whether to drop the last batch if it's incomplete.
            local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                using a local in-memory shuffle buffer, and this value will serve as the
                minimum number of rows that must be in the local in-memory shuffle
                buffer in order to yield a batch. When there are no more rows to add to
                the buffer, the remaining rows in the buffer will be drained. This
                buffer size must be greater than or equal to ``batch_size``, and
                therefore ``batch_size`` must also be specified when using local
                shuffling.
            local_shuffle_seed: The seed to use for the local random shuffle.

        Returns:
            An iterator over TensorFlow Tensor batches.
        """
        from ray.air._internal.tensorflow_utils import (
            convert_ndarray_batch_to_tf_tensor_batch,
        )

        batch_iterable = self.iter_batches(
            prefetch_batches=prefetch_batches,
            batch_size=batch_size,
            drop_last=drop_last,
            local_shuffle_buffer_size=local_shuffle_buffer_size,
            local_shuffle_seed=local_shuffle_seed,
        )
        # Lazily convert each numpy batch to TF tensors as it is consumed;
        # `map` keeps this a streaming transformation (no materialization).
        mapped_iterable = map(
            lambda batch: convert_ndarray_batch_to_tf_tensor_batch(
                batch, dtypes=dtypes
            ),
            batch_iterable,
        )

        return mapped_iterable

    def to_torch(
        self,
        *,
        label_column: Optional[str] = None,
        feature_columns: Optional[
            Union[List[str], List[List[str]], Dict[str, List[str]]]
        ] = None,
        label_column_dtype: Optional["torch.dtype"] = None,
        feature_column_dtypes: Optional[
            Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]]
        ] = None,
        batch_size: int = 1,
        prefetch_batches: int = 1,
        drop_last: bool = False,
        local_shuffle_buffer_size: Optional[int] = None,
        local_shuffle_seed: Optional[int] = None,
        unsqueeze_label_tensor: bool = True,
unsqueeze_feature_tensors: bool = True, + ) -> "torch.utils.data.IterableDataset": + """Return a Torch IterableDataset over this dataset. + + This is only supported for datasets convertible to Arrow records. + + It is recommended to use the returned ``IterableDataset`` directly + instead of passing it into a torch ``DataLoader``. + + Each element in IterableDataset will be a tuple consisting of 2 + elements. The first item contains the feature tensor(s), and the + second item is the label tensor. Those can take on different + forms, depending on the specified arguments. + + For the features tensor (N is the ``batch_size`` and n, m, k + are the number of features per tensor): + + * If ``feature_columns`` is a ``List[str]``, the features will be + a tensor of shape (N, n), with columns corresponding to + ``feature_columns`` + + * If ``feature_columns`` is a ``List[List[str]]``, the features will be + a list of tensors of shape [(N, m),...,(N, k)], with columns of each + tensor corresponding to the elements of ``feature_columns`` + + * If ``feature_columns`` is a ``Dict[str, List[str]]``, the features + will be a dict of key-tensor pairs of shape + {key1: (N, m),..., keyN: (N, k)}, with columns of each + tensor corresponding to the value of ``feature_columns`` under the + key. + + If ``unsqueeze_label_tensor=True`` (default), the label tensor will be + of shape (N, 1). Otherwise, it will be of shape (N,). + If ``label_column`` is specified as ``None``, then no column from the + ``Dataset`` will be treated as the label, and the output label tensor + will be ``None``. + + Note that you probably want to call ``.split()`` on this dataset if + there are to be multiple Torch workers consuming the data. + + Time complexity: O(1) + + Args: + label_column: The name of the column used as the + label (second element of the output list). Can be None for + prediction, in which case the second element of returned + tuple will also be None. 
+ feature_columns: The names of the columns + to use as the features. Can be a list of lists or + a dict of string-list pairs for multi-tensor output. + If None, then use all columns except the label column as + the features. + label_column_dtype: The torch dtype to + use for the label column. If None, then automatically infer + the dtype. + feature_column_dtypes: The dtypes to use for the feature + tensors. This should match the format of ``feature_columns``, + or be a single dtype, in which case it will be applied to + all tensors. If None, then automatically infer the dtype. + batch_size: How many samples per batch to yield at a time. + Defaults to 1. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 1. + drop_last: Set to True to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If + False and the size of dataset is not divisible by the batch + size, then the last batch will be smaller. Defaults to False. + local_shuffle_buffer_size: If non-None, the data will be randomly shuffled + using a local in-memory shuffle buffer, and this value will serve as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer will be drained. This + buffer size must be greater than or equal to ``batch_size``, and + therefore ``batch_size`` must also be specified when using local + shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + unsqueeze_label_tensor: If set to True, the label tensor + will be unsqueezed (reshaped to (N, 1)). Otherwise, it will + be left as is, that is (N, ). 
In general, regression loss + functions expect an unsqueezed tensor, while classification + loss functions expect a squeezed one. Defaults to True. + unsqueeze_feature_tensors: If set to True, the features tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final features tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. + + Returns: + A torch IterableDataset. + """ + import torch + + from ray.air._internal.torch_utils import convert_pandas_to_torch_tensor + from ray.data._internal.torch_iterable_dataset import TorchIterableDataset + + # If an empty collection is passed in, treat it the same as None + if not feature_columns: + feature_columns = None + + if feature_column_dtypes and not isinstance(feature_column_dtypes, torch.dtype): + if isinstance(feature_columns, dict): + if not isinstance(feature_column_dtypes, dict): + raise TypeError( + "If `feature_columns` is a dict, " + "`feature_column_dtypes` must be None, `torch.dtype`," + f" or dict, got {type(feature_column_dtypes)}." + ) + if set(feature_columns) != set(feature_column_dtypes): + raise ValueError( + "`feature_columns` and `feature_column_dtypes` " + "must have the same keys." + ) + if any(not subcolumns for subcolumns in feature_columns.values()): + raise ValueError("column list may not be empty") + elif isinstance(feature_columns[0], (list, tuple)): + if not isinstance(feature_column_dtypes, (list, tuple)): + raise TypeError( + "If `feature_columns` is a list of lists, " + "`feature_column_dtypes` must be None, `torch.dtype`," + f" or a sequence, got {type(feature_column_dtypes)}." + ) + if len(feature_columns) != len(feature_column_dtypes): + raise ValueError( + "`feature_columns` and `feature_column_dtypes` " + "must have the same length." 
+ ) + if any(not subcolumns for subcolumns in feature_columns): + raise ValueError("column list may not be empty") + + def make_generator(): + for batch in self.iter_batches( + batch_size=batch_size, + batch_format="pandas", + prefetch_batches=prefetch_batches, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + ): + if label_column: + label_tensor = convert_pandas_to_torch_tensor( + batch, + [label_column], + label_column_dtype, + unsqueeze=unsqueeze_label_tensor, + ) + batch.pop(label_column) + else: + label_tensor = None + + if isinstance(feature_columns, dict): + features_tensor = { + key: convert_pandas_to_torch_tensor( + batch, + feature_columns[key], + ( + feature_column_dtypes[key] + if isinstance(feature_column_dtypes, dict) + else feature_column_dtypes + ), + unsqueeze=unsqueeze_feature_tensors, + ) + for key in feature_columns + } + else: + features_tensor = convert_pandas_to_torch_tensor( + batch, + columns=feature_columns, + column_dtypes=feature_column_dtypes, + unsqueeze=unsqueeze_feature_tensors, + ) + + yield (features_tensor, label_tensor) + + return TorchIterableDataset(make_generator) + + @PublicAPI + def to_tf( + self, + feature_columns: Union[str, List[str]], + label_columns: Union[str, List[str]], + *, + additional_columns: Union[Optional[str], Optional[List[str]]] = None, + prefetch_batches: int = 1, + batch_size: int = 1, + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None, + label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None, + additional_type_spec: Union[ + Optional["tf.TypeSpec"], Optional[Dict[str, "tf.TypeSpec"]] + ] = None, + ) -> "tf.data.Dataset": + """Return a TF Dataset over this dataset. + + .. warning:: + If your dataset contains ragged tensors, this method errors. 
To prevent + errors, :ref:`resize your tensors `. + + Examples: + >>> import ray + >>> ds = ray.data.read_csv( + ... "s3://anonymous@air-example-data/iris.csv" + ... ) + >>> it = ds.iterator(); it + DataIterator(Dataset( + num_rows=?, + schema={ + sepal length (cm): double, + sepal width (cm): double, + petal length (cm): double, + petal width (cm): double, + target: int64 + } + )) + + If your model accepts a single tensor as input, specify a single feature column. + + >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target") + <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))> + + If your model accepts a dictionary as input, specify a list of feature columns. + + >>> it.to_tf(["sepal length (cm)", "sepal width (cm)"], "target") + <_OptionsDataset element_spec=({'sepal length (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), 'sepal width (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal width (cm)')}, TensorSpec(shape=(None,), dtype=tf.int64, name='target'))> + + If your dataset contains multiple features but your model accepts a single + tensor as input, combine features with + :class:`~ray.data.preprocessors.Concatenator`. 
+ + >>> from ray.data.preprocessors import Concatenator + >>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"] + >>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features") + >>> it = preprocessor.transform(ds).iterator() + >>> it + DataIterator(Concatenator + +- Dataset( + num_rows=?, + schema={ + sepal length (cm): double, + sepal width (cm): double, + petal length (cm): double, + petal width (cm): double, + target: int64 + } + )) + >>> it.to_tf("features", "target") + <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))> + + If your model accepts different types, shapes, or names of tensors as input, specify the type spec. + If type specs are not specified, they are automatically inferred from the schema of the iterator. + + >>> import tensorflow as tf + >>> it.to_tf( + ... feature_columns="features", + ... label_columns="target", + ... feature_type_spec=tf.TensorSpec(shape=(None, 4), dtype=tf.float32, name="features"), + ... label_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="label") + ... ) + <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name='features'), TensorSpec(shape=(None,), dtype=tf.float32, name='label'))> + + If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns. + A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``. 
+ + >>> import pandas as pd + >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df))) + >>> it = ds.iterator() + >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target", additional_columns="sample weights") + <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))> + + If your model accepts different types, shapes, or names for the additional metadata, specify the type spec of the additional column. + + >>> it.to_tf( + ... feature_columns="sepal length (cm)", + ... label_columns="target", + ... additional_columns="sample weights", + ... additional_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="weight") + ... ) + <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.float32, name='weight'))> + + Args: + feature_columns: Columns that correspond to model inputs. If this is a + string, the input data is a tensor. If this is a list, the input data + is a ``dict`` that maps column names to their tensor representation. + label_columns: Columns that correspond to model targets. If this is a + string, the target data is a tensor. If this is a list, the target data + is a ``dict`` that maps column names to their tensor representation. + additional_columns: Columns that correspond to sample weights or other metadata. + If this is a string, the weight data is a tensor. If this is a list, the + weight data is a ``dict`` that maps column names to their tensor representation. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. 
Defaults to 1. + batch_size: Record batch size. Defaults to 1. + drop_last: Set to True to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If + False and the size of dataset is not divisible by the batch + size, then the last batch will be smaller. Defaults to False. + local_shuffle_buffer_size: If non-None, the data will be randomly shuffled + using a local in-memory shuffle buffer, and this value will serve as the + minimum number of rows that must be in the local in-memory shuffle + buffer in order to yield a batch. When there are no more rows to add to + the buffer, the remaining rows in the buffer will be drained. This + buffer size must be greater than or equal to ``batch_size``, and + therefore ``batch_size`` must also be specified when using local + shuffling. + local_shuffle_seed: The seed to use for the local random shuffle. + feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is + only one column, specify a `tf.TypeSpec`. If there are multiple columns, + specify a ``dict`` that maps column names to their `tf.TypeSpec`. + Default is `None` to automatically infer the type of each column. + label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is + only one column, specify a `tf.TypeSpec`. If there are multiple columns, + specify a ``dict`` that maps column names to their `tf.TypeSpec`. + Default is `None` to automatically infer the type of each column. + additional_type_spec: The `tf.TypeSpec` of `additional_columns`. If there + is only one column, specify a `tf.TypeSpec`. If there are multiple + columns, specify a ``dict`` that maps column names to their `tf.TypeSpec`. + Default is `None` to automatically infer the type of each column. + + Returns: + A ``tf.data.Dataset`` that yields inputs and targets. 
+ """ # noqa: E501 + + from ray.air._internal.tensorflow_utils import ( + convert_ndarray_to_tf_tensor, + get_type_spec, + ) + + try: + import tensorflow as tf + except ImportError: + raise ValueError("tensorflow must be installed!") + + def validate_column(column: str) -> None: + if column not in valid_columns: + raise ValueError( + f"You specified '{column}' in `feature_columns`, " + f"`label_columns`, or `additional_columns`, but there's no " + f"column named '{column}' in the dataset. " + f"Valid column names are: {valid_columns}." + ) + + def validate_columns(columns: Union[str, List]) -> None: + if isinstance(columns, list): + for column in columns: + validate_column(column) + else: + validate_column(columns) + + def convert_batch_to_tensors( + batch: Dict[str, np.ndarray], + *, + columns: Union[str, List[str]], + type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]], + ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]: + if isinstance(columns, str): + return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec) + return { + column: convert_ndarray_to_tf_tensor( + batch[column], type_spec=type_spec[column] + ) + for column in columns + } + + def generator(): + for batch in self.iter_batches( + prefetch_batches=prefetch_batches, + batch_size=batch_size, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + ): + assert isinstance(batch, dict) + features = convert_batch_to_tensors( + batch, columns=feature_columns, type_spec=feature_type_spec + ) + labels = convert_batch_to_tensors( + batch, columns=label_columns, type_spec=label_type_spec + ) + + if additional_columns is None: + yield features, labels + else: + additional_metadata = convert_batch_to_tensors( + batch, + columns=additional_columns, + type_spec=additional_type_spec, + ) + yield features, labels, additional_metadata + + if feature_type_spec is None or label_type_spec is None: + schema = self.schema() + valid_columns = 
set(schema.names) + validate_columns(feature_columns) + validate_columns(label_columns) + feature_type_spec = get_type_spec(schema, columns=feature_columns) + label_type_spec = get_type_spec(schema, columns=label_columns) + + if additional_columns is not None and additional_type_spec is None: + schema = self.schema() + valid_columns = set(schema.names) + validate_columns(additional_columns) + additional_type_spec = get_type_spec(schema, columns=additional_columns) + + if additional_columns is not None: + dataset = tf.data.Dataset.from_generator( + generator, + output_signature=( + feature_type_spec, + label_type_spec, + additional_type_spec, + ), + ) + else: + dataset = tf.data.Dataset.from_generator( + generator, output_signature=(feature_type_spec, label_type_spec) + ) + + options = tf.data.Options() + options.experimental_distribute.auto_shard_policy = ( + tf.data.experimental.AutoShardPolicy.OFF + ) + return dataset.with_options(options) + + @PublicAPI + def materialize(self) -> "MaterializedDataset": + """Execute and materialize this data iterator into object store memory. + + .. note:: + This method triggers the execution and materializes all blocks + of the iterator, returning its contents as a + :class:`~ray.data.dataset.MaterializedDataset` for further processing. + """ + + from ray.data.dataset import MaterializedDataset + + ref_bundles_iter, stats, _ = self._to_ref_bundle_iterator() + + ref_bundles = list(ref_bundles_iter) + execution_plan = ExecutionPlan(stats) + logical_plan = LogicalPlan( + InputData(input_data=ref_bundles), + execution_plan._context, + ) + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + def __del__(self): + # Clear metrics on deletion in case the iterator was not fully consumed. + StatsManager.clear_iteration_metrics(self._get_dataset_tag()) + + +# Backwards compatibility alias. 
+DatasetIterator = DataIterator diff --git a/.venv/lib/python3.11/site-packages/ray/data/preprocessor.py b/.venv/lib/python3.11/site-packages/ray/data/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..9db73405a702cd3f7786c61b5a34d272e8b47db7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/preprocessor.py @@ -0,0 +1,318 @@ +import abc +import base64 +import collections +import pickle +import warnings +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, Union + +from ray.air.util.data_batch_conversion import BatchFormat +from ray.util.annotations import DeveloperAPI, PublicAPI + +if TYPE_CHECKING: + import numpy as np + import pandas as pd + + from ray.air.data_batch_type import DataBatchType + from ray.data import Dataset + + +@PublicAPI(stability="beta") +class PreprocessorNotFittedException(RuntimeError): + """Error raised when the preprocessor needs to be fitted first.""" + + pass + + +@PublicAPI(stability="beta") +class Preprocessor(abc.ABC): + """Implements an ML preprocessing operation. + + Preprocessors are stateful objects that can be fitted against a Dataset and used + to transform both local data batches and distributed data. For example, a + Normalization preprocessor may calculate the mean and stdev of a field during + fitting, and uses these attributes to implement its normalization transform. + + Preprocessors can also be stateless and transform data without needed to be fitted. + For example, a preprocessor may simply remove a column, which does not require + any state to be fitted. + + If you are implementing your own Preprocessor sub-class, you should override the + following: + + * ``_fit`` if your preprocessor is stateful. Otherwise, set + ``_is_fittable=False``. + * ``_transform_pandas`` and/or ``_transform_numpy`` for best performance, + implement both. Otherwise, the data will be converted to the match the + implemented method. 
+ """ + + class FitStatus(str, Enum): + """The fit status of preprocessor.""" + + NOT_FITTABLE = "NOT_FITTABLE" + NOT_FITTED = "NOT_FITTED" + # Only meaningful for Chain preprocessors. + # At least one contained preprocessor in the chain preprocessor + # is fitted and at least one that can be fitted is not fitted yet. + # This is a state that show up if caller only interacts + # with the chain preprocessor through intended Preprocessor APIs. + PARTIALLY_FITTED = "PARTIALLY_FITTED" + FITTED = "FITTED" + + # Preprocessors that do not need to be fitted must override this. + _is_fittable = True + + def _check_has_fitted_state(self): + """Checks if the Preprocessor has fitted state. + + This is also used as an indiciation if the Preprocessor has been fit, following + convention from Ray versions prior to 2.6. + This allows preprocessors that have been fit in older versions of Ray to be + used to transform data in newer versions. + """ + + fitted_vars = [v for v in vars(self) if v.endswith("_")] + return bool(fitted_vars) + + def fit_status(self) -> "Preprocessor.FitStatus": + if not self._is_fittable: + return Preprocessor.FitStatus.NOT_FITTABLE + elif ( + hasattr(self, "_fitted") and self._fitted + ) or self._check_has_fitted_state(): + return Preprocessor.FitStatus.FITTED + else: + return Preprocessor.FitStatus.NOT_FITTED + + def fit(self, ds: "Dataset") -> "Preprocessor": + """Fit this Preprocessor to the Dataset. + + Fitted state attributes will be directly set in the Preprocessor. + + Calling it more than once will overwrite all previously fitted state: + ``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``. + + Args: + ds: Input dataset. + + Returns: + Preprocessor: The fitted Preprocessor with state attributes. + """ + fit_status = self.fit_status() + if fit_status == Preprocessor.FitStatus.NOT_FITTABLE: + # No-op as there is no state to be fitted. 
+ return self + + if fit_status in ( + Preprocessor.FitStatus.FITTED, + Preprocessor.FitStatus.PARTIALLY_FITTED, + ): + warnings.warn( + "`fit` has already been called on the preprocessor (or at least one " + "contained preprocessors if this is a chain). " + "All previously fitted state will be overwritten!" + ) + + fitted_ds = self._fit(ds) + self._fitted = True + return fitted_ds + + def fit_transform(self, ds: "Dataset") -> "Dataset": + """Fit this Preprocessor to the Dataset and then transform the Dataset. + + Calling it more than once will overwrite all previously fitted state: + ``preprocessor.fit_transform(A).fit_transform(B)`` + is equivalent to ``preprocessor.fit_transform(B)``. + + Args: + ds: Input Dataset. + + Returns: + ray.data.Dataset: The transformed Dataset. + """ + self.fit(ds) + return self.transform(ds) + + def transform(self, ds: "Dataset") -> "Dataset": + """Transform the given dataset. + + Args: + ds: Input Dataset. + + Returns: + ray.data.Dataset: The transformed Dataset. + + Raises: + PreprocessorNotFittedException: if ``fit`` is not called yet. + """ + fit_status = self.fit_status() + if fit_status in ( + Preprocessor.FitStatus.PARTIALLY_FITTED, + Preprocessor.FitStatus.NOT_FITTED, + ): + raise PreprocessorNotFittedException( + "`fit` must be called before `transform`, " + "or simply use fit_transform() to run both steps" + ) + transformed_ds = self._transform(ds) + return transformed_ds + + def transform_batch(self, data: "DataBatchType") -> "DataBatchType": + """Transform a single batch of data. + + The data will be converted to the format supported by the Preprocessor, + based on which ``_transform_*`` methods are implemented. + + Args: + data: Input data batch. + + Returns: + DataBatchType: + The transformed data batch. This may differ + from the input type depending on which ``_transform_*`` methods + are implemented. 
+ """ + fit_status = self.fit_status() + if fit_status in ( + Preprocessor.FitStatus.PARTIALLY_FITTED, + Preprocessor.FitStatus.NOT_FITTED, + ): + raise PreprocessorNotFittedException( + "`fit` must be called before `transform_batch`." + ) + return self._transform_batch(data) + + @DeveloperAPI + def _fit(self, ds: "Dataset") -> "Preprocessor": + """Sub-classes should override this instead of fit().""" + raise NotImplementedError() + + def _determine_transform_to_use(self) -> BatchFormat: + """Determine which batch format to use based on Preprocessor implementation. + + * If only `_transform_pandas` is implemented, then use ``pandas`` batch format. + * If only `_transform_numpy` is implemented, then use ``numpy`` batch format. + * If both are implemented, then use the Preprocessor defined preferred batch + format. + """ + + has_transform_pandas = ( + self.__class__._transform_pandas != Preprocessor._transform_pandas + ) + has_transform_numpy = ( + self.__class__._transform_numpy != Preprocessor._transform_numpy + ) + + if has_transform_numpy and has_transform_pandas: + return self.preferred_batch_format() + elif has_transform_numpy: + return BatchFormat.NUMPY + elif has_transform_pandas: + return BatchFormat.PANDAS + else: + raise NotImplementedError( + "None of `_transform_numpy` or `_transform_pandas` are implemented. " + "At least one of these transform functions must be implemented " + "for Preprocessor transforms." + ) + + def _transform(self, ds: "Dataset") -> "Dataset": + # TODO(matt): Expose `batch_size` or similar configurability. + # The default may be too small for some datasets and too large for others. + transform_type = self._determine_transform_to_use() + + # Our user-facing batch format should only be pandas or NumPy, other + # formats {arrow, simple} are internal. 
+ kwargs = self._get_transform_config() + if transform_type == BatchFormat.PANDAS: + return ds.map_batches( + self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs + ) + elif transform_type == BatchFormat.NUMPY: + return ds.map_batches( + self._transform_numpy, batch_format=BatchFormat.NUMPY, **kwargs + ) + else: + raise ValueError( + "Invalid transform type returned from _determine_transform_to_use; " + f'"pandas" and "numpy" allowed, but got: {transform_type}' + ) + + def _get_transform_config(self) -> Dict[str, Any]: + """Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`. + + This can be implemented by subclassing preprocessors. + """ + return {} + + def _transform_batch(self, data: "DataBatchType") -> "DataBatchType": + # For minimal install to locally import air modules + import numpy as np + import pandas as pd + + from ray.air.util.data_batch_conversion import ( + _convert_batch_type_to_numpy, + _convert_batch_type_to_pandas, + ) + + try: + import pyarrow + except ImportError: + pyarrow = None + + if not isinstance( + data, (pd.DataFrame, pyarrow.Table, collections.abc.Mapping, np.ndarray) + ): + raise ValueError( + "`transform_batch` is currently only implemented for Pandas " + "DataFrames, pyarrow Tables, NumPy ndarray and dictionary of " + f"ndarray. Got {type(data)}." 
            )

        transform_type = self._determine_transform_to_use()

        # NOTE(review): no else-branch — _determine_transform_to_use can only
        # return PANDAS or NUMPY (it raises otherwise), so the implicit None
        # fall-through below is unreachable in practice.
        if transform_type == BatchFormat.PANDAS:
            return self._transform_pandas(_convert_batch_type_to_pandas(data))
        elif transform_type == BatchFormat.NUMPY:
            return self._transform_numpy(_convert_batch_type_to_numpy(data))

    @DeveloperAPI
    def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run the transformation on a data batch in a Pandas DataFrame format."""
        raise NotImplementedError()

    @DeveloperAPI
    def _transform_numpy(
        self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
    ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
        """Run the transformation on a data batch in a NumPy ndarray format."""
        raise NotImplementedError()

    @classmethod
    @DeveloperAPI
    def preferred_batch_format(cls) -> BatchFormat:
        """Batch format hint for upstream producers to try yielding best block format.

        The preferred batch format to use if both `_transform_pandas` and
        `_transform_numpy` are implemented. Defaults to Pandas.

        Can be overriden by Preprocessor classes depending on which transform
        path is the most optimal.
        """
        return BatchFormat.PANDAS

    @DeveloperAPI
    def serialize(self) -> str:
        """Return this preprocessor serialized as a string.

        Note: this is not a stable serialization format as it uses `pickle`.
        """
        # Convert it to a plain string so that it can be included as JSON metadata
        # in Trainer checkpoints.
        return base64.b64encode(pickle.dumps(self)).decode("ascii")

    @staticmethod
    @DeveloperAPI
    def deserialize(serialized: str) -> "Preprocessor":
        """Load the original preprocessor serialized via `self.serialize()`."""
        # SECURITY: `pickle.loads` executes arbitrary code during
        # deserialization — only pass payloads produced by `serialize()` from
        # a trusted source.
        return pickle.loads(base64.b64decode(serialized))
diff --git a/.venv/lib/python3.11/site-packages/ray/data/random_access_dataset.py b/.venv/lib/python3.11/site-packages/ray/data/random_access_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24c6796f7ca6b3dcbb63b97a722882e7b0d4687
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/data/random_access_dataset.py
@@ -0,0 +1,293 @@
import bisect
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional

import numpy as np

import ray
from ray.data._internal.execution.interfaces.ref_bundle import (
    _ref_bundles_iterator_to_block_refs_list,
)
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import BlockAccessor
from ray.data.context import DataContext
from ray.types import ObjectRef
from ray.util.annotations import PublicAPI

try:
    import pyarrow as pa
except ImportError:
    pa = None

if TYPE_CHECKING:
    from ray.data import Dataset

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class RandomAccessDataset:
    """A class that provides distributed, random access to a Dataset.

    See: ``Dataset.to_random_access_dataset()``.
    """

    def __init__(
        self,
        ds: "Dataset",
        key: str,
        num_workers: int,
    ):
        """Construct a RandomAccessDataset (internal API).

        The constructor is a private API. Use ``ds.to_random_access_dataset()``
        to construct a RandomAccessDataset.
+ """ + schema = ds.schema(fetch_if_missing=True) + if schema is None or isinstance(schema, type): + raise ValueError("RandomAccessDataset only supports Arrow-format blocks.") + + start = time.perf_counter() + logger.info("[setup] Indexing dataset by sort key.") + sorted_ds = ds.sort(key) + get_bounds = cached_remote_fn(_get_bounds) + bundles = sorted_ds.iter_internal_ref_bundles() + blocks = _ref_bundles_iterator_to_block_refs_list(bundles) + + logger.info("[setup] Computing block range bounds.") + bounds = ray.get([get_bounds.remote(b, key) for b in blocks]) + self._non_empty_blocks = [] + self._lower_bound = None + self._upper_bounds = [] + for i, b in enumerate(bounds): + if b: + self._non_empty_blocks.append(blocks[i]) + if self._lower_bound is None: + self._lower_bound = b[0] + self._upper_bounds.append(b[1]) + + logger.info("[setup] Creating {} random access workers.".format(num_workers)) + ctx = DataContext.get_current() + scheduling_strategy = ctx.scheduling_strategy + self._workers = [ + _RandomAccessWorker.options(scheduling_strategy=scheduling_strategy).remote( + key + ) + for _ in range(num_workers) + ] + ( + self._block_to_workers_map, + self._worker_to_blocks_map, + ) = self._compute_block_to_worker_assignments() + + logger.info( + "[setup] Worker to blocks assignment: {}".format(self._worker_to_blocks_map) + ) + ray.get( + [ + w.assign_blocks.remote( + { + i: self._non_empty_blocks[i] + for i in self._worker_to_blocks_map[w] + } + ) + for w in self._workers + ] + ) + + logger.info("[setup] Finished assigning blocks to workers.") + self._build_time = time.perf_counter() - start + + def _compute_block_to_worker_assignments(self): + # Return values. + block_to_workers: dict[int, List["ray.ActorHandle"]] = defaultdict(list) + worker_to_blocks: dict["ray.ActorHandle", List[int]] = defaultdict(list) + + # Aux data structures. 
+ loc_to_workers: dict[str, List["ray.ActorHandle"]] = defaultdict(list) + locs = ray.get([w.ping.remote() for w in self._workers]) + for i, loc in enumerate(locs): + loc_to_workers[loc].append(self._workers[i]) + block_locs = ray.experimental.get_object_locations(self._non_empty_blocks) + + # First, try to assign all blocks to all workers at its location. + for block_idx, block in enumerate(self._non_empty_blocks): + block_info = block_locs[block] + locs = block_info.get("node_ids", []) + for loc in locs: + for worker in loc_to_workers[loc]: + block_to_workers[block_idx].append(worker) + worker_to_blocks[worker].append(block_idx) + + # Randomly assign any leftover blocks to at least one worker. + # TODO: the load balancing here could be improved. + for block_idx, block in enumerate(self._non_empty_blocks): + if len(block_to_workers[block_idx]) == 0: + worker = random.choice(self._workers) + block_to_workers[block_idx].append(worker) + worker_to_blocks[worker].append(block_idx) + + return block_to_workers, worker_to_blocks + + def get_async(self, key: Any) -> ObjectRef[Any]: + """Asynchronously finds the record for a single key. + + Args: + key: The key of the record to find. + + Returns: + ObjectRef containing the record (in pydict form), or None if not found. + """ + block_index = self._find_le(key) + if block_index is None: + return ray.put(None) + return self._worker_for(block_index).get.remote(block_index, key) + + def multiget(self, keys: List[Any]) -> List[Optional[Any]]: + """Synchronously find the records for a list of keys. + + Args: + keys: List of keys to find the records for. + + Returns: + List of found records (in pydict form), or None for missing records. 
+ """ + batches = defaultdict(list) + for k in keys: + batches[self._find_le(k)].append(k) + futures = {} + for index, keybatch in batches.items(): + if index is None: + continue + fut = self._worker_for(index).multiget.remote( + [index] * len(keybatch), keybatch + ) + futures[index] = fut + results = {} + for i, fut in futures.items(): + keybatch = batches[i] + values = ray.get(fut) + for k, v in zip(keybatch, values): + results[k] = v + return [results.get(k) for k in keys] + + def stats(self) -> str: + """Returns a string containing access timing information.""" + stats = ray.get([w.stats.remote() for w in self._workers]) + total_time = sum(s["total_time"] for s in stats) + accesses = [s["num_accesses"] for s in stats] + blocks = [s["num_blocks"] for s in stats] + msg = "RandomAccessDataset:\n" + msg += "- Build time: {}s\n".format(round(self._build_time, 2)) + msg += "- Num workers: {}\n".format(len(stats)) + msg += "- Blocks per worker: {} min, {} max, {} mean\n".format( + min(blocks), max(blocks), int(sum(blocks) / len(blocks)) + ) + msg += "- Accesses per worker: {} min, {} max, {} mean\n".format( + min(accesses), max(accesses), int(sum(accesses) / len(accesses)) + ) + msg += "- Mean access time: {}us\n".format( + int(total_time / (1 + sum(accesses)) * 1e6) + ) + return msg + + def _worker_for(self, block_index: int): + return random.choice(self._block_to_workers_map[block_index]) + + def _find_le(self, x: Any) -> int: + i = bisect.bisect_left(self._upper_bounds, x) + if i >= len(self._upper_bounds) or x < self._lower_bound: + return None + return i + + +@ray.remote(num_cpus=0) +class _RandomAccessWorker: + def __init__(self, key_field): + self.blocks = None + self.key_field = key_field + self.num_accesses = 0 + self.total_time = 0 + + def assign_blocks(self, block_ref_dict): + self.blocks = {k: ray.get(ref) for k, ref in block_ref_dict.items()} + + def get(self, block_index, key): + start = time.perf_counter() + result = self._get(block_index, key) + 
self.total_time += time.perf_counter() - start + self.num_accesses += 1 + return result + + def multiget(self, block_indices, keys): + start = time.perf_counter() + block = self.blocks[block_indices[0]] + if len(set(block_indices)) == 1 and isinstance( + self.blocks[block_indices[0]], pa.Table + ): + # Fast path: use np.searchsorted for vectorized search on a single block. + # This is ~3x faster than the naive case. + block = self.blocks[block_indices[0]] + col = block[self.key_field] + indices = np.searchsorted(col, keys) + acc = BlockAccessor.for_block(block) + result = [acc._get_row(i) for i in indices] + # assert result == [self._get(i, k) for i, k in zip(block_indices, keys)] + else: + result = [self._get(i, k) for i, k in zip(block_indices, keys)] + self.total_time += time.perf_counter() - start + self.num_accesses += 1 + return result + + def ping(self): + return ray.get_runtime_context().get_node_id() + + def stats(self) -> dict: + return { + "num_blocks": len(self.blocks), + "num_accesses": self.num_accesses, + "total_time": self.total_time, + } + + def _get(self, block_index, key): + if block_index is None: + return None + block = self.blocks[block_index] + column = block[self.key_field] + if isinstance(block, pa.Table): + column = _ArrowListWrapper(column) + i = _binary_search_find(column, key) + if i is None: + return None + acc = BlockAccessor.for_block(block) + return acc._get_row(i) + + +def _binary_search_find(column, x): + i = bisect.bisect_left(column, x) + if i != len(column) and column[i] == x: + return i + return None + + +class _ArrowListWrapper: + def __init__(self, arrow_col): + self.arrow_col = arrow_col + + def __getitem__(self, i): + return self.arrow_col[i].as_py() + + def __len__(self): + return len(self.arrow_col) + + +def _get_bounds(block, key): + if len(block) == 0: + return None + b = (block[key][0], block[key][len(block) - 1]) + if isinstance(block, pa.Table): + b = (b[0].as_py(), b[1].as_py()) + return b diff --git 
a/.venv/lib/python3.11/site-packages/ray/data/read_api.py b/.venv/lib/python3.11/site-packages/ray/data/read_api.py new file mode 100644 index 0000000000000000000000000000000000000000..aa16afa5e10dc3d2b4bc10f5f1005bd9589e0f70 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/read_api.py @@ -0,0 +1,3620 @@ +import collections +import logging +import os +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) + +import numpy as np + +import ray +from ray._private.auto_init_hook import wrap_auto_init +from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray +from ray.data._internal.datasource.audio_datasource import AudioDatasource +from ray.data._internal.datasource.avro_datasource import AvroDatasource +from ray.data._internal.datasource.bigquery_datasource import BigQueryDatasource +from ray.data._internal.datasource.binary_datasource import BinaryDatasource +from ray.data._internal.datasource.clickhouse_datasource import ClickHouseDatasource +from ray.data._internal.datasource.csv_datasource import CSVDatasource +from ray.data._internal.datasource.delta_sharing_datasource import ( + DeltaSharingDatasource, +) +from ray.data._internal.datasource.hudi_datasource import HudiDatasource +from ray.data._internal.datasource.iceberg_datasource import IcebergDatasource +from ray.data._internal.datasource.image_datasource import ( + ImageDatasource, + ImageFileMetadataProvider, +) +from ray.data._internal.datasource.json_datasource import JSONDatasource +from ray.data._internal.datasource.lance_datasource import LanceDatasource +from ray.data._internal.datasource.mongo_datasource import MongoDatasource +from ray.data._internal.datasource.numpy_datasource import NumpyDatasource +from ray.data._internal.datasource.parquet_bulk_datasource import ParquetBulkDatasource +from ray.data._internal.datasource.parquet_datasource import ParquetDatasource +from 
ray.data._internal.datasource.range_datasource import RangeDatasource +from ray.data._internal.datasource.sql_datasource import SQLDatasource +from ray.data._internal.datasource.text_datasource import TextDatasource +from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource +from ray.data._internal.datasource.torch_datasource import TorchDatasource +from ray.data._internal.datasource.video_datasource import VideoDatasource +from ray.data._internal.datasource.webdataset_datasource import WebDatasetDatasource +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.logical.operators.from_operators import ( + FromArrow, + FromBlocks, + FromItems, + FromNumpy, + FromPandas, +) +from ray.data._internal.logical.operators.read_operator import Read +from ray.data._internal.logical.optimizers import LogicalPlan +from ray.data._internal.plan import ExecutionPlan +from ray.data._internal.remote_fn import cached_remote_fn +from ray.data._internal.stats import DatasetStats +from ray.data._internal.util import ( + _autodetect_parallelism, + get_table_block_metadata, + ndarray_to_block, + pandas_df_to_arrow_block, +) +from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata +from ray.data.context import DataContext +from ray.data.dataset import Dataset, MaterializedDataset +from ray.data.datasource import ( + BaseFileMetadataProvider, + Connection, + Datasource, + PathPartitionFilter, +) +from ray.data.datasource.datasource import Reader +from ray.data.datasource.file_based_datasource import ( + FileShuffleConfig, + _unwrap_arrow_serialization_workaround, + _validate_shuffle_arg, +) +from ray.data.datasource.file_meta_provider import ( + DefaultFileMetadataProvider, + FastFileMetadataProvider, +) +from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider +from ray.data.datasource.partitioning import Partitioning +from ray.types import ObjectRef +from 
ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +if TYPE_CHECKING: + import dask + import datasets + import mars + import modin + import pandas + import pyarrow + import pymongoarrow.api + import pyspark + import tensorflow as tf + import torch + from pyiceberg.expressions import BooleanExpression + from tensorflow_metadata.proto.v0 import schema_pb2 + + from ray.data._internal.datasource.tfrecords_datasource import TFXReadOptions + + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +def from_blocks(blocks: List[Block]): + """Create a :class:`~ray.data.Dataset` from a list of blocks. + + This method is primarily used for testing. Unlike other methods like + :func:`~ray.data.from_pandas` and :func:`~ray.data.from_arrow`, this method + gaurentees that it won't modify the number of blocks. + + Args: + blocks: List of blocks to create the dataset from. + + Returns: + A :class:`~ray.data.Dataset` holding the blocks. + """ + block_refs = [ray.put(block) for block in blocks] + metadata = [BlockAccessor.for_block(block).get_metadata() for block in blocks] + from_blocks_op = FromBlocks(block_refs, metadata) + execution_plan = ExecutionPlan( + DatasetStats(metadata={"FromBlocks": metadata}, parent=None) + ) + logical_plan = LogicalPlan(from_blocks_op, execution_plan._context) + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + +@PublicAPI +def from_items( + items: List[Any], + *, + parallelism: int = -1, + override_num_blocks: Optional[int] = None, +) -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a list of local Python objects. + + Use this method to create small datasets from data that fits in memory. 
+ + Examples: + + >>> import ray + >>> ds = ray.data.from_items([1, 2, 3, 4, 5]) + >>> ds + MaterializedDataset(num_blocks=..., num_rows=5, schema={item: int64}) + >>> ds.schema() + Column Type + ------ ---- + item int64 + + Args: + items: List of local Python objects. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`~ray.data.Dataset` holding the items. + """ + import builtins + + parallelism = _get_num_output_blocks(parallelism, override_num_blocks) + if parallelism == 0: + raise ValueError(f"parallelism must be -1 or > 0, got: {parallelism}") + + detected_parallelism, _, _ = _autodetect_parallelism( + parallelism, + ray.util.get_current_placement_group(), + DataContext.get_current(), + ) + # Truncate parallelism to number of items to avoid empty blocks. + detected_parallelism = min(len(items), detected_parallelism) + + if detected_parallelism > 0: + block_size, remainder = divmod(len(items), detected_parallelism) + else: + block_size, remainder = 0, 0 + # NOTE: We need to explicitly use the builtins range since we override range below, + # with the definition of ray.data.range. + blocks: List[ObjectRef[Block]] = [] + metadata: List[BlockMetadata] = [] + for i in builtins.range(detected_parallelism): + stats = BlockExecStats.builder() + builder = DelegatingBlockBuilder() + # Evenly distribute remainder across block slices while preserving record order. 
+ block_start = i * block_size + min(i, remainder) + block_end = (i + 1) * block_size + min(i + 1, remainder) + for j in builtins.range(block_start, block_end): + item = items[j] + if not isinstance(item, collections.abc.Mapping): + item = {"item": item} + builder.add(item) + block = builder.build() + blocks.append(ray.put(block)) + metadata.append( + BlockAccessor.for_block(block).get_metadata(exec_stats=stats.build()) + ) + + from_items_op = FromItems(blocks, metadata) + execution_plan = ExecutionPlan( + DatasetStats(metadata={"FromItems": metadata}, parent=None) + ) + logical_plan = LogicalPlan(from_items_op, execution_plan._context) + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + +@PublicAPI +def range( + n: int, + *, + parallelism: int = -1, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """Creates a :class:`~ray.data.Dataset` from a range of integers [0..n). + + This function allows for easy creation of synthetic datasets for testing or + benchmarking :ref:`Ray Data `. + + Examples: + + >>> import ray + >>> ds = ray.data.range(10000) + >>> ds + Dataset(num_rows=10000, schema={id: int64}) + >>> ds.map(lambda row: {"id": row["id"] * 2}).take(4) + [{'id': 0}, {'id': 2}, {'id': 4}, {'id': 6}] + + Args: + n: The upper bound of the range of integers. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. 
+ + Returns: + A :class:`~ray.data.Dataset` producing the integers from the range 0 to n. + + .. seealso:: + + :meth:`~ray.data.range_tensor` + Call this method for creating synthetic datasets of tensor data. + + """ + datasource = RangeDatasource(n=n, block_format="arrow", column_name="id") + return read_datasource( + datasource, + parallelism=parallelism, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +def range_tensor( + n: int, + *, + shape: Tuple = (1,), + parallelism: int = -1, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """Creates a :class:`~ray.data.Dataset` tensors of the provided shape from range + [0...n]. + + This function allows for easy creation of synthetic tensor datasets for testing or + benchmarking :ref:`Ray Data `. + + Examples: + + >>> import ray + >>> ds = ray.data.range_tensor(1000, shape=(2, 2)) + >>> ds + Dataset(num_rows=1000, schema={data: numpy.ndarray(shape=(2, 2), dtype=int64)}) + >>> ds.map_batches(lambda row: {"data": row["data"] * 2}).take(2) + [{'data': array([[0, 0], + [0, 0]])}, {'data': array([[2, 2], + [2, 2]])}] + + Args: + n: The upper bound of the range of tensor records. + shape: The shape of each tensor in the dataset. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. 
+ + Returns: + A :class:`~ray.data.Dataset` producing the tensor data from range 0 to n. + + .. seealso:: + + :meth:`~ray.data.range` + Call this method to create synthetic datasets of integer data. + + """ + datasource = RangeDatasource( + n=n, block_format="tensor", column_name="data", tensor_shape=tuple(shape) + ) + return read_datasource( + datasource, + parallelism=parallelism, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +@wrap_auto_init +def read_datasource( + datasource: Datasource, + *, + parallelism: int = -1, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + **read_args, +) -> Dataset: + """Read a stream from a custom :class:`~ray.data.Datasource`. + + Args: + datasource: The :class:`~ray.data.Datasource` to read data from. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + read_args: Additional kwargs to pass to the :class:`~ray.data.Datasource` + implementation. + + Returns: + :class:`~ray.data.Dataset` that reads data from the :class:`~ray.data.Datasource`. 
+ """ # noqa: E501 + parallelism = _get_num_output_blocks(parallelism, override_num_blocks) + + ctx = DataContext.get_current() + + if ray_remote_args is None: + ray_remote_args = {} + + if not datasource.supports_distributed_reads: + ray_remote_args["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), + soft=False, + ) + + if "scheduling_strategy" not in ray_remote_args: + ray_remote_args["scheduling_strategy"] = ctx.scheduling_strategy + + datasource_or_legacy_reader = _get_datasource_or_legacy_reader( + datasource, + ctx, + read_args, + ) + + cur_pg = ray.util.get_current_placement_group() + requested_parallelism, _, inmemory_size = _autodetect_parallelism( + parallelism, + ctx.target_max_block_size, + DataContext.get_current(), + datasource_or_legacy_reader, + placement_group=cur_pg, + ) + + # TODO(hchen/chengsu): Remove the duplicated get_read_tasks call here after + # removing LazyBlockList code path. + read_tasks = datasource_or_legacy_reader.get_read_tasks(requested_parallelism) + import uuid + + stats = DatasetStats( + metadata={"Read": [read_task.metadata for read_task in read_tasks]}, + parent=None, + needs_stats_actor=True, + stats_uuid=uuid.uuid4(), + ) + read_op = Read( + datasource, + datasource_or_legacy_reader, + parallelism, + inmemory_size, + len(read_tasks) if read_tasks else 0, + ray_remote_args, + concurrency, + ) + execution_plan = ExecutionPlan(stats) + logical_plan = LogicalPlan(read_op, execution_plan._context) + return Dataset( + plan=execution_plan, + logical_plan=logical_plan, + ) + + +@PublicAPI(stability="alpha") +def read_audio( + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + partition_filter: Optional[PathPartitionFilter] = None, + partitioning: Optional[Partitioning] = None, + include_paths: bool = False, + ignore_missing_paths: bool = False, + file_extensions: Optional[List[str]] 
= AudioDatasource._FILE_EXTENSIONS, + shuffle: Union[Literal["files"], None] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +): + """Creates a :class:`~ray.data.Dataset` from audio files. + + Examples: + >>> import ray + >>> path = "s3://anonymous@air-example-data-2/6G-audio-data-LibriSpeech-train-clean-100-flac/train-clean-100/5022/29411/5022-29411-0000.flac" + >>> ds = ray.data.read_audio(path) + >>> ds.schema() + Column Type + ------ ---- + amplitude numpy.ndarray(shape=(1, 191760), dtype=float) + sample_rate int64 + + Args: + paths: A single file or directory, or a list of file or directory paths. + A list of paths can contain both files and directories. + filesystem: The pyarrow filesystem + implementation to read from. These filesystems are specified in the + `pyarrow docs `_. Specify this parameter if + you need to provide specific configurations to the filesystem. By default, + the filesystem is automatically selected based on the scheme of the paths. + For example, if the path begins with ``s3://``, the `S3FileSystem` is used. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_input_file `_. + when opening input files to read. + partition_filter: A + :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use + with a custom callback to read only selected partitions of a dataset. + partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object + that describes how paths are organized. Defaults to ``None``. + include_paths: If ``True``, include the path to each image. File paths are + stored in the ``'path'`` column. + ignore_missing_paths: If True, ignores any file/directory paths in ``paths`` + that are not found. Defaults to False. + file_extensions: A list of file extensions to filter files by. + concurrency: The maximum number of Ray tasks to run concurrently. 
Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks. + + Returns: + A :class:`~ray.data.Dataset` containing audio amplitudes and associated + metadata. + """ # noqa: E501 + datasource = AudioDatasource( + paths, + filesystem=filesystem, + open_stream_args=arrow_open_stream_args, + meta_provider=DefaultFileMetadataProvider(), + partition_filter=partition_filter, + partitioning=partitioning, + ignore_missing_paths=ignore_missing_paths, + shuffle=shuffle, + include_paths=include_paths, + file_extensions=file_extensions, + ) + return read_datasource( + datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI(stability="alpha") +def read_videos( + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + partition_filter: Optional[PathPartitionFilter] = None, + partitioning: Optional[Partitioning] = None, + include_paths: bool = False, + ignore_missing_paths: bool = False, + file_extensions: Optional[List[str]] = VideoDatasource._FILE_EXTENSIONS, + shuffle: Union[Literal["files"], None] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +): + """Creates a :class:`~ray.data.Dataset` from video files. + + Each row in the resulting dataset represents a video frame. 
+ + Examples: + >>> import ray + >>> path = "s3://anonymous@ray-example-data/basketball.mp4" + >>> ds = ray.data.read_videos(path) + >>> ds.schema() + Column Type + ------ ---- + frame numpy.ndarray(shape=(720, 1280, 3), dtype=uint8) + frame_index int64 + + Args: + paths: A single file or directory, or a list of file or directory paths. + A list of paths can contain both files and directories. + filesystem: The pyarrow filesystem + implementation to read from. These filesystems are specified in the + `pyarrow docs `_. Specify this parameter if + you need to provide specific configurations to the filesystem. By default, + the filesystem is automatically selected based on the scheme of the paths. + For example, if the path begins with ``s3://``, the `S3FileSystem` is used. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_input_file `_. + when opening input files to read. + partition_filter: A + :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use + with a custom callback to read only selected partitions of a dataset. + partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object + that describes how paths are organized. Defaults to ``None``. + include_paths: If ``True``, include the path to each image. File paths are + stored in the ``'path'`` column. + ignore_missing_paths: If True, ignores any file/directory paths in ``paths`` + that are not found. Defaults to False. + file_extensions: A list of file extensions to filter files by. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks. + + Returns: + A :class:`~ray.data.Dataset` containing video frames from the video files. 
+ """ + datasource = VideoDatasource( + paths, + filesystem=filesystem, + open_stream_args=arrow_open_stream_args, + meta_provider=DefaultFileMetadataProvider(), + partition_filter=partition_filter, + partitioning=partitioning, + ignore_missing_paths=ignore_missing_paths, + shuffle=shuffle, + include_paths=include_paths, + file_extensions=file_extensions, + ) + return read_datasource( + datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI(stability="alpha") +def read_mongo( + uri: str, + database: str, + collection: str, + *, + pipeline: Optional[List[Dict]] = None, + schema: Optional["pymongoarrow.api.Schema"] = None, + parallelism: int = -1, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + **mongo_args, +) -> Dataset: + """Create a :class:`~ray.data.Dataset` from a MongoDB database. + + The data to read from is specified via the ``uri``, ``database`` and ``collection`` + of the MongoDB. The dataset is created from the results of executing + ``pipeline`` against the ``collection``. If ``pipeline`` is None, the entire + ``collection`` is read. + + .. tip:: + + For more details about these MongoDB concepts, see the following: + - URI: https://www.mongodb.com/docs/manual/reference/connection-string/ + - Database and Collection: https://www.mongodb.com/docs/manual/core/databases-and-collections/ + - Pipeline: https://www.mongodb.com/docs/manual/core/aggregation-pipeline/ + + To read the MongoDB in parallel, the execution of the pipeline is run on partitions + of the collection, with a Ray read task to handle a partition. Partitions are + created in an attempt to evenly distribute the documents into the specified number + of partitions. 
The number of partitions is determined by ``parallelism`` which can + be requested from this interface or automatically chosen if unspecified (see the + ``parallelism`` arg below). + + Examples: + >>> import ray + >>> from pymongoarrow.api import Schema # doctest: +SKIP + >>> ds = ray.data.read_mongo( # doctest: +SKIP + ... uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin", # noqa: E501 + ... database="my_db", + ... collection="my_collection", + ... pipeline=[{"$match": {"col2": {"$gte": 0, "$lt": 100}}}, {"$sort": "sort_field"}], # noqa: E501 + ... schema=Schema({"col1": pa.string(), "col2": pa.int64()}), + ... override_num_blocks=10, + ... ) + + Args: + uri: The URI of the source MongoDB where the dataset is + read from. For the URI format, see details in the `MongoDB docs `_. + database: The name of the database hosted in the MongoDB. This database + must exist otherwise ValueError is raised. + collection: The name of the collection in the database. This collection + must exist otherwise ValueError is raised. + pipeline: A `MongoDB pipeline `_, which is executed on the given collection + with results used to create Dataset. If None, the entire collection will + be read. + schema: The schema used to read the collection. If None, it'll be inferred from + the results of pipeline. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. 
@PublicAPI(stability="alpha")
def read_bigquery(
    project_id: str,
    dataset: Optional[str] = None,
    query: Optional[str] = None,
    *,
    parallelism: int = -1,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """Create a dataset from BigQuery.

    The data to read is specified via ``project_id``, ``dataset`` and/or
    ``query``. The dataset is created from the results of executing ``query``
    if a query is provided; otherwise the entire ``dataset`` is read. This
    method uses the BigQuery Storage Read API, which reads in parallel with
    one Ray read task per stream.

    .. warning::
        The maximum query response size is 10GB.

    Examples:
        .. testcode::
            :skipif: True

            import ray
            # Users will need to authenticate beforehand (https://cloud.google.com/sdk/gcloud/reference/auth/login)
            ds = ray.data.read_bigquery(
                project_id="my_project",
                query="SELECT * FROM `bigquery-public-data.samples.gsod` LIMIT 1000",
            )

    Args:
        project_id: The name of the associated Google Cloud Project that hosts
            the dataset to read.
        dataset: The name of the dataset hosted in BigQuery in the format of
            ``dataset_id.table_id``. Both the dataset_id and table_id must
            exist, otherwise an exception is raised.
        query: The SQL query whose results are read. If provided, the query is
            executed instead of reading the entire ``dataset``.
        parallelism: This argument is deprecated. Use ``override_num_blocks``.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        concurrency: The maximum number of Ray tasks to run concurrently.
            This doesn't change the total number of tasks run or the total
            number of output blocks. By default, concurrency is dynamically
            decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read
            tasks. By default, this is dynamically decided based on input data
            size and available resources; you shouldn't manually set it in
            most cases.

    Returns:
        Dataset producing rows from the results of executing the query (or
        reading the entire dataset) on the specified BigQuery dataset.
    """  # noqa: E501
    datasource = BigQueryDatasource(project_id=project_id, dataset=dataset, query=query)
    return read_datasource(
        datasource,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
+ """ # noqa: E501 + datasource = BigQueryDatasource(project_id=project_id, dataset=dataset, query=query) + return read_datasource( + datasource, + parallelism=parallelism, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +def read_parquet( + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + columns: Optional[List[str]] = None, + parallelism: int = -1, + ray_remote_args: Dict[str, Any] = None, + tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None, + meta_provider: Optional[ParquetMetadataProvider] = None, + partition_filter: Optional[PathPartitionFilter] = None, + partitioning: Optional[Partitioning] = Partitioning("hive"), + shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None, + include_paths: bool = False, + file_extensions: Optional[List[str]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + **arrow_parquet_args, +) -> Dataset: + """Creates a :class:`~ray.data.Dataset` from parquet files. + + + Examples: + Read a file in remote storage. + + >>> import ray + >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet") + >>> ds.schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width double + variety string + + Read a directory in remote storage. + + >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris-parquet/") + + Read multiple local files. + + >>> ray.data.read_parquet( + ... ["local:///path/to/file1", "local:///path/to/file2"]) # doctest: +SKIP + + Specify a schema for the parquet file. + + >>> import pyarrow as pa + >>> fields = [("sepal.length", pa.float32()), + ... ("sepal.width", pa.float32()), + ... ("petal.length", pa.float32()), + ... ("petal.width", pa.float32()), + ... 
("variety", pa.string())] + >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet", + ... schema=pa.schema(fields)) + >>> ds.schema() + Column Type + ------ ---- + sepal.length float + sepal.width float + petal.length float + petal.width float + variety string + + The Parquet reader also supports projection and filter pushdown, allowing column + selection and row filtering to be pushed down to the file scan. + + .. testcode:: + + import pyarrow as pa + + # Create a Dataset by reading a Parquet file, pushing column selection and + # row filtering down to the file scan. + ds = ray.data.read_parquet( + "s3://anonymous@ray-example-data/iris.parquet", + columns=["sepal.length", "variety"], + filter=pa.dataset.field("sepal.length") > 5.0, + ) + + ds.show(2) + + .. testoutput:: + + {'sepal.length': 5.1, 'variety': 'Setosa'} + {'sepal.length': 5.4, 'variety': 'Setosa'} + + For further arguments you can pass to PyArrow as a keyword argument, see the + `PyArrow API reference `_. + + Args: + paths: A single file path or directory, or a list of file paths. Multiple + directories are not supported. + filesystem: The PyArrow filesystem + implementation to read from. These filesystems are specified in the + `pyarrow docs `_. Specify this parameter if + you need to provide specific configurations to the filesystem. By default, + the filesystem is automatically selected based on the scheme of the paths. + For example, if the path begins with ``s3://``, the ``S3FileSystem`` is + used. If ``None``, this function uses a system-chosen implementation. + columns: A list of column names to read. Only the specified columns are + read during the file scan. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. 
+ tensor_column_schema: A dict of column name to PyArrow dtype and shape + mappings for converting a Parquet column containing serialized + tensors (ndarrays) as their elements to PyArrow tensors. This function + assumes that the tensors are serialized in the raw + NumPy array format in C-contiguous order (e.g., via + `arr.tobytes()`). + meta_provider: A :ref:`file metadata provider `. Custom + metadata providers may be able to resolve file metadata more quickly and/or + accurately. In most cases you do not need to set this parameter. + partition_filter: A + :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use + with a custom callback to read only selected partitions of a dataset. + partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object + that describes how paths are organized. Defaults to HIVE partitioning. + shuffle: If setting to "files", randomly shuffle input files order before read. + If setting to :class:`~ray.data.FileShuffleConfig`, you can pass a seed to + shuffle the input files. Defaults to not shuffle with ``None``. + arrow_parquet_args: Other parquet read options to pass to PyArrow. For the full + set of arguments, see the `PyArrow API `_ + include_paths: If ``True``, include the path to each file. File paths are + stored in the ``'path'`` column. + file_extensions: A list of file extensions to filter files by. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. 
@PublicAPI(stability="beta")
def read_images(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    arrow_open_file_args: Optional[Dict[str, Any]] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    partitioning: Partitioning = None,
    size: Optional[Tuple[int, int]] = None,
    mode: Optional[str] = None,
    include_paths: bool = False,
    ignore_missing_paths: bool = False,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    file_extensions: Optional[List[str]] = ImageDatasource._FILE_EXTENSIONS,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """Creates a :class:`~ray.data.Dataset` from image files.

    Examples:
        >>> import ray
        >>> path = "s3://anonymous@ray-example-data/batoidea/JPEGImages/"
        >>> ds = ray.data.read_images(path)
        >>> ds.schema()
        Column  Type
        ------  ----
        image   numpy.ndarray(shape=(32, 32, 3), dtype=uint8)

        If you need image file paths, set ``include_paths=True``. If your
        images are arranged in labeled directories (e.g. ``root/dog/xxx.png``),
        you can include the labels by specifying a
        :class:`~ray.data.datasource.partitioning.Partitioning` with
        ``field_names`` and ``base_dir``.

    Args:
        paths: A single file or directory, or a list of file or directory
            paths. A list of paths can contain both files and directories.
        filesystem: The pyarrow filesystem implementation to read from.
            Specify this parameter if you need to provide specific
            configurations to the filesystem. By default, the filesystem is
            automatically selected based on the scheme of the paths.
        parallelism: This argument is deprecated. Use ``override_num_blocks``.
        meta_provider: [Deprecated] A file metadata provider. Custom providers
            may resolve file metadata more quickly and/or accurately. In most
            cases you do not need to set this. If ``None``, a system-chosen
            implementation is used.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        arrow_open_file_args: kwargs passed to
            ``pyarrow.fs.FileSystem.open_input_file`` when opening input files.
        partition_filter: A
            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`.
            Use with a custom callback to read only selected partitions. By
            default, this filters out any file paths whose file extension does
            not match ``*.png``, ``*.jpg``, ``*.jpeg``, ``*.tiff``, ``*.bmp``,
            or ``*.gif``.
        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning`
            object that describes how paths are organized. Defaults to
            ``None``.
        size: The desired height and width of loaded images. If unspecified,
            images retain their original shape.
        mode: A Pillow mode describing the desired type and depth of pixels.
            If unspecified, image modes are inferred by Pillow.
        include_paths: If ``True``, include the path to each image in the
            ``'path'`` column.
        ignore_missing_paths: If True, ignores any file/directory paths in
            ``paths`` that are not found. Defaults to False.
        shuffle: If set to "files", randomly shuffle input file order before
            read. If set to :class:`~ray.data.FileShuffleConfig`, you can pass
            a seed to shuffle the input files. Defaults to no shuffle.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently.
            This doesn't change the total number of tasks run or the total
            number of output blocks. By default, concurrency is dynamically
            decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read
            tasks. By default, this is dynamically decided based on input data
            size and available resources; you shouldn't manually set it in
            most cases.

    Returns:
        A :class:`~ray.data.Dataset` producing tensors that represent the
        images at the specified paths.

    Raises:
        ValueError: if ``size`` contains non-positive numbers.
        ValueError: if ``mode`` is unsupported.
    """
    _emit_meta_provider_deprecation_warning(meta_provider)

    if meta_provider is None:
        meta_provider = ImageFileMetadataProvider()

    datasource = ImageDatasource(
        paths,
        size=size,
        mode=mode,
        include_paths=include_paths,
        filesystem=filesystem,
        meta_provider=meta_provider,
        open_stream_args=arrow_open_file_args,
        partition_filter=partition_filter,
        partitioning=partitioning,
        ignore_missing_paths=ignore_missing_paths,
        shuffle=shuffle,
        file_extensions=file_extensions,
    )
    return read_datasource(
        datasource,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
@Deprecated
def read_parquet_bulk(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    columns: Optional[List[str]] = None,
    parallelism: int = -1,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    arrow_open_file_args: Optional[Dict[str, Any]] = None,
    tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    include_paths: bool = False,
    file_extensions: Optional[List[str]] = ParquetBulkDatasource._FILE_EXTENSIONS,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
    **arrow_parquet_args,
) -> Dataset:
    """Create :class:`~ray.data.Dataset` from parquet files without reading metadata.

    Use :meth:`~ray.data.read_parquet` for most cases. Use this function only
    when all provided paths point to files and metadata fetching via
    :meth:`~ray.data.read_parquet` takes too long, or the parquet files do not
    all have a unified schema. Performance slowdowns are possible with very
    large parquet files.

    .. warning::

        Only provide file paths as input (i.e., no directory paths). An
        OSError is raised if one or more paths point to directories.

    Examples:
        >>> ray.data.read_parquet_bulk(  # doctest: +SKIP
        ...     ["/path/to/file1", "/path/to/file2"])

    Args:
        paths: A single file path or a list of file paths.
        filesystem: The PyArrow filesystem implementation to read from.
            Specify this parameter if you need to provide specific
            configurations to the filesystem. By default, the filesystem is
            automatically selected based on the scheme of the paths.
        columns: A list of column names to read. Only the specified columns
            are read during the file scan.
        parallelism: This argument is deprecated. Use ``override_num_blocks``.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        arrow_open_file_args: kwargs passed to
            ``pyarrow.fs.FileSystem.open_input_file`` when opening input files.
        tensor_column_schema: A dict of column name to PyArrow dtype and shape
            mappings for converting a Parquet column containing serialized
            tensors (ndarrays) as their elements to PyArrow tensors. Assumes
            tensors are serialized in the raw NumPy array format in
            C-contiguous order (e.g. via ``arr.tobytes()``).
        meta_provider: [Deprecated] A file metadata provider. Custom providers
            may resolve file metadata more quickly and/or accurately. In most
            cases you do not need to set this. If ``None``, a system-chosen
            implementation is used.
        partition_filter: A
            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`.
            Use with a custom callback to read only selected partitions. By
            default, this filters out any file paths whose file extension does
            not match "*.parquet*".
        shuffle: If set to "files", randomly shuffle input file order before
            read. If set to :class:`~ray.data.FileShuffleConfig`, you can pass
            a seed to shuffle the input files. Defaults to no shuffle.
        include_paths: If ``True``, include the path to each file in the
            ``'path'`` column.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently.
            This doesn't change the total number of tasks run or the total
            number of output blocks. By default, concurrency is dynamically
            decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read
            tasks. By default, this is dynamically decided based on input data
            size and available resources; you shouldn't manually set it in
            most cases.
        arrow_parquet_args: Other parquet read options to pass to PyArrow.

    Returns:
        :class:`~ray.data.Dataset` producing records read from the specified
        paths.
    """
    _emit_meta_provider_deprecation_warning(meta_provider)

    warnings.warn(
        "`read_parquet_bulk` is deprecated and will be removed after May 2025. Use "
        "`read_parquet` instead.",
        DeprecationWarning,
    )

    if meta_provider is None:
        meta_provider = FastFileMetadataProvider()
    read_table_args = _resolve_parquet_args(
        tensor_column_schema,
        **arrow_parquet_args,
    )
    if columns is not None:
        read_table_args["columns"] = columns

    datasource = ParquetBulkDatasource(
        paths,
        read_table_args=read_table_args,
        filesystem=filesystem,
        open_stream_args=arrow_open_file_args,
        meta_provider=meta_provider,
        partition_filter=partition_filter,
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
    )
    return read_datasource(
        datasource,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
@PublicAPI
def read_json(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    partitioning: Partitioning = Partitioning("hive"),
    include_paths: bool = False,
    ignore_missing_paths: bool = False,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    file_extensions: Optional[List[str]] = JSONDatasource._FILE_EXTENSIONS,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
    **arrow_json_args,
) -> Dataset:
    """Creates a :class:`~ray.data.Dataset` from JSON and JSONL files.

    For a JSON file, the whole file is read as one row.
    For a JSONL file, each line of the file is read as a separate row.

    Examples:
        Read a JSON file in remote storage.

        >>> import ray
        >>> ds = ray.data.read_json("s3://anonymous@ray-example-data/log.json")
        >>> ds.schema()
        Column     Type
        ------     ----
        timestamp  timestamp[...]
        size       int64

        Read a JSONL file in remote storage.

        >>> ds = ray.data.read_json("s3://anonymous@ray-example-data/train.jsonl")
        >>> ds.schema()
        Column  Type
        ------  ----
        input   string

        Read multiple local files.

        >>> ray.data.read_json(  # doctest: +SKIP
        ...     ["local:///path/to/file1", "local:///path/to/file2"])

        By default, :meth:`~ray.data.read_json` parses Hive-style partitions
        from file paths. If your data adheres to a different partitioning
        scheme, set the ``partitioning`` parameter.

        >>> ds = ray.data.read_json("s3://anonymous@ray-example-data/year=2022/month=09/sales.json")
        >>> ds.take(1)
        [{'order_number': 10107, 'quantity': 30, 'year': '2022', 'month': '09'}]

    Args:
        paths: A single file or directory, or a list of file or directory
            paths. A list of paths can contain both files and directories.
        filesystem: The PyArrow filesystem implementation to read from.
            Specify this parameter if you need to provide specific
            configurations to the filesystem. By default, the filesystem is
            automatically selected based on the scheme of the paths.
        parallelism: This argument is deprecated. Use ``override_num_blocks``.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        arrow_open_stream_args: kwargs passed to
            ``pyarrow.fs.FileSystem.open_input_file`` when opening input files.
        meta_provider: [Deprecated] A file metadata provider. Custom providers
            may resolve file metadata more quickly and/or accurately. In most
            cases you do not need to set this. If ``None``, a system-chosen
            implementation is used.
        partition_filter: A
            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`.
            Use with a custom callback to read only selected partitions. By
            default, this filters out any file paths whose file extension does
            not match "*.json" or "*.jsonl".
        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning`
            object that describes how paths are organized. By default, this
            function parses Hive-style partitions.
        include_paths: If ``True``, include the path to each file in the
            ``'path'`` column.
        ignore_missing_paths: If True, ignores any file paths in ``paths``
            that are not found. Defaults to False.
        shuffle: If set to "files", randomly shuffle input file order before
            read. If set to ``FileShuffleConfig``, you can pass a random seed
            to shuffle the input files, e.g. ``FileShuffleConfig(seed=42)``.
            Defaults to no shuffle.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently.
            This doesn't change the total number of tasks run or the total
            number of output blocks. By default, concurrency is dynamically
            decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read
            tasks. By default, this is dynamically decided based on input data
            size and available resources; you shouldn't manually set it in
            most cases.
        arrow_json_args: JSON read options to pass to ``pyarrow.json.read_json``.

    Returns:
        :class:`~ray.data.Dataset` producing records read from the specified
        paths.
    """  # noqa: E501
    _emit_meta_provider_deprecation_warning(meta_provider)

    if meta_provider is None:
        meta_provider = DefaultFileMetadataProvider()

    datasource = JSONDatasource(
        paths,
        arrow_json_args=arrow_json_args,
        filesystem=filesystem,
        open_stream_args=arrow_open_stream_args,
        meta_provider=meta_provider,
        partition_filter=partition_filter,
        partitioning=partitioning,
        ignore_missing_paths=ignore_missing_paths,
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
    )
    return read_datasource(
        datasource,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
+ """ # noqa: E501 + _emit_meta_provider_deprecation_warning(meta_provider) + + if meta_provider is None: + meta_provider = DefaultFileMetadataProvider() + + datasource = JSONDatasource( + paths, + arrow_json_args=arrow_json_args, + filesystem=filesystem, + open_stream_args=arrow_open_stream_args, + meta_provider=meta_provider, + partition_filter=partition_filter, + partitioning=partitioning, + ignore_missing_paths=ignore_missing_paths, + shuffle=shuffle, + include_paths=include_paths, + file_extensions=file_extensions, + ) + return read_datasource( + datasource, + parallelism=parallelism, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +def read_csv( + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + parallelism: int = -1, + ray_remote_args: Dict[str, Any] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + meta_provider: Optional[BaseFileMetadataProvider] = None, + partition_filter: Optional[PathPartitionFilter] = None, + partitioning: Partitioning = Partitioning("hive"), + include_paths: bool = False, + ignore_missing_paths: bool = False, + shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None, + file_extensions: Optional[List[str]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + **arrow_csv_args, +) -> Dataset: + """Creates a :class:`~ray.data.Dataset` from CSV files. + + Examples: + Read a file in remote storage. + + >>> import ray + >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") + >>> ds.schema() + Column Type + ------ ---- + sepal length (cm) double + sepal width (cm) double + petal length (cm) double + petal width (cm) double + target int64 + + Read multiple local files. + + >>> ray.data.read_csv( # doctest: +SKIP + ... ["local:///path/to/file1", "local:///path/to/file2"]) + + Read a directory from remote storage. 
+ + >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris-csv/") + + Read files that use a different delimiter. For more uses of ParseOptions see + https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html # noqa: #501 + + >>> from pyarrow import csv + >>> parse_options = csv.ParseOptions(delimiter="\\t") + >>> ds = ray.data.read_csv( + ... "s3://anonymous@ray-example-data/iris.tsv", + ... parse_options=parse_options) + >>> ds.schema() + Column Type + ------ ---- + sepal.length double + sepal.width double + petal.length double + petal.width double + variety string + + Convert a date column with a custom format from a CSV file. For more uses of ConvertOptions see https://arrow.apache.org/docs/python/generated/pyarrow.csv.ConvertOptions.html # noqa: #501 + + >>> from pyarrow import csv + >>> convert_options = csv.ConvertOptions( + ... timestamp_parsers=["%m/%d/%Y"]) + >>> ds = ray.data.read_csv( + ... "s3://anonymous@ray-example-data/dow_jones.csv", + ... convert_options=convert_options) + + By default, :meth:`~ray.data.read_csv` parses + `Hive-style partitions `_ + from file paths. If your data adheres to a different partitioning scheme, set + the ``partitioning`` parameter. + + >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/year=2022/month=09/sales.csv") + >>> ds.take(1) + [{'order_number': 10107, 'quantity': 30, 'year': '2022', 'month': '09'}] + + By default, :meth:`~ray.data.read_csv` reads all files from file paths. If you want to filter + files by file extensions, set the ``file_extensions`` parameter. + + Read only ``*.csv`` files from a directory. + + >>> ray.data.read_csv("s3://anonymous@ray-example-data/different-extensions/", + ... file_extensions=["csv"]) + Dataset(num_rows=?, schema={a: int64, b: int64}) + + Args: + paths: A single file or directory, or a list of file or directory paths. + A list of paths can contain both files and directories. + filesystem: The PyArrow filesystem + implementation to read from. 
These filesystems are specified in the + `pyarrow docs `_. Specify this parameter if + you need to provide specific configurations to the filesystem. By default, + the filesystem is automatically selected based on the scheme of the paths. + For example, if the path begins with ``s3://``, the `S3FileSystem` is used. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_input_file `_. + when opening input files to read. + meta_provider: [Deprecated] A :ref:`file metadata provider `. + Custom metadata providers may be able to resolve file metadata more quickly + and/or accurately. In most cases, you do not need to set this. If ``None``, + this function uses a system-chosen implementation. + partition_filter: A + :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. + Use with a custom callback to read only selected partitions of a + dataset. By default, no files are filtered. + partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object + that describes how paths are organized. By default, this function parses + `Hive-style partitions `_. + include_paths: If ``True``, include the path to each file. File paths are + stored in the ``'path'`` column. + ignore_missing_paths: If True, ignores any file paths in ``paths`` that are not + found. Defaults to False. + shuffle: If setting to "files", randomly shuffle input files order before read. + If setting to :class:`~ray.data.FileShuffleConfig`, you can pass a seed to + shuffle the input files. Defaults to not shuffle with ``None``. + arrow_csv_args: CSV read options to pass to + `pyarrow.csv.open_csv `_ + when opening CSV files. + file_extensions: A list of file extensions to filter files by. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. 
@PublicAPI
def read_text(
    paths: Union[str, List[str]],
    *,
    encoding: str = "utf-8",
    drop_empty_lines: bool = True,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    partitioning: Partitioning = None,
    include_paths: bool = False,
    ignore_missing_paths: bool = False,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    file_extensions: Optional[List[str]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """Create a :class:`~ray.data.Dataset` from lines stored in text files.

    Each decoded line of each input file becomes one row of the output
    dataset, stored under a single ``text`` column.

    Examples:
        Read a file in remote storage.

        >>> import ray
        >>> ds = ray.data.read_text("s3://anonymous@ray-example-data/this.txt")
        >>> ds.schema()
        Column  Type
        ------  ----
        text    string

        Read multiple local files.

        >>> ray.data.read_text( # doctest: +SKIP
        ...     ["local:///path/to/file1", "local:///path/to/file2"])

    Args:
        paths: A single file or directory, or a list of file or directory paths.
            A list of paths can contain both files and directories.
        encoding: The encoding of the files (e.g., "utf-8" or "ascii").
        drop_empty_lines: If ``True``, lines that are empty after decoding are
            omitted from the output.
        filesystem: The PyArrow filesystem
            implementation to read from. These filesystems are specified in the
            `PyArrow docs `_. Specify this parameter if
            you need to provide specific configurations to the filesystem. By default,
            the filesystem is automatically selected based on the scheme of the paths.
            For example, if the path begins with ``s3://``, the `S3FileSystem` is used.
        parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks and
            in the subsequent text decoding map task.
        arrow_open_stream_args: kwargs passed to
            `pyarrow.fs.FileSystem.open_input_file `_.
            when opening input files to read.
        meta_provider: [Deprecated] A :ref:`file metadata provider `.
            Custom metadata providers may be able to resolve file metadata more quickly
            and/or accurately. In most cases, you do not need to set this. If ``None``,
            this function uses a system-chosen implementation.
        partition_filter: A
            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`.
            Use with a custom callback to read only selected partitions of a
            dataset. By default, no files are filtered.
        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
            that describes how paths are organized. Defaults to ``None``.
        include_paths: If ``True``, include the path to each file. File paths are
            stored in the ``'path'`` column.
        ignore_missing_paths: If True, ignores any file paths in ``paths`` that are not
            found. Defaults to False.
        shuffle: If setting to "files", randomly shuffle input files order before read.
            If setting to :class:`~ray.data.FileShuffleConfig`, you can pass a seed to
            shuffle the input files. Defaults to not shuffle with ``None``.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.

    Returns:
        :class:`~ray.data.Dataset` producing lines of text read from the specified
        paths.
    """
    # Surface the deprecation warning if the caller supplied a provider.
    _emit_meta_provider_deprecation_warning(meta_provider)

    # Fall back to the system-chosen metadata provider when none is given.
    provider = (
        DefaultFileMetadataProvider() if meta_provider is None else meta_provider
    )

    source = TextDatasource(
        paths,
        drop_empty_lines=drop_empty_lines,
        encoding=encoding,
        filesystem=filesystem,
        open_stream_args=arrow_open_stream_args,
        meta_provider=provider,
        partition_filter=partition_filter,
        partitioning=partitioning,
        ignore_missing_paths=ignore_missing_paths,
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
    )
    return read_datasource(
        source,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
+ """ + _emit_meta_provider_deprecation_warning(meta_provider) + + if meta_provider is None: + meta_provider = DefaultFileMetadataProvider() + + datasource = TextDatasource( + paths, + drop_empty_lines=drop_empty_lines, + encoding=encoding, + filesystem=filesystem, + open_stream_args=arrow_open_stream_args, + meta_provider=meta_provider, + partition_filter=partition_filter, + partitioning=partitioning, + ignore_missing_paths=ignore_missing_paths, + shuffle=shuffle, + include_paths=include_paths, + file_extensions=file_extensions, + ) + return read_datasource( + datasource, + parallelism=parallelism, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +def read_avro( + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + parallelism: int = -1, + ray_remote_args: Optional[Dict[str, Any]] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + meta_provider: Optional[BaseFileMetadataProvider] = None, + partition_filter: Optional[PathPartitionFilter] = None, + partitioning: Partitioning = None, + include_paths: bool = False, + ignore_missing_paths: bool = False, + shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None, + file_extensions: Optional[List[str]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """Create a :class:`~ray.data.Dataset` from records stored in Avro files. + + Examples: + Read an Avro file in remote storage or local storage. + + >>> import ray + >>> ds = ray.data.read_avro("s3://anonymous@ray-example-data/mnist.avro") + >>> ds.schema() + Column Type + ------ ---- + features list + label int64 + dataType string + + >>> ray.data.read_avro( # doctest: +SKIP + ... ["local:///path/to/file1", "local:///path/to/file2"]) + + Args: + paths: A single file or directory, or a list of file or directory paths. + A list of paths can contain both files and directories. 
+ filesystem: The PyArrow filesystem + implementation to read from. These filesystems are specified in the + `PyArrow docs `_. Specify this parameter if + you need to provide specific configurations to the filesystem. By default, + the filesystem is automatically selected based on the scheme of the paths. + For example, if the path begins with ``s3://``, the `S3FileSystem` is used. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks and + in the subsequent text decoding map task. + arrow_open_stream_args: kwargs passed to + `pyarrow.fs.FileSystem.open_input_file `_. + when opening input files to read. + meta_provider: [Deprecated] A :ref:`file metadata provider `. + Custom metadata providers may be able to resolve file metadata more quickly + and/or accurately. In most cases, you do not need to set this. If ``None``, + this function uses a system-chosen implementation. + partition_filter: A + :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. + Use with a custom callback to read only selected partitions of a + dataset. By default, no files are filtered. + partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object + that describes how paths are organized. Defaults to ``None``. + include_paths: If ``True``, include the path to each file. File paths are + stored in the ``'path'`` column. + ignore_missing_paths: If True, ignores any file paths in ``paths`` that are not + found. Defaults to False. + shuffle: If setting to "files", randomly shuffle input files order before read. + If setting to :class:`~ray.data.FileShuffleConfig`, you can pass a seed to + shuffle the input files. Defaults to not shuffle with ``None``. + file_extensions: A list of file extensions to filter files by. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. 
@PublicAPI
def read_numpy(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    partitioning: Partitioning = None,
    include_paths: bool = False,
    ignore_missing_paths: bool = False,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    file_extensions: Optional[List[str]] = NumpyDatasource._FILE_EXTENSIONS,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
    **numpy_load_args,
) -> Dataset:
    """Create an Arrow dataset from numpy files.

    Examples:
        Read a directory of files in remote storage.

        >>> import ray
        >>> ray.data.read_numpy("s3://bucket/path") # doctest: +SKIP

        Read multiple local files.

        >>> ray.data.read_numpy(["/path/to/file1", "/path/to/file2"]) # doctest: +SKIP

        Read multiple directories.

        >>> ray.data.read_numpy( # doctest: +SKIP
        ...     ["s3://bucket/path1", "s3://bucket/path2"])

    Args:
        paths: A single file/directory path or a list of file/directory paths.
            A list of paths can contain both files and directories.
        filesystem: The filesystem implementation to read from.
        parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
        arrow_open_stream_args: kwargs passed to
            `pyarrow.fs.FileSystem.open_input_stream `_.
        numpy_load_args: Other options to pass to np.load.
        meta_provider: File metadata provider. Custom metadata providers may
            be able to resolve file metadata more quickly and/or accurately. If
            ``None``, this function uses a system-chosen implementation.
        partition_filter: Path-based partition filter, if any. Can be used
            with a custom callback to read only selected partitions of a dataset.
            By default, this filters out any file paths whose file extension does not
            match "*.npy*".
        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
            that describes how paths are organized. Defaults to ``None``.
        include_paths: If ``True``, include the path to each file. File paths are
            stored in the ``'path'`` column.
        ignore_missing_paths: If True, ignores any file paths in ``paths`` that are not
            found. Defaults to False.
        shuffle: If setting to "files", randomly shuffle input files order before read.
            If setting to ``FileShuffleConfig``, a random seed can be passed to
            shuffle the input files, i.e. ``FileShuffleConfig(seed = 42)``.
            Defaults to not shuffle with ``None``.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.

    Returns:
        Dataset holding Tensor records read from the specified paths.
    """  # noqa: E501
    # Warn (via the shared helper) when the deprecated meta_provider is passed.
    _emit_meta_provider_deprecation_warning(meta_provider)

    if meta_provider is None:
        meta_provider = DefaultFileMetadataProvider()

    datasource = NumpyDatasource(
        paths,
        numpy_load_args=numpy_load_args,
        filesystem=filesystem,
        open_stream_args=arrow_open_stream_args,
        meta_provider=meta_provider,
        partition_filter=partition_filter,
        partitioning=partitioning,
        ignore_missing_paths=ignore_missing_paths,
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
    )
    # NOTE: unlike most readers in this module, read_numpy takes no
    # ray_remote_args parameter, so none is forwarded here.
    return read_datasource(
        datasource,
        parallelism=parallelism,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
@PublicAPI(stability="alpha")
def read_tfrecords(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    include_paths: bool = False,
    ignore_missing_paths: bool = False,
    tf_schema: Optional["schema_pb2.Schema"] = None,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    file_extensions: Optional[List[str]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
    tfx_read_options: Optional["TFXReadOptions"] = None,
) -> Dataset:
    """Create a :class:`~ray.data.Dataset` from TFRecord files that contain
    `tf.train.Example `_
    messages.

    .. tip::
        Using the ``tfx-bsl`` library is more performant when reading large
        datasets (for example, in production use cases). To use this
        implementation, you must first install ``tfx-bsl``:

        1. `pip install tfx_bsl --no-dependencies`
        2. Pass tfx_read_options to read_tfrecords, for example:
           `ds = read_tfrecords(path, ..., tfx_read_options=TFXReadOptions())`

    .. warning::
        This function exclusively supports ``tf.train.Example`` messages. If a file
        contains a message that isn't of type ``tf.train.Example``, then this function
        fails.

    Examples:
        >>> import ray
        >>> ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")
        Dataset(
           num_rows=?,
           schema={...}
        )

        We can also read compressed TFRecord files, which use one of the
        `compression types supported by Arrow `_:

        >>> ray.data.read_tfrecords(
        ...     "s3://anonymous@ray-example-data/iris.tfrecords.gz",
        ...     arrow_open_stream_args={"compression": "gzip"},
        ... )
        Dataset(
           num_rows=?,
           schema={...}
        )

    Args:
        paths: A single file or directory, or a list of file or directory paths.
            A list of paths can contain both files and directories.
        filesystem: The PyArrow filesystem
            implementation to read from. These filesystems are specified in the
            `PyArrow docs `_. Specify this parameter if
            you need to provide specific configurations to the filesystem. By default,
            the filesystem is automatically selected based on the scheme of the paths.
            For example, if the path begins with ``s3://``, the `S3FileSystem` is used.
        parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
        arrow_open_stream_args: kwargs passed to
            `pyarrow.fs.FileSystem.open_input_file `_.
            when opening input files to read. To read a compressed TFRecord file,
            pass the corresponding compression type (e.g., for ``GZIP`` or ``ZLIB``),
            use ``arrow_open_stream_args={'compression': 'gzip'}``).
        meta_provider: [Deprecated] A :ref:`file metadata provider `.
            Custom metadata providers may be able to resolve file metadata more quickly
            and/or accurately. In most cases, you do not need to set this. If ``None``,
            this function uses a system-chosen implementation.
        partition_filter: A
            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`.
            Use with a custom callback to read only selected partitions of a
            dataset.
        include_paths: If ``True``, include the path to each file. File paths are
            stored in the ``'path'`` column.
        ignore_missing_paths: If True, ignores any file paths in ``paths`` that are not
            found. Defaults to False.
        tf_schema: Optional TensorFlow Schema which is used to explicitly set the schema
            of the underlying Dataset.
        shuffle: If setting to "files", randomly shuffle input files order before read.
            If setting to :class:`~ray.data.FileShuffleConfig`, you can pass a seed to
            shuffle the input files. Defaults to not shuffle with ``None``.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.
        tfx_read_options: Specifies read options when reading TFRecord files with TFX.
            When no options are provided, the default version without tfx-bsl will
            be used to read the tfrecords.

    Returns:
        A :class:`~ray.data.Dataset` that contains the example features.

    Raises:
        ValueError: If a file contains a message that isn't a ``tf.train.Example``.
    """
    import platform

    # Warn (via the shared helper) when the deprecated meta_provider is passed.
    _emit_meta_provider_deprecation_warning(meta_provider)

    # Tracks whether the faster tfx-bsl implementation is actually usable.
    tfx_read = False

    # Only attempt the tfx-bsl fast path when the caller asked for it.
    # NOTE(review): the `platform.processor() != "arm"` guard presumably exists
    # because tfx-bsl is unavailable/unsupported on ARM — confirm upstream.
    if tfx_read_options and platform.processor() != "arm":
        try:
            import tfx_bsl  # noqa: F401

            tfx_read = True
        except ModuleNotFoundError:
            # override the tfx_read_options given that tfx-bsl is not installed
            tfx_read_options = None
            logger.warning(
                "Please install tfx-bsl package with"
                " `pip install tfx_bsl --no-dependencies`."
                " This can help speed up the reading of large TFRecord files."
            )

    if meta_provider is None:
        meta_provider = DefaultFileMetadataProvider()
    datasource = TFRecordDatasource(
        paths,
        tf_schema=tf_schema,
        filesystem=filesystem,
        open_stream_args=arrow_open_stream_args,
        meta_provider=meta_provider,
        partition_filter=partition_filter,
        ignore_missing_paths=ignore_missing_paths,
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
        tfx_read_options=tfx_read_options,
    )
    ds = read_datasource(
        datasource,
        parallelism=parallelism,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )

    # When the tfx fast path was used with auto schema inference enabled and no
    # explicit schema was supplied, post-process the dataset to infer and apply
    # the schema.
    if (
        tfx_read_options
        and tfx_read_options.auto_infer_schema
        and tfx_read
        and not tf_schema
    ):
        from ray.data._internal.datasource.tfrecords_datasource import (
            _infer_schema_and_transform,
        )

        return _infer_schema_and_transform(ds)

    return ds
@PublicAPI(stability="alpha")
def read_webdataset(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    decoder: Optional[Union[bool, str, callable, list]] = True,
    fileselect: Optional[Union[list, callable]] = None,
    filerename: Optional[Union[list, callable]] = None,
    suffixes: Optional[Union[list, callable]] = None,
    verbose_open: bool = False,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    include_paths: bool = False,
    file_extensions: Optional[List[str]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
    expand_json: bool = False,
) -> Dataset:
    """Create a :class:`~ray.data.Dataset` from
    `WebDataset `_ files.

    Args:
        paths: A single file/directory path or a list of file/directory paths.
            A list of paths can contain both files and directories.
        filesystem: The filesystem implementation to read from.
        parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
        arrow_open_stream_args: Key-word arguments passed to
            `pyarrow.fs.FileSystem.open_input_stream `_.
            To read a compressed file,
            pass the corresponding compression type (e.g. for ``GZIP`` or ``ZLIB``, use
            ``arrow_open_stream_args={'compression': 'gzip'}``).
        meta_provider: File metadata provider. Custom metadata providers may
            be able to resolve file metadata more quickly and/or accurately. If
            ``None``, this function uses a system-chosen implementation.
        partition_filter: Path-based partition filter, if any. Can be used
            with a custom callback to read only selected partitions of a dataset.
        decoder: A function or list of functions to decode the data.
        fileselect: A callable or list of glob patterns to select files.
        filerename: A function or list of tuples to rename files prior to grouping.
        suffixes: A function or list of suffixes to select for creating samples.
        verbose_open: Whether to print the file names as they are opened.
        shuffle: If setting to "files", randomly shuffle input files order before read.
            If setting to ``FileShuffleConfig``, a random seed can be passed to
            shuffle the input files, i.e. ``FileShuffleConfig(seed = 42)``.
            Defaults to not shuffle with ``None``.
        include_paths: If ``True``, include the path to each file. File paths are
            stored in the ``'path'`` column.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.
        expand_json: If ``True``, expand JSON objects into individual samples.
            Defaults to ``False``.

    Returns:
        A :class:`~ray.data.Dataset` that contains the example features.
    """  # noqa: E501
    # Warn (via the shared helper) when the deprecated meta_provider is passed.
    _emit_meta_provider_deprecation_warning(meta_provider)

    if meta_provider is None:
        meta_provider = DefaultFileMetadataProvider()

    datasource = WebDatasetDatasource(
        paths,
        decoder=decoder,
        fileselect=fileselect,
        filerename=filerename,
        suffixes=suffixes,
        verbose_open=verbose_open,
        filesystem=filesystem,
        open_stream_args=arrow_open_stream_args,
        meta_provider=meta_provider,
        partition_filter=partition_filter,
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
        expand_json=expand_json,
    )
    return read_datasource(
        datasource,
        parallelism=parallelism,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
@PublicAPI
def read_binary_files(
    paths: Union[str, List[str]],
    *,
    include_paths: bool = False,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    ray_remote_args: Dict[str, Any] = None,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider: Optional[BaseFileMetadataProvider] = None,
    partition_filter: Optional[PathPartitionFilter] = None,
    partitioning: Partitioning = None,
    ignore_missing_paths: bool = False,
    shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
    file_extensions: Optional[List[str]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """Create a :class:`~ray.data.Dataset` from binary files of arbitrary contents.

    Examples:
        Read a file in remote storage.

        >>> import ray
        >>> path = "s3://anonymous@ray-example-data/pdf-sample_0.pdf"
        >>> ds = ray.data.read_binary_files(path)
        >>> ds.schema()
        Column  Type
        ------  ----
        bytes   binary

        Read multiple local files.

        >>> ray.data.read_binary_files( # doctest: +SKIP
        ...     ["local:///path/to/file1", "local:///path/to/file2"])

        Read a file with the filepaths included as a column in the dataset.

        >>> path = "s3://anonymous@ray-example-data/pdf-sample_0.pdf"
        >>> ds = ray.data.read_binary_files(path, include_paths=True)
        >>> ds.take(1)[0]["path"]
        'ray-example-data/pdf-sample_0.pdf'


    Args:
        paths: A single file or directory, or a list of file or directory paths.
            A list of paths can contain both files and directories.
        include_paths: If ``True``, include the path to each file. File paths are
            stored in the ``'path'`` column.
        filesystem: The PyArrow filesystem
            implementation to read from. These filesystems are specified in the
            `PyArrow docs `_. Specify this parameter if
            you need to provide specific configurations to the filesystem. By default,
            the filesystem is automatically selected based on the scheme of the paths.
            For example, if the path begins with ``s3://``, the `S3FileSystem` is used.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
        arrow_open_stream_args: kwargs passed to
            `pyarrow.fs.FileSystem.open_input_file `_.
        meta_provider: [Deprecated] A :ref:`file metadata provider `.
            Custom metadata providers may be able to resolve file metadata more quickly
            and/or accurately. In most cases, you do not need to set this. If ``None``,
            this function uses a system-chosen implementation.
        partition_filter: A
            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`.
            Use with a custom callback to read only selected partitions of a
            dataset. By default, no files are filtered.
        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
            that describes how paths are organized. Defaults to ``None``.
        ignore_missing_paths: If True, ignores any file paths in ``paths`` that are not
            found. Defaults to False.
        shuffle: If setting to "files", randomly shuffle input files order before read.
            If setting to :class:`~ray.data.FileShuffleConfig`, you can pass a seed to
            shuffle the input files. Defaults to not shuffle with ``None``.
        file_extensions: A list of file extensions to filter files by.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.

    Returns:
        :class:`~ray.data.Dataset` producing rows read from the specified paths.
    """
    # Warn (via the shared helper) when the deprecated meta_provider is passed.
    _emit_meta_provider_deprecation_warning(meta_provider)

    if meta_provider is None:
        meta_provider = DefaultFileMetadataProvider()

    datasource = BinaryDatasource(
        paths,
        include_paths=include_paths,
        filesystem=filesystem,
        open_stream_args=arrow_open_stream_args,
        meta_provider=meta_provider,
        partition_filter=partition_filter,
        partitioning=partitioning,
        ignore_missing_paths=ignore_missing_paths,
        shuffle=shuffle,
        file_extensions=file_extensions,
    )
    return read_datasource(
        datasource,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
@PublicAPI(stability="alpha")
def read_sql(
    sql: str,
    connection_factory: Callable[[], Connection],
    *,
    parallelism: int = -1,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """Read from a database that provides a
    `Python DB API2-compliant `_ connector.

    .. note::

        By default, ``read_sql`` launches multiple read tasks, and each task executes a
        ``LIMIT`` and ``OFFSET`` to fetch a subset of the rows. However, for many
        databases, ``OFFSET`` is slow.

        As a workaround, set ``override_num_blocks=1`` to directly fetch all rows in a
        single task. Note that this approach requires all result rows to fit in the
        memory of single task. If the rows don't fit, your program may raise an out of
        memory error.

    Examples:

        For examples of reading from larger databases like MySQL and PostgreSQL, see
        :ref:`Reading from SQL Databases `.

        .. testcode::

            import sqlite3

            import ray

            # Create a simple database
            connection = sqlite3.connect("example.db")
            connection.execute("CREATE TABLE movie(title, year, score)")
            connection.execute(
                \"\"\"
                INSERT INTO movie VALUES
                ('Monty Python and the Holy Grail', 1975, 8.2),
                ("Monty Python Live at the Hollywood Bowl", 1982, 7.9),
                ("Monty Python's Life of Brian", 1979, 8.0),
                ("Rocky II", 1979, 7.3)
                \"\"\"
            )
            connection.commit()
            connection.close()

            def create_connection():
                return sqlite3.connect("example.db")

            # Get all movies
            ds = ray.data.read_sql("SELECT * FROM movie", create_connection)
            # Get movies after the year 1980
            ds = ray.data.read_sql(
                "SELECT title, score FROM movie WHERE year >= 1980", create_connection
            )
            # Get the number of movies per year
            ds = ray.data.read_sql(
                "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
            )

        .. testcode::
            :hide:

            import os
            os.remove("example.db")

    Args:
        sql: The SQL query to execute.
        connection_factory: A function that takes no arguments and returns a
            Python DB API2
            `Connection object `_.
        parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
        ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
        concurrency: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.

    Returns:
        A :class:`Dataset` containing the queried data.
    """
    # Guard clause: anything other than the default (-1) or a single task (1)
    # is rejected up front, before any datasource work happens.
    if parallelism not in (-1, 1):
        raise ValueError(
            "To ensure correctness, 'read_sql' always launches one task. The "
            "'parallelism' argument you specified can't be used."
        )

    return read_datasource(
        SQLDatasource(sql=sql, connection_factory=connection_factory),
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )
Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`Dataset` containing the queried data. + """ + if parallelism != -1 and parallelism != 1: + raise ValueError( + "To ensure correctness, 'read_sql' always launches one task. The " + "'parallelism' argument you specified can't be used." + ) + + datasource = SQLDatasource(sql=sql, connection_factory=connection_factory) + return read_datasource( + datasource, + parallelism=parallelism, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI(stability="alpha") +def read_databricks_tables( + *, + warehouse_id: str, + table: Optional[str] = None, + query: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + parallelism: int = -1, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """Read a Databricks unity catalog table or Databricks SQL execution result. + + Before calling this API, set the ``DATABRICKS_TOKEN`` environment + variable to your Databricks warehouse access token. + + .. code-block:: console + + export DATABRICKS_TOKEN=... + + If you're not running your program on the Databricks runtime, also set the + ``DATABRICKS_HOST`` environment variable. + + .. code-block:: console + + export DATABRICKS_HOST=adb-..azuredatabricks.net + + .. note:: + + This function is built on the + `Databricks statement execution API `_. + + Examples: + + .. 
testcode:: + :skipif: True + + import ray + + ds = ray.data.read_databricks_tables( + warehouse_id='...', + catalog='catalog_1', + schema='db_1', + query='select id from table_1 limit 750000', + ) + + Args: + warehouse_id: The ID of the Databricks warehouse. The query statement is + executed on this warehouse. + table: The name of UC table you want to read. If this argument is set, + you can't set ``query`` argument, and the reader generates query + of ``select * from {table_name}`` under the hood. + query: The query you want to execute. If this argument is set, + you can't set ``table_name`` argument. + catalog: (Optional) The default catalog name used by the query. + schema: (Optional) The default schema used by the query. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`Dataset` containing the queried data. 
+ """ # noqa: E501 + from ray.data._internal.datasource.databricks_uc_datasource import ( + DatabricksUCDatasource, + ) + from ray.util.spark.utils import get_spark_session, is_in_databricks_runtime + + def get_dbutils(): + no_dbutils_error = RuntimeError("No dbutils module found.") + try: + import IPython + + ip_shell = IPython.get_ipython() + if ip_shell is None: + raise no_dbutils_error + return ip_shell.ns_table["user_global"]["dbutils"] + except ImportError: + raise no_dbutils_error + except KeyError: + raise no_dbutils_error + + token = os.environ.get("DATABRICKS_TOKEN") + + if not token: + raise ValueError( + "Please set environment variable 'DATABRICKS_TOKEN' to " + "databricks workspace access token." + ) + + host = os.environ.get("DATABRICKS_HOST") + if not host: + if is_in_databricks_runtime(): + ctx = ( + get_dbutils().notebook.entry_point.getDbutils().notebook().getContext() + ) + host = ctx.tags().get("browserHostName").get() + else: + raise ValueError( + "You are not in databricks runtime, please set environment variable " + "'DATABRICKS_HOST' to databricks workspace URL" + '(e.g. "adb-..azuredatabricks.net").' 
+ ) + + if not catalog: + catalog = get_spark_session().sql("SELECT CURRENT_CATALOG()").collect()[0][0] + + if not schema: + schema = get_spark_session().sql("SELECT CURRENT_DATABASE()").collect()[0][0] + + if query is not None and table is not None: + raise ValueError("Only one of 'query' and 'table' arguments can be set.") + + if table: + query = f"select * from {table}" + + if query is None: + raise ValueError("One of 'query' and 'table' arguments should be set.") + + datasource = DatabricksUCDatasource( + host=host, + token=token, + warehouse_id=warehouse_id, + catalog=catalog, + schema=schema, + query=query, + ) + return read_datasource( + datasource=datasource, + parallelism=parallelism, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI(stability="alpha") +def read_hudi( + table_uri: str, + *, + storage_options: Optional[Dict[str, str]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """ + Create a :class:`~ray.data.Dataset` from an + `Apache Hudi table `_. + + Examples: + >>> import ray + >>> ds = ray.data.read_hudi( # doctest: +SKIP + ... table_uri="/hudi/trips", + ... ) + + Args: + table_uri: The URI of the Hudi table to read from. Local file paths, S3, and GCS + are supported. + storage_options: Extra options that make sense for a particular storage + connection. This is used to store connection parameters like credentials, + endpoint, etc. See more explanation + `here `_. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. 
+ override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`~ray.data.Dataset` producing records read from the Hudi table. + """ # noqa: E501 + datasource = HudiDatasource( + table_uri=table_uri, + storage_options=storage_options, + ) + + return read_datasource( + datasource=datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +def from_dask(df: "dask.dataframe.DataFrame") -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a + `Dask DataFrame `_. + + Args: + df: A `Dask DataFrame`_. + + Returns: + A :class:`~ray.data.MaterializedDataset` holding rows read from the DataFrame. + """ # noqa: E501 + import dask + + from ray.util.dask import ray_dask_get + + partitions = df.to_delayed() + persisted_partitions = dask.persist(*partitions, scheduler=ray_dask_get) + + import pandas + + def to_ref(df): + if isinstance(df, pandas.DataFrame): + return ray.put(df) + elif isinstance(df, ray.ObjectRef): + return df + else: + raise ValueError( + "Expected a Ray object ref or a Pandas DataFrame, " f"got {type(df)}" + ) + + ds = from_pandas_refs( + [to_ref(next(iter(part.dask.values()))) for part in persisted_partitions], + ) + return ds + + +@PublicAPI +def from_mars(df: "mars.dataframe.DataFrame") -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a + `Mars DataFrame `_. + + Args: + df: A `Mars DataFrame`_, which must be executed by Mars-on-Ray. + + Returns: + A :class:`~ray.data.MaterializedDataset` holding rows read from the DataFrame. 
+ """ # noqa: E501 + import mars.dataframe as md + + ds: Dataset = md.to_ray_dataset(df) + return ds + + +@PublicAPI +def from_modin(df: "modin.pandas.dataframe.DataFrame") -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a + `Modin DataFrame `_. + + Args: + df: A `Modin DataFrame`_, which must be using the Ray backend. + + Returns: + A :class:`~ray.data.MaterializedDataset` rows read from the DataFrame. + """ # noqa: E501 + from modin.distributed.dataframe.pandas.partitions import unwrap_partitions + + parts = unwrap_partitions(df, axis=0) + ds = from_pandas_refs(parts) + return ds + + +@PublicAPI +def from_pandas( + dfs: Union["pandas.DataFrame", List["pandas.DataFrame"]], + override_num_blocks: Optional[int] = None, +) -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a list of pandas dataframes. + + Examples: + >>> import pandas as pd + >>> import ray + >>> df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> ray.data.from_pandas(df) + MaterializedDataset(num_blocks=1, num_rows=3, schema={a: int64, b: int64}) + + Create a Ray Dataset from a list of Pandas DataFrames. + + >>> ray.data.from_pandas([df, df]) + MaterializedDataset(num_blocks=2, num_rows=6, schema={a: int64, b: int64}) + + Args: + dfs: A pandas dataframe or a list of pandas dataframes. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + :class:`~ray.data.Dataset` holding data read from the dataframes. + """ + import pandas as pd + + if isinstance(dfs, pd.DataFrame): + dfs = [dfs] + + if override_num_blocks is not None: + if len(dfs) > 1: + # I assume most users pass a single DataFrame as input. For simplicity, I'm + # concatenating DataFrames, even though it's not efficient. 
+ ary = pd.concat(dfs, axis=0) + else: + ary = dfs[0] + dfs = np.array_split(ary, override_num_blocks) + + from ray.air.util.data_batch_conversion import ( + _cast_ndarray_columns_to_tensor_extension, + ) + + context = DataContext.get_current() + if context.enable_tensor_extension_casting: + dfs = [_cast_ndarray_columns_to_tensor_extension(df.copy()) for df in dfs] + + return from_pandas_refs([ray.put(df) for df in dfs]) + + +@DeveloperAPI +def from_pandas_refs( + dfs: Union[ObjectRef["pandas.DataFrame"], List[ObjectRef["pandas.DataFrame"]]], +) -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a list of Ray object references to + pandas dataframes. + + Examples: + >>> import pandas as pd + >>> import ray + >>> df_ref = ray.put(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + >>> ray.data.from_pandas_refs(df_ref) + MaterializedDataset(num_blocks=1, num_rows=3, schema={a: int64, b: int64}) + + Create a Ray Dataset from a list of Pandas Dataframes references. + + >>> ray.data.from_pandas_refs([df_ref, df_ref]) + MaterializedDataset(num_blocks=2, num_rows=6, schema={a: int64, b: int64}) + + Args: + dfs: A Ray object reference to a pandas dataframe, or a list of + Ray object references to pandas dataframes. + + Returns: + :class:`~ray.data.Dataset` holding data read from the dataframes. 
+ """ + if isinstance(dfs, ray.ObjectRef): + dfs = [dfs] + elif isinstance(dfs, list): + for df in dfs: + if not isinstance(df, ray.ObjectRef): + raise ValueError( + "Expected list of Ray object refs, " + f"got list containing {type(df)}" + ) + else: + raise ValueError( + "Expected Ray object ref or list of Ray object refs, " f"got {type(df)}" + ) + + context = DataContext.get_current() + if context.enable_pandas_block: + get_metadata = cached_remote_fn(get_table_block_metadata) + metadata = ray.get([get_metadata.remote(df) for df in dfs]) + execution_plan = ExecutionPlan( + DatasetStats(metadata={"FromPandas": metadata}, parent=None) + ) + logical_plan = LogicalPlan(FromPandas(dfs, metadata), execution_plan._context) + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + df_to_block = cached_remote_fn(pandas_df_to_arrow_block, num_returns=2) + + res = [df_to_block.remote(df) for df in dfs] + blocks, metadata = map(list, zip(*res)) + metadata = ray.get(metadata) + execution_plan = ExecutionPlan( + DatasetStats(metadata={"FromPandas": metadata}, parent=None) + ) + logical_plan = LogicalPlan(FromPandas(blocks, metadata), execution_plan._context) + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + +@PublicAPI +def from_numpy(ndarrays: Union[np.ndarray, List[np.ndarray]]) -> MaterializedDataset: + """Creates a :class:`~ray.data.Dataset` from a list of NumPy ndarrays. + + Examples: + >>> import numpy as np + >>> import ray + >>> arr = np.array([1]) + >>> ray.data.from_numpy(arr) + MaterializedDataset(num_blocks=1, num_rows=1, schema={data: int64}) + + Create a Ray Dataset from a list of NumPy arrays. + + >>> ray.data.from_numpy([arr, arr]) + MaterializedDataset(num_blocks=2, num_rows=2, schema={data: int64}) + + Args: + ndarrays: A NumPy ndarray or a list of NumPy ndarrays. + + Returns: + :class:`~ray.data.Dataset` holding data from the given ndarrays. 
+ """ + if isinstance(ndarrays, np.ndarray): + ndarrays = [ndarrays] + + return from_numpy_refs([ray.put(ndarray) for ndarray in ndarrays]) + + +@DeveloperAPI +def from_numpy_refs( + ndarrays: Union[ObjectRef[np.ndarray], List[ObjectRef[np.ndarray]]], +) -> MaterializedDataset: + """Creates a :class:`~ray.data.Dataset` from a list of Ray object references to + NumPy ndarrays. + + Examples: + >>> import numpy as np + >>> import ray + >>> arr_ref = ray.put(np.array([1])) + >>> ray.data.from_numpy_refs(arr_ref) + MaterializedDataset(num_blocks=1, num_rows=1, schema={data: int64}) + + Create a Ray Dataset from a list of NumPy array references. + + >>> ray.data.from_numpy_refs([arr_ref, arr_ref]) + MaterializedDataset(num_blocks=2, num_rows=2, schema={data: int64}) + + Args: + ndarrays: A Ray object reference to a NumPy ndarray or a list of Ray object + references to NumPy ndarrays. + + Returns: + :class:`~ray.data.Dataset` holding data from the given ndarrays. + """ + if isinstance(ndarrays, ray.ObjectRef): + ndarrays = [ndarrays] + elif isinstance(ndarrays, list): + for ndarray in ndarrays: + if not isinstance(ndarray, ray.ObjectRef): + raise ValueError( + "Expected list of Ray object refs, " + f"got list containing {type(ndarray)}" + ) + else: + raise ValueError( + f"Expected Ray object ref or list of Ray object refs, got {type(ndarray)}" + ) + + ctx = DataContext.get_current() + ndarray_to_block_remote = cached_remote_fn(ndarray_to_block, num_returns=2) + + res = [ndarray_to_block_remote.remote(ndarray, ctx) for ndarray in ndarrays] + blocks, metadata = map(list, zip(*res)) + metadata = ray.get(metadata) + + execution_plan = ExecutionPlan( + DatasetStats(metadata={"FromNumpy": metadata}, parent=None) + ) + logical_plan = LogicalPlan(FromNumpy(blocks, metadata), execution_plan._context) + + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + +@PublicAPI +def from_arrow( + tables: Union["pyarrow.Table", bytes, List[Union["pyarrow.Table", bytes]]], +) 
-> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a list of PyArrow tables. + + Examples: + >>> import pyarrow as pa + >>> import ray + >>> table = pa.table({"x": [1]}) + >>> ray.data.from_arrow(table) + MaterializedDataset(num_blocks=1, num_rows=1, schema={x: int64}) + + Create a Ray Dataset from a list of PyArrow tables. + + >>> ray.data.from_arrow([table, table]) + MaterializedDataset(num_blocks=2, num_rows=2, schema={x: int64}) + + + Args: + tables: A PyArrow table, or a list of PyArrow tables, + or its streaming format in bytes. + + Returns: + :class:`~ray.data.Dataset` holding data from the PyArrow tables. + """ + import pyarrow as pa + + if isinstance(tables, (pa.Table, bytes)): + tables = [tables] + return from_arrow_refs([ray.put(t) for t in tables]) + + +@DeveloperAPI +def from_arrow_refs( + tables: Union[ + ObjectRef[Union["pyarrow.Table", bytes]], + List[ObjectRef[Union["pyarrow.Table", bytes]]], + ], +) -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a list of Ray object references to + PyArrow tables. + + Examples: + >>> import pyarrow as pa + >>> import ray + >>> table_ref = ray.put(pa.table({"x": [1]})) + >>> ray.data.from_arrow_refs(table_ref) + MaterializedDataset(num_blocks=1, num_rows=1, schema={x: int64}) + + Create a Ray Dataset from a list of PyArrow table references + + >>> ray.data.from_arrow_refs([table_ref, table_ref]) + MaterializedDataset(num_blocks=2, num_rows=2, schema={x: int64}) + + + Args: + tables: A Ray object reference to Arrow table, or list of Ray object + references to Arrow tables, or its streaming format in bytes. + + Returns: + :class:`~ray.data.Dataset` holding data read from the tables. 
+ """ + if isinstance(tables, ray.ObjectRef): + tables = [tables] + + get_metadata = cached_remote_fn(get_table_block_metadata) + metadata = ray.get([get_metadata.remote(t) for t in tables]) + execution_plan = ExecutionPlan( + DatasetStats(metadata={"FromArrow": metadata}, parent=None) + ) + logical_plan = LogicalPlan(FromArrow(tables, metadata), execution_plan._context) + + return MaterializedDataset( + execution_plan, + logical_plan, + ) + + +@PublicAPI(stability="alpha") +def read_delta_sharing_tables( + url: str, + *, + limit: Optional[int] = None, + version: Optional[int] = None, + timestamp: Optional[str] = None, + json_predicate_hints: Optional[str] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """ + Read data from a Delta Sharing table. + Delta Sharing projct https://github.com/delta-io/delta-sharing/tree/main + + This function reads data from a Delta Sharing table specified by the URL. + It supports various options such as limiting the number of rows, specifying + a version or timestamp, and configuring concurrency. + + Before calling this function, ensure that the URL is correctly formatted + to point to the Delta Sharing table you want to access. Make sure you have + a valid delta_share profile in the working directory. + + Examples: + + .. testcode:: + :skipif: True + + import ray + + ds = ray.data.read_delta_sharing_tables( + url=f"your-profile.json#your-share-name.your-schema-name.your-table-name", + limit=100000, + version=1, + ) + + Args: + url: A URL under the format + "#..". + Example can be found at + https://github.com/delta-io/delta-sharing/blob/main/README.md#quick-start + limit: A non-negative integer. Load only the ``limit`` rows if the + parameter is specified. Use this optional parameter to explore the + shared table without loading the entire table into memory. + version: A non-negative integer. 
Load the snapshot of the table at + the specified version. + timestamp: A timestamp to specify the version of the table to read. + json_predicate_hints: Predicate hints to be applied to the table. For more + details, see: + https://github.com/delta-io/delta-sharing/blob/main/PROTOCOL.md#json-predicates-for-filtering. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control the number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`Dataset` containing the queried data. + + Raises: + ValueError: If the URL is not properly formatted or if there is an issue + with the Delta Sharing table connection. + """ + + datasource = DeltaSharingDatasource( + url=url, + json_predicate_hints=json_predicate_hints, + limit=limit, + version=version, + timestamp=timestamp, + ) + # DeltaSharing limit is at the add_files level, it will not return + # exactly the limit number of rows but it will return less files and rows. + return ray.data.read_datasource( + datasource=datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI +def from_spark( + df: "pyspark.sql.DataFrame", + *, + parallelism: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a + `Spark DataFrame `_. + + Args: + df: A `Spark DataFrame`_, which must be created by RayDP (Spark-on-Ray). + parallelism: This argument is deprecated. 
Use ``override_num_blocks`` argument. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`~ray.data.MaterializedDataset` holding rows read from the DataFrame. + """ # noqa: E501 + import raydp + + parallelism = _get_num_output_blocks(parallelism, override_num_blocks) + return raydp.spark.spark_dataframe_to_ray_dataset(df, parallelism) + + +@PublicAPI +def from_huggingface( + dataset: Union["datasets.Dataset", "datasets.IterableDataset"], + parallelism: int = -1, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Union[MaterializedDataset, Dataset]: + """Create a :class:`~ray.data.MaterializedDataset` from a + `Hugging Face Datasets Dataset `_ + or a :class:`~ray.data.Dataset` from a `Hugging Face Datasets IterableDataset `_. + For an `IterableDataset`, we use a streaming implementation to read data. + + If the dataset is a public Hugging Face Dataset that is hosted on the Hugging Face Hub and + no transformations have been applied, then the `hosted parquet files `_ + will be passed to :meth:`~ray.data.read_parquet` to perform a distributed read. All + other cases will be done with a single node read. + + Example: + + .. + The following `testoutput` is mocked to avoid illustrating download + logs like "Downloading and preparing dataset 162.17 MiB". + + .. testcode:: + + import ray + import datasets + + hf_dataset = datasets.load_dataset("tweet_eval", "emotion") + ray_ds = ray.data.from_huggingface(hf_dataset["train"]) + print(ray_ds) + + hf_dataset_stream = datasets.load_dataset("tweet_eval", "emotion", streaming=True) + ray_ds_stream = ray.data.from_huggingface(hf_dataset_stream["train"]) + print(ray_ds_stream) + + .. 
testoutput:: + :options: +MOCK + + MaterializedDataset( + num_blocks=..., + num_rows=3257, + schema={text: string, label: int64} + ) + Dataset( + num_rows=3257, + schema={text: string, label: int64} + ) + + Args: + dataset: A `Hugging Face Datasets Dataset`_ or `Hugging Face Datasets IterableDataset`_. + `DatasetDict `_ + and `IterableDatasetDict `_ + are not supported. + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`~ray.data.Dataset` holding rows from the `Hugging Face Datasets Dataset`_. + """ # noqa: E501 + import datasets + from aiohttp.client_exceptions import ClientResponseError + + from ray.data._internal.datasource.huggingface_datasource import ( + HuggingFaceDatasource, + ) + + if isinstance(dataset, (datasets.IterableDataset, datasets.Dataset)): + try: + # Attempt to read data via Hugging Face Hub parquet files. If the + # returned list of files is empty, attempt read via other methods. + file_urls = HuggingFaceDatasource.list_parquet_urls_from_dataset(dataset) + if len(file_urls) > 0: + # If file urls are returned, the parquet files are available via API + # TODO: Add support for reading from http filesystem in + # FileBasedDatasource. 
GH Issue: + # https://github.com/ray-project/ray/issues/42706 + import fsspec.implementations.http + + http = fsspec.implementations.http.HTTPFileSystem() + return read_parquet( + file_urls, + parallelism=parallelism, + filesystem=http, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ray_remote_args={ + "retry_exceptions": [FileNotFoundError, ClientResponseError] + }, + ) + except (FileNotFoundError, ClientResponseError): + logger.warning( + "Distrubuted read via Hugging Face Hub parquet files failed, " + "falling back on single node read." + ) + + if isinstance(dataset, datasets.IterableDataset): + # For an IterableDataset, we can use a streaming implementation to read data. + return read_datasource( + HuggingFaceDatasource(dataset=dataset), + parallelism=parallelism, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + if isinstance(dataset, datasets.Dataset): + # For non-streaming Hugging Face Dataset, we don't support override_num_blocks + if override_num_blocks is not None: + raise ValueError( + "`override_num_blocks` parameter is not supported for " + "streaming Hugging Face Datasets. Please omit the parameter or " + "use non-streaming mode to read the dataset." + ) + + # To get the resulting Arrow table from a Hugging Face Dataset after + # applying transformations (e.g., train_test_split(), shard(), select()), + # we create a copy of the Arrow table, which applies the indices + # mapping from the transformations. + hf_ds_arrow = dataset.with_format("arrow") + ray_ds = from_arrow(hf_ds_arrow[:]) + return ray_ds + elif isinstance(dataset, (datasets.DatasetDict, datasets.IterableDatasetDict)): + available_keys = list(dataset.keys()) + raise DeprecationWarning( + "You provided a Hugging Face DatasetDict or IterableDatasetDict, " + "which contains multiple datasets, but `from_huggingface` now " + "only accepts a single Hugging Face Dataset. 
To convert just " + "a single Hugging Face Dataset to a Ray Dataset, specify a split. " + "For example, `ray.data.from_huggingface(my_dataset_dictionary" + f"['{available_keys[0]}'])`. " + f"Available splits are {available_keys}." + ) + else: + raise TypeError( + f"`dataset` must be a `datasets.Dataset`, but got {type(dataset)}" + ) + + +@PublicAPI +def from_tf( + dataset: "tf.data.Dataset", +) -> MaterializedDataset: + """Create a :class:`~ray.data.Dataset` from a + `TensorFlow Dataset `_. + + This function is inefficient. Use it to read small datasets or prototype. + + .. warning:: + If your dataset is large, this function may execute slowly or raise an + out-of-memory error. To avoid issues, read the underyling data with a function + like :meth:`~ray.data.read_images`. + + .. note:: + This function isn't parallelized. It loads the entire dataset into the local + node's memory before moving the data to the distributed object store. + + Examples: + >>> import ray + >>> import tensorflow_datasets as tfds + >>> dataset, _ = tfds.load('cifar10', split=["train", "test"]) # doctest: +SKIP + >>> ds = ray.data.from_tf(dataset) # doctest: +SKIP + >>> ds # doctest: +SKIP + MaterializedDataset( + num_blocks=..., + num_rows=50000, + schema={ + id: binary, + image: numpy.ndarray(shape=(32, 32, 3), dtype=uint8), + label: int64 + } + ) + >>> ds.take(1) # doctest: +SKIP + [{'id': b'train_16399', 'image': array([[[143, 96, 70], + [141, 96, 72], + [135, 93, 72], + ..., + [ 96, 37, 19], + [105, 42, 18], + [104, 38, 20]], + ..., + [[195, 161, 126], + [187, 153, 123], + [186, 151, 128], + ..., + [212, 177, 147], + [219, 185, 155], + [221, 187, 157]]], dtype=uint8), 'label': 7}] + + Args: + dataset: A `TensorFlow Dataset`_. + + Returns: + A :class:`MaterializedDataset` that contains the samples stored in the `TensorFlow Dataset`_. + """ # noqa: E501 + # FIXME: `as_numpy_iterator` errors if `dataset` contains ragged tensors. 
+ return from_items(list(dataset.as_numpy_iterator())) + + +@PublicAPI +def from_torch( + dataset: "torch.utils.data.Dataset", + local_read: bool = False, +) -> Dataset: + """Create a :class:`~ray.data.Dataset` from a + `Torch Dataset `_. + + .. note:: + The input dataset can either be map-style or iterable-style, and can have arbitrarily large amount of data. + The data will be sequentially streamed with one single read task. + + Examples: + >>> import ray + >>> from torchvision import datasets + >>> dataset = datasets.MNIST("data", download=True) # doctest: +SKIP + >>> ds = ray.data.from_torch(dataset) # doctest: +SKIP + >>> ds # doctest: +SKIP + MaterializedDataset(num_blocks=..., num_rows=60000, schema={item: object}) + >>> ds.take(1) # doctest: +SKIP + {"item": (, 5)} + + Args: + dataset: A `Torch Dataset`_. + local_read: If ``True``, perform the read as a local read. + + Returns: + A :class:`~ray.data.Dataset` containing the Torch dataset samples. + """ # noqa: E501 + + # Files may not be accessible from all nodes, run the read task on current node. + ray_remote_args = {} + if local_read: + ray_remote_args = { + "scheduling_strategy": NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), + soft=False, + ), + # The user might have initialized Ray to have num_cpus = 0 for the head + # node. For a local read we expect the read task to be executed on the + # head node, so we should set num_cpus = 0 for the task to allow it to + # run regardless of the user's head node configuration. + "num_cpus": 0, + } + return read_datasource( + TorchDatasource(dataset=dataset), + ray_remote_args=ray_remote_args, + # Only non-parallel, streaming read is currently supported + override_num_blocks=1, + ) + + +@PublicAPI +def read_iceberg( + *, + table_identifier: str, + row_filter: Union[str, "BooleanExpression"] = None, + parallelism: int = -1, + selected_fields: Tuple[str, ...] 
= ("*",), + snapshot_id: Optional[int] = None, + scan_kwargs: Optional[Dict[str, str]] = None, + catalog_kwargs: Optional[Dict[str, str]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """Create a :class:`~ray.data.Dataset` from an Iceberg table. + + The table to read from is specified using a fully qualified ``table_identifier``. + Using PyIceberg, any intended row filters, selection of specific fields and + picking of a particular snapshot ID are applied, and the files that satisfy + the query are distributed across Ray read tasks. + The number of output blocks is determined by ``override_num_blocks`` + which can be requested from this interface or automatically chosen if + unspecified. + + .. tip:: + + For more details on PyIceberg, see + - URI: https://py.iceberg.apache.org/ + + Examples: + >>> import ray + >>> from pyiceberg.expressions import EqualTo #doctest: +SKIP + >>> ds = ray.data.read_iceberg( #doctest: +SKIP + ... table_identifier="db_name.table_name", + ... row_filter=EqualTo("column_name", "literal_value"), + ... catalog_kwargs={"name": "default", "type": "glue"} + ... ) + + Args: + table_identifier: Fully qualified table identifier (``db_name.table_name``) + row_filter: A PyIceberg :class:`~pyiceberg.expressions.BooleanExpression` + to use to filter the data *prior* to reading + parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. + selected_fields: Which columns from the data to read, passed directly to + PyIceberg's load functions. Should be an tuple of string column names. + snapshot_id: Optional snapshot ID for the Iceberg table, by default the latest + snapshot is used + scan_kwargs: Optional arguments to pass to PyIceberg's Table.scan() function + (e.g., case_sensitive, limit, etc.) + catalog_kwargs: Optional arguments to pass to PyIceberg's catalog.load_catalog() + function (e.g., name, type, etc.). 
For the function definition, see + `pyiceberg catalog + `_. + ray_remote_args: Optional arguments to pass to :func:`ray.remote` in the + read tasks. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources, and capped at the number of + physical files to be read. You shouldn't manually set this value in most + cases. + + Returns: + :class:`~ray.data.Dataset` with rows from the Iceberg table. + """ + + # Setup the Datasource + datasource = IcebergDatasource( + table_identifier=table_identifier, + row_filter=row_filter, + selected_fields=selected_fields, + snapshot_id=snapshot_id, + scan_kwargs=scan_kwargs, + catalog_kwargs=catalog_kwargs, + ) + + dataset = read_datasource( + datasource=datasource, + parallelism=parallelism, + override_num_blocks=override_num_blocks, + ray_remote_args=ray_remote_args, + ) + + return dataset + + +@PublicAPI +def read_lance( + uri: str, + *, + columns: Optional[List[str]] = None, + filter: Optional[str] = None, + storage_options: Optional[Dict[str, str]] = None, + scanner_options: Optional[Dict[str, Any]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """ + Create a :class:`~ray.data.Dataset` from a + `Lance Dataset `_. + + Examples: + >>> import ray + >>> ds = ray.data.read_lance( # doctest: +SKIP + ... uri="./db_name.lance", + ... columns=["image", "label"], + ... filter="label = 2 AND text IS NOT NULL", + ... ) + + Args: + uri: The URI of the Lance dataset to read from. Local file paths, S3, and GCS + are supported. + columns: The columns to read. By default, all columns are read. + filter: Read returns only the rows matching the filter. By default, no + filter is applied. + storage_options: Extra options that make sense for a particular storage + connection. 
This is used to store connection parameters like credentials, + endpoint, etc. For more information, see `Object Store Configuration `_. + scanner_options: Additional options to configure the `LanceDataset.scanner()` + method, such as `batch_size`. For more information, + see `LanceDB API doc `_ + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`~ray.data.Dataset` producing records read from the Lance dataset. + """ # noqa: E501 + datasource = LanceDatasource( + uri=uri, + columns=columns, + filter=filter, + storage_options=storage_options, + scanner_options=scanner_options, + ) + + return read_datasource( + datasource=datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +@PublicAPI(stability="alpha") +def read_clickhouse( + *, + table: str, + dsn: str, + columns: Optional[List[str]] = None, + filter: Optional[str] = None, + order_by: Optional[Tuple[List[str], bool]] = None, + client_settings: Optional[Dict[str, Any]] = None, + client_kwargs: Optional[Dict[str, Any]] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, +) -> Dataset: + """ + Create a :class:`~ray.data.Dataset` from a ClickHouse table or view. + + Examples: + >>> import ray + >>> ds = ray.data.read_clickhouse( # doctest: +SKIP + ... 
table="default.table", + ... dsn="clickhouse+http://username:password@host:8124/default", + ... columns=["timestamp", "age", "status", "text", "label"], + ... filter="age > 18 AND status = 'active'", + ... order_by=(["timestamp"], False), + ... ) + + Args: + table: Fully qualified table or view identifier (e.g., + "default.table_name"). + dsn: A string in standard DSN (Data Source Name) HTTP format (e.g., + "clickhouse+http://username:password@host:8124/default"). + For more information, see `ClickHouse Connection String doc + `_. + columns: Optional list of columns to select from the data source. + If no columns are specified, all columns will be selected by default. + filter: Optional SQL filter string that will be used in the WHERE statement + (e.g., "label = 2 AND text IS NOT NULL"). The filter string must be valid for use in + a ClickHouse SQL WHERE clause. Please Note: Parallel reads are not currently supported + when a filter is set. Specifying a filter forces the parallelism to 1 to ensure + deterministic and consistent results. For more information, see `ClickHouse SQL WHERE Clause doc + `_. + order_by: Optional tuple containing a list of columns to order by and a boolean indicating whether the order + should be descending (True for DESC, False for ASC). Please Note: order_by is required to support + parallelism. If not provided, the data will be read in a single task. This is to ensure + that the data is read in a consistent order across all tasks. + client_settings: Optional ClickHouse server settings to be used with the session/every request. + For more information, see `ClickHouse Client Settings + `_. + client_kwargs: Optional additional arguments to pass to the ClickHouse client. For more information, + see `ClickHouse Core Settings `_. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control number of tasks to run concurrently. 
This doesn't change the + total number of tasks run or the total number of output blocks. By default, + concurrency is dynamically decided based on the available resources. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + + Returns: + A :class:`~ray.data.Dataset` producing records read from the ClickHouse table or view. + """ # noqa: E501 + datasource = ClickHouseDatasource( + table=table, + dsn=dsn, + columns=columns, + filter=filter, + order_by=order_by, + client_settings=client_settings, + client_kwargs=client_kwargs, + ) + + return read_datasource( + datasource=datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + ) + + +def _get_datasource_or_legacy_reader( + ds: Datasource, + ctx: DataContext, + kwargs: dict, +) -> Union[Datasource, Reader]: + """Generates reader. + + Args: + ds: Datasource to read from. + ctx: Dataset config to use. + kwargs: Additional kwargs to pass to the legacy reader if + `Datasource.create_reader` is implemented. + + Returns: + The datasource or a generated legacy reader. + """ + kwargs = _unwrap_arrow_serialization_workaround(kwargs) + + DataContext._set_current(ctx) + + if ds.should_create_reader: + warnings.warn( + "`create_reader` has been deprecated in Ray 2.9. 
Instead of creating a " + "`Reader`, implement `Datasource.get_read_tasks` and " + "`Datasource.estimate_inmemory_data_size`.", + DeprecationWarning, + ) + datasource_or_legacy_reader = ds.create_reader(**kwargs) + else: + datasource_or_legacy_reader = ds + + return datasource_or_legacy_reader + + +def _resolve_parquet_args( + tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None, + **arrow_parquet_args, +) -> Dict[str, Any]: + if tensor_column_schema is not None: + existing_block_udf = arrow_parquet_args.pop("_block_udf", None) + + def _block_udf(block: "pyarrow.Table") -> "pyarrow.Table": + from ray.data.extensions import ArrowTensorArray + + for tensor_col_name, (dtype, shape) in tensor_column_schema.items(): + # NOTE(Clark): We use NumPy to consolidate these potentially + # non-contiguous buffers, and to do buffer bookkeeping in + # general. + np_col = _create_possibly_ragged_ndarray( + [ + np.ndarray(shape, buffer=buf.as_buffer(), dtype=dtype) + for buf in block.column(tensor_col_name) + ] + ) + + block = block.set_column( + block._ensure_integer_index(tensor_col_name), + tensor_col_name, + ArrowTensorArray.from_numpy(np_col, tensor_col_name), + ) + if existing_block_udf is not None: + # Apply UDF after casting the tensor columns. + block = existing_block_udf(block) + return block + + arrow_parquet_args["_block_udf"] = _block_udf + return arrow_parquet_args + + +def _get_num_output_blocks( + parallelism: int = -1, + override_num_blocks: Optional[int] = None, +) -> int: + if parallelism != -1: + logger.warning( + "The argument ``parallelism`` is deprecated in Ray 2.10. Please specify " + "argument ``override_num_blocks`` instead." 
+ ) + elif override_num_blocks is not None: + parallelism = override_num_blocks + return parallelism + + +def _emit_meta_provider_deprecation_warning( + meta_provider: Optional[BaseFileMetadataProvider], +) -> None: + if meta_provider is not None: + warnings.warn( + "The `meta_provider` argument is deprecated and will be removed after May " + "2025.", + DeprecationWarning, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/includes/__init__.pxd b/.venv/lib/python3.11/site-packages/ray/includes/__init__.pxd new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/includes/common.pxd b/.venv/lib/python3.11/site-packages/ray/includes/common.pxd new file mode 100644 index 0000000000000000000000000000000000000000..a55e101758b33f5f646e6086fb6969f24001f49a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/common.pxd @@ -0,0 +1,749 @@ +from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr, unique_ptr +from libcpp.string cimport string as c_string + +from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t, uint32_t +from libcpp.unordered_map cimport unordered_map +from libcpp.vector cimport vector as c_vector +from libcpp.pair cimport pair as c_pair +from ray.includes.optional cimport ( + optional, +) +from ray.includes.unique_ids cimport ( + CActorID, + CJobID, + CClusterID, + CWorkerID, + CObjectID, + CTaskID, + CPlacementGroupID, + CNodeID, +) +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, +) + + +cdef extern from * namespace "polyfill" nogil: + """ + namespace polyfill { + + template + inline typename std::remove_reference::type&& move(T& t) { + return std::move(t); + } + + template + inline typename std::remove_reference::type&& move(T&& t) { + return std::move(t); + } + + } // namespace polyfill + """ + cdef T move[T](T) + + +cdef extern from "ray/common/status.h" namespace "ray" nogil: + # 
TODO(ryw) in Cython 3.x we can directly use `cdef enum class CStatusCode` + cdef cppclass CStatusCode "ray::StatusCode": + pass + cdef CStatusCode CStatusCode_OK "ray::StatusCode::OK" + c_bool operator==(CStatusCode lhs, CStatusCode rhs) + + cdef cppclass CRayStatus "ray::Status": + CRayStatus() + CRayStatus(CStatusCode code, const c_string &msg) + CRayStatus(CStatusCode code, const c_string &msg, int rpc_code) + CRayStatus(const CRayStatus &s) + + @staticmethod + CRayStatus OK() + + @staticmethod + CRayStatus OutOfMemory(const c_string &msg) + + @staticmethod + CRayStatus KeyError(const c_string &msg) + + @staticmethod + CRayStatus Invalid(const c_string &msg) + + @staticmethod + CRayStatus IOError(const c_string &msg) + + @staticmethod + CRayStatus TypeError(const c_string &msg) + + @staticmethod + CRayStatus UnknownError(const c_string &msg) + + @staticmethod + CRayStatus NotImplemented(const c_string &msg) + + @staticmethod + CRayStatus ObjectStoreFull(const c_string &msg) + + @staticmethod + CRayStatus RedisError(const c_string &msg) + + @staticmethod + CRayStatus TimedOut(const c_string &msg) + + @staticmethod + CRayStatus InvalidArgument(const c_string &msg) + + @staticmethod + CRayStatus Interrupted(const c_string &msg) + + @staticmethod + CRayStatus IntentionalSystemExit(const c_string &msg) + + @staticmethod + CRayStatus UnexpectedSystemExit(const c_string &msg) + + @staticmethod + CRayStatus CreationTaskError(const c_string &msg) + + @staticmethod + CRayStatus NotFound() + + @staticmethod + CRayStatus ObjectRefEndOfStream() + + c_bool ok() + c_bool IsOutOfMemory() + c_bool IsKeyError() + c_bool IsInvalid() + c_bool IsIOError() + c_bool IsTypeError() + c_bool IsUnknownError() + c_bool IsNotImplemented() + c_bool IsObjectStoreFull() + c_bool IsAlreadyExists() + c_bool IsOutOfDisk() + c_bool IsRedisError() + c_bool IsTimedOut() + c_bool IsInvalidArgument() + c_bool IsInterrupted() + c_bool ShouldExitWorker() + c_bool IsObjectNotFound() + c_bool IsNotFound() 
+ c_bool IsObjectUnknownOwner() + c_bool IsRpcError() + c_bool IsOutOfResource() + c_bool IsObjectRefEndOfStream() + c_bool IsIntentionalSystemExit() + c_bool IsUnexpectedSystemExit() + c_bool IsChannelError() + c_bool IsChannelTimeoutError() + + c_string ToString() + c_string CodeAsString() + CStatusCode code() + c_string message() + int rpc_code() + + # We can later add more of the common status factory methods as needed + cdef CRayStatus RayStatus_OK "Status::OK"() + cdef CRayStatus RayStatus_Invalid "Status::Invalid"() + cdef CRayStatus RayStatus_NotImplemented "Status::NotImplemented"() + + +cdef extern from "ray/common/id.h" namespace "ray" nogil: + const CTaskID GenerateTaskId(const CJobID &job_id, + const CTaskID &parent_task_id, + int parent_task_counter) + + +cdef extern from "src/ray/protobuf/common.pb.h" nogil: + cdef cppclass CLanguage "Language": + pass + cdef cppclass CWorkerType "ray::core::WorkerType": + pass + cdef cppclass CWorkerExitType "ray::rpc::WorkerExitType": + pass + cdef cppclass CTaskType "ray::TaskType": + pass + cdef cppclass CPlacementStrategy "ray::core::PlacementStrategy": + pass + cdef cppclass CDefaultSchedulingStrategy "ray::rpc::DefaultSchedulingStrategy": # noqa: E501 + CDefaultSchedulingStrategy() + cdef cppclass CSpreadSchedulingStrategy "ray::rpc::SpreadSchedulingStrategy": # noqa: E501 + CSpreadSchedulingStrategy() + cdef cppclass CPlacementGroupSchedulingStrategy "ray::rpc::PlacementGroupSchedulingStrategy": # noqa: E501 + CPlacementGroupSchedulingStrategy() + void set_placement_group_id(const c_string& placement_group_id) + void set_placement_group_bundle_index(int64_t placement_group_bundle_index) # noqa: E501 + void set_placement_group_capture_child_tasks(c_bool placement_group_capture_child_tasks) # noqa: E501 + cdef cppclass CNodeAffinitySchedulingStrategy "ray::rpc::NodeAffinitySchedulingStrategy": # noqa: E501 + CNodeAffinitySchedulingStrategy() + void set_node_id(const c_string& node_id) + void set_soft(c_bool 
soft) + void set_spill_on_unavailable(c_bool spill_on_unavailable) + void set_fail_on_unavailable(c_bool fail_on_unavailable) + cdef cppclass CSchedulingStrategy "ray::rpc::SchedulingStrategy": + CSchedulingStrategy() + void clear_scheduling_strategy() + CSpreadSchedulingStrategy* mutable_spread_scheduling_strategy() + CDefaultSchedulingStrategy* mutable_default_scheduling_strategy() + CPlacementGroupSchedulingStrategy* mutable_placement_group_scheduling_strategy() # noqa: E501 + CNodeAffinitySchedulingStrategy* mutable_node_affinity_scheduling_strategy() + CNodeLabelSchedulingStrategy* mutable_node_label_scheduling_strategy() + cdef cppclass CAddress "ray::rpc::Address": + CAddress() + const c_string &SerializeAsString() const + void ParseFromString(const c_string &serialized) + void CopyFrom(const CAddress& address) + const c_string &worker_id() + cdef cppclass CObjectReference "ray::rpc::ObjectReference": + CObjectReference() + CAddress owner_address() const + const c_string &object_id() const + const c_string &call_site() const + cdef cppclass CNodeLabelSchedulingStrategy "ray::rpc::NodeLabelSchedulingStrategy": # noqa: E501 + CNodeLabelSchedulingStrategy() + CLabelMatchExpressions* mutable_hard() + CLabelMatchExpressions* mutable_soft() + cdef cppclass CLabelMatchExpressions "ray::rpc::LabelMatchExpressions": # noqa: E501 + CLabelMatchExpressions() + CLabelMatchExpression* add_expressions() + cdef cppclass CLabelMatchExpression "ray::rpc::LabelMatchExpression": # noqa: E501 + CLabelMatchExpression() + void set_key(const c_string &key) + CLabelOperator* mutable_operator_() + cdef cppclass CLabelIn "ray::rpc::LabelIn": # noqa: E501 + CLabelIn() + void add_values(const c_string &value) + cdef cppclass CLabelNotIn "ray::rpc::LabelNotIn": # noqa: E501 + CLabelNotIn() + void add_values(const c_string &value) + cdef cppclass CLabelExists "ray::rpc::LabelExists": # noqa: E501 + CLabelExists() + cdef cppclass CLabelDoesNotExist "ray::rpc::LabelDoesNotExist": # noqa: 
E501 + CLabelDoesNotExist() + cdef cppclass CLabelNotIn "ray::rpc::LabelNotIn": # noqa: E501 + CLabelNotIn() + void add_values(const c_string &value) + cdef cppclass CLabelOperator "ray::rpc::LabelOperator": # noqa: E501 + CLabelOperator() + CLabelIn* mutable_label_in() + CLabelNotIn* mutable_label_not_in() + CLabelExists* mutable_label_exists() + CLabelDoesNotExist* mutable_label_does_not_exist() + cdef cppclass CLineageReconstructionTask "ray::rpc::LineageReconstructionTask": + CLineageReconstructionTask() + const c_string &SerializeAsString() const + + +# This is a workaround for C++ enum class since Cython has no corresponding +# representation. +cdef extern from "src/ray/protobuf/common.pb.h" nogil: + cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON" + cdef CLanguage LANGUAGE_CPP "Language::CPP" + cdef CLanguage LANGUAGE_JAVA "Language::JAVA" + +cdef extern from "src/ray/protobuf/common.pb.h" nogil: + cdef CWorkerType WORKER_TYPE_WORKER "ray::core::WorkerType::WORKER" + cdef CWorkerType WORKER_TYPE_DRIVER "ray::core::WorkerType::DRIVER" + cdef CWorkerType WORKER_TYPE_SPILL_WORKER "ray::core::WorkerType::SPILL_WORKER" # noqa: E501 + cdef CWorkerType WORKER_TYPE_RESTORE_WORKER "ray::core::WorkerType::RESTORE_WORKER" # noqa: E501 + cdef CWorkerType WORKER_TYPE_UTIL_WORKER "ray::core::WorkerType::UTIL_WORKER" # noqa: E501 + cdef CWorkerExitType WORKER_EXIT_TYPE_USER_ERROR "ray::rpc::WorkerExitType::USER_ERROR" # noqa: E501 + cdef CWorkerExitType WORKER_EXIT_TYPE_SYSTEM_ERROR "ray::rpc::WorkerExitType::SYSTEM_ERROR" # noqa: E501 + cdef CWorkerExitType WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM_ERROR "ray::rpc::WorkerExitType::INTENDED_SYSTEM_EXIT" # noqa: E501 + +cdef extern from "src/ray/protobuf/common.pb.h" nogil: + cdef CTaskType TASK_TYPE_NORMAL_TASK "ray::TaskType::NORMAL_TASK" + cdef CTaskType TASK_TYPE_ACTOR_CREATION_TASK "ray::TaskType::ACTOR_CREATION_TASK" # noqa: E501 + cdef CTaskType TASK_TYPE_ACTOR_TASK "ray::TaskType::ACTOR_TASK" + +cdef extern from 
"src/ray/protobuf/common.pb.h" nogil: + cdef CPlacementStrategy PLACEMENT_STRATEGY_PACK \ + "ray::core::PlacementStrategy::PACK" + cdef CPlacementStrategy PLACEMENT_STRATEGY_SPREAD \ + "ray::core::PlacementStrategy::SPREAD" + cdef CPlacementStrategy PLACEMENT_STRATEGY_STRICT_PACK \ + "ray::core::PlacementStrategy::STRICT_PACK" + cdef CPlacementStrategy PLACEMENT_STRATEGY_STRICT_SPREAD \ + "ray::core::PlacementStrategy::STRICT_SPREAD" + +cdef extern from "ray/common/buffer.h" namespace "ray" nogil: + cdef cppclass CBuffer "ray::Buffer": + uint8_t *Data() const + size_t Size() const + c_bool IsPlasmaBuffer() const + + cdef cppclass LocalMemoryBuffer(CBuffer): + LocalMemoryBuffer(uint8_t *data, size_t size, c_bool copy_data) + LocalMemoryBuffer(size_t size) + + cdef cppclass SharedMemoryBuffer(CBuffer): + SharedMemoryBuffer( + const shared_ptr[CBuffer] &buffer, + int64_t offset, + int64_t size) + c_bool IsPlasmaBuffer() const + +cdef extern from "ray/common/ray_object.h" nogil: + cdef cppclass CRayObject "ray::RayObject": + CRayObject(const shared_ptr[CBuffer] &data, + const shared_ptr[CBuffer] &metadata, + const c_vector[CObjectReference] &nested_refs) + c_bool HasData() const + c_bool HasMetadata() const + const size_t DataSize() const + const shared_ptr[CBuffer] &GetData() + const shared_ptr[CBuffer] &GetMetadata() const + c_bool IsInPlasmaError() const + +cdef extern from "ray/core_worker/common.h" nogil: + cdef cppclass CRayFunction "ray::core::RayFunction": + CRayFunction() + CRayFunction(CLanguage language, + const CFunctionDescriptor &function_descriptor) + CLanguage GetLanguage() + const CFunctionDescriptor GetFunctionDescriptor() + + cdef cppclass CTaskArg "ray::TaskArg": + pass + + cdef cppclass CTaskArgByReference "ray::TaskArgByReference": + CTaskArgByReference(const CObjectID &object_id, + const CAddress &owner_address, + const c_string &call_site) + + cdef cppclass CTaskArgByValue "ray::TaskArgByValue": + CTaskArgByValue(const shared_ptr[CRayObject] 
&data) + + cdef cppclass CTaskOptions "ray::core::TaskOptions": + CTaskOptions() + CTaskOptions(c_string name, int num_returns, + unordered_map[c_string, double] &resources, + c_string concurrency_group_name, + int64_t generator_backpressure_num_objects) + CTaskOptions(c_string name, int num_returns, + unordered_map[c_string, double] &resources, + c_string concurrency_group_name, + int64_t generator_backpressure_num_objects, + c_string serialized_runtime_env) + CTaskOptions(c_string name, int num_returns, + unordered_map[c_string, double] &resources, + c_string concurrency_group_name, + int64_t generator_backpressure_num_objects, + c_string serialized_runtime_env, c_bool enable_task_events, + const unordered_map[c_string, c_string] &labels) + + cdef cppclass CActorCreationOptions "ray::core::ActorCreationOptions": + CActorCreationOptions() + CActorCreationOptions( + int64_t max_restarts, + int64_t max_task_retries, + int32_t max_concurrency, + const unordered_map[c_string, double] &resources, + const unordered_map[c_string, double] &placement_resources, + const c_vector[c_string] &dynamic_worker_options, + optional[c_bool] is_detached, c_string &name, c_string &ray_namespace, + c_bool is_asyncio, + const CSchedulingStrategy &scheduling_strategy, + c_string serialized_runtime_env, + const c_vector[CConcurrencyGroup] &concurrency_groups, + c_bool execute_out_of_order, + int32_t max_pending_calls, + c_bool enable_task_events, + const unordered_map[c_string, c_string] &labels) + + cdef cppclass CPlacementGroupCreationOptions \ + "ray::core::PlacementGroupCreationOptions": + CPlacementGroupCreationOptions() + CPlacementGroupCreationOptions( + const c_string &name, + CPlacementStrategy strategy, + const c_vector[unordered_map[c_string, double]] &bundles, + c_bool is_detached, + double max_cpu_fraction_per_node, + CNodeID soft_target_node_id, + ) + + cdef cppclass CObjectLocation "ray::core::ObjectLocation": + const CNodeID &GetPrimaryNodeID() const + const int64_t 
GetObjectSize() const + const c_vector[CNodeID] &GetNodeIDs() const + c_bool IsSpilled() const + const c_string &GetSpilledURL() const + const CNodeID &GetSpilledNodeID() const + const c_bool GetDidSpill() const + +cdef extern from "ray/gcs/gcs_client/python_callbacks.h" namespace "ray::gcs": + cdef cppclass MultiItemPyCallback[T]: + MultiItemPyCallback( + object (*)(CRayStatus, c_vector[T] &&) nogil, + void (object, object) nogil, + object) nogil + + cdef cppclass OptionalItemPyCallback[T]: + OptionalItemPyCallback( + object (*)(CRayStatus, const optional[T]&) nogil, + void (object, object) nogil, + object) nogil + + cdef cppclass StatusPyCallback: + StatusPyCallback( + object (*)(CRayStatus) nogil, + void (object, object) nogil, + object) nogil + +cdef extern from "ray/gcs/gcs_client/accessor.h" nogil: + cdef cppclass CActorInfoAccessor "ray::gcs::ActorInfoAccessor": + CRayStatus AsyncGetAllByFilter( + const optional[CActorID] &actor_id, + const optional[CJobID] &job_id, + const optional[c_string] &actor_state_name, + const MultiItemPyCallback[CActorTableData] &callback, + int64_t timeout_ms) + + CRayStatus AsyncKillActor(const CActorID &actor_id, + c_bool force_kill, + c_bool no_restart, + const StatusPyCallback &callback, + int64_t timeout_ms) + + cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor": + CRayStatus GetAll( + const optional[c_string] &job_or_submission_id, + c_bool skip_submission_job_info_field, + c_bool skip_is_running_tasks_field, + c_vector[CJobTableData] &result, + int64_t timeout_ms) + + CRayStatus AsyncGetAll( + const optional[c_string] &job_or_submission_id, + c_bool skip_submission_job_info_field, + c_bool skip_is_running_tasks_field, + const MultiItemPyCallback[CJobTableData] &callback, + int64_t timeout_ms) + + cdef cppclass CNodeInfoAccessor "ray::gcs::NodeInfoAccessor": + CRayStatus CheckAlive( + const c_vector[c_string] &raylet_addresses, + int64_t timeout_ms, + c_vector[c_bool] &result) + + CRayStatus AsyncCheckAlive( + const 
c_vector[c_string] &raylet_addresses, + int64_t timeout_ms, + const MultiItemPyCallback[c_bool] &callback) + + CRayStatus DrainNodes( + const c_vector[CNodeID] &node_ids, + int64_t timeout_ms, + c_vector[c_string] &drained_node_ids) + + CRayStatus GetAllNoCache( + int64_t timeout_ms, + c_vector[CGcsNodeInfo] &result) + + CRayStatus AsyncGetAll( + const MultiItemPyCallback[CGcsNodeInfo] &callback, + int64_t timeout_ms, + optional[CNodeID] node_id) + + cdef cppclass CNodeResourceInfoAccessor "ray::gcs::NodeResourceInfoAccessor": + CRayStatus GetAllResourceUsage( + int64_t timeout_ms, + CGetAllResourceUsageReply &serialized_reply) + + cdef cppclass CInternalKVAccessor "ray::gcs::InternalKVAccessor": + CRayStatus Keys( + const c_string &ns, + const c_string &prefix, + int64_t timeout_ms, + c_vector[c_string] &value) + + CRayStatus Put( + const c_string &ns, + const c_string &key, + const c_string &value, + c_bool overwrite, + int64_t timeout_ms, + c_bool &added) + + CRayStatus Get( + const c_string &ns, + const c_string &key, + int64_t timeout_ms, + c_string &value) + + CRayStatus MultiGet( + const c_string &ns, + const c_vector[c_string] &keys, + int64_t timeout_ms, + unordered_map[c_string, c_string] &values) + + CRayStatus Del( + const c_string &ns, + const c_string &key, + c_bool del_by_prefix, + int64_t timeout_ms, + int& num_deleted) + + CRayStatus Exists( + const c_string &ns, + const c_string &key, + int64_t timeout_ms, + c_bool &exists) + + CRayStatus AsyncInternalKVKeys( + const c_string &ns, + const c_string &prefix, + int64_t timeout_ms, + const OptionalItemPyCallback[c_vector[c_string]] &callback) + + CRayStatus AsyncInternalKVGet( + const c_string &ns, + const c_string &key, + int64_t timeout_ms, + const OptionalItemPyCallback[c_string] &callback) + + CRayStatus AsyncInternalKVMultiGet( + const c_string &ns, + const c_vector[c_string] &keys, + int64_t timeout_ms, + const OptionalItemPyCallback[unordered_map[c_string, c_string]] &callback) + + CRayStatus 
AsyncInternalKVPut( + const c_string &ns, + const c_string &key, + const c_string &value, + c_bool overwrite, + int64_t timeout_ms, + const OptionalItemPyCallback[c_bool] &callback) + + CRayStatus AsyncInternalKVExists( + const c_string &ns, + const c_string &key, + int64_t timeout_ms, + const OptionalItemPyCallback[c_bool] &callback) + + CRayStatus AsyncInternalKVDel( + const c_string &ns, + const c_string &key, + c_bool del_by_prefix, + int64_t timeout_ms, + const OptionalItemPyCallback[int] &callback) + + cdef cppclass CRuntimeEnvAccessor "ray::gcs::RuntimeEnvAccessor": + CRayStatus PinRuntimeEnvUri( + const c_string &uri, + int expiration_s, + int64_t timeout_ms) + + cdef cppclass CAutoscalerStateAccessor "ray::gcs::AutoscalerStateAccessor": + + CRayStatus RequestClusterResourceConstraint( + int64_t timeout_ms, + const c_vector[unordered_map[c_string, double]] &bundles, + const c_vector[int64_t] &count_array + ) + + CRayStatus GetClusterResourceState( + int64_t timeout_ms, + c_string &serialized_reply + ) + + CRayStatus GetClusterStatus( + int64_t timeout_ms, + c_string &serialized_reply + ) + + CRayStatus ReportAutoscalingState( + int64_t timeout_ms, + const c_string &serialized_state + ) + + CRayStatus ReportClusterConfig( + int64_t timeout_ms, + const c_string &serialized_cluster_config + ) + + CRayStatus DrainNode( + const c_string &node_id, + int32_t reason, + const c_string &reason_message, + int64_t deadline_timestamp_ms, + int64_t timeout_ms, + c_bool &is_accepted, + c_string &rejection_reason_message + ) + + +cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil: + cdef enum CGrpcStatusCode "grpc::StatusCode": + UNAVAILABLE "grpc::StatusCode::UNAVAILABLE", + UNKNOWN "grpc::StatusCode::UNKNOWN", + DEADLINE_EXCEEDED "grpc::StatusCode::DEADLINE_EXCEEDED", + RESOURCE_EXHAUSTED "grpc::StatusCode::RESOURCE_EXHAUSTED", + UNIMPLEMENTED "grpc::StatusCode::UNIMPLEMENTED", + + cdef cppclass CGcsClientOptions "ray::gcs::GcsClientOptions": + CGcsClientOptions( + 
const c_string &gcs_address, int port, CClusterID cluster_id, + c_bool allow_cluster_id_nil, c_bool fetch_cluster_id_if_nil) + + cdef cppclass CGcsClient "ray::gcs::GcsClient": + CGcsClient(const CGcsClientOptions &options) + + c_pair[c_string, int] GetGcsServerAddress() const + CClusterID GetClusterId() const + + CActorInfoAccessor& Actors() + CJobInfoAccessor& Jobs() + CInternalKVAccessor& InternalKV() + CNodeInfoAccessor& Nodes() + CNodeResourceInfoAccessor& NodeResources() + CRuntimeEnvAccessor& RuntimeEnvs() + CAutoscalerStateAccessor& Autoscaler() + + cdef CRayStatus ConnectOnSingletonIoContext(CGcsClient &gcs_client, int timeout_ms) + +cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil: + unordered_map[c_string, double] PythonGetResourcesTotal( + const CGcsNodeInfo& node_info) + +cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" nogil: + + cdef cppclass CPythonGcsPublisher "ray::gcs::PythonGcsPublisher": + + CPythonGcsPublisher(const c_string& gcs_address) + + CRayStatus Connect() + + CRayStatus PublishError( + const c_string &key_id, const CErrorTableData &data, int64_t num_retries) + + CRayStatus PublishLogs(const c_string &key_id, const CLogBatch &data) + + cdef cppclass CPythonGcsSubscriber "ray::gcs::PythonGcsSubscriber": + + CPythonGcsSubscriber( + const c_string& gcs_address, int gcs_port, CChannelType channel_type, + const c_string& subscriber_id, const c_string& worker_id) + + CRayStatus Subscribe() + + int64_t last_batch_size() + + CRayStatus PollError( + c_string* key_id, int64_t timeout_ms, CErrorTableData* data) + + CRayStatus PollLogs( + c_string* key_id, int64_t timeout_ms, CLogBatch* data) + + CRayStatus PollActor( + c_string* key_id, int64_t timeout_ms, CActorTableData* data) + + CRayStatus Close() + +cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" namespace "ray::gcs" nogil: + c_vector[c_string] PythonGetLogBatchLines(const CLogBatch& log_batch) + +cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace 
"ray::gcs" nogil: + unordered_map[c_string, c_string] PythonGetNodeLabels( + const CGcsNodeInfo& node_info) + +cdef extern from "src/ray/protobuf/gcs.pb.h" nogil: + cdef enum CChannelType "ray::rpc::ChannelType": + RAY_ERROR_INFO_CHANNEL "ray::rpc::ChannelType::RAY_ERROR_INFO_CHANNEL", + RAY_LOG_CHANNEL "ray::rpc::ChannelType::RAY_LOG_CHANNEL", + GCS_ACTOR_CHANNEL "ray::rpc::ChannelType::GCS_ACTOR_CHANNEL", + + cdef cppclass CJobConfig "ray::rpc::JobConfig": + c_string ray_namespace() const + const c_string &SerializeAsString() const + + cdef cppclass CNodeDeathInfo "ray::rpc::NodeDeathInfo": + int reason() const + c_string reason_message() const + + cdef cppclass CGcsNodeInfo "ray::rpc::GcsNodeInfo": + c_string node_id() const + c_string node_name() const + int state() const + c_string node_manager_address() const + c_string node_manager_hostname() const + int node_manager_port() const + int object_manager_port() const + c_string object_store_socket_name() const + c_string raylet_socket_name() const + int metrics_export_port() const + int runtime_env_agent_port() const + CNodeDeathInfo death_info() const + void ParseFromString(const c_string &serialized) + const c_string& SerializeAsString() const + + cdef enum CGcsNodeState "ray::rpc::GcsNodeInfo_GcsNodeState": + ALIVE "ray::rpc::GcsNodeInfo_GcsNodeState_ALIVE", + + cdef cppclass CJobTableData "ray::rpc::JobTableData": + c_string job_id() const + c_bool is_dead() const + CJobConfig config() const + const c_string &SerializeAsString() const + + cdef cppclass CGetAllResourceUsageReply "ray::rpc::GetAllResourceUsageReply": + const c_string& SerializeAsString() const + + cdef cppclass CPythonFunction "ray::rpc::PythonFunction": + void set_key(const c_string &key) + c_string key() const + + cdef cppclass CErrorTableData "ray::rpc::ErrorTableData": + c_string job_id() const + c_string type() const + c_string error_message() const + double timestamp() const + + void set_job_id(const c_string &job_id) + void 
set_type(const c_string &type) + void set_error_message(const c_string &error_message) + void set_timestamp(double timestamp) + + cdef cppclass CLogBatch "ray::rpc::LogBatch": + c_string ip() const + c_string pid() const + c_string job_id() const + c_bool is_error() const + c_string actor_name() const + c_string task_name() const + + void set_ip(const c_string &ip) + void set_pid(const c_string &pid) + void set_job_id(const c_string &job_id) + void set_is_error(c_bool is_error) + void add_lines(const c_string &line) + void set_actor_name(const c_string &actor_name) + void set_task_name(const c_string &task_name) + + cdef cppclass CActorTableData "ray::rpc::ActorTableData": + CAddress address() const + void ParseFromString(const c_string &serialized) + const c_string &SerializeAsString() const + +cdef extern from "ray/common/task/task_spec.h" nogil: + cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup": + CConcurrencyGroup( + const c_string &name, + uint32_t max_concurrency, + const c_vector[CFunctionDescriptor] &c_fds) + CConcurrencyGroup() + c_string GetName() const + uint32_t GetMaxConcurrency() const + c_vector[CFunctionDescriptor] GetFunctionDescriptors() const + +cdef extern from "ray/common/constants.h" nogil: + cdef const char[] kWorkerSetupHookKeyName + cdef int kResourceUnitScaling + cdef const char[] kImplicitResourcePrefix + cdef int kStreamingGeneratorReturn + cdef const char[] kGcsAutoscalerStateNamespace + cdef const char[] kGcsAutoscalerV2EnabledKey + cdef const char[] kGcsAutoscalerClusterConfigKey diff --git a/.venv/lib/python3.11/site-packages/ray/includes/function_descriptor.pxd b/.venv/lib/python3.11/site-packages/ray/includes/function_descriptor.pxd new file mode 100644 index 0000000000000000000000000000000000000000..5124405772b8ae0382625f95843943861be55bfb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/function_descriptor.pxd @@ -0,0 +1,80 @@ +from libc.stdint cimport uint8_t, uint64_t +from libcpp cimport bool as 
c_bool +from libcpp.memory cimport unique_ptr, shared_ptr +from libcpp.string cimport string as c_string +from libcpp.unordered_map cimport unordered_map +from libcpp.vector cimport vector as c_vector + +from ray.includes.common cimport ( + CLanguage, +) +from ray.includes.unique_ids cimport ( + CActorID, + CJobID, + CObjectID, + CTaskID, +) + +cdef extern from "src/ray/protobuf/common.pb.h" nogil: + cdef cppclass CFunctionDescriptorType \ + "ray::FunctionDescriptorType": + pass + + cdef CFunctionDescriptorType EmptyFunctionDescriptorType \ + "ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET" + cdef CFunctionDescriptorType JavaFunctionDescriptorType \ + "ray::FunctionDescriptorType::kJavaFunctionDescriptor" + cdef CFunctionDescriptorType PythonFunctionDescriptorType \ + "ray::FunctionDescriptorType::kPythonFunctionDescriptor" + cdef CFunctionDescriptorType CppFunctionDescriptorType \ + "ray::FunctionDescriptorType::kCppFunctionDescriptor" + + +cdef extern from "ray/common/function_descriptor.h" nogil: + cdef cppclass CFunctionDescriptorInterface \ + "ray::FunctionDescriptorInterface": + CFunctionDescriptorType Type() + c_string ToString() + c_string Serialize() + + ctypedef shared_ptr[CFunctionDescriptorInterface] CFunctionDescriptor \ + "ray::FunctionDescriptor" + + cdef cppclass CFunctionDescriptorBuilder "ray::FunctionDescriptorBuilder": + @staticmethod + CFunctionDescriptor Empty() + + @staticmethod + CFunctionDescriptor BuildJava(const c_string &class_name, + const c_string &function_name, + const c_string &signature) + + @staticmethod + CFunctionDescriptor BuildPython(const c_string &module_name, + const c_string &class_name, + const c_string &function_name, + const c_string &function_source_hash) + + @staticmethod + CFunctionDescriptor BuildCpp(const c_string &function_name, + const c_string &caller, + const c_string &class_name) + + @staticmethod + CFunctionDescriptor Deserialize(const c_string &serialized_binary) + + cdef cppclass 
CJavaFunctionDescriptor "ray::JavaFunctionDescriptor": + c_string ClassName() + c_string FunctionName() + c_string Signature() + + cdef cppclass CPythonFunctionDescriptor "ray::PythonFunctionDescriptor": + c_string ModuleName() + c_string ClassName() + c_string FunctionName() + c_string FunctionHash() + + cdef cppclass CCppFunctionDescriptor "ray::CppFunctionDescriptor": + c_string FunctionName() + c_string Caller() + c_string ClassName() diff --git a/.venv/lib/python3.11/site-packages/ray/includes/global_state_accessor.pxd b/.venv/lib/python3.11/site-packages/ray/includes/global_state_accessor.pxd new file mode 100644 index 0000000000000000000000000000000000000000..f1874ffeee0e9aaa939921419812a92e60373675 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/global_state_accessor.pxd @@ -0,0 +1,144 @@ +from libcpp.string cimport string as c_string +from libcpp cimport bool as c_bool +from libcpp.vector cimport vector as c_vector +from libcpp.unordered_map cimport unordered_map +from libcpp.memory cimport unique_ptr +from libc.stdint cimport ( + int32_t as c_int32_t, + uint32_t as c_uint32_t, + int64_t as c_int64_t, +) +from ray.includes.unique_ids cimport ( + CActorID, + CJobID, + CNodeID, + CObjectID, + CWorkerID, + CPlacementGroupID, +) +from ray.includes.common cimport ( + CRayStatus, + CGcsClientOptions, +) +from ray.includes.optional cimport ( + optional +) + +cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: + cdef cppclass CGlobalStateAccessor "ray::gcs::GlobalStateAccessor": + CGlobalStateAccessor(const CGcsClientOptions&) + c_bool Connect() + void Disconnect() + c_vector[c_string] GetAllJobInfo( + c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field) + CJobID GetNextJobID() + c_vector[c_string] GetAllNodeInfo() + c_vector[c_string] GetAllAvailableResources() + c_vector[c_string] GetAllTotalResources() + unordered_map[CNodeID, c_int64_t] GetDrainingNodes() + unique_ptr[c_string] GetInternalKV( + const 
c_string &namespace, const c_string &key) + c_vector[c_string] GetAllTaskEvents() + unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id) + unique_ptr[c_string] GetAllResourceUsage() + c_vector[c_string] GetAllActorInfo( + optional[CActorID], optional[CJobID], optional[c_string]) + unique_ptr[c_string] GetActorInfo(const CActorID &actor_id) + unique_ptr[c_string] GetWorkerInfo(const CWorkerID &worker_id) + c_vector[c_string] GetAllWorkerInfo() + c_bool AddWorkerInfo(const c_string &serialized_string) + c_bool UpdateWorkerDebuggerPort(const CWorkerID &worker_id, + const c_uint32_t debuger_port) + c_bool UpdateWorkerNumPausedThreads(const CWorkerID &worker_id, + const c_int32_t num_paused_threads_delta) + c_uint32_t GetWorkerDebuggerPort(const CWorkerID &worker_id) + unique_ptr[c_string] GetPlacementGroupInfo( + const CPlacementGroupID &placement_group_id) + unique_ptr[c_string] GetPlacementGroupByName( + const c_string &placement_group_name, + const c_string &ray_namespace, + ) + c_vector[c_string] GetAllPlacementGroupInfo() + c_string GetSystemConfig() + CRayStatus GetNodeToConnectForDriver( + const c_string &node_ip_address, + c_string *node_to_connect) + CRayStatus GetNode( + const c_string &node_id_hex_str, + c_string *node_info) + +cdef extern from * namespace "ray::gcs" nogil: + """ + #include + #include "ray/gcs/gcs_server/store_client_kv.h" + namespace ray { + namespace gcs { + + bool RedisGetKeySync(const std::string& host, + int32_t port, + const std::string& username, + const std::string& password, + bool use_ssl, + const std::string& config, + const std::string& key, + std::string* data) { + // Logging default value see class `RayLog`. 
+ InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, + ray::RayLog::ShutDownRayLog, + "ray_init", + ray::RayLogLevel::WARNING, + /*log_filepath=*/"", + /*log_rotation_max_size=*/1ULL << 29, + /*log_rotation_file_num=*/10); + + RedisClientOptions options(host, port, username, password, use_ssl); + + std::string config_list; + RAY_CHECK(absl::Base64Unescape(config, &config_list)); + RayConfig::instance().initialize(config_list); + + instrumented_io_context io_service; + + auto redis_client = std::make_shared(options); + auto status = redis_client->Connect(io_service); + RAY_CHECK_OK(status) << "Failed to connect to redis."; + + auto cli = std::make_unique( + std::make_unique(std::move(redis_client))); + + bool ret_val = false; + cli->Get("session", key, {[&](std::optional result) { + if (result.has_value()) { + *data = result.value(); + ret_val = true; + } else { + RAY_LOG(INFO) << "Failed to retrieve the key " << key + << " from persistent storage."; + ret_val = false; + } + }, io_service}); + io_service.run_for(std::chrono::milliseconds(1000)); + + return ret_val; + } + + } + } + """ + c_bool RedisGetKeySync(const c_string& host, + c_int32_t port, + const c_string& username, + const c_string& password, + c_bool use_ssl, + const c_string& config, + const c_string& key, + c_string* data) + + +cdef extern from * namespace "ray::gcs" nogil: + c_bool RedisDelKeyPrefixSync(const c_string& host, + c_int32_t port, + const c_string& username, + const c_string& password, + c_bool use_ssl, + const c_string& key_prefix) diff --git a/.venv/lib/python3.11/site-packages/ray/includes/libcoreworker.pxd b/.venv/lib/python3.11/site-packages/ray/includes/libcoreworker.pxd new file mode 100644 index 0000000000000000000000000000000000000000..5f97fa67f5ae7d85d93c883ff4c9e51707f28742 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/libcoreworker.pxd @@ -0,0 +1,457 @@ +# cython: profile = False +# distutils: language = c++ +# cython: embedsignature = True + 
+from libc.stdint cimport int64_t, uint64_t +from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr, unique_ptr +from libcpp.pair cimport pair as c_pair +from libcpp.string cimport string as c_string +from libcpp.unordered_map cimport unordered_map +from libcpp.utility cimport pair +from libcpp.vector cimport vector as c_vector + +from ray.includes.unique_ids cimport ( + CActorID, + CClusterID, + CNodeID, + CJobID, + CTaskID, + CObjectID, + CPlacementGroupID, + CWorkerID, + ObjectIDIndexType, +) + +from ray.includes.common cimport ( + CAddress, + CObjectReference, + CActorCreationOptions, + CBuffer, + CPlacementGroupCreationOptions, + CObjectLocation, + CObjectReference, + CRayFunction, + CRayObject, + CRayStatus, + CTaskArg, + CTaskOptions, + CTaskType, + CWorkerType, + CLanguage, + CGcsClientOptions, + LocalMemoryBuffer, + CJobConfig, + CConcurrencyGroup, + CSchedulingStrategy, + CWorkerExitType, + CLineageReconstructionTask, +) +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, +) + +from ray.includes.optional cimport ( + optional, +) + +ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ + ResourceMappingType + +ctypedef void (*ray_callback_function) \ + (shared_ptr[CRayObject] result_object, + CObjectID object_id, void* user_data) + +ctypedef void (*plasma_callback_function) \ + (CObjectID object_id, int64_t data_size, int64_t metadata_size) + +# NOTE: This ctypedef is needed, because Cython doesn't compile +# "pair[shared_ptr[const CActorHandle], CRayStatus]". +# This is a bug of cython: https://github.com/cython/cython/issues/3967. 
+ctypedef shared_ptr[const CActorHandle] ActorHandleSharedPtr + + +cdef extern from "ray/core_worker/profile_event.h" nogil: + cdef cppclass CProfileEvent "ray::core::worker::ProfileEvent": + void SetExtraData(const c_string &extra_data) + +cdef extern from "ray/core_worker/fiber.h" nogil: + cdef cppclass CFiberEvent "ray::core::FiberEvent": + CFiberEvent() + void Wait() + void Notify() + +cdef extern from "ray/core_worker/experimental_mutable_object_manager.h" nogil: + cdef cppclass CReaderRefInfo "ray::experimental::ReaderRefInfo": + CReaderRefInfo() + CObjectID reader_ref_id + CActorID owner_reader_actor_id + int64_t num_reader_actors + + +cdef extern from "ray/core_worker/context.h" nogil: + cdef cppclass CWorkerContext "ray::core::WorkerContext": + c_bool CurrentActorIsAsync() + const c_string &GetCurrentSerializedRuntimeEnv() + int CurrentActorMaxConcurrency() + const CActorID &GetRootDetachedActorID() + +cdef extern from "ray/core_worker/generator_waiter.h" nogil: + cdef cppclass CGeneratorBackpressureWaiter "ray::core::GeneratorBackpressureWaiter": # noqa + CGeneratorBackpressureWaiter( + int64_t generator_backpressure_num_objects, + (CRayStatus() nogil) check_signals) + CRayStatus WaitAllObjectsReported() + +cdef extern from "ray/core_worker/core_worker.h" nogil: + cdef cppclass CActorHandle "ray::core::ActorHandle": + CActorID GetActorID() const + CJobID CreationJobID() const + CLanguage ActorLanguage() const + CFunctionDescriptor ActorCreationTaskFunctionDescriptor() const + c_string ExtensionData() const + int MaxPendingCalls() const + int MaxTaskRetries() const + c_bool EnableTaskEvents() const + + cdef cppclass CCoreWorker "ray::core::CoreWorker": + CWorkerType GetWorkerType() + CLanguage GetLanguage() + + c_vector[CObjectReference] SubmitTask( + const CRayFunction &function, + const c_vector[unique_ptr[CTaskArg]] &args, + const CTaskOptions &options, + int max_retries, + c_bool retry_exceptions, + const CSchedulingStrategy &scheduling_strategy, + 
c_string debugger_breakpoint, + c_string serialized_retry_exception_allowlist, + c_string call_site, + const CTaskID current_task_id) + CRayStatus CreateActor( + const CRayFunction &function, + const c_vector[unique_ptr[CTaskArg]] &args, + const CActorCreationOptions &options, + const c_string &extension_data, + c_string call_site, + CActorID *actor_id) + CRayStatus CreatePlacementGroup( + const CPlacementGroupCreationOptions &options, + CPlacementGroupID *placement_group_id) + CRayStatus RemovePlacementGroup( + const CPlacementGroupID &placement_group_id) + CRayStatus WaitPlacementGroupReady( + const CPlacementGroupID &placement_group_id, int64_t timeout_seconds) + CRayStatus SubmitActorTask( + const CActorID &actor_id, const CRayFunction &function, + const c_vector[unique_ptr[CTaskArg]] &args, + const CTaskOptions &options, + int max_retries, + c_bool retry_exceptions, + c_string serialized_retry_exception_allowlist, + c_string call_site, + c_vector[CObjectReference] &task_returns, + const CTaskID current_task_id) + CRayStatus KillActor( + const CActorID &actor_id, c_bool force_kill, + c_bool no_restart) + CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill, + c_bool recursive) + + unique_ptr[CProfileEvent] CreateProfileEvent( + const c_string &event_type) + CRayStatus AllocateReturnObject( + const CObjectID &object_id, + const size_t &data_size, + const shared_ptr[CBuffer] &metadata, + const c_vector[CObjectID] &contained_object_id, + const CAddress &caller_address, + int64_t *task_output_inlined_bytes, + shared_ptr[CRayObject] *return_object) + CRayStatus SealReturnObject( + const CObjectID &return_id, + const shared_ptr[CRayObject] &return_object, + const CObjectID &generator_id, + const CAddress &caller_address + ) + c_bool PinExistingReturnObject( + const CObjectID &return_id, + shared_ptr[CRayObject] *return_object, + const CObjectID &generator_id, + const CAddress &caller_address) + void AsyncDelObjectRefStream(const CObjectID 
&generator_id) + CRayStatus TryReadObjectRefStream( + const CObjectID &generator_id, + CObjectReference *object_ref_out) + c_bool StreamingGeneratorIsFinished(const CObjectID &generator_id) const + pair[CObjectReference, c_bool] PeekObjectRefStream( + const CObjectID &generator_id) + CObjectID AllocateDynamicReturnId( + const CAddress &owner_address, + const CTaskID &task_id, + optional[ObjectIDIndexType] put_index) + + CJobID GetCurrentJobId() + CTaskID GetCurrentTaskId() + const c_string GetCurrentTaskName() + const c_string GetCurrentTaskFunctionName() + void UpdateTaskIsDebuggerPaused( + const CTaskID &task_id, + const c_bool is_debugger_paused) + int64_t GetCurrentTaskAttemptNumber() + CNodeID GetCurrentNodeId() + int64_t GetTaskDepth() + c_bool GetCurrentTaskRetryExceptions() + CPlacementGroupID GetCurrentPlacementGroupId() + CWorkerID GetWorkerID() + c_bool ShouldCaptureChildTasksInPlacementGroup() + const CActorID &GetActorId() + const c_string GetActorName() + void SetActorTitle(const c_string &title) + void SetActorReprName(const c_string &repr_name) + void SetWebuiDisplay(const c_string &key, const c_string &message) + CTaskID GetCallerId() + const ResourceMappingType &GetResourceIDs() const + void RemoveActorHandleReference(const CActorID &actor_id) + optional[int] GetLocalActorState(const CActorID &actor_id) const + CActorID DeserializeAndRegisterActorHandle(const c_string &bytes, const + CObjectID &outer_object_id, + c_bool add_local_ref) + CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string + *bytes, + CObjectID *c_actor_handle_id) + ActorHandleSharedPtr GetActorHandle(const CActorID &actor_id) const + pair[ActorHandleSharedPtr, CRayStatus] GetNamedActorHandle( + const c_string &name, const c_string &ray_namespace) + pair[c_vector[c_pair[c_string, c_string]], CRayStatus] ListNamedActors( + c_bool all_namespaces) + void AddLocalReference(const CObjectID &object_id) + void RemoveLocalReference(const CObjectID &object_id) + void 
PutObjectIntoPlasma(const CRayObject &object, + const CObjectID &object_id) + const CAddress &GetRpcAddress() const + CRayStatus GetOwnerAddress(const CObjectID &object_id, + CAddress *owner_address) const + c_vector[CObjectReference] GetObjectRefs( + const c_vector[CObjectID] &object_ids) const + + CRayStatus GetOwnershipInfo(const CObjectID &object_id, + CAddress *owner_address, + c_string *object_status) + void RegisterOwnershipInfoAndResolveFuture( + const CObjectID &object_id, + const CObjectID &outer_object_id, + const CAddress &owner_address, + const c_string &object_status) + + CRayStatus Put(const CRayObject &object, + const c_vector[CObjectID] &contained_object_ids, + CObjectID *object_id) + CRayStatus Put(const CRayObject &object, + const c_vector[CObjectID] &contained_object_ids, + const CObjectID &object_id) + CRayStatus CreateOwnedAndIncrementLocalRef( + c_bool is_mutable, + const shared_ptr[CBuffer] &metadata, + const size_t data_size, + const c_vector[CObjectID] &contained_object_ids, + CObjectID *object_id, shared_ptr[CBuffer] *data, + c_bool created_by_worker, + const unique_ptr[CAddress] &owner_address, + c_bool inline_small_object) + CRayStatus CreateExisting(const shared_ptr[CBuffer] &metadata, + const size_t data_size, + const CObjectID &object_id, + const CAddress &owner_address, + shared_ptr[CBuffer] *data, + c_bool created_by_worker) + CRayStatus ExperimentalChannelWriteAcquire( + const CObjectID &object_id, + const shared_ptr[CBuffer] &metadata, + uint64_t data_size, + int64_t num_readers, + int64_t timeout_ms, + shared_ptr[CBuffer] *data) + CRayStatus ExperimentalChannelWriteRelease( + const CObjectID &object_id) + CRayStatus ExperimentalChannelSetError( + const CObjectID &object_id) + CRayStatus ExperimentalRegisterMutableObjectWriter( + const CObjectID &writer_object_id, + const c_vector[CNodeID] &remote_reader_node_ids) + CRayStatus ExperimentalRegisterMutableObjectReader(const CObjectID &object_id) + CRayStatus 
ExperimentalRegisterMutableObjectReaderRemote( + const CObjectID &object_id, + const c_vector[CReaderRefInfo] &remote_reader_ref_info) + CRayStatus SealOwned(const CObjectID &object_id, c_bool pin_object, + const unique_ptr[CAddress] &owner_address) + CRayStatus SealExisting(const CObjectID &object_id, c_bool pin_object, + const CObjectID &generator_id, + const unique_ptr[CAddress] &owner_address) + CRayStatus Get(const c_vector[CObjectID] &ids, int64_t timeout_ms, + c_vector[shared_ptr[CRayObject]] results) + CRayStatus GetIfLocal( + const c_vector[CObjectID] &ids, + c_vector[shared_ptr[CRayObject]] *results) + CRayStatus Contains(const CObjectID &object_id, c_bool *has_object, + c_bool *is_in_plasma) + CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects, + int64_t timeout_ms, c_vector[c_bool] *results, + c_bool fetch_local) + CRayStatus Delete(const c_vector[CObjectID] &object_ids, + c_bool local_only) + CRayStatus GetLocalObjectLocations( + const c_vector[CObjectID] &object_ids, + c_vector[optional[CObjectLocation]] *results) + CRayStatus GetLocationFromOwner( + const c_vector[CObjectID] &object_ids, + int64_t timeout_ms, + c_vector[shared_ptr[CObjectLocation]] *results) + CRayStatus TriggerGlobalGC() + CRayStatus ReportGeneratorItemReturns( + const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object, + const CObjectID &generator_id, + const CAddress &caller_address, + int64_t item_index, + uint64_t attempt_number, + shared_ptr[CGeneratorBackpressureWaiter] waiter) + c_string MemoryUsageString() + int GetMemoryStoreSize() + + CWorkerContext &GetWorkerContext() + void YieldCurrentFiber(CFiberEvent &coroutine_done) + + unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts() + c_vector[CTaskID] GetPendingChildrenTasks(const CTaskID &task_id) const + + void GetAsync(const CObjectID &object_id, + ray_callback_function success_callback, + void* python_user_callback) + + CRayStatus PushError(const CJobID &job_id, const 
c_string &type, + const c_string &error_message, double timestamp) + CRayStatus SetResource(const c_string &resource_name, + const double capacity, + const CNodeID &client_Id) + + CJobConfig GetJobConfig() + + int64_t GetNumTasksSubmitted() const + + int64_t GetNumLeasesRequested() const + + int64_t GetLocalMemoryStoreBytesUsed() const + + void RecordTaskLogStart( + const CTaskID &task_id, + int attempt_number, + const c_string& stdout_path, + const c_string& stderr_path, + int64_t stdout_start_offset, + int64_t stderr_start_offset) const + + void RecordTaskLogEnd( + const CTaskID &task_id, + int attempt_number, + int64_t stdout_end_offset, + int64_t stderr_end_offset) const + + void Exit(const CWorkerExitType exit_type, + const c_string &detail, + const shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) + + unordered_map[CLineageReconstructionTask, uint64_t] \ + GetLocalOngoingLineageReconstructionTasks() const + + cdef cppclass CCoreWorkerOptions "ray::core::CoreWorkerOptions": + CWorkerType worker_type + CLanguage language + c_string store_socket + c_string raylet_socket + CJobID job_id + CGcsClientOptions gcs_options + c_bool enable_logging + c_string log_dir + c_bool install_failure_signal_handler + c_bool interactive + c_string node_ip_address + int node_manager_port + c_string raylet_ip_address + c_string driver_name + c_string stdout_file + c_string stderr_file + (CRayStatus( + const CAddress &caller_address, + CTaskType task_type, + const c_string name, + const CRayFunction &ray_function, + const unordered_map[c_string, double] &resources, + const c_vector[shared_ptr[CRayObject]] &args, + const c_vector[CObjectReference] &arg_refs, + const c_string debugger_breakpoint, + const c_string serialized_retry_exception_allowlist, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *dynamic_returns, + c_vector[c_pair[CObjectID, c_bool]] *streaming_generator_returns, + 
shared_ptr[LocalMemoryBuffer] + &creation_task_exception_pb_bytes, + c_bool *is_retryable_error, + c_string *application_error, + const c_vector[CConcurrencyGroup] &defined_concurrency_groups, + const c_string name_of_concurrency_group_to_execute, + c_bool is_reattempt, + c_bool is_streaming_generator, + c_bool should_retry_exceptions, + int64_t generator_backpressure_num_objects + ) nogil) task_execution_callback + (void(const CWorkerID &) nogil) on_worker_shutdown + (CRayStatus() nogil) check_signals + (void(c_bool) nogil) gc_collect + (c_vector[c_string]( + const c_vector[CObjectReference] &) nogil) spill_objects + (int64_t( + const c_vector[CObjectReference] &, + const c_vector[c_string] &) nogil) restore_spilled_objects + (void( + const c_vector[c_string]&, + CWorkerType) nogil) delete_spilled_objects + (void( + const c_string&, + const c_vector[c_string]&) nogil) run_on_util_worker_handler + (void(const CRayObject&) nogil) unhandled_exception_handler + (void( + const CTaskID &c_task_id, + const CRayFunction &ray_function, + const c_string c_name_of_concurrency_group_to_execute + ) nogil) cancel_async_task + (void(c_string *stack_out) nogil) get_lang_stack + c_bool is_local_mode + int num_workers + (c_bool(const CTaskID &) nogil) kill_main + CCoreWorkerOptions() + (void() nogil) terminate_asyncio_thread + c_string serialized_job_config + int metrics_agent_port + int runtime_env_hash + int startup_token + CClusterID cluster_id + c_string session_name + c_string entrypoint + int64_t worker_launch_time_ms + int64_t worker_launched_time_ms + + cdef cppclass CCoreWorkerProcess "ray::core::CoreWorkerProcess": + @staticmethod + void Initialize(const CCoreWorkerOptions &options) + # Only call this in CoreWorker.__cinit__, + # use CoreWorker.core_worker to access C++ CoreWorker. 
+ + @staticmethod + CCoreWorker &GetCoreWorker() + + @staticmethod + void Shutdown() + + @staticmethod + void RunTaskExecutionLoop() diff --git a/.venv/lib/python3.11/site-packages/ray/includes/metric.pxd b/.venv/lib/python3.11/site-packages/ray/includes/metric.pxd new file mode 100644 index 0000000000000000000000000000000000000000..32c05aea215160c93cc0f2ef3cdbcf7c6c174d8d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/metric.pxd @@ -0,0 +1,45 @@ +from libcpp.string cimport string as c_string +from libcpp.unordered_map cimport unordered_map +from libcpp.vector cimport vector as c_vector + +cdef extern from "opencensus/tags/tag_key.h" nogil: + cdef cppclass CTagKey "opencensus::tags::TagKey": + @staticmethod + CTagKey Register(c_string &name) + const c_string &name() const + +cdef extern from "ray/stats/metric.h" nogil: + cdef cppclass CMetric "ray::stats::Metric": + CMetric(const c_string &name, + const c_string &description, + const c_string &unit, + const c_vector[c_string] &tag_keys) + c_string GetName() const + void Record(double value) + void Record(double value, + unordered_map[c_string, c_string] &tags) + + cdef cppclass CGauge "ray::stats::Gauge": + CGauge(const c_string &name, + const c_string &description, + const c_string &unit, + const c_vector[c_string] &tag_keys) + + cdef cppclass CCount "ray::stats::Count": + CCount(const c_string &name, + const c_string &description, + const c_string &unit, + const c_vector[c_string] &tag_keys) + + cdef cppclass CSum "ray::stats::Sum": + CSum(const c_string &name, + const c_string &description, + const c_string &unit, + const c_vector[c_string] &tag_keys) + + cdef cppclass CHistogram "ray::stats::Histogram": + CHistogram(const c_string &name, + const c_string &description, + const c_string &unit, + const c_vector[double] &boundaries, + const c_vector[c_string] &tag_keys) diff --git a/.venv/lib/python3.11/site-packages/ray/includes/optional.pxd 
b/.venv/lib/python3.11/site-packages/ray/includes/optional.pxd new file mode 100644 index 0000000000000000000000000000000000000000..a3539824ae73a16df48b1203fbf8b877d00fad42 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/optional.pxd @@ -0,0 +1,36 @@ +# Currently Cython does not support std::optional. +# See: https://github.com/cython/cython/pull/3294 +from libcpp cimport bool + +cdef extern from "" namespace "std" nogil: + cdef cppclass nullopt_t: + nullopt_t() + + cdef nullopt_t nullopt + + cdef cppclass optional[T]: + ctypedef T value_type + optional() + optional(nullopt_t) + optional(optional&) except + + optional(T&) except + + bool has_value() + T& value() + T& value_or[U](U& default_value) + void swap(optional&) + void reset() + T& emplace(...) + T& operator*() + # T* operator->() # Not Supported + optional& operator=(optional&) + optional& operator=[U](U&) + bool operator bool() + bool operator!() + bool operator==[U](optional&, U&) + bool operator!=[U](optional&, U&) + bool operator<[U](optional&, U&) + bool operator>[U](optional&, U&) + bool operator<=[U](optional&, U&) + bool operator>=[U](optional&, U&) + + optional[T] make_optional[T](...) 
except + diff --git a/.venv/lib/python3.11/site-packages/ray/includes/ray_config.pxd b/.venv/lib/python3.11/site-packages/ray/includes/ray_config.pxd new file mode 100644 index 0000000000000000000000000000000000000000..7189c2b5bd14e3648c6748433ac71191f2629d8e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/ray_config.pxd @@ -0,0 +1,98 @@ +from libcpp cimport bool as c_bool +from libc.stdint cimport int64_t, uint64_t, uint32_t +from libcpp.string cimport string as c_string +from libcpp.unordered_map cimport unordered_map + + +cdef extern from "ray/common/ray_config.h" nogil: + cdef cppclass RayConfig "RayConfig": + @staticmethod + RayConfig &instance() + + void initialize(const c_string& config_list) + + int64_t ray_cookie() const + + int64_t handler_warning_timeout_ms() const + + int64_t debug_dump_period_milliseconds() const + + int64_t object_timeout_milliseconds() const + + int64_t raylet_client_num_connect_attempts() const + + int64_t raylet_client_connect_timeout_milliseconds() const + + int64_t raylet_fetch_timeout_milliseconds() const + + int64_t kill_worker_timeout_milliseconds() const + + int64_t worker_register_timeout_seconds() const + + int64_t redis_db_connect_retries() + + int64_t redis_db_connect_wait_milliseconds() const + + int object_manager_pull_timeout_ms() const + + int object_manager_push_timeout_ms() const + + uint64_t object_manager_default_chunk_size() const + + uint32_t maximum_gcs_deletion_batch_size() const + + int64_t max_direct_call_object_size() const + + int64_t task_rpc_inlined_bytes_limit() const + + uint64_t metrics_report_interval_ms() const + + c_bool enable_timeline() const + + uint32_t max_grpc_message_size() const + + c_bool record_ref_creation_sites() const + + c_string REDIS_CA_CERT() const + + c_string REDIS_CA_PATH() const + + c_string REDIS_CLIENT_CERT() const + + c_string REDIS_CLIENT_KEY() const + + c_string REDIS_SERVER_NAME() const + + int64_t health_check_initial_delay_ms() const + + int64_t 
health_check_period_ms() const + + int64_t health_check_timeout_ms() const + + int64_t health_check_failure_threshold() const + + uint64_t memory_monitor_refresh_ms() const + + int64_t grpc_keepalive_time_ms() const + + int64_t grpc_keepalive_timeout_ms() const + + int64_t grpc_client_keepalive_time_ms() const + + int64_t grpc_client_keepalive_timeout_ms() const + + c_bool enable_autoscaler_v2() const + + c_string predefined_unit_instance_resources() const + + c_string custom_unit_instance_resources() const + + int64_t nums_py_gcs_reconnect_retry() const + + int64_t py_gcs_connect_timeout_s() const + + int gcs_rpc_server_reconnect_timeout_s() const + + int maximum_gcs_destroyed_actor_cached_count() const + + c_bool record_task_actor_creation_sites() const diff --git a/.venv/lib/python3.11/site-packages/ray/includes/unique_ids.pxd b/.venv/lib/python3.11/site-packages/ray/includes/unique_ids.pxd new file mode 100644 index 0000000000000000000000000000000000000000..84f511ec5107d3ca950d623862348ef9c504b201 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/includes/unique_ids.pxd @@ -0,0 +1,218 @@ +from libcpp cimport bool as c_bool +from libcpp.string cimport string as c_string +from libc.stdint cimport uint8_t, uint32_t, int64_t + +cdef extern from "ray/common/id.h" namespace "ray" nogil: + cdef cppclass CBaseID[T]: + @staticmethod + T FromBinary(const c_string &binary) + + @staticmethod + T FromHex(const c_string &hex_str) + + @staticmethod + const T Nil() + + @staticmethod + size_t Size() + + size_t Hash() const + c_bool IsNil() const + c_bool operator==(const CBaseID &rhs) const + c_bool operator!=(const CBaseID &rhs) const + const uint8_t *data() const + + c_string Binary() const + c_string Hex() const + + cdef cppclass CUniqueID "ray::UniqueID"(CBaseID): + CUniqueID() + + @staticmethod + size_t Size() + + @staticmethod + CUniqueID FromRandom() + + @staticmethod + CUniqueID FromBinary(const c_string &binary) + + @staticmethod + const CUniqueID Nil() + + 
@staticmethod + size_t Size() + + cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID): + + @staticmethod + CActorClassID FromBinary(const c_string &binary) + + @staticmethod + CActorClassID FromHex(const c_string &hex_str) + + cdef cppclass CActorID "ray::ActorID"(CBaseID[CActorID]): + + @staticmethod + CActorID FromBinary(const c_string &binary) + + @staticmethod + CActorID FromHex(const c_string &hex_str) + + @staticmethod + const CActorID Nil() + + @staticmethod + size_t Size() + + @staticmethod + CActorID Of(CJobID job_id, CTaskID parent_task_id, + int64_t parent_task_counter) + + CJobID JobId() + + cdef cppclass CNodeID "ray::NodeID"(CUniqueID): + + @staticmethod + CNodeID FromBinary(const c_string &binary) + + @staticmethod + CNodeID FromHex(const c_string &hex_str) + + @staticmethod + const CNodeID Nil() + + cdef cppclass CConfigID "ray::ConfigID"(CUniqueID): + + @staticmethod + CConfigID FromBinary(const c_string &binary) + + cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID): + + @staticmethod + CFunctionID FromBinary(const c_string &binary) + + @staticmethod + CFunctionID FromHex(const c_string &hex_str) + + cdef cppclass CJobID "ray::JobID"(CBaseID[CJobID]): + + @staticmethod + CJobID FromBinary(const c_string &binary) + + @staticmethod + CJobID FromHex(const c_string &hex_str) + + @staticmethod + const CJobID Nil() + + @staticmethod + size_t Size() + + @staticmethod + CJobID FromInt(uint32_t value) + + uint32_t ToInt() + + cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): + + @staticmethod + CTaskID FromBinary(const c_string &binary) + + @staticmethod + CTaskID FromHex(const c_string &hex_str) + + @staticmethod + const CTaskID Nil() + + @staticmethod + size_t Size() + + @staticmethod + CTaskID ForDriverTask(const CJobID &job_id) + + @staticmethod + CTaskID FromRandom(const CJobID &job_id) + + @staticmethod + CTaskID ForActorCreationTask(CActorID actor_id) + + @staticmethod + CTaskID ForActorTask(CJobID job_id, CTaskID parent_task_id, + 
int64_t parent_task_counter, CActorID actor_id) + + @staticmethod + CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id, + int64_t parent_task_counter) + + CActorID ActorId() const + + CJobID JobId() const + + cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): + + @staticmethod + int64_t MaxObjectIndex() + + @staticmethod + CObjectID FromBinary(const c_string &binary) + + @staticmethod + CObjectID FromRandom() + + @staticmethod + const CObjectID Nil() + + @staticmethod + CObjectID FromIndex(const CTaskID &task_id, int64_t index) + + @staticmethod + size_t Size() + + c_bool is_put() + + int64_t ObjectIndex() const + + CTaskID TaskId() const + + cdef cppclass CClusterID "ray::ClusterID"(CUniqueID): + + @staticmethod + CClusterID FromBinary(const c_string &binary) + + @staticmethod + CClusterID FromHex(const c_string &hex_str) + + @staticmethod + CClusterID FromRandom() + + @staticmethod + const CClusterID Nil() + + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): + + @staticmethod + CWorkerID FromBinary(const c_string &binary) + + @staticmethod + CWorkerID FromHex(const c_string &hex_str) + + cdef cppclass CPlacementGroupID "ray::PlacementGroupID" \ + (CBaseID[CPlacementGroupID]): + + @staticmethod + CPlacementGroupID FromBinary(const c_string &binary) + + @staticmethod + CPlacementGroupID FromHex(const c_string &hex_str) + + @staticmethod + const CPlacementGroupID Nil() + + @staticmethod + size_t Size() + + @staticmethod + CPlacementGroupID Of(CJobID job_id) + + ctypedef uint32_t ObjectIDIndexType diff --git a/.venv/lib/python3.11/site-packages/ray/runtime_env/__init__.py b/.venv/lib/python3.11/site-packages/ray/runtime_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f401770d0ef075b7c733116b2e7f5ce14f5825e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/runtime_env/__init__.py @@ -0,0 +1,8 @@ +from ray._private.runtime_env.mpi import mpi_init # noqa: E402,F401 +from ray.runtime_env.runtime_env 
import RuntimeEnv, RuntimeEnvConfig # noqa: E402,F401 + +__all__ = [ + "RuntimeEnvConfig", + "RuntimeEnv", + "mpi_init", +] diff --git a/.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a273614886590f735ae1909979c5e4055be58874 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/runtime_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/runtime_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c109071d6514a57428c86e098e2132130886a5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/runtime_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/runtime_env/runtime_env.py b/.venv/lib/python3.11/site-packages/ray/runtime_env/runtime_env.py new file mode 100644 index 0000000000000000000000000000000000000000..d04247b75a7773c42a6d867096baad4949bac6ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/runtime_env/runtime_env.py @@ -0,0 +1,662 @@ +import json +import logging +import os +from copy import deepcopy +from dataclasses import asdict, is_dataclass +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import ray +from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS +from ray._private.runtime_env.conda import get_uri as get_conda_uri +from ray._private.runtime_env.default_impl import get_image_uri_plugin_cls +from ray._private.runtime_env.pip import get_uri as get_pip_uri +from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager +from ray._private.runtime_env.uv import get_uri as get_uv_uri +from 
ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN +from ray._private.thirdparty.dacite import from_dict +from ray.core.generated.runtime_env_common_pb2 import ( + RuntimeEnvConfig as ProtoRuntimeEnvConfig, +) +from ray.util.annotations import PublicAPI + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="stable") +class RuntimeEnvConfig(dict): + """Used to specify configuration options for a runtime environment. + + The config is not included when calculating the runtime_env hash, + which means that two runtime_envs with the same options but different + configs are considered the same for caching purposes. + + Args: + setup_timeout_seconds: The timeout of runtime environment + creation, timeout is in seconds. The value `-1` means disable + timeout logic, except `-1`, `setup_timeout_seconds` cannot be + less than or equal to 0. The default value of `setup_timeout_seconds` + is 600 seconds. + eager_install: Indicates whether to install the runtime environment + on the cluster at `ray.init()` time, before the workers are leased. + This flag is set to `True` by default. 
+ """ + + known_fields: Set[str] = {"setup_timeout_seconds", "eager_install", "log_files"} + + _default_config: Dict = { + "setup_timeout_seconds": DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS, + "eager_install": True, + "log_files": [], + } + + def __init__( + self, + setup_timeout_seconds: int = DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS, + eager_install: bool = True, + log_files: Optional[List[str]] = None, + ): + super().__init__() + if not isinstance(setup_timeout_seconds, int): + raise TypeError( + "setup_timeout_seconds must be of type int, " + f"got: {type(setup_timeout_seconds)}" + ) + elif setup_timeout_seconds <= 0 and setup_timeout_seconds != -1: + raise ValueError( + "setup_timeout_seconds must be greater than zero " + f"or equals to -1, got: {setup_timeout_seconds}" + ) + self["setup_timeout_seconds"] = setup_timeout_seconds + + if not isinstance(eager_install, bool): + raise TypeError( + f"eager_install must be a boolean. got {type(eager_install)}" + ) + self["eager_install"] = eager_install + + if log_files is not None: + if not isinstance(log_files, list): + raise TypeError( + "log_files must be a list of strings or None, got " + f"{log_files} with type {type(log_files)}." + ) + for file_name in log_files: + if not isinstance(file_name, str): + raise TypeError("Each item in log_files must be a string.") + else: + log_files = self._default_config["log_files"] + + self["log_files"] = log_files + + @staticmethod + def parse_and_validate_runtime_env_config( + config: Union[Dict, "RuntimeEnvConfig"] + ) -> "RuntimeEnvConfig": + if isinstance(config, RuntimeEnvConfig): + return config + elif isinstance(config, Dict): + unknown_fields = set(config.keys()) - RuntimeEnvConfig.known_fields + if len(unknown_fields): + logger.warning( + "The following unknown entries in the runtime_env_config " + f"dictionary will be ignored: {unknown_fields}." 
+ ) + config_dict = dict() + for field in RuntimeEnvConfig.known_fields: + if field in config: + config_dict[field] = config[field] + return RuntimeEnvConfig(**config_dict) + else: + raise TypeError( + "runtime_env['config'] must be of type dict or RuntimeEnvConfig, " + f"got: {type(config)}" + ) + + @classmethod + def default_config(cls): + return RuntimeEnvConfig(**cls._default_config) + + def build_proto_runtime_env_config(self) -> ProtoRuntimeEnvConfig: + runtime_env_config = ProtoRuntimeEnvConfig() + runtime_env_config.setup_timeout_seconds = self["setup_timeout_seconds"] + runtime_env_config.eager_install = self["eager_install"] + if self["log_files"] is not None: + runtime_env_config.log_files.extend(self["log_files"]) + return runtime_env_config + + @classmethod + def from_proto(cls, runtime_env_config: ProtoRuntimeEnvConfig): + setup_timeout_seconds = runtime_env_config.setup_timeout_seconds + # Cause python class RuntimeEnvConfig has validate to avoid + # setup_timeout_seconds equals zero, so setup_timeout_seconds + # on RuntimeEnvConfig is zero means other Language(except python) + # dosn't assign value to setup_timeout_seconds. So runtime_env_agent + # assign the default value to setup_timeout_seconds. + if setup_timeout_seconds == 0: + setup_timeout_seconds = cls._default_config["setup_timeout_seconds"] + return cls( + setup_timeout_seconds=setup_timeout_seconds, + eager_install=runtime_env_config.eager_install, + log_files=list(runtime_env_config.log_files), + ) + + def to_dict(self) -> Dict: + return dict(deepcopy(self)) + + +# Due to circular reference, field config can only be assigned a value here +OPTION_TO_VALIDATION_FN[ + "config" +] = RuntimeEnvConfig.parse_and_validate_runtime_env_config + + +@PublicAPI +class RuntimeEnv(dict): + """This class is used to define a runtime environment for a job, task, + or actor. + + See :ref:`runtime-environments` for detailed documentation. 
+ + This class can be used interchangeably with an unstructured dictionary + in the relevant API calls. + + Can specify a runtime environment whole job, whether running a script + directly on the cluster, using Ray Job submission, or using Ray Client: + + .. code-block:: python + + from ray.runtime_env import RuntimeEnv + # Starting a single-node local Ray cluster + ray.init(runtime_env=RuntimeEnv(...)) + + .. code-block:: python + + from ray.runtime_env import RuntimeEnv + # Connecting to remote cluster using Ray Client + ray.init("ray://123.456.7.89:10001", runtime_env=RuntimeEnv(...)) + + Can specify different runtime environments per-actor or per-task using + ``.options()`` or the ``@ray.remote`` decorator: + + .. code-block:: python + + from ray.runtime_env import RuntimeEnv + # Invoke a remote task that runs in a specified runtime environment. + f.options(runtime_env=RuntimeEnv(...)).remote() + + # Instantiate an actor that runs in a specified runtime environment. + actor = SomeClass.options(runtime_env=RuntimeEnv(...)).remote() + + # Specify a runtime environment in the task definition. Future invocations via + # `g.remote()` use this runtime environment unless overridden by using + # `.options()` as above. + @ray.remote(runtime_env=RuntimeEnv(...)) + def g(): + pass + + # Specify a runtime environment in the actor definition. Future instantiations + # via `MyClass.remote()` use this runtime environment unless overridden by + # using `.options()` as above. + @ray.remote(runtime_env=RuntimeEnv(...)) + class MyClass: + pass + + Here are some examples of RuntimeEnv initialization: + + .. 
code-block:: python + + # Example for using conda + RuntimeEnv(conda={ + "channels": ["defaults"], "dependencies": ["codecov"]}) + RuntimeEnv(conda="pytorch_p36") # Found on DLAMIs + + # Example for using container + RuntimeEnv( + container={"image": "anyscale/ray-ml:nightly-py38-cpu", + "run_options": ["--cap-drop SYS_ADMIN","--log-level=debug"]}) + + # Example for set env_vars + RuntimeEnv(env_vars={"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"}) + + # Example for set pip + RuntimeEnv( + pip={"packages":["tensorflow", "requests"], "pip_check": False, + "pip_version": "==22.0.2;python_version=='3.8.11'"}) + + # Example for using image_uri + RuntimeEnv( + image_uri="rayproject/ray:2.39.0-py312-cu123") + + Args: + py_modules: List of URIs (either in the GCS or external + storage), each of which is a zip file that Ray unpacks and + inserts into the PYTHONPATH of the workers. + working_dir: URI (either in the GCS or external storage) of a zip + file that Ray unpacks in the directory of each task/actor. + pip: Either a list of pip packages, a string + containing the path to a pip requirements.txt file, or a Python + dictionary that has three fields: 1) ``packages`` (required, List[str]): a + list of pip packages, 2) ``pip_check`` (optional, bool): whether enable + pip check at the end of pip install, defaults to False. + 3) ``pip_version`` (optional, str): the version of pip, Ray prepends + the package name "pip" in front of the ``pip_version`` to form the final + requirement string, the syntax of a requirement specifier is defined in + full in PEP 508. + uv: Either a list of pip packages, or a Python dictionary that has one field: + 1) ``packages`` (required, List[str]). + conda: Either the conda YAML config, the name of a + local conda env (e.g., "pytorch_p36"), or the path to a conda + environment.yaml file. + Ray automatically injects the dependency into the conda + env to ensure compatibility with the cluster Ray. 
Ray may automatically + mangle the conda name to avoid conflicts between runtime envs. + This field can't be specified at the same time as the 'pip' field. + To use pip with conda, specify your pip dependencies within + the conda YAML config: + https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#create-env-file-manually + container: Require a given (Docker) container image, + The Ray worker process runs in a container with this image. + This parameter only works alone, or with the ``config`` or + ``env_vars`` parameters. + The `run_options` list spec is here: + https://docs.docker.com/engine/reference/run/ + env_vars: Environment variables to set. + worker_process_setup_hook: (Experimental) The setup hook that's + called after workers start and before Tasks and Actors are scheduled. + A module name (string type) or callable (function) can be passed. + When a module name is passed, Ray worker should be able to access the + module name. When a callable is passed, callable should be serializable. + When a runtime env is specified by job submission API, + only a module name (string) is allowed. + nsight: Dictionary mapping nsight profile option name to it's value. + config: config for runtime environment. Either + a dict or a RuntimeEnvConfig. Field: (1) setup_timeout_seconds, the + timeout of runtime environment creation, timeout is in seconds. + image_uri: URI to a container image. The Ray worker process runs + in a container with this image. This parameter only works alone, + or with the ``config`` or ``env_vars`` parameters. + """ + + known_fields: Set[str] = { + "py_modules", + "java_jars", + "working_dir", + "conda", + "pip", + "uv", + "container", + "excludes", + "env_vars", + "_ray_release", + "_ray_commit", + "_inject_current_ray", + "config", + # TODO(SongGuyang): We add this because the test + # `test_experimental_package_github` set a `docker` + # field which is not supported. We should remove it + # with the test. 
+ "docker", + "worker_process_setup_hook", + "_nsight", + "mpi", + "image_uri", + } + + extensions_fields: Set[str] = { + "_ray_release", + "_ray_commit", + "_inject_current_ray", + } + + def __init__( + self, + *, + py_modules: Optional[List[str]] = None, + working_dir: Optional[str] = None, + pip: Optional[List[str]] = None, + conda: Optional[Union[Dict[str, str], str]] = None, + container: Optional[Dict[str, str]] = None, + env_vars: Optional[Dict[str, str]] = None, + worker_process_setup_hook: Optional[Union[Callable, str]] = None, + nsight: Optional[Union[str, Dict[str, str]]] = None, + config: Optional[Union[Dict, RuntimeEnvConfig]] = None, + _validate: bool = True, + mpi: Optional[Dict] = None, + image_uri: Optional[str] = None, + uv: Optional[List[str]] = None, + **kwargs, + ): + super().__init__() + + runtime_env = kwargs + if py_modules is not None: + runtime_env["py_modules"] = py_modules + if working_dir is not None: + runtime_env["working_dir"] = working_dir + if pip is not None: + runtime_env["pip"] = pip + if uv is not None: + runtime_env["uv"] = uv + if conda is not None: + runtime_env["conda"] = conda + if nsight is not None: + runtime_env["_nsight"] = nsight + if container is not None: + runtime_env["container"] = container + if env_vars is not None: + runtime_env["env_vars"] = env_vars + if config is not None: + runtime_env["config"] = config + if worker_process_setup_hook is not None: + runtime_env["worker_process_setup_hook"] = worker_process_setup_hook + if mpi is not None: + runtime_env["mpi"] = mpi + if image_uri is not None: + runtime_env["image_uri"] = image_uri + if runtime_env.get("java_jars"): + runtime_env["java_jars"] = runtime_env.get("java_jars") + + self.update(runtime_env) + + # Blindly trust that the runtime_env has already been validated. + # This is dangerous and should only be used internally (e.g., on the + # deserialization codepath. 
+ if not _validate: + return + + if (self.get("conda") is not None) + (self.get("pip") is not None) + ( + self.get("uv") is not None + ) > 1: + raise ValueError( + "The 'pip' field, 'uv' field, and 'conda' field of " + "runtime_env cannot be specified at the same time.\n" + f"specified pip field: {self.get('pip')}\n" + f"specified conda field: {self.get('conda')}\n" + f"specified uv field: {self.get('uv')}\n" + "To use pip with conda, please only set the 'conda'" + "field, and specify your pip dependencies within the conda YAML " + "config dict: see https://conda.io/projects/conda/en/latest/" + "user-guide/tasks/manage-environments.html" + "#create-env-file-manually" + ) + + if self.get("container"): + invalid_keys = set(runtime_env.keys()) - {"container", "config", "env_vars"} + if len(invalid_keys): + raise ValueError( + "The 'container' field currently cannot be used " + "together with other fields of runtime_env. " + f"Specified fields: {invalid_keys}" + ) + + logger.warning( + "The `container` runtime environment field is DEPRECATED and will be " + "removed after July 31, 2025. Use `image_uri` instead. See " + "https://docs.ray.io/en/latest/serve/advanced-guides/multi-app-container.html." # noqa + ) + + if self.get("image_uri"): + image_uri_plugin_cls = get_image_uri_plugin_cls() + invalid_keys = ( + set(runtime_env.keys()) - image_uri_plugin_cls.get_compatible_keys() + ) + if len(invalid_keys): + raise ValueError( + "The 'image_uri' field currently cannot be used " + "together with other fields of runtime_env. " + f"Specified fields: {invalid_keys}" + ) + + for option, validate_fn in OPTION_TO_VALIDATION_FN.items(): + option_val = self.get(option) + if option_val is not None: + del self[option] + self[option] = option_val + + if "_ray_commit" not in self: + if self.get("pip") or self.get("conda"): + self["_ray_commit"] = ray.__commit__ + + # Used for testing wheels that have not yet been merged into master. 
+ # If this is set to True, then we do not inject Ray into the conda + # or pip dependencies. + if "_inject_current_ray" not in self: + if "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ: + self["_inject_current_ray"] = True + + # NOTE(architkulkarni): This allows worker caching code in C++ to check + # if a runtime env is empty without deserializing it. This is a catch- + # all; for validated inputs we won't set the key if the value is None. + if all(val is None for val in self.values()): + self.clear() + + def __setitem__(self, key: str, value: Any) -> None: + if is_dataclass(value): + jsonable_type = asdict(value) + else: + jsonable_type = value + RuntimeEnvPluginSchemaManager.validate(key, jsonable_type) + res_value = jsonable_type + if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN: + res_value = OPTION_TO_VALIDATION_FN[key](jsonable_type) + if res_value is None: + return + return super().__setitem__(key, res_value) + + def set(self, name: str, value: Any) -> None: + self.__setitem__(name, value) + + def get(self, name, default=None, data_class=None): + if name not in self: + return default + if not data_class: + return self.__getitem__(name) + else: + return from_dict(data_class=data_class, data=self.__getitem__(name)) + + @classmethod + def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821 + return cls(_validate=False, **json.loads(serialized_runtime_env)) + + def serialize(self) -> str: + # To ensure the accuracy of Proto, `__setitem__` can only guarantee the + # accuracy of a certain field, not the overall accuracy + runtime_env = type(self)(_validate=True, **self) + return json.dumps( + runtime_env, + sort_keys=True, + ) + + def to_dict(self) -> Dict: + runtime_env_dict = dict(deepcopy(self)) + + # Replace strongly-typed RuntimeEnvConfig with a dict to allow the returned + # dict to work properly as a field in a dataclass. 
Details in issue #26986 + if runtime_env_dict.get("config"): + runtime_env_dict["config"] = runtime_env_dict["config"].to_dict() + + return runtime_env_dict + + def has_working_dir(self) -> bool: + return self.get("working_dir") is not None + + def working_dir_uri(self) -> Optional[str]: + return self.get("working_dir") + + def py_modules_uris(self) -> List[str]: + if "py_modules" in self: + return list(self["py_modules"]) + return [] + + def conda_uri(self) -> Optional[str]: + if "conda" in self: + return get_conda_uri(self) + return None + + def pip_uri(self) -> Optional[str]: + if "pip" in self: + return get_pip_uri(self) + return None + + def uv_uri(self) -> Optional[str]: + if "uv" in self: + return get_uv_uri(self) + return None + + def plugin_uris(self) -> List[str]: + """Not implemented yet, always return a empty list""" + return [] + + def working_dir(self) -> str: + return self.get("working_dir", "") + + def py_modules(self) -> List[str]: + if "py_modules" in self: + return list(self["py_modules"]) + return [] + + def java_jars(self) -> List[str]: + if "java_jars" in self: + return list(self["java_jars"]) + return [] + + def mpi(self) -> Optional[Union[str, Dict[str, str]]]: + return self.get("mpi", None) + + def nsight(self) -> Optional[Union[str, Dict[str, str]]]: + return self.get("_nsight", None) + + def env_vars(self) -> Dict: + return self.get("env_vars", {}) + + def has_conda(self) -> str: + if self.get("conda"): + return True + return False + + def conda_env_name(self) -> str: + if not self.has_conda() or not isinstance(self["conda"], str): + return None + return self["conda"] + + def conda_config(self) -> str: + if not self.has_conda() or not isinstance(self["conda"], dict): + return None + return json.dumps(self["conda"], sort_keys=True) + + def has_pip(self) -> bool: + if self.get("pip"): + return True + return False + + def has_uv(self) -> bool: + if self.get("uv"): + return True + return False + + def virtualenv_name(self) -> Optional[str]: + 
if not self.has_pip() or not isinstance(self["pip"], str): + return None + return self["pip"] + + def pip_config(self) -> Dict: + if not self.has_pip() or isinstance(self["pip"], str): + return {} + # Parse and validate field pip on method `__setitem__` + self["pip"] = self["pip"] + return self["pip"] + + def uv_config(self) -> Dict: + if not self.has_uv() or isinstance(self["uv"], str): + return {} + # Parse and validate field pip on method `__setitem__` + self["uv"] = self["uv"] + return self["uv"] + + def get_extension(self, key) -> Optional[str]: + if key not in RuntimeEnv.extensions_fields: + raise ValueError( + f"Extension key must be one of {RuntimeEnv.extensions_fields}, " + f"got: {key}" + ) + return self.get(key) + + def has_py_container(self) -> bool: + if self.get("container"): + return True + return False + + def py_container_image(self) -> Optional[str]: + if not self.has_py_container(): + return None + return self["container"].get("image", "") + + def py_container_worker_path(self) -> Optional[str]: + if not self.has_py_container(): + return None + return self["container"].get("worker_path", "") + + def py_container_run_options(self) -> List: + if not self.has_py_container(): + return None + return self["container"].get("run_options", []) + + def image_uri(self) -> Optional[str]: + return self.get("image_uri") + + def plugins(self) -> List[Tuple[str, Any]]: + result = list() + for key, value in self.items(): + if key not in self.known_fields: + result.append((key, value)) + return result + + +def _merge_runtime_env( + parent: Optional[RuntimeEnv], + child: Optional[RuntimeEnv], + override: bool = False, +) -> Optional[RuntimeEnv]: + """Merge the parent and child runtime environments. + + If override = True, the child's runtime env overrides the parent's + runtime env in the event of a conflict. + + Merging happens per key (i.e., "conda", "pip", ...), but + "env_vars" are merged per env var key. 
+ + It returns None if Ray fails to merge runtime environments because + of a conflict and `override = False`. + + Args: + parent: Parent runtime env. + child: Child runtime env. + override: If True, the child's runtime env overrides + conflicting fields. + Returns: + The merged runtime env's if Ray successfully merges them. + None if the runtime env's conflict. Empty dict if + parent and child are both None. + """ + if parent is None: + parent = {} + if child is None: + child = {} + + parent = deepcopy(parent) + child = deepcopy(child) + parent_env_vars = parent.pop("env_vars", {}) + child_env_vars = child.pop("env_vars", {}) + + if not override: + if set(parent.keys()).intersection(set(child.keys())): + return None + if set(parent_env_vars.keys()).intersection(set(child_env_vars.keys())): # noqa + return None + + parent.update(child) + parent_env_vars.update(child_env_vars) + if parent_env_vars: + parent["env_vars"] = parent_env_vars + + return parent diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/__init__.py b/.venv/lib/python3.11/site-packages/ray/widgets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2479501640ba3f2b0be30c3548ee465a13a481a9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/__init__.py @@ -0,0 +1,4 @@ +from ray.widgets.render import Template +from ray.widgets.util import make_table_html_repr + +__all__ = ["Template", "make_table_html_repr"] diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebfcb8da95e44f5cb4d4f64d0c8501ec57fafbe8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/render.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/render.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66672b0828d92bfde11d7cefa54947a43c9076c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/render.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0705bceecbe582ca3e46f649fcc85408404beb10 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/render.py b/.venv/lib/python3.11/site-packages/ray/widgets/render.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e861d39925680c403ff996e9279d4d349bafe5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/render.py @@ -0,0 +1,39 @@ +import pathlib +from typing import List + +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class Template: + """Class which provides basic HTML templating.""" + + def __init__(self, file: str): + with open(pathlib.Path(__file__).parent / "templates" / file, "r") as f: + self.template = f.read() + + def render(self, **kwargs) -> str: + """Render an HTML template with the given data. + + This is done by replacing instances of `{{ key }}` with `value` + from the keyword arguments. + + Returns: + HTML template with the keys of the kwargs replaced with corresponding + values. + """ + rendered = self.template + for key, value in kwargs.items(): + if isinstance(value, List): + value = "".join(value) + rendered = rendered.replace("{{ " + key + " }}", value if value else "") + return rendered + + @staticmethod + def list_templates() -> List[pathlib.Path]: + """List the available HTML templates. 
+ + Returns: + A list of files with .html.j2 extensions inside ../templates/ + """ + return (pathlib.Path(__file__).parent / "templates").glob("*.html.j2") diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/context.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..26cc0ef6c8784c9859151f8553b240d7a94e9360 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context.html.j2 @@ -0,0 +1,6 @@ + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_dashrow.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_dashrow.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..47fbc2fa6f6da7e8a44c88c0e07ae51c6a737dbd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_dashrow.html.j2 @@ -0,0 +1,4 @@ + + Dashboard: + {{ dashboard_url }} + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_logo.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_logo.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..9233fe3a77226c760442d5366b8d5fb65b93abcd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_logo.html.j2 @@ -0,0 +1,13 @@ + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_table.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_table.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..d06822d0c1f56534e5009ea2ff81340fdc6f7850 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/context_table.html.j2 @@ -0,0 +1,11 @@ + + + + + + + + + + {{ dashboard_row }} + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/divider.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/divider.html.j2 new file mode 
100644 index 0000000000000000000000000000000000000000..b9a04173d7e0918c54a1d950da25938a215be3a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/divider.html.j2 @@ -0,0 +1,9 @@ +
+ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/rendered_html_common.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/rendered_html_common.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..35b4cee0133a3b4105e84ca7a66a2ffeb01f4c5b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/rendered_html_common.html.j2 @@ -0,0 +1,3 @@ + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/run_config.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/run_config.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..4cf392dff519cdd5d7c0629ca3394f47ef705825 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/run_config.html.j2 @@ -0,0 +1,18 @@ +
+
+ {{ settings }} +
+
+ {{ subconfigs }} +
+
+ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/scrollableTable.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/scrollableTable.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..2ec1637b92ee7de3d22bba0a956b8539464c0501 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/scrollableTable.html.j2 @@ -0,0 +1,20 @@ + + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/title_data.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/title_data.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..1731a157d17b4674e3d7cc427164065dafe49de2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/title_data.html.j2 @@ -0,0 +1,11 @@ + + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/title_data_mini.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/title_data_mini.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..bfe654c56346b09a1999dda379d129aaf1e317be --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/title_data_mini.html.j2 @@ -0,0 +1,4 @@ + diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/trial_progress.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/trial_progress.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..f3a323193e7fa242f9cc37a75d22d02841b74da6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/trial_progress.html.j2 @@ -0,0 +1,17 @@ +
+

Trial Progress

+ {{ table }} +
+ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/tune_status.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/tune_status.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..df422f89af4e7d03bc70006414f12652f5898824 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/tune_status.html.j2 @@ -0,0 +1,49 @@ +
+
+
+

Tune Status

+ {{ status_table }} +
+
+
+

System Info

+ {{ sys_info_message }} +
+ {{ messages }} +
+
+
+

Trial Status

+ {{ trial_progress }} +
+
+ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/templates/tune_status_messages.html.j2 b/.venv/lib/python3.11/site-packages/ray/widgets/templates/tune_status_messages.html.j2 new file mode 100644 index 0000000000000000000000000000000000000000..da8e75f5f58d3a5c78274429e7490947aff4c410 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/templates/tune_status_messages.html.j2 @@ -0,0 +1,25 @@ +
+
+

Messages

+ {{ memory_message }} + {{ trial_progress_messages }} + {{ trial_errors }} +
+ diff --git a/.venv/lib/python3.11/site-packages/ray/widgets/util.py b/.venv/lib/python3.11/site-packages/ray/widgets/util.py new file mode 100644 index 0000000000000000000000000000000000000000..2f171c6519cde57fd69033e040a9305ebbbdf53b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/widgets/util.py @@ -0,0 +1,207 @@ +import importlib +import logging +import sys +import textwrap +from functools import wraps +from typing import Any, Callable, Iterable, Optional, TypeVar, Union + +from packaging.version import Version + +from ray._private.thirdparty.tabulate.tabulate import tabulate +from ray.util.annotations import DeveloperAPI +from ray.widgets import Template + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +@DeveloperAPI +def make_table_html_repr( + obj: Any, title: Optional[str] = None, max_height: str = "none" +) -> str: + """Generate a generic html repr using a table. + + Args: + obj: Object for which a repr is to be generated + title: If present, a title for the section is included + max_height: Maximum height of the table; valid values + are given by the max-height CSS property + + Returns: + HTML representation of the object + """ + data = {} + for k, v in vars(obj).items(): + if isinstance(v, (str, bool, int, float)): + data[k] = str(v) + + elif isinstance(v, dict) or hasattr(v, "__dict__"): + data[k] = Template("scrollableTable.html.j2").render( + table=tabulate( + v.items() if isinstance(v, dict) else vars(v).items(), + tablefmt="html", + showindex=False, + headers=["Setting", "Value"], + ), + max_height="none", + ) + + table = Template("scrollableTable.html.j2").render( + table=tabulate( + data.items(), + tablefmt="unsafehtml", + showindex=False, + headers=["Setting", "Value"], + ), + max_height=max_height, + ) + + if title: + content = Template("title_data.html.j2").render(title=title, data=table) + else: + content = table + + return content + + +def _has_missing( + *deps: Iterable[Union[str, 
Optional[str]]], message: Optional[str] = None +): + """Return a list of missing dependencies. + + Args: + deps: Dependencies to check for + message: Message to be emitted if a dependency isn't found + + Returns: + A list of dependencies which can't be found, if any + """ + missing = [] + for (lib, _) in deps: + if importlib.util.find_spec(lib) is None: + missing.append(lib) + + if missing: + if not message: + message = f"Run `pip install {' '.join(missing)}` for rich notebook output." + + # stacklevel=3: First level is this function, then ensure_notebook_deps, + # then the actual function affected. + logger.info(f"Missing packages: {missing}. {message}", stacklevel=3) + + return missing + + +def _has_outdated( + *deps: Iterable[Union[str, Optional[str]]], message: Optional[str] = None +): + outdated = [] + for (lib, version) in deps: + try: + + module = importlib.import_module(lib) + if version and Version(module.__version__) < Version(version): + outdated.append([lib, version, module.__version__]) + except ImportError: + pass + + if outdated: + outdated_strs = [] + install_args = [] + for lib, version, installed in outdated: + outdated_strs.append(f"{lib}=={installed} found, needs {lib}>={version}") + install_args.append(f"{lib}>={version}") + + outdated_str = textwrap.indent("\n".join(outdated_strs), " ") + install_str = " ".join(install_args) + + if not message: + message = f"Run `pip install -U {install_str}` for rich notebook output." + + # stacklevel=3: First level is this function, then ensure_notebook_deps, + # then the actual function affected. + logger.info(f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3) + + return outdated + + +@DeveloperAPI +def repr_with_fallback( + *notebook_deps: Iterable[Union[str, Optional[str]]] +) -> Callable[[F], F]: + """Decorator which strips rich notebook output from mimebundles in certain cases. + + Fallback to plaintext and don't use rich output in the following cases: + 1. 
In a notebook environment and the appropriate dependencies are not installed. + 2. In a ipython shell environment. + 3. In Google Colab environment. + See https://github.com/googlecolab/colabtools/ issues/60 for more information + about the status of this issue. + + Args: + notebook_deps: The required dependencies and version for notebook environment. + + Returns: + A function that returns the usual _repr_mimebundle_, unless any of the 3 + conditions above hold, in which case it returns a mimebundle that only contains + a single text/plain mimetype. + """ + message = ( + "Run `pip install -U ipywidgets`, then restart " + "the notebook server for rich notebook output." + ) + if _can_display_ipywidgets(*notebook_deps, message=message): + + def wrapper(func: F) -> F: + @wraps(func) + def wrapped(self, *args, **kwargs): + return func(self, *args, **kwargs) + + return wrapped + + else: + + def wrapper(func: F) -> F: + @wraps(func) + def wrapped(self, *args, **kwargs): + return {"text/plain": repr(self)} + + return wrapped + + return wrapper + + +def _get_ipython_shell_name() -> str: + if "IPython" in sys.modules: + from IPython import get_ipython + + return get_ipython().__class__.__name__ + return "" + + +def _can_display_ipywidgets(*deps, message) -> bool: + # Default to safe behavior: only display widgets if running in a notebook + # that has valid dependencies + if in_notebook() and not ( + _has_missing(*deps, message=message) or _has_outdated(*deps, message=message) + ): + return True + + return False + + +@DeveloperAPI +def in_notebook(shell_name: Optional[str] = None) -> bool: + """Return whether we are in a Jupyter notebook or qtconsole.""" + if not shell_name: + shell_name = _get_ipython_shell_name() + return shell_name == "ZMQInteractiveShell" + + +@DeveloperAPI +def in_ipython_shell(shell_name: Optional[str] = None) -> bool: + """Return whether we are in a terminal running IPython""" + if not shell_name: + shell_name = _get_ipython_shell_name() + return 
shell_name == "TerminalInteractiveShell" diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ac372d13088eee69b059f6d8e5eed6bfdf17ea4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33299e0c839f0e1f85c24a0b4f6196c0edbd2090 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a622fe847a3c44f824a4ac6cbd34444db2827021 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/event_listener.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/event_listener.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb81ba686369031e2ee342ffb294cd13553e19af Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/event_listener.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e4fc4889a08d6cbf0f0620f2ab13056bf9538e72 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/http_event_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/http_event_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06b39a259d7b000c362222d5ddd2c12d386b6e26 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/http_event_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b8fd10f2fde97ab0a00ae05e21c783cb4e097df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/serialization_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/task_executor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/task_executor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cd317a5461a4fa92773171621f0d37bf58d6a67 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/task_executor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_storage.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_storage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06318cb103606a7d30f55ed29805010ee8c36576 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/__pycache__/workflow_storage.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/workflow/storage/__init__.py b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7da27ecc89000cb89e01c65d5d70b6b78380f2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__init__.py @@ -0,0 +1,9 @@ +from ray.workflow.storage.base import Storage +from ray.workflow.storage.base import DataLoadError, DataSaveError, KeyNotFoundError + +__all__ = ( + "Storage", + "DataLoadError", + "DataSaveError", + "KeyNotFoundError", +) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..298949ac6f17498e85f9e6213338872ca096c6da Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff2d62c43ec3405ed99d7a1f2a93c846da2db256 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/debug.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71c96688363a9e8393253dd57ce0c5ff903a539f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/debug.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/filesystem.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/filesystem.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffe5e04b9fa1dc4d5b533e4e0bfd916d7fdfaf9f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/workflow/storage/__pycache__/filesystem.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/base.py b/.venv/lib/python3.11/site-packages/ray/workflow/storage/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5323ff516c4cf2450636fee2c0afe568779573d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/storage/base.py @@ -0,0 +1,76 @@ +import abc +from abc import abstractmethod +from typing import Any, List + + +class DataLoadError(Exception): + pass + + +class DataSaveError(Exception): + pass + + +class KeyNotFoundError(KeyError): + pass + + +class Storage(metaclass=abc.ABCMeta): + """Abstract base class for the low-level workflow storage. + This class only provides low level primitives, e.g. save a certain + type of object. + """ + + @abstractmethod + def make_key(self, *names: str) -> str: + """Make key from name sections.""" + + @abstractmethod + async def put(self, key: str, data: Any, is_json: bool = False) -> None: + """Put object into storage. + + Args: + key: The key of the object. + data: The object data. + is_json: True if the object is a json object. + """ + + @abstractmethod + async def get(self, key: str, is_json: bool = False) -> Any: + """Get object from storage. + + Args: + key: The key of the object. + is_json: True if the object is a json object. + + Returns: + The object from storage. + """ + + @abstractmethod + async def delete_prefix(self, key_prefix: str) -> None: + """Delete an object with prefix. + + Args: + key_prefix: The prefix to delete. + """ + + @abstractmethod + async def scan_prefix(self, key_prefix: str) -> List[str]: + """List all keys with the prefix. 
+ + Args: + key_prefix: The prefix of the key. + + Returns: + List of matched keys. + """ + + @property + @abstractmethod + def storage_url(self) -> str: + """Get the URL of the storage.""" + + @abstractmethod + def __reduce__(self): + """Reduce the storage to a serializable object.""" diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/debug.py b/.venv/lib/python3.11/site-packages/ray/workflow/storage/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4f8f04466a9ea350cdeebedfc43f6598335fb8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/storage/debug.py @@ -0,0 +1,190 @@ +import json +from typing import Any, List +from urllib import parse +import pathlib +from filelock import FileLock +from ray.workflow.storage.base import Storage +from ray.workflow.storage.filesystem import FilesystemStorageImpl +import ray.cloudpickle +from ray.workflow import serialization_context + + +class LoggedStorage(FilesystemStorageImpl): + """A storage records all writing to storage sequentially.""" + + def __init__(self, workflow_root_dir: str): + super().__init__(workflow_root_dir) + self._log_dir = self._workflow_root_dir + self._count = self._log_dir / "count.log" + self._op_counter = self._log_dir / "op_counter.pkl" + if not self._log_dir.exists(): + self._log_dir.mkdir() + # only one process initializes the count + with FileLock(str(self._workflow_root_dir / ".lock")): + if not self._count.exists(): + with open(self._count, "x") as f: + f.write("0") + if not self._op_counter.exists(): + with open(self._op_counter, "wb") as f: + ray.cloudpickle.dump({}, f) + + def get_op_counter(self): + with FileLock(str(self._log_dir / ".lock")): + with open(self._op_counter, "rb") as f: + counter = ray.cloudpickle.load(f) + return counter + + def update_count(self, op: str, key): + counter = None + with open(self._op_counter, "rb") as f: + counter = ray.cloudpickle.load(f) + if op not in counter: + counter[op] = [] + 
counter[op].append(key) + with open(self._op_counter, "wb") as f: + ray.cloudpickle.dump(counter, f) + + async def put(self, key: str, data: Any, is_json: bool = False) -> None: + with FileLock(str(self._log_dir / ".lock")): + self.update_count("put", key) + with open(self._count, "r") as f: + count = int(f.read()) + k1 = self._log_dir / f"{count}.metadata.json" + k2 = self._log_dir / f"{count}.value" + await super().put( + str(k1), + {"operation": "put", "key": key, "is_json": is_json}, + is_json=True, + ) + await super().put(str(k2), data, is_json=is_json) + with open(self._count, "w") as f: + f.write(str(count + 1)) + + async def get(self, key: str, is_json=False) -> None: + with FileLock(str(self._log_dir / ".lock")): + self.update_count("get", key) + + async def delete_prefix(self, key: str) -> None: + with FileLock(str(self._log_dir / ".lock")): + with open(self._count, "r") as f: + count = int(f.read()) + k1 = self._log_dir / f"{count}.metadata.json" + await super().put( + str(k1), {"operation": "delete_prefix", "key": key}, is_json=True + ) + with open(self._count, "w") as f: + f.write(str(count + 1)) + + def get_metadata(self, index: int) -> Any: + with open(self._log_dir / f"{index}.metadata.json") as f: + return json.load(f) + + def get_value(self, index: int, is_json: bool) -> Any: + path = self._log_dir / f"{index}.value" + if is_json: + with open(path) as f: + return json.load(f) + else: + with open(path, "rb") as f: + with serialization_context.workflow_args_keeping_context(): + return ray.cloudpickle.load(f) + + def __len__(self): + with open(self._count, "r") as f: + return int(f.read()) + + +class DebugStorage(Storage): + """A storage for debugging purpose.""" + + def __init__(self, wrapped_storage: "Storage", path: str): + self._log_on = True + self._path = path + self._wrapped_storage = wrapped_storage + log_path = pathlib.Path(path) + parsed = parse.urlparse(wrapped_storage.storage_url) + log_path = ( + log_path + / parsed.scheme.strip("/") + / 
parsed.netloc.strip("/") + / parsed.path.strip("/") + ) + if not log_path.exists(): + log_path.mkdir(parents=True) + self._logged_storage = LoggedStorage(str(log_path)) + self._op_log_file = log_path / "debug_operations.log" + + def make_key(self, *names: str) -> str: + return self._wrapped_storage.make_key(*names) + + async def get(self, key: str, is_json: bool = False) -> Any: + await self._logged_storage.get(key, is_json) + return await self._wrapped_storage.get(key, is_json) + + async def put(self, key: str, data: Any, is_json: bool = False) -> None: + if self._log_on: + await self._logged_storage.put(key, data, is_json) + await self._wrapped_storage.put(key, data, is_json) + + async def delete_prefix(self, prefix: str) -> None: + if self._log_on: + await self._logged_storage.delete_prefix(prefix) + await self._wrapped_storage.delete_prefix(prefix) + + async def scan_prefix(self, key_prefix: str) -> List[str]: + return await self._wrapped_storage.scan_prefix(key_prefix) + + @property + def storage_url(self) -> str: + store_url = parse.quote_plus(self._wrapped_storage.storage_url) + parsed_url = parse.ParseResult( + scheme="debug", + path=str(pathlib.Path(self._path).absolute()), + netloc="", + params="", + query=f"storage={store_url}", + fragment="", + ) + return parse.urlunparse(parsed_url) + + def __reduce__(self): + return DebugStorage, (self._wrapped_storage, self._path) + + @property + def wrapped_storage(self) -> "Storage": + """Get wrapped storage.""" + return self._wrapped_storage + + async def replay(self, index: int) -> None: + """Replay the a record to the storage. + + Args: + index: The index of the recorded log to replay. 
+ """ + log = self.get_log(index) + op = log["operation"] + if op == "put": + is_json = log["is_json"] + data = self.get_value(index, is_json) + await self._wrapped_storage.put(log["key"], data, is_json) + elif op == "delete_prefix": + await self._wrapped_storage.delete_prefix(log["key"]) + elif op == "get": + pass + else: + raise ValueError(f"Unknown operation '{op}'.") + + def get_log(self, index: int) -> Any: + return self._logged_storage.get_metadata(index) + + def get_value(self, index: int, is_json: bool) -> Any: + return self._logged_storage.get_value(index, is_json) + + def log_off(self): + self._log_on = False + + def log_on(self): + self._log_on = True + + def __len__(self): + return len(self._logged_storage) diff --git a/.venv/lib/python3.11/site-packages/ray/workflow/storage/filesystem.py b/.venv/lib/python3.11/site-packages/ray/workflow/storage/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..8d26c8694d2fab45d81ec31ea98988b9addb0679 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/workflow/storage/filesystem.py @@ -0,0 +1,172 @@ +import os +import contextlib +import json +import shutil +import pathlib +from typing import Any, List +import uuid + +from ray.workflow.storage.base import Storage, KeyNotFoundError + +import ray.cloudpickle + + +@contextlib.contextmanager +def _open_atomic(path: pathlib.Path, mode="r"): + """Open file with atomic file writing support. File reading is also + adapted to atomic file writing (for example, the backup file + is used when an atomic write failed previously.) + + TODO(suquark): race condition like two processes writing the + same file is still not safe. This may not be an issue, because + in our current implementation, we only need to guarantee the + file is either fully written or not existing. + + Args: + path: The file path. + mode: Open mode same as "open()". + + Returns: + File object. 
+ """ + if "a" in mode or "+" in mode: + raise ValueError("Atomic open does not support appending.") + # backup file is hidden by default + backup_path = path.with_name(f".{path.name}.backup") + if "r" in mode: # read mode + if _file_exists(path): + f = open(path, mode) + else: + raise KeyNotFoundError(path) + try: + yield f + finally: + f.close() + elif "x" in mode: # create mode + if path.exists(): + raise FileExistsError(path) + tmp_new_fn = path.with_suffix(f".{path.name}.{uuid.uuid4().hex}") + if not tmp_new_fn.parent.exists(): + tmp_new_fn.parent.mkdir(parents=True) + f = open(tmp_new_fn, mode) + write_ok = True + try: + yield f + except Exception: + write_ok = False + raise + finally: + f.close() + if write_ok: + # "commit" file if writing succeeded + tmp_new_fn.rename(path) + else: + # remove file if writing failed + tmp_new_fn.unlink() + elif "w" in mode: # overwrite mode + # backup existing file + if path.exists(): + # remove an even older backup file + if backup_path.exists(): + backup_path.unlink() + path.rename(backup_path) + tmp_new_fn = path.with_suffix(f".{path.name}.{uuid.uuid4().hex}") + if not tmp_new_fn.parent.exists(): + tmp_new_fn.parent.mkdir(parents=True) + f = open(tmp_new_fn, mode) + write_ok = True + try: + yield f + except Exception: + write_ok = False + raise + finally: + f.close() + if write_ok: + tmp_new_fn.rename(path) + # cleanup the backup file + if backup_path.exists(): + backup_path.unlink() + else: + # remove file if writing failed + tmp_new_fn.unlink() + else: + raise ValueError(f"Unknown file open mode {mode}.") + + +def _file_exists(path: pathlib.Path) -> bool: + """During atomic writing, we backup the original file. If the writing + failed during the middle, then only the backup exists. We consider the + file exists if the file or the backup file exists. We also automatically + restore the backup file to the original path if only backup file exists. + + Args: + path: File path. 
+ + Returns: + True if the file and backup exists. + """ + backup_path = path.with_name(f".{path.name}.backup") + if path.exists(): + return True + elif backup_path.exists(): + backup_path.rename(path) + return True + return False + + +class FilesystemStorageImpl(Storage): + """Filesystem implementation for accessing workflow storage. + + We do not repeat the same comments for abstract methods in the base class. + """ + + def __init__(self, workflow_root_dir: str): + self._workflow_root_dir = pathlib.Path(workflow_root_dir) + if self._workflow_root_dir.exists(): + if not self._workflow_root_dir.is_dir(): + raise ValueError( + f"storage path {workflow_root_dir} must be a directory." + ) + else: + self._workflow_root_dir.mkdir(parents=True) + + def make_key(self, *names: str) -> str: + return os.path.join(str(self._workflow_root_dir), *names) + + async def put(self, key: str, data: Any, is_json: bool = False) -> None: + if is_json: + with _open_atomic(pathlib.Path(key), "w") as f: + return json.dump(data, f) + else: + with _open_atomic(pathlib.Path(key), "wb") as f: + return ray.cloudpickle.dump(data, f) + + async def get(self, key: str, is_json: bool = False) -> Any: + if is_json: + with _open_atomic(pathlib.Path(key)) as f: + return json.load(f) + else: + with _open_atomic(pathlib.Path(key), "rb") as f: + return ray.cloudpickle.load(f) + + async def delete_prefix(self, key_prefix: str) -> None: + path = pathlib.Path(key_prefix) + if path.is_dir(): + shutil.rmtree(str(path)) + else: + path.unlink() + + async def scan_prefix(self, key_prefix: str) -> List[str]: + try: + path = pathlib.Path(key_prefix) + return [p.name for p in path.iterdir()] + except FileNotFoundError: + return [] + + @property + def storage_url(self) -> str: + return "file://" + str(self._workflow_root_dir.absolute()) + + def __reduce__(self): + return FilesystemStorageImpl, (self._workflow_root_dir,)