koichi12 commited on
Commit
98ca408
·
verified ·
1 Parent(s): 80c179b

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/data/__init__.py +165 -0
  2. .venv/lib/python3.11/site-packages/ray/data/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/data/__pycache__/aggregate.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/data/__pycache__/block.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/data/__pycache__/context.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/data/__pycache__/exceptions.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/data/__pycache__/grouped_data.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/data/__pycache__/iterator.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/data/__pycache__/preprocessor.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/data/__pycache__/random_access_dataset.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/data/aggregate.py +76 -0
  12. .venv/lib/python3.11/site-packages/ray/data/block.py +561 -0
  13. .venv/lib/python3.11/site-packages/ray/data/context.py +468 -0
  14. .venv/lib/python3.11/site-packages/ray/data/dataset.py +0 -0
  15. .venv/lib/python3.11/site-packages/ray/data/datasource/datasink.py +164 -0
  16. .venv/lib/python3.11/site-packages/ray/data/datasource/datasource.py +243 -0
  17. .venv/lib/python3.11/site-packages/ray/data/datasource/file_based_datasource.py +572 -0
  18. .venv/lib/python3.11/site-packages/ray/data/datasource/file_meta_provider.py +484 -0
  19. .venv/lib/python3.11/site-packages/ray/data/datasource/filename_provider.py +122 -0
  20. .venv/lib/python3.11/site-packages/ray/data/datasource/parquet_meta_provider.py +252 -0
  21. .venv/lib/python3.11/site-packages/ray/data/exceptions.py +91 -0
  22. .venv/lib/python3.11/site-packages/ray/data/grouped_data.py +494 -0
  23. .venv/lib/python3.11/site-packages/ray/data/iterator.py +931 -0
  24. .venv/lib/python3.11/site-packages/ray/data/preprocessor.py +318 -0
  25. .venv/lib/python3.11/site-packages/ray/data/random_access_dataset.py +293 -0
  26. .venv/lib/python3.11/site-packages/ray/data/read_api.py +0 -0
  27. .venv/lib/python3.11/site-packages/ray/includes/__init__.pxd +0 -0
  28. .venv/lib/python3.11/site-packages/ray/includes/common.pxd +749 -0
  29. .venv/lib/python3.11/site-packages/ray/includes/function_descriptor.pxd +80 -0
  30. .venv/lib/python3.11/site-packages/ray/includes/global_state_accessor.pxd +144 -0
  31. .venv/lib/python3.11/site-packages/ray/includes/libcoreworker.pxd +457 -0
  32. .venv/lib/python3.11/site-packages/ray/includes/metric.pxd +45 -0
  33. .venv/lib/python3.11/site-packages/ray/includes/optional.pxd +36 -0
  34. .venv/lib/python3.11/site-packages/ray/includes/ray_config.pxd +98 -0
  35. .venv/lib/python3.11/site-packages/ray/includes/unique_ids.pxd +218 -0
  36. .venv/lib/python3.11/site-packages/ray/runtime_env/__init__.py +8 -0
  37. .venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/runtime_env.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/runtime_env/runtime_env.py +662 -0
  40. .venv/lib/python3.11/site-packages/ray/widgets/__init__.py +4 -0
  41. .venv/lib/python3.11/site-packages/ray/widgets/__pycache__/__init__.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/widgets/__pycache__/render.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/widgets/__pycache__/util.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/widgets/render.py +39 -0
  45. .venv/lib/python3.11/site-packages/ray/widgets/templates/context.html.j2 +6 -0
  46. .venv/lib/python3.11/site-packages/ray/widgets/templates/context_dashrow.html.j2 +4 -0
  47. .venv/lib/python3.11/site-packages/ray/widgets/templates/context_logo.html.j2 +13 -0
  48. .venv/lib/python3.11/site-packages/ray/widgets/templates/context_table.html.j2 +11 -0
  49. .venv/lib/python3.11/site-packages/ray/widgets/templates/divider.html.j2 +9 -0
  50. .venv/lib/python3.11/site-packages/ray/widgets/templates/rendered_html_common.html.j2 +3 -0
.venv/lib/python3.11/site-packages/ray/data/__init__.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Short term workaround for https://github.com/ray-project/ray/issues/32435
2
+ # Dataset has a hard dependency on pandas, so it doesn't need to be delayed.
3
+ import pandas # noqa
4
+ from packaging.version import parse as parse_version
5
+
6
+ from ray._private.utils import _get_pyarrow_version
7
+ from ray.data._internal.compute import ActorPoolStrategy
8
+ from ray.data._internal.datasource.tfrecords_datasource import TFXReadOptions
9
+ from ray.data._internal.execution.interfaces import (
10
+ ExecutionOptions,
11
+ ExecutionResources,
12
+ NodeIdStr,
13
+ )
14
+ from ray.data._internal.logging import configure_logging
15
+ from ray.data.context import DataContext, DatasetContext
16
+ from ray.data.dataset import Dataset, Schema
17
+ from ray.data.datasource import (
18
+ BlockBasedFileDatasink,
19
+ Datasink,
20
+ Datasource,
21
+ FileShuffleConfig,
22
+ ReadTask,
23
+ RowBasedFileDatasink,
24
+ )
25
+ from ray.data.iterator import DataIterator, DatasetIterator
26
+ from ray.data.preprocessor import Preprocessor
27
+ from ray.data.read_api import ( # noqa: F401
28
+ from_arrow,
29
+ from_arrow_refs,
30
+ from_blocks,
31
+ from_dask,
32
+ from_huggingface,
33
+ from_items,
34
+ from_mars,
35
+ from_modin,
36
+ from_numpy,
37
+ from_numpy_refs,
38
+ from_pandas,
39
+ from_pandas_refs,
40
+ from_spark,
41
+ from_tf,
42
+ from_torch,
43
+ range,
44
+ range_tensor,
45
+ read_audio,
46
+ read_avro,
47
+ read_bigquery,
48
+ read_binary_files,
49
+ read_clickhouse,
50
+ read_csv,
51
+ read_databricks_tables,
52
+ read_datasource,
53
+ read_delta_sharing_tables,
54
+ read_hudi,
55
+ read_iceberg,
56
+ read_images,
57
+ read_json,
58
+ read_lance,
59
+ read_mongo,
60
+ read_numpy,
61
+ read_parquet,
62
+ read_parquet_bulk,
63
+ read_sql,
64
+ read_text,
65
+ read_tfrecords,
66
+ read_videos,
67
+ read_webdataset,
68
+ )
# Module-level cached global functions for callable classes. It needs to be defined here
# since it has to be process-global across cloudpickled funcs.
_map_actor_context = None

# Set up Ray Data's logger configuration as a module-import side effect.
configure_logging()

try:
    import pyarrow as pa

    # https://github.com/apache/arrow/pull/38608 deprecated `PyExtensionType`, and
    # disabled its deserialization by default. To ensure that users can load data
    # written with earlier version of Ray Data, we enable auto-loading of serialized
    # tensor extensions.
    pyarrow_version = _get_pyarrow_version()
    if not isinstance(pyarrow_version, str):
        # PyArrow is mocked in documentation builds. In this case, we don't need to do
        # anything.
        pass
    else:
        from ray._private.ray_constants import env_bool

        # Opt-in escape hatch: auto-loading PyExtensionType is off by default.
        RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE = env_bool(
            "RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE", False
        )

        if (
            parse_version(pyarrow_version) >= parse_version("14.0.1")
            and RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE
        ):
            pa.PyExtensionType.set_auto_load(True)
    # Import these arrow extension types to ensure that they are registered.
    from ray.air.util.tensor_extensions.arrow import (  # noqa
        ArrowTensorType,
        ArrowVariableShapedTensorType,
    )
except ModuleNotFoundError:
    # pyarrow is optional at import time; skip extension registration if absent.
    pass


__all__ = [
    "ActorPoolStrategy",
    "BlockBasedFileDatasink",
    "Dataset",
    "DataContext",
    "DatasetContext",  # Backwards compatibility alias.
    "DataIterator",
    "DatasetIterator",  # Backwards compatibility alias.
    "Datasink",
    "Datasource",
    "ExecutionOptions",
    "ExecutionResources",
    "FileShuffleConfig",
    "NodeIdStr",
    "ReadTask",
    "RowBasedFileDatasink",
    "Schema",
    "from_dask",
    "from_items",
    "from_arrow",
    "from_arrow_refs",
    "from_mars",
    "from_modin",
    "from_numpy",
    "from_numpy_refs",
    "from_pandas",
    "from_pandas_refs",
    "from_spark",
    "from_tf",
    "from_torch",
    "from_huggingface",
    "range",
    "range_tensor",
    "read_audio",
    "read_avro",
    "read_text",
    "read_binary_files",
    "read_clickhouse",
    "read_csv",
    "read_datasource",
    "read_delta_sharing_tables",
    "read_hudi",
    "read_iceberg",
    "read_images",
    "read_json",
    "read_lance",
    "read_numpy",
    "read_mongo",
    "read_parquet",
    "read_parquet_bulk",
    "read_sql",
    "read_tfrecords",
    "read_videos",
    "read_webdataset",
    "Preprocessor",
    "TFXReadOptions",
]
.venv/lib/python3.11/site-packages/ray/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.97 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/aggregate.cpython-311.pyc ADDED
Binary file (4.36 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/block.cpython-311.pyc ADDED
Binary file (25.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/context.cpython-311.pyc ADDED
Binary file (19.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/exceptions.cpython-311.pyc ADDED
Binary file (4.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/grouped_data.cpython-311.pyc ADDED
Binary file (24.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/iterator.cpython-311.pyc ADDED
Binary file (46.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/preprocessor.cpython-311.pyc ADDED
Binary file (14.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/__pycache__/random_access_dataset.cpython-311.pyc ADDED
Binary file (19.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/aggregate.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Callable, Optional, Union
2
+
3
+ from ray.data.block import AggType, Block, BlockAccessor, KeyType, T, U
4
+ from ray.util.annotations import PublicAPI
5
+
6
+ if TYPE_CHECKING:
7
+ import pyarrow as pa
8
+
9
+
@PublicAPI
class AggregateFn:
    """Defines an aggregate function in the accumulator style.

    Aggregates a collection of inputs of type T into
    a single output value of type U.
    See https://www.sigops.org/s/conferences/sosp/2009/papers/yu-sosp09.pdf
    for more details about accumulator-based aggregation.

    Args:
        init: This is called once for each group to return the empty accumulator.
            For example, an empty accumulator for a sum would be 0.
        merge: This may be called multiple times, each time to merge
            two accumulators into one.
        name: The name of the aggregation. This will be used as the column name
            in the output Dataset.
        accumulate_row: This is called once per row of the same group.
            This combines the accumulator and the row, returns the updated
            accumulator. Exactly one of accumulate_row and accumulate_block must
            be provided.
        accumulate_block: This is used to calculate the aggregation for a
            single block, and is vectorized alternative to accumulate_row. This will
            be given a base accumulator and the entire block, allowing for
            vectorized accumulation of the block. Exactly one of accumulate_row and
            accumulate_block must be provided.
        finalize: This is called once to compute the final aggregation
            result from the fully merged accumulator.

    Raises:
        ValueError: If neither or both of accumulate_row/accumulate_block are given.
        TypeError: If ``name`` is not a string.
    """

    def __init__(
        self,
        init: Callable[[KeyType], AggType],
        merge: Callable[[AggType, AggType], AggType],
        name: str,
        accumulate_row: Optional[Callable[[AggType, T], AggType]] = None,
        accumulate_block: Optional[Callable[[AggType, Block], AggType]] = None,
        finalize: Optional[Callable[[AggType], U]] = None,
    ):
        # Exactly one accumulation strategy must be supplied (XOR check).
        if (accumulate_row is None) == (accumulate_block is None):
            raise ValueError(
                "Exactly one of accumulate_row or accumulate_block must be provided."
            )

        # Validate `name` before building the wrapper so misuse fails fast with
        # a message that matches the actual check (a string is required, not
        # merely "provided").
        if not isinstance(name, str):
            raise TypeError(f"`name` must be a string, got {type(name).__name__}.")

        if accumulate_block is None:
            # Wrap the row-wise accumulator so the rest of the engine only ever
            # deals with block-level accumulation.
            def accumulate_block(a: AggType, block: Block) -> AggType:
                block_acc = BlockAccessor.for_block(block)
                for r in block_acc.iter_rows(public_row_format=False):
                    a = accumulate_row(a, r)
                return a

        if finalize is None:
            # Identity finalizer: the merged accumulator is the final result.
            finalize = lambda a: a  # noqa: E731

        self.init = init
        self.merge = merge
        self.name = name
        self.accumulate_block = accumulate_block
        self.finalize = finalize

    def _validate(self, schema: Optional[Union[type, "pa.lib.Schema"]]) -> None:
        """Raise an error if this cannot be applied to the given schema.

        The base implementation accepts any schema; subclasses override this
        to enforce column/type requirements.
        """
        pass
.venv/lib/python3.11/site-packages/ray/data/block.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import os
4
+ import time
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Callable,
11
+ Dict,
12
+ Iterator,
13
+ List,
14
+ Literal,
15
+ Optional,
16
+ Protocol,
17
+ Tuple,
18
+ TypeVar,
19
+ Union,
20
+ )
21
+
22
+ import numpy as np
23
+
24
+ import ray
25
+ from ray import DynamicObjectRefGenerator
26
+ from ray.air.util.tensor_extensions.arrow import ArrowConversionError
27
+ from ray.data._internal.util import _check_pyarrow_version, _truncated_repr
28
+ from ray.types import ObjectRef
29
+ from ray.util import log_once
30
+ from ray.util.annotations import DeveloperAPI
31
+
32
+ import psutil
33
+
34
+ try:
35
+ import resource
36
+ except ImportError:
37
+ resource = None
38
+
39
+ if TYPE_CHECKING:
40
+ import pandas
41
+ import pyarrow
42
+
43
+ from ray.data._internal.block_builder import BlockBuilder
44
+ from ray.data._internal.planner.exchange.sort_task_spec import SortKey
45
+ from ray.data.aggregate import AggregateFn
46
+
47
+
# Row/input element type (contravariant) and output element type (covariant)
# for user-defined transformations.
T = TypeVar("T", contravariant=True)
U = TypeVar("U", covariant=True)

# Group-by key type and aggregation accumulator type.
KeyType = TypeVar("KeyType")
AggType = TypeVar("AggType")


# Represents a batch of records to be stored in the Ray object store.
#
# Block data can be accessed in a uniform way via ``BlockAccessors`` like
# ``ArrowBlockAccessor``.
Block = Union["pyarrow.Table", "pandas.DataFrame"]


logger = logging.getLogger(__name__)


@DeveloperAPI
class BlockType(Enum):
    # The two physical block representations supported by Ray Data.
    ARROW = "arrow"
    PANDAS = "pandas"


# User-facing data batch type. This is the data type for data that is supplied to and
# returned from batch UDFs.
DataBatch = Union["pyarrow.Table", "pandas.DataFrame", Dict[str, np.ndarray]]

# User-facing data column type. This is the data type for data that is supplied to and
# returned from column UDFs.
DataBatchColumn = Union[
    "pyarrow.ChunkedArray", "pyarrow.Array", "pandas.Series", np.ndarray
]


# A class type that implements __call__.
CallableClass = type


class _CallableClassProtocol(Protocol[T, U]):
    # Structural type for callable-class UDFs: one positional argument,
    # returning either a single result or an iterator of results.
    def __call__(self, __arg: T) -> Union[U, Iterator[U]]:
        ...


# A user defined function passed to map, map_batches, etc.
UserDefinedFunction = Union[
    Callable[[T], U],
    Callable[[T], Iterator[U]],
    "_CallableClassProtocol",
]

# A list of block references pending computation by a single task. For example,
# this may be the output of a task reading a file.
BlockPartition = List[Tuple[ObjectRef[Block], "BlockMetadata"]]

# The metadata that describes the output of a BlockPartition. This has the
# same type as the metadata that describes each block in the partition.
BlockPartitionMetadata = List["BlockMetadata"]

# TODO(ekl/chengsu): replace this with just
# `DynamicObjectRefGenerator` once block splitting
# is on by default. When block splitting is off, the type is a plain block.
MaybeBlockPartition = Union[Block, DynamicObjectRefGenerator]
110
+
111
+ VALID_BATCH_FORMATS = ["pandas", "pyarrow", "numpy", None]
112
+ DEFAULT_BATCH_FORMAT = "numpy"
113
+
114
+
115
+ def _apply_batch_format(given_batch_format: Optional[str]) -> str:
116
+ if given_batch_format == "default":
117
+ given_batch_format = DEFAULT_BATCH_FORMAT
118
+ if given_batch_format not in VALID_BATCH_FORMATS:
119
+ raise ValueError(
120
+ f"The given batch format {given_batch_format} isn't allowed (must be one of"
121
+ f" {VALID_BATCH_FORMATS})."
122
+ )
123
+ return given_batch_format
124
+
125
+
126
+ def _apply_batch_size(
127
+ given_batch_size: Optional[Union[int, Literal["default"]]]
128
+ ) -> Optional[int]:
129
+ if given_batch_size == "default":
130
+ return ray.data.context.DEFAULT_BATCH_SIZE
131
+ else:
132
+ return given_batch_size
133
+
134
+
@DeveloperAPI
class BlockExecStats:
    """Execution stats for this block.

    Attributes:
        wall_time_s: The wall-clock time it took to compute this block.
        cpu_time_s: The CPU time it took to compute this block.
        node_id: A unique id for the node that computed this block.
    """

    def __init__(self):
        self.start_time_s: Optional[float] = None
        self.end_time_s: Optional[float] = None
        self.wall_time_s: Optional[float] = None
        # Time spent inside user-defined functions; starts at 0 rather than
        # None so it can be accumulated incrementally.
        self.udf_time_s: Optional[float] = 0
        self.cpu_time_s: Optional[float] = None
        self.node_id = ray.runtime_context.get_runtime_context().get_node_id()
        # Max memory usage. May be an overestimate since we do not
        # differentiate from previous tasks on the same worker.
        self.max_rss_bytes: int = 0
        self.task_idx: Optional[int] = None

    @staticmethod
    def builder() -> "_BlockExecStatsBuilder":
        # Preferred construction path: the builder records start timestamps
        # at creation and fills in the deltas on build().
        return _BlockExecStatsBuilder()

    def __repr__(self):
        # Only the headline timing fields are shown; other attributes are
        # available directly on the instance.
        return repr(
            {
                "wall_time_s": self.wall_time_s,
                "cpu_time_s": self.cpu_time_s,
                "udf_time_s": self.udf_time_s,
                "node_id": self.node_id,
            }
        )
170
+
171
+
class _BlockExecStatsBuilder:
    """Helper class for building block stats.

    Instantiating the builder records the start timestamps; calling
    ``build()`` captures the end timestamps and returns a populated
    ``BlockExecStats`` with the deltas filled in.
    """

    def __init__(self):
        self.start_time = time.perf_counter()
        self.start_cpu = time.process_time()

    def build(self) -> "BlockExecStats":
        self.end_time = time.perf_counter()
        self.end_cpu = time.process_time()

        exec_stats = BlockExecStats()
        exec_stats.start_time_s = self.start_time
        exec_stats.end_time_s = self.end_time
        exec_stats.wall_time_s = self.end_time - self.start_time
        exec_stats.cpu_time_s = self.end_cpu - self.start_cpu
        if resource is not None:
            # ru_maxrss is the peak RSS; multiply by 1e3 to convert the
            # kilobyte units reported on Linux into bytes.
            exec_stats.max_rss_bytes = int(
                resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1e3
            )
        else:
            # NOTE(swang): resource package is not supported on Windows. This
            # is only the memory usage at the end of the task, not the peak
            # memory.
            current_proc = psutil.Process(os.getpid())
            exec_stats.max_rss_bytes = int(current_proc.memory_info().rss)
        return exec_stats
203
+
204
+
@DeveloperAPI
@dataclass
class BlockMetadata:
    """Metadata about the block."""

    #: The number of rows contained in this block, or None.
    num_rows: Optional[int]
    #: The approximate size in bytes of this block, or None.
    size_bytes: Optional[int]
    #: The pyarrow schema or types of the block elements, or None.
    schema: Optional[Union[type, "pyarrow.lib.Schema"]]
    #: The list of file paths used to generate this block, or
    #: the empty list if indeterminate.
    input_files: Optional[List[str]]
    #: Execution stats for this block.
    exec_stats: Optional[BlockExecStats]

    def __post_init__(self):
        # Normalize missing input_files to an empty list so consumers can
        # iterate without a None check.
        if self.input_files is None:
            self.input_files = []
        if self.size_bytes is not None:
            # Require size_bytes to be int, ray.util.metrics objects
            # will not take other types like numpy.int64
            assert isinstance(self.size_bytes, int)
229
+
230
+
@DeveloperAPI
class BlockAccessor:
    """Provides accessor methods for a specific block.

    Ideally, we wouldn't need a separate accessor classes for blocks. However,
    this is needed if we want to support storing ``pyarrow.Table`` directly
    as a top-level Ray object, without a wrapping class (issue #17186).

    Most methods here are abstract and raise ``NotImplementedError``; concrete
    behavior lives in the Arrow and Pandas accessor subclasses.
    """

    def num_rows(self) -> int:
        """Return the number of rows contained in this block."""
        raise NotImplementedError

    def iter_rows(self, public_row_format: bool) -> Iterator[T]:
        """Iterate over the rows of this block.

        Args:
            public_row_format: Whether to cast rows into the public Dict row
                format (this incurs extra copy conversions).
        """
        raise NotImplementedError

    def slice(self, start: int, end: int, copy: bool) -> Block:
        """Return a slice of this block.

        Args:
            start: The starting index of the slice (inclusive).
            end: The ending index of the slice (exclusive).
            copy: Whether to perform a data copy for the slice.

        Returns:
            The sliced block result.
        """
        raise NotImplementedError

    def take(self, indices: List[int]) -> Block:
        """Return a new block containing the provided row indices.

        Args:
            indices: The row indices to return.

        Returns:
            A new block containing the provided row indices.
        """
        raise NotImplementedError

    def select(self, columns: List[Optional[str]]) -> Block:
        """Return a new block containing the provided columns."""
        raise NotImplementedError

    def rename_columns(self, columns_rename: Dict[str, str]) -> Block:
        """Return the block reflecting the renamed columns."""
        raise NotImplementedError

    def random_shuffle(self, random_seed: Optional[int]) -> Block:
        """Randomly shuffle this block."""
        raise NotImplementedError

    def to_pandas(self) -> "pandas.DataFrame":
        """Convert this block into a Pandas dataframe."""
        raise NotImplementedError

    def to_numpy(
        self, columns: Optional[Union[str, List[str]]] = None
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """Convert this block (or columns of block) into a NumPy ndarray.

        Args:
            columns: Name of columns to convert, or None if converting all columns.
        """
        raise NotImplementedError

    def to_arrow(self) -> "pyarrow.Table":
        """Convert this block into an Arrow table."""
        raise NotImplementedError

    def to_block(self) -> Block:
        """Return the base block that this accessor wraps."""
        raise NotImplementedError

    def to_default(self) -> Block:
        """Return the default data format for this accessor."""
        return self.to_block()

    def to_batch_format(self, batch_format: Optional[str]) -> DataBatch:
        """Convert this block into the provided batch format.

        Args:
            batch_format: The batch format to convert this block to. ``None``
                returns the raw block; ``"default"``/``"native"`` return the
                accessor's default format.

        Returns:
            This block formatted as the provided batch format.

        Raises:
            ValueError: If ``batch_format`` is not one of VALID_BATCH_FORMATS.
        """
        if batch_format is None:
            return self.to_block()
        elif batch_format == "default" or batch_format == "native":
            return self.to_default()
        elif batch_format == "pandas":
            return self.to_pandas()
        elif batch_format == "pyarrow":
            return self.to_arrow()
        elif batch_format == "numpy":
            return self.to_numpy()
        else:
            raise ValueError(
                f"The batch format must be one of {VALID_BATCH_FORMATS}, got: "
                f"{batch_format}"
            )

    def size_bytes(self) -> int:
        """Return the approximate size in bytes of this block."""
        raise NotImplementedError

    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
        """Return the Python type or pyarrow schema of this block."""
        raise NotImplementedError

    def get_metadata(
        self,
        input_files: Optional[List[str]] = None,
        exec_stats: Optional[BlockExecStats] = None,
    ) -> BlockMetadata:
        """Create a metadata object from this block."""
        return BlockMetadata(
            num_rows=self.num_rows(),
            size_bytes=self.size_bytes(),
            schema=self.schema(),
            input_files=input_files,
            exec_stats=exec_stats,
        )

    def zip(self, other: "Block") -> "Block":
        """Zip this block with another block of the same type and size."""
        raise NotImplementedError

    @staticmethod
    def builder() -> "BlockBuilder":
        """Create a builder for this block type."""
        raise NotImplementedError

    @classmethod
    def batch_to_block(
        cls,
        batch: DataBatch,
        block_type: Optional[BlockType] = None,
    ) -> Block:
        """Create a block from user-facing data formats.

        Args:
            batch: The batch to convert. A mapping is converted to a block;
                standalone numpy arrays are rejected; anything else is passed
                through unchanged (assumed to already be a block).
            block_type: Target block type. If None, Arrow is tried first and
                Pandas is used as a fallback on Arrow conversion failure.
        """

        if isinstance(batch, np.ndarray):
            raise ValueError(
                f"Error validating {_truncated_repr(batch)}: "
                "Standalone numpy arrays are not "
                "allowed in Ray 2.5. Return a dict of field -> array, "
                "e.g., `{'data': array}` instead of `array`."
            )

        elif isinstance(batch, collections.abc.Mapping):
            if block_type is None or block_type == BlockType.ARROW:
                try:
                    return cls.batch_to_arrow_block(batch)
                except ArrowConversionError as e:
                    # Warn only once per process to avoid log spam.
                    if log_once("_fallback_to_pandas_block_warning"):
                        logger.warning(
                            f"Failed to convert batch to Arrow due to: {e}; "
                            f"falling back to Pandas block"
                        )

                    if block_type is None:
                        return cls.batch_to_pandas_block(batch)
                    else:
                        raise e
            else:
                assert block_type == BlockType.PANDAS
                return cls.batch_to_pandas_block(batch)
        # Not an ndarray or mapping: pass through as an existing block.
        return batch

    @classmethod
    def batch_to_arrow_block(cls, batch: Dict[str, Any]) -> Block:
        """Create an Arrow block from user-facing data formats."""
        from ray.data._internal.arrow_block import ArrowBlockBuilder

        return ArrowBlockBuilder._table_from_pydict(batch)

    @classmethod
    def batch_to_pandas_block(cls, batch: Dict[str, Any]) -> Block:
        """Create a Pandas block from user-facing data formats."""
        from ray.data._internal.pandas_block import PandasBlockAccessor

        return PandasBlockAccessor.numpy_to_block(batch)

    @staticmethod
    def for_block(block: Block) -> "BlockAccessor[T]":
        """Create a block accessor for the given block.

        Dispatches on the runtime type of ``block``: Arrow tables, pandas
        DataFrames, and bytes (deserialized via ``from_bytes``) are supported.
        """
        _check_pyarrow_version()
        import pandas
        import pyarrow

        if isinstance(block, pyarrow.Table):
            from ray.data._internal.arrow_block import ArrowBlockAccessor

            return ArrowBlockAccessor(block)
        elif isinstance(block, pandas.DataFrame):
            from ray.data._internal.pandas_block import PandasBlockAccessor

            return PandasBlockAccessor(block)
        elif isinstance(block, bytes):
            from ray.data._internal.arrow_block import ArrowBlockAccessor

            return ArrowBlockAccessor.from_bytes(block)
        elif isinstance(block, list):
            raise ValueError(
                f"Error validating {_truncated_repr(block)}: "
                "Standalone Python objects are not "
                "allowed in Ray 2.5. To use Python objects in a dataset, "
                "wrap them in a dict of numpy arrays, e.g., "
                "return `{'item': batch}` instead of just `batch`."
            )
        else:
            raise TypeError("Not a block type: {} ({})".format(block, type(block)))

    def sample(self, n_samples: int, sort_key: "SortKey") -> "Block":
        """Return a random sample of items from this block."""
        raise NotImplementedError

    def sort_and_partition(
        self, boundaries: List[T], sort_key: "SortKey"
    ) -> List["Block"]:
        """Return a list of sorted partitions of this block."""
        raise NotImplementedError

    def combine(self, key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
        """Combine rows with the same key into an accumulator."""
        raise NotImplementedError

    @staticmethod
    def merge_sorted_blocks(
        blocks: List["Block"], sort_key: "SortKey"
    ) -> Tuple[Block, BlockMetadata]:
        """Return a sorted block by merging a list of sorted blocks."""
        raise NotImplementedError

    @staticmethod
    def aggregate_combined_blocks(
        blocks: List[Block], sort_key: "SortKey", aggs: Tuple["AggregateFn"]
    ) -> Tuple[Block, BlockMetadata]:
        """Aggregate partially combined and sorted blocks."""
        raise NotImplementedError

    def block_type(self) -> BlockType:
        """Return the block type of this block."""
        raise NotImplementedError
482
+
483
+
484
+ def _get_block_boundaries(columns: list[np.ndarray]) -> np.ndarray:
485
+ """Compute boundaries of the groups within a block, which is represented
486
+ by a list of 1D numpy arrays for each column. In each column,
487
+ NaNs/None are considered to be the same group.
488
+
489
+ Args:
490
+ columns: a list of 1D numpy arrays. This is generally given by the
491
+ dictionary values of ``BlockAccessor.to_numpy()``.
492
+
493
+ Returns:
494
+ A list of starting indices of each group and an end index of the last
495
+ group, i.e., there are ``num_groups + 1`` entries and the first and last
496
+ entries are 0 and ``len(array)`` respectively.
497
+ """
498
+
499
+ # There are 3 categories: general, numerics with NaN, and categorical with None.
500
+ # We only needed to check the last element for NaNs/None, as they are assumed to
501
+ # be sorted.
502
+ general_arrays = []
503
+ num_arrays_with_nan = []
504
+ cat_arrays_with_none = []
505
+ for arr in columns:
506
+ if np.issubdtype(arr.dtype, np.number) and np.isnan(arr[-1]):
507
+ num_arrays_with_nan.append(arr)
508
+ elif not np.issubdtype(arr.dtype, np.number) and arr[-1] is None:
509
+ cat_arrays_with_none.append(arr)
510
+ else:
511
+ general_arrays.append(arr)
512
+
513
+ # Compute the difference between each pair of elements. Handle the cases
514
+ # where neighboring elements are both NaN or None. Output as a list of
515
+ # boolean arrays.
516
+ diffs = []
517
+ if len(general_arrays) > 0:
518
+ diffs.append(
519
+ np.vstack([arr[1:] != arr[:-1] for arr in general_arrays]).any(axis=0)
520
+ )
521
+ if len(num_arrays_with_nan) > 0:
522
+ # Two neighboring numeric elements belong to the same group when they are
523
+ # 1) both finite and equal
524
+ # or 2) both np.nan
525
+ diffs.append(
526
+ np.vstack(
527
+ [
528
+ (arr[1:] != arr[:-1])
529
+ & (np.isfinite(arr[1:]) | np.isfinite(arr[:-1]))
530
+ for arr in num_arrays_with_nan
531
+ ]
532
+ ).any(axis=0)
533
+ )
534
+ if len(cat_arrays_with_none) > 0:
535
+ # Two neighboring str/object elements belong to the same group when they are
536
+ # 1) both finite and equal
537
+ # or 2) both None
538
+ diffs.append(
539
+ np.vstack(
540
+ [
541
+ (arr[1:] != arr[:-1])
542
+ & ~(np.equal(arr[1:], None) & np.equal(arr[:-1], None))
543
+ for arr in cat_arrays_with_none
544
+ ]
545
+ ).any(axis=0)
546
+ )
547
+
548
+ # A series of vectorized operations to compute the boundaries:
549
+ # - column_stack: stack the bool arrays into a single 2D bool array
550
+ # - any() and nonzero(): find the indices where any of the column diffs are True
551
+ # - add 1 to get the index of the first element of the next group
552
+ # - hstack(): include the 0 and last indices to the boundaries
553
+ boundaries = np.hstack(
554
+ [
555
+ [0],
556
+ (np.column_stack(diffs).any(axis=1).nonzero()[0] + 1),
557
+ [len(columns[0])],
558
+ ]
559
+ ).astype(int)
560
+
561
+ return boundaries
.venv/lib/python3.11/site-packages/ray/data/context.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import threading
4
+ import warnings
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
7
+
8
+ import ray
9
+ from ray._private.ray_constants import env_bool, env_integer
10
+ from ray._private.worker import WORKER_MODE
11
+ from ray.util.annotations import DeveloperAPI
12
+ from ray.util.debug import log_once
13
+ from ray.util.scheduling_strategies import SchedulingStrategyT
14
+
15
+ if TYPE_CHECKING:
16
+ from ray.data._internal.execution.interfaces import ExecutionOptions
17
+
18
logger = logging.getLogger(__name__)

# The context singleton on this process.
_default_context: "Optional[DataContext]" = None
# Guards lazy creation of the singleton in DataContext.get_current().
_context_lock = threading.Lock()


# We chose 128MiB for default: With streaming execution and num_cpus many concurrent
# tasks, the memory footprint will be about 2 * num_cpus * target_max_block_size ~= RAM
# * DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION * 0.3 (default object store memory
# fraction set by Ray core), assuming typical memory:core ratio of 4:1.
DEFAULT_TARGET_MAX_BLOCK_SIZE = 128 * 1024 * 1024

# We set a higher target block size because we have to materialize
# all input blocks anyway, so there is no performance advantage to having
# smaller blocks. Setting a larger block size allows avoiding overhead from an
# excessive number of partitions.
# We choose 1GiB as 4x less than the typical memory:core ratio (4:1).
DEFAULT_SHUFFLE_TARGET_MAX_BLOCK_SIZE = 1024 * 1024 * 1024

# We will attempt to slice blocks whose size exceeds this factor *
# target_max_block_size. We will warn the user if slicing fails and we produce
# blocks larger than this threshold.
MAX_SAFE_BLOCK_SIZE_FACTOR = 1.5

DEFAULT_TARGET_MIN_BLOCK_SIZE = 1 * 1024 * 1024

# This default appears to work well with most file sizes on remote storage systems,
# which is very sensitive to the buffer size.
DEFAULT_STREAMING_READ_BUFFER_SIZE = 32 * 1024 * 1024

DEFAULT_ENABLE_PANDAS_BLOCK = True

DEFAULT_READ_OP_MIN_NUM_BLOCKS = 200

DEFAULT_ACTOR_PREFETCHER_ENABLED = False

# NOTE(review): `bool()` of the raw env-var string means any non-empty value
# (including "0" or "false") enables push-based shuffle — confirm intended.
DEFAULT_USE_PUSH_BASED_SHUFFLE = bool(
    os.environ.get("RAY_DATA_PUSH_BASED_SHUFFLE", None)
)

DEFAULT_SCHEDULING_STRATEGY = "SPREAD"

# This default enables locality-based scheduling in Ray for tasks where arg data
# transfer is a bottleneck.
DEFAULT_SCHEDULING_STRATEGY_LARGE_ARGS = "DEFAULT"

DEFAULT_LARGE_ARGS_THRESHOLD = 50 * 1024 * 1024

DEFAULT_USE_POLARS = False

# Unlike the push-based-shuffle flag above, this one parses the env var as an
# int, so "0" disables it as expected.
DEFAULT_EAGER_FREE = bool(int(os.environ.get("RAY_DATA_EAGER_FREE", "1")))

DEFAULT_DECODING_SIZE_ESTIMATION_ENABLED = True

DEFAULT_MIN_PARALLELISM = 200

DEFAULT_ENABLE_TENSOR_EXTENSION_CASTING = True

# NOTE: V1 tensor type format only supports tensors of no more than 2Gb in
# total cumulative size (due to it internally utilizing int32 offsets)
#
# V2 in turn relies on int64 offsets, therefore having a limit of ~9Eb (exabytes)
DEFAULT_USE_ARROW_TENSOR_V2 = env_bool("RAY_DATA_USE_ARROW_TENSOR_V2", True)

DEFAULT_AUTO_LOG_STATS = False

DEFAULT_VERBOSE_STATS_LOG = False

DEFAULT_TRACE_ALLOCATIONS = bool(int(os.environ.get("RAY_DATA_TRACE_ALLOCATIONS", "0")))

DEFAULT_LOG_INTERNAL_STACK_TRACE_TO_STDOUT = env_bool(
    "RAY_DATA_LOG_INTERNAL_STACK_TRACE_TO_STDOUT", False
)

DEFAULT_RAY_DATA_RAISE_ORIGINAL_MAP_EXCEPTION = env_bool(
    "RAY_DATA_RAISE_ORIGINAL_MAP_EXCEPTION", False
)

DEFAULT_USE_RAY_TQDM = bool(int(os.environ.get("RAY_TQDM", "1")))

# Globally enable or disable all progress bars.
# If this is False, both the global and operator-level progress bars are disabled.
DEFAULT_ENABLE_PROGRESS_BARS = not bool(
    env_integer("RAY_DATA_DISABLE_PROGRESS_BARS", 0)
)
DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = env_bool(
    "RAY_DATA_ENABLE_PROGRESS_BAR_NAME_TRUNCATION", True
)

DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS = False


# `write_file_retry_on_errors` is deprecated in favor of `retried_io_errors`. You
# shouldn't need to modify `DEFAULT_WRITE_FILE_RETRY_ON_ERRORS`.
DEFAULT_WRITE_FILE_RETRY_ON_ERRORS = (
    "AWS Error INTERNAL_FAILURE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error SLOW_DOWN",
    "AWS Error UNKNOWN (HTTP status 503)",
)

# Error-message substrings that trigger a retry for both reads and writes;
# superset of DEFAULT_WRITE_FILE_RETRY_ON_ERRORS above.
DEFAULT_RETRIED_IO_ERRORS = (
    "AWS Error INTERNAL_FAILURE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error SLOW_DOWN",
    "AWS Error UNKNOWN (HTTP status 503)",
    "AWS Error SERVICE_UNAVAILABLE",
)

DEFAULT_WARN_ON_DRIVER_MEMORY_USAGE_BYTES = 2 * 1024 * 1024 * 1024

DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS = False

DEFAULT_ENABLE_OP_RESOURCE_RESERVATION = env_bool(
    "RAY_DATA_ENABLE_OP_RESOURCE_RESERVATION", True
)

DEFAULT_OP_RESOURCE_RESERVATION_RATIO = float(
    os.environ.get("RAY_DATA_OP_RESERVATION_RATIO", "0.5")
)

DEFAULT_MAX_ERRORED_BLOCKS = 0

# Use this to prefix important warning messages for the user.
WARN_PREFIX = "⚠️ "

# Use this to prefix important success messages for the user.
OK_PREFIX = "✔️ "

# Default batch size for batch transformations.
DEFAULT_BATCH_SIZE = 1024

# Default value of the max number of blocks that can be buffered at the
# streaming generator of each `DataOpTask`.
# Note, if this value is too large, we'll need to allocate more memory
# buffer for the pending task outputs, which may lead to bad performance
# as we may not have enough memory buffer for the operator outputs.
# If the value is too small, the task may be frequently blocked due to
# streaming generator backpressure.
DEFAULT_MAX_NUM_BLOCKS_IN_STREAMING_GEN_BUFFER = 2

# Default value for whether or not to try to create directories for write
# calls if the URI is an S3 URI.
DEFAULT_S3_TRY_CREATE_DIR = False

DEFAULT_WAIT_FOR_MIN_ACTORS_S = env_integer(
    "RAY_DATA_DEFAULT_WAIT_FOR_MIN_ACTORS_S", 60 * 10
)
167
+
168
+
169
def _execution_options_factory() -> "ExecutionOptions":
    """Build a fresh ``ExecutionOptions`` for the ``execution_options`` field.

    The import happens inside the function body because importing the
    execution interfaces at module load time would create a circular
    dependency.
    """
    from ray.data._internal.execution.interfaces import ExecutionOptions

    return ExecutionOptions()
174
+
175
+
176
@DeveloperAPI
@dataclass
class DataContext:
    """Global settings for Ray Data.

    Configure this class to enable advanced features and tune performance.

    .. warning::
        Apply changes before creating a :class:`~ray.data.Dataset`. Changes made after
        won't take effect.

    .. note::
        This object is automatically propagated to workers. Access it from the driver
        and remote workers with :meth:`DataContext.get_current()`.

    Examples:
        >>> from ray.data import DataContext
        >>> DataContext.get_current().enable_progress_bars = False

    Args:
        target_max_block_size: The max target block size in bytes for reads and
            transformations.
        target_shuffle_max_block_size: The max target block size in bytes for shuffle
            ops like ``random_shuffle``, ``sort``, and ``repartition``.
        target_min_block_size: Ray Data avoids creating blocks smaller than this
            size in bytes on read. This takes precedence over
            ``read_op_min_num_blocks``.
        streaming_read_buffer_size: Buffer size when doing streaming reads from local or
            remote storage.
        enable_pandas_block: Whether pandas block format is enabled.
        actor_prefetcher_enabled: Whether to use actor based block prefetcher.
        use_push_based_shuffle: Whether to use push-based shuffle.
        pipeline_push_based_shuffle_reduce_tasks:
        scheduling_strategy: The global scheduling strategy. For tasks with large args,
            ``scheduling_strategy_large_args`` takes precedence.
        scheduling_strategy_large_args: Scheduling strategy for tasks with large args.
        large_args_threshold: Size in bytes after which point task arguments are
            considered large. Choose a value so that the data transfer overhead is
            significant in comparison to task scheduling (i.e., low tens of ms).
        use_polars: Whether to use Polars for tabular dataset sorts, groupbys, and
            aggregations.
        eager_free: Whether to eagerly free memory.
        decoding_size_estimation: Whether to estimate in-memory decoding data size for
            data source.
        min_parallelism: This setting is deprecated. Use ``read_op_min_num_blocks``
            instead.
        read_op_min_num_blocks: Minimum number of read output blocks for a dataset.
        enable_tensor_extension_casting: Whether to automatically cast NumPy ndarray
            columns in Pandas DataFrames to tensor extension columns.
        use_arrow_tensor_v2: Config enabling V2 version of ArrowTensorArray supporting
            tensors > 2Gb in size (off by default)
        enable_fallback_to_arrow_object_ext_type: Enables fallback to serialize column
            values not supported by Arrow natively (like user-defined custom Python
            classes for ex, etc) using `ArrowPythonObjectType` (simply serializing
            these as bytes)
        enable_auto_log_stats: Whether to automatically log stats after execution. If
            disabled, you can still manually print stats with ``Dataset.stats()``.
        verbose_stats_logs: Whether stats logs should be verbose. This includes fields
            such as `extra_metrics` in the stats output, which are excluded by default.
        trace_allocations: Whether to trace allocations / eager free. This adds
            significant performance overheads and should only be used for debugging.
        execution_options: The
            :class:`~ray.data._internal.execution.interfaces.execution_options.ExecutionOptions`
            to use.
        use_ray_tqdm: Whether to enable distributed tqdm.
        enable_progress_bars: Whether to enable progress bars.
        enable_progress_bar_name_truncation: If True, the name of the progress bar
            (often the operator name) will be truncated if it exceeds
            `ProgressBar.MAX_NAME_LENGTH`. Otherwise, the full operator name is shown.
        enable_get_object_locations_for_metrics: Whether to enable
            ``get_object_locations`` for metrics.
        write_file_retry_on_errors: A list of substrings of error messages that should
            trigger a retry when writing files. This is useful for handling transient
            errors when writing to remote storage systems.
        warn_on_driver_memory_usage_bytes: If driver memory exceeds this threshold,
            Ray Data warns you. For now, this only applies to shuffle ops because most
            other ops are unlikely to use as much driver memory.
        actor_task_retry_on_errors: The application-level errors that actor task should
            retry. This follows same format as :ref:`retry_exceptions <task-retries>` in
            Ray Core. Default to `False` to not retry on any errors. Set to `True` to
            retry all errors, or set to a list of errors to retry.
        enable_op_resource_reservation: Whether to reserve resources for each operator.
        op_resource_reservation_ratio: The ratio of the total resources to reserve for
            each operator.
        max_errored_blocks: Max number of blocks that are allowed to have errors,
            unlimited if negative. This option allows application-level exceptions in
            block processing tasks. These exceptions may be caused by UDFs (e.g., due to
            corrupted data samples) or IO errors. Data in the failed blocks are dropped.
            This option can be useful to prevent a long-running job from failing due to
            a small number of bad blocks.
        log_internal_stack_trace_to_stdout: Whether to include internal Ray Data/Ray
            Core code stack frames when logging to stdout. The full stack trace is
            always written to the Ray Data log file.
        raise_original_map_exception: Whether to raise the original exception
            encountered in map UDF instead of wrapping it in a `UserCodeException`.
        print_on_execution_start: If ``True``, print execution information when
            execution starts.
        s3_try_create_dir: If ``True``, try to create directories on S3 when a write
            call is made with a S3 URI.
        wait_for_min_actors_s: The default time to wait for minimum requested
            actors to start before raising a timeout, in seconds.
        retried_io_errors: A list of substrings of error messages that should
            trigger a retry when reading or writing files. This is useful for handling
            transient errors when reading from remote storage systems.
    """

    target_max_block_size: int = DEFAULT_TARGET_MAX_BLOCK_SIZE
    target_shuffle_max_block_size: int = DEFAULT_SHUFFLE_TARGET_MAX_BLOCK_SIZE
    target_min_block_size: int = DEFAULT_TARGET_MIN_BLOCK_SIZE
    streaming_read_buffer_size: int = DEFAULT_STREAMING_READ_BUFFER_SIZE
    enable_pandas_block: bool = DEFAULT_ENABLE_PANDAS_BLOCK
    actor_prefetcher_enabled: bool = DEFAULT_ACTOR_PREFETCHER_ENABLED
    use_push_based_shuffle: bool = DEFAULT_USE_PUSH_BASED_SHUFFLE
    pipeline_push_based_shuffle_reduce_tasks: bool = True
    scheduling_strategy: SchedulingStrategyT = DEFAULT_SCHEDULING_STRATEGY
    scheduling_strategy_large_args: SchedulingStrategyT = (
        DEFAULT_SCHEDULING_STRATEGY_LARGE_ARGS
    )
    large_args_threshold: int = DEFAULT_LARGE_ARGS_THRESHOLD
    use_polars: bool = DEFAULT_USE_POLARS
    eager_free: bool = DEFAULT_EAGER_FREE
    decoding_size_estimation: bool = DEFAULT_DECODING_SIZE_ESTIMATION_ENABLED
    min_parallelism: int = DEFAULT_MIN_PARALLELISM
    read_op_min_num_blocks: int = DEFAULT_READ_OP_MIN_NUM_BLOCKS
    enable_tensor_extension_casting: bool = DEFAULT_ENABLE_TENSOR_EXTENSION_CASTING
    use_arrow_tensor_v2: bool = DEFAULT_USE_ARROW_TENSOR_V2
    enable_fallback_to_arrow_object_ext_type: Optional[bool] = None
    enable_auto_log_stats: bool = DEFAULT_AUTO_LOG_STATS
    verbose_stats_logs: bool = DEFAULT_VERBOSE_STATS_LOG
    trace_allocations: bool = DEFAULT_TRACE_ALLOCATIONS
    execution_options: "ExecutionOptions" = field(
        default_factory=_execution_options_factory
    )
    use_ray_tqdm: bool = DEFAULT_USE_RAY_TQDM
    enable_progress_bars: bool = DEFAULT_ENABLE_PROGRESS_BARS
    # By default, enable the progress bar for operator-level progress.
    # In __post_init__(), we disable operator-level progress
    # bars when running in a Ray job.
    enable_operator_progress_bars: bool = True
    enable_progress_bar_name_truncation: bool = (
        DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION
    )
    enable_get_object_locations_for_metrics: bool = (
        DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS
    )
    # NOTE(review): the default is a tuple although the annotation says
    # List[str]; a tuple is required for a safe (immutable) dataclass default —
    # consider `field(default_factory=list(...))` if a list is truly intended.
    write_file_retry_on_errors: List[str] = DEFAULT_WRITE_FILE_RETRY_ON_ERRORS
    warn_on_driver_memory_usage_bytes: int = DEFAULT_WARN_ON_DRIVER_MEMORY_USAGE_BYTES
    actor_task_retry_on_errors: Union[
        bool, List[BaseException]
    ] = DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS
    op_resource_reservation_enabled: bool = DEFAULT_ENABLE_OP_RESOURCE_RESERVATION
    op_resource_reservation_ratio: float = DEFAULT_OP_RESOURCE_RESERVATION_RATIO
    max_errored_blocks: int = DEFAULT_MAX_ERRORED_BLOCKS
    log_internal_stack_trace_to_stdout: bool = (
        DEFAULT_LOG_INTERNAL_STACK_TRACE_TO_STDOUT
    )
    raise_original_map_exception: bool = DEFAULT_RAY_DATA_RAISE_ORIGINAL_MAP_EXCEPTION
    print_on_execution_start: bool = True
    s3_try_create_dir: bool = DEFAULT_S3_TRY_CREATE_DIR
    wait_for_min_actors_s: int = DEFAULT_WAIT_FOR_MIN_ACTORS_S
    retried_io_errors: List[str] = field(
        default_factory=lambda: list(DEFAULT_RETRIED_IO_ERRORS)
    )

    # FIX: was annotated as plain `float` with a `None` default; `None` is a
    # valid value (meaning "use Ray core's default"), so mark it Optional.
    override_object_store_memory_limit_fraction: Optional[float] = None

    def __post_init__(self):
        # The additional ray remote args that should be added to
        # the task-pool-based data tasks.
        self._task_pool_data_task_remote_args: Dict[str, Any] = {}
        # The extra key-value style configs.
        # These configs are managed by individual components or plugins via
        # `set_config`, `get_config` and `remove_config`.
        # The reason why we use a dict instead of individual fields is to decouple
        # the DataContext from the plugin implementations, as well as to avoid
        # circular dependencies.
        self._kv_configs: Dict[str, Any] = {}
        self._max_num_blocks_in_streaming_gen_buffer = (
            DEFAULT_MAX_NUM_BLOCKS_IN_STREAMING_GEN_BUFFER
        )

        is_ray_job = os.environ.get("RAY_JOB_ID") is not None
        if is_ray_job:
            is_driver = ray.get_runtime_context().worker.mode != WORKER_MODE
            if is_driver and log_once(
                "ray_data_disable_operator_progress_bars_in_ray_jobs"
            ):
                logger.info(
                    "Disabling operator-level progress bars by default in Ray Jobs. "
                    "To enable progress bars for all operators, set "
                    "`ray.data.DataContext.get_current()"
                    ".enable_operator_progress_bars = True`."
                )
            # Disable operator-level progress bars by default in Ray jobs.
            # The global progress bar for the overall Dataset execution will
            # still be enabled, unless the user also sets
            # `ray.data.DataContext.get_current().enable_progress_bars = False`.
            self.enable_operator_progress_bars = False
        else:
            # When not running in Ray job, operator-level progress
            # bars are enabled by default.
            self.enable_operator_progress_bars = True

    def __setattr__(self, name: str, value: Any) -> None:
        # Emit a deprecation warning when the deprecated field is set to a
        # non-default value (the default assignment during __init__ compares
        # equal and therefore doesn't warn).
        if (
            name == "write_file_retry_on_errors"
            and value != DEFAULT_WRITE_FILE_RETRY_ON_ERRORS
        ):
            warnings.warn(
                "`write_file_retry_on_errors` is deprecated. Configure "
                "`retried_io_errors` instead.",
                DeprecationWarning,
            )

        super().__setattr__(name, value)

    @staticmethod
    def get_current() -> "DataContext":
        """Get or create the current DataContext.

        When a Dataset is created, the current DataContext will be sealed.
        Changes to `DataContext.get_current()` will not impact existing Datasets.

        Examples:

            .. testcode::
                import ray

                context = ray.data.DataContext.get_current()

                context.target_max_block_size = 100 * 1024 ** 2
                ds1 = ray.data.range(1)
                context.target_max_block_size = 1 * 1024 ** 2
                ds2 = ray.data.range(1)

                # ds1's target_max_block_size will be 100MB
                ds1.take_all()
                # ds2's target_max_block_size will be 1MB
                ds2.take_all()

        Developer notes: Avoid using `DataContext.get_current()` in data
        internal components, use the DataContext object captured in the
        Dataset and pass it around as arguments.
        """

        global _default_context

        with _context_lock:
            if _default_context is None:
                _default_context = DataContext()

            return _default_context

    @staticmethod
    def _set_current(context: "DataContext") -> None:
        """Set the current context in a remote worker.

        This is used internally by Dataset to propagate the driver context to
        remote workers used for parallelization.
        """
        global _default_context
        _default_context = context

    def get_config(self, key: str, default: Any = None) -> Any:
        """Get the value for a key-value style config.

        Args:
            key: The key of the config.
            default: The default value to return if the key is not found.
        Returns: The value for the key, or the default value if the key is not found.
        """
        return self._kv_configs.get(key, default)

    def set_config(self, key: str, value: Any) -> None:
        """Set the value for a key-value style config.

        Args:
            key: The key of the config.
            value: The value of the config.
        """
        self._kv_configs[key] = value

    def remove_config(self, key: str) -> None:
        """Remove a key-value style config.

        Args:
            key: The key of the config.
        """
        self._kv_configs.pop(key, None)


# Backwards compatibility alias.
DatasetContext = DataContext
.venv/lib/python3.11/site-packages/ray/data/dataset.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/ray/data/datasource/datasink.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Generic, Iterable, List, Optional, TypeVar
4
+
5
+ import ray
6
+ from ray.data._internal.execution.interfaces import TaskContext
7
+ from ray.data.block import Block, BlockAccessor
8
+ from ray.util.annotations import DeveloperAPI
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""


# NOTE(review): `@dataclass` is listed above `@DeveloperAPI`, so `DeveloperAPI`
# is applied first; most Ray code orders these the other way — confirm intended.
@dataclass
@DeveloperAPI
class WriteResult(Generic[WriteReturnType]):
    """Aggregated result of the Datasink write operations."""

    # Total number of written rows.
    num_rows: int
    # Total size in bytes of written data.
    size_bytes: int
    # All returned values of `Datasink.write`.
    write_returns: List[WriteReturnType]
28
+
29
+
30
@DeveloperAPI
class Datasink(Generic[WriteReturnType]):
    """Interface for defining write-related logic.

    If you want to write data to something that isn't built-in, subclass this class
    and call :meth:`~ray.data.Dataset.write_datasink`.
    """

    def on_write_start(self) -> None:
        """Callback invoked once when a write job starts.

        Override to perform one-time setup for write tasks, such as creating a
        staging bucket in S3. The default implementation does nothing.
        """

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> WriteReturnType:
        """Write a stream of blocks; invoked by a single write task.

        Args:
            blocks: Generator of data blocks to persist.
            ctx: ``TaskContext`` describing the write task.

        Returns:
            The result of this write task. When the entire write operator
            finishes, every task's return value is collected into
            ``WriteResult.write_returns`` and handed to
            :meth:`on_write_complete`.
        """
        raise NotImplementedError

    def on_write_complete(self, write_result: WriteResult[WriteReturnType]):
        """Callback invoked once after all write tasks succeeded.

        Use this to "commit" the write output. It must succeed before
        ``write_datasink()`` returns to the user; if it raises,
        :meth:`on_write_failed` is called instead.

        Args:
            write_result: Aggregated result of the Write operator, containing
                the per-task write results and stats.
        """

    def on_write_failed(self, error: Exception) -> None:
        """Best-effort callback invoked when a write job fails.

        Args:
            error: The first error encountered.
        """

    def get_name(self) -> str:
        """Return a human-readable name for this datasink.

        Used as the name of the write tasks: the class name with a single
        leading underscore and a trailing ``"Datasink"`` suffix stripped.
        """
        name = type(self).__name__
        if name.startswith("_"):
            name = name[1:]
        suffix = "Datasink"
        return name[: -len(suffix)] if name.endswith(suffix) else name

    @property
    def supports_distributed_writes(self) -> bool:
        """If ``False``, only launch write tasks on the driver's node."""
        return True

    @property
    def min_rows_per_write(self) -> Optional[int]:
        """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call.

        If ``None``, Ray Data passes a system-chosen number of rows.
        """
        return None
112
+
113
+
114
@DeveloperAPI
class DummyOutputDatasink(Datasink[None]):
    """An example implementation of a writable datasource for testing.
    Examples:
        >>> import ray
        >>> from ray.data.datasource import DummyOutputDatasink
        >>> output = DummyOutputDatasink()
        >>> ray.data.range(10).write_datasink(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        ctx = ray.data.DataContext.get_current()

        # Setup a dummy actor to send the data. In a real datasource, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote(scheduling_strategy=ctx.scheduling_strategy)
        class DataSink:
            def __init__(self):
                # Running total of rows received across all write() calls.
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block) -> None:
                # Count rows via the block's accessor rather than assuming a
                # concrete block format.
                block = BlockAccessor.for_block(block)
                self.rows_written += block.num_rows()

            def get_rows_written(self):
                return self.rows_written

        self.data_sink = DataSink.remote()
        # Counters updated by the on_write_complete / on_write_failed
        # callbacks below; tests assert on these.
        self.num_ok = 0
        self.num_failed = 0
        # When False, write() raises to simulate a failing sink.
        self.enabled = True

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        # Fan each block out to the actor, then wait for all writes to land.
        tasks = []
        if not self.enabled:
            raise ValueError("disabled")
        for b in blocks:
            tasks.append(self.data_sink.write.remote(b))
        ray.get(tasks)

    def on_write_complete(self, write_result: WriteResult[None]):
        self.num_ok += 1

    def on_write_failed(self, error: Exception) -> None:
        self.num_failed += 1
.venv/lib/python3.11/site-packages/ray/data/datasource/datasource.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Iterable, List, Optional
2
+
3
+ import numpy as np
4
+
5
+ from ray.data._internal.util import _check_pyarrow_version
6
+ from ray.data.block import Block, BlockMetadata
7
+ from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
8
+
9
+
10
@PublicAPI
class Datasource:
    """Interface for defining a custom :class:`~ray.data.Dataset` datasource.

    To read a datasource into a dataset, use :meth:`~ray.data.read_datasource`.
    """  # noqa: E501

    @Deprecated
    def create_reader(self, **read_args) -> "Reader":
        """
        Deprecated: Implement :meth:`~ray.data.Datasource.get_read_tasks` and
        :meth:`~ray.data.Datasource.estimate_inmemory_data_size` instead.
        """
        # Wrap the legacy prepare_read() hook behind the Reader interface.
        return _LegacyDatasourceReader(self, **read_args)

    @Deprecated
    def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask"]:
        """
        Deprecated: Implement :meth:`~ray.data.Datasource.get_read_tasks` and
        :meth:`~ray.data.Datasource.estimate_inmemory_data_size` instead.
        """
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a human-readable name for this datasource.

        Used as the name of the read tasks: the class name with a trailing
        ``"Datasource"`` suffix stripped.
        """
        name = type(self).__name__
        suffix = "Datasource"
        return name[: -len(suffix)] if name.endswith(suffix) else name

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return an estimate of the in-memory data size, or None if unknown.

        Note that the in-memory data size may be larger than the on-disk data size.
        """
        raise NotImplementedError

    def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
        """Execute the read and return read tasks.

        Args:
            parallelism: The requested read parallelism. The number of read
                tasks should equal to this value if possible.

        Returns:
            A list of read tasks that can be executed to read blocks from the
            datasource in parallel.
        """
        raise NotImplementedError

    @property
    def should_create_reader(self) -> bool:
        # Use the legacy Reader path unless the subclass overrides BOTH
        # new-style hooks.
        cls = type(self)
        overrides_get_read_tasks = cls.get_read_tasks is not Datasource.get_read_tasks
        overrides_size_estimate = (
            cls.estimate_inmemory_data_size
            is not Datasource.estimate_inmemory_data_size
        )
        return not (overrides_get_read_tasks and overrides_size_estimate)

    @property
    def supports_distributed_reads(self) -> bool:
        """If ``False``, only launch read tasks on the driver's node."""
        return True
81
+
82
+
83
@Deprecated
class Reader:
    """A bound read operation for a :class:`~ray.data.Datasource`.

    This is a stateful class so that reads can be prepared in multiple stages.
    For example, it is useful for :class:`Datasets <ray.data.Dataset>` to know the
    in-memory size of the read prior to executing it.
    """

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return an estimate of the in-memory data size, or None if unknown.

        Note that the in-memory data size may be larger than the on-disk data size.
        """
        # Abstract: subclasses (e.g. _LegacyDatasourceReader) must override.
        raise NotImplementedError

    def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
        """Execute the read and return read tasks.

        Args:
            parallelism: The requested read parallelism. The number of read
                tasks should equal to this value if possible.
            read_args: Additional kwargs to pass to the datasource impl.

        Returns:
            A list of read tasks that can be executed to read blocks from the
            datasource in parallel.
        """
        # NOTE(review): the docstring documents ``read_args`` but the signature
        # takes none — the kwargs are bound at construction time in
        # implementations; confirm and fix the docstring upstream.
        raise NotImplementedError
112
+
113
+
114
class _LegacyDatasourceReader(Reader):
    # Adapter exposing a legacy `Datasource.prepare_read()` as a `Reader`,
    # binding the read kwargs at construction time.

    def __init__(self, datasource: Datasource, **read_args):
        self._datasource = datasource
        self._read_args = read_args

    def estimate_inmemory_data_size(self) -> Optional[int]:
        # Legacy datasources have no way to report a size estimate.
        return None

    def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
        # Delegate to the deprecated prepare_read() hook with the bound kwargs.
        return self._datasource.prepare_read(parallelism, **self._read_args)
124
+
125
+
126
@DeveloperAPI
class ReadTask(Callable[[], Iterable[Block]]):
    """A function used to read blocks from the :class:`~ray.data.Dataset`.

    Read tasks are generated by :meth:`~ray.data.Datasource.get_read_tasks`,
    and return a list of ``ray.data.Block`` when called. Initial metadata about the read
    operation can be retrieved via the ``metadata`` attribute prior to executing the
    read. Final metadata is returned after the read along with the blocks.

    Ray will execute read tasks in remote functions to parallelize execution.
    Note that the number of blocks returned can vary at runtime. For example,
    if a task is reading a single large file it can return multiple blocks to
    avoid running out of memory during the read.

    The initial metadata should reflect all the blocks returned by the read,
    e.g., if the metadata says ``num_rows=1000``, the read can return a single
    block of 1000 rows, or multiple blocks with 1000 rows altogether.

    The final metadata (returned with the actual block) reflects the exact
    contents of the block itself.
    """

    def __init__(self, read_fn: Callable[[], Iterable[Block]], metadata: BlockMetadata):
        self._metadata = metadata
        self._read_fn = read_fn

    @property
    def metadata(self) -> BlockMetadata:
        """Initial metadata describing the blocks this task will produce."""
        return self._metadata

    @property
    def read_fn(self) -> Callable[[], Iterable[Block]]:
        """The underlying zero-argument function that performs the read."""
        return self._read_fn

    def __call__(self) -> Iterable[Block]:
        result = self._read_fn()
        if not hasattr(result, "__iter__"):
            # FIX: the DeprecationWarning was previously instantiated but never
            # emitted, so users got no signal before the subsequent TypeError
            # from `yield from`. Actually emit it now.
            import warnings

            warnings.warn(
                "Read function must return Iterable[Block], got {}. "
                "Probably you need to return `[block]` instead of "
                "`block`.".format(result),
                DeprecationWarning,
            )
        yield from result
169
+
170
+
171
@DeveloperAPI
class RandomIntRowDatasource(Datasource):
    """Example datasource that produces rows of random ``int64`` columns.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import RandomIntRowDatasource
        >>> source = RandomIntRowDatasource() # doctest: +SKIP
        >>> ray.data.read_datasource( # doctest: +SKIP
        ...     source, n=10, num_columns=2).take()
        {'c_0': 1717767200176864416, 'c_1': 999657309586757214}
        {'c_0': 4983608804013926748, 'c_1': 1160140066899844087}
    """

    def __init__(self, n: int, num_columns: int):
        # Total number of rows and the column count to generate.
        self._n = n
        self._num_columns = num_columns

    def estimate_inmemory_data_size(self) -> Optional[int]:
        # Every int64 cell occupies 8 bytes.
        return self._n * self._num_columns * 8

    def get_read_tasks(
        self,
        parallelism: int,
    ) -> List[ReadTask]:
        _check_pyarrow_version()
        import pyarrow

        total_rows = self._n
        columns = self._num_columns
        # At least one row per block, even if parallelism exceeds n.
        rows_per_block = max(1, total_rows // parallelism)

        def build_block(count: int, num_columns: int) -> Block:
            # One random int64 array per column.
            return pyarrow.Table.from_arrays(
                np.random.randint(
                    np.iinfo(np.int64).max, size=(num_columns, count), dtype=np.int64
                ),
                names=[f"c_{i}" for i in range(num_columns)],
            )

        # Derive the schema from a tiny dummy table with the same columns.
        schema = pyarrow.Table.from_pydict(
            {f"c_{i}": [0] for i in range(columns)}
        ).schema

        tasks: List[ReadTask] = []
        for start in range(0, total_rows, rows_per_block):
            count = min(rows_per_block, total_rows - start)
            metadata = BlockMetadata(
                num_rows=count,
                size_bytes=8 * count * columns,
                schema=schema,
                input_files=None,
                exec_stats=None,
            )
            # Bind `count`/`num_columns` via defaults so each lambda keeps
            # its own values rather than the loop's final ones.
            tasks.append(
                ReadTask(
                    lambda count=count, num_columns=columns: [
                        build_block(count, num_columns)
                    ],
                    metadata,
                )
            )

        return tasks

    def get_name(self) -> str:
        """Return a human-readable name for this datasource.

        Used as the name of the generated read tasks.
        Note: overrides the base `Datasource` method.
        """
        return "RandomInt"
.venv/lib/python3.11/site-packages/ray/data/datasource/file_based_datasource.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+ from dataclasses import dataclass
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Callable,
8
+ Dict,
9
+ Iterable,
10
+ Iterator,
11
+ List,
12
+ Literal,
13
+ Optional,
14
+ Union,
15
+ )
16
+
17
+ import numpy as np
18
+
19
+ import ray
20
+ from ray.data._internal.util import (
21
+ _check_pyarrow_version,
22
+ _is_local_scheme,
23
+ call_with_retry,
24
+ make_async_gen,
25
+ )
26
+ from ray.data.block import Block, BlockAccessor
27
+ from ray.data.context import DataContext
28
+ from ray.data.datasource.datasource import Datasource, ReadTask
29
+ from ray.data.datasource.file_meta_provider import (
30
+ BaseFileMetadataProvider,
31
+ DefaultFileMetadataProvider,
32
+ )
33
+ from ray.data.datasource.partitioning import (
34
+ Partitioning,
35
+ PathPartitionFilter,
36
+ PathPartitionParser,
37
+ )
38
+ from ray.data.datasource.path_util import (
39
+ _has_file_extension,
40
+ _resolve_paths_and_filesystem,
41
+ )
42
+ from ray.util.annotations import DeveloperAPI
43
+
44
+ if TYPE_CHECKING:
45
+ import pandas as pd
46
+ import pyarrow
47
+
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ # We should parallelize file size fetch operations beyond this threshold.
53
+ FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD = 16
54
+
55
+ # 16 file size fetches from S3 takes ~1.5 seconds with Arrow's S3FileSystem.
56
+ PATHS_PER_FILE_SIZE_FETCH_TASK = 16
57
+
58
+ # The max retry backoff in seconds for opening file.
59
+ OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS = 32
60
+
61
+ # The max number of attempts for opening file.
62
+ OPEN_FILE_MAX_ATTEMPTS = 10
63
+
64
+
65
@DeveloperAPI
@dataclass
class FileShuffleConfig:
    """Configuration for file shuffling.

    Controls how files are shuffled while reading file-based datasets.

    .. note::
        A seed makes the *file order* deterministic, but row order may still
        vary because read tasks execute in parallel and complete in any
        order. If you need to preserve the order of rows, set
        `DataContext.get_current().execution_options.preserve_order`.

    Args:
        seed: An optional integer seed for the file shuffler. If provided, Ray Data
            shuffles files deterministically based on this seed.

    Example:
        >>> import ray
        >>> from ray.data import FileShuffleConfig
        >>> shuffle = FileShuffleConfig(seed=42)
        >>> ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea", shuffle=shuffle)
    """  # noqa: E501

    # None means a fresh (non-deterministic) shuffle per run.
    seed: Optional[int] = None

    def __post_init__(self):
        """Reject any seed that is neither ``None`` nor an integer."""
        seed_is_invalid = self.seed is not None and not isinstance(self.seed, int)
        if seed_is_invalid:
            raise ValueError("Seed must be an integer or None.")
+ raise ValueError("Seed must be an integer or None.")
96
+
97
+
98
@DeveloperAPI
class FileBasedDatasource(Datasource):
    """File-based datasource for reading files.

    Don't use this class directly. Instead, subclass it and implement `_read_stream()`.
    """

    # If `_WRITE_FILE_PER_ROW` is `True`, this datasource calls `_write_row` and writes
    # each row to a file. Otherwise, this datasource calls `_write_block` and writes
    # each block to a file.
    _WRITE_FILE_PER_ROW = False
    # Default file extensions used by subclasses to filter inputs; None = no filter.
    _FILE_EXTENSIONS: Optional[Union[str, List[str]]] = None
    # Number of threads for concurrent reading within each read task.
    # If zero or negative, reading will be performed in the main thread.
    _NUM_THREADS_PER_TASK = 0

    def __init__(
        self,
        paths: Union[str, List[str]],
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
        open_stream_args: Optional[Dict[str, Any]] = None,
        meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
        partition_filter: PathPartitionFilter = None,
        partitioning: Partitioning = None,
        ignore_missing_paths: bool = False,
        shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
        include_paths: bool = False,
        file_extensions: Optional[List[str]] = None,
    ):
        """Resolve, expand, filter, and optionally shuffle the input paths.

        Raises:
            ValueError: If local paths are used with Ray Client, if no paths
                remain after expansion/filtering, or if ``shuffle`` is invalid.
        """
        _check_pyarrow_version()

        # Local-scheme paths can only be read on the driver's node.
        self._supports_distributed_reads = not _is_local_scheme(paths)
        if not self._supports_distributed_reads and ray.util.client.ray.is_connected():
            raise ValueError(
                "Because you're using Ray Client, read tasks scheduled on the Ray "
                "cluster can't access your local files. To fix this issue, store "
                "files in cloud storage or a distributed filesystem like NFS."
            )

        self._schema = schema
        self._open_stream_args = open_stream_args
        self._meta_provider = meta_provider
        self._partition_filter = partition_filter
        self._partitioning = partitioning
        self._ignore_missing_paths = ignore_missing_paths
        self._include_paths = include_paths
        paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
        # Expand directories into concrete (path, size) pairs, then split the
        # pairs into two parallel lists.
        paths, file_sizes = map(
            list,
            zip(
                *meta_provider.expand_paths(
                    paths,
                    self._filesystem,
                    partitioning,
                    ignore_missing_paths=ignore_missing_paths,
                )
            ),
        )

        if ignore_missing_paths and len(paths) == 0:
            raise ValueError(
                "None of the provided paths exist. "
                "The 'ignore_missing_paths' field is set to True."
            )

        if self._partition_filter is not None:
            # Use partition filter to skip files which are not needed.
            path_to_size = dict(zip(paths, file_sizes))
            paths = self._partition_filter(paths)
            file_sizes = [path_to_size[p] for p in paths]
            if len(paths) == 0:
                raise ValueError(
                    "No input files found to read. Please double check that "
                    "'partition_filter' field is set properly."
                )

        if file_extensions is not None:
            # Keep the size list aligned with the surviving paths.
            path_to_size = dict(zip(paths, file_sizes))
            paths = [p for p in paths if _has_file_extension(p, file_extensions)]
            file_sizes = [path_to_size[p] for p in paths]
            if len(paths) == 0:
                raise ValueError(
                    "No input files found to read with the following file extensions: "
                    f"{file_extensions}. Please double check that "
                    "'file_extensions' field is set properly."
                )

        _validate_shuffle_arg(shuffle)
        self._file_metadata_shuffler = None
        if shuffle == "files":
            self._file_metadata_shuffler = np.random.default_rng()
        elif isinstance(shuffle, FileShuffleConfig):
            # Create a NumPy random generator with a fixed seed if provided
            self._file_metadata_shuffler = np.random.default_rng(shuffle.seed)

        # Read tasks serialize `FileBasedDatasource` instances, and the list of paths
        # can be large. To avoid slow serialization speeds, we store a reference to
        # the paths rather than the paths themselves.
        self._paths_ref = ray.put(paths)
        self._file_sizes_ref = ray.put(file_sizes)

    def _paths(self) -> List[str]:
        """Fetch the expanded file paths from the object store."""
        return ray.get(self._paths_ref)

    def _file_sizes(self) -> List[float]:
        """Fetch the per-file sizes (entries may be None) from the object store."""
        return ray.get(self._file_sizes_ref)

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Sum the known file sizes; unknown (None) entries are skipped.

        NOTE(review): this is the on-disk byte total, used here as a proxy for
        the in-memory size — the two can differ for compressed formats.
        """
        total_size = 0
        for sz in self._file_sizes():
            if sz is not None:
                total_size += sz
        return total_size

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Split the file list into at most ``parallelism`` read tasks.

        Each task opens its files (with retries), streams blocks via
        `_read_stream()`, and optionally attaches partition values and a
        "path" column.
        """
        import numpy as np

        ctx = DataContext.get_current()
        open_stream_args = self._open_stream_args
        partitioning = self._partitioning

        paths = self._paths()
        file_sizes = self._file_sizes()

        if self._file_metadata_shuffler is not None:
            # Shuffle (path, size) pairs together so they stay aligned.
            files_metadata = list(zip(paths, file_sizes))
            shuffled_files_metadata = [
                files_metadata[i]
                for i in self._file_metadata_shuffler.permutation(len(files_metadata))
            ]
            paths, file_sizes = list(map(list, zip(*shuffled_files_metadata)))

        read_stream = self._read_stream
        filesystem = _wrap_s3_serialization_workaround(self._filesystem)

        if open_stream_args is None:
            open_stream_args = {}

        open_input_source = self._open_input_source

        def read_files(
            read_paths: Iterable[str],
        ) -> Iterable[Block]:
            # Runs inside the remote read task; yields blocks one file at a time.
            nonlocal filesystem, open_stream_args, partitioning

            # Propagate the driver's DataContext into the worker process.
            DataContext._set_current(ctx)
            fs = _unwrap_s3_serialization_workaround(filesystem)
            for read_path in read_paths:
                partitions: Dict[str, str] = {}
                if partitioning is not None:
                    parse = PathPartitionParser(partitioning)
                    partitions = parse(read_path)

                with _open_file_with_retry(
                    read_path,
                    # Bind read_path as a default to avoid late-binding issues.
                    lambda read_path=read_path: open_input_source(
                        fs, read_path, **open_stream_args
                    ),
                ) as f:
                    for block in read_stream(f, read_path):
                        if partitions:
                            block = _add_partitions(block, partitions)
                        if self._include_paths:
                            block_accessor = BlockAccessor.for_block(block)
                            block = block_accessor.append_column(
                                "path", [read_path] * block_accessor.num_rows()
                            )
                        yield block

        def create_read_task_fn(read_paths, num_threads):
            # Factory so each task captures its own `read_paths` slice.
            def read_task_fn():
                nonlocal num_threads, read_paths

                # TODO: We should refactor the code so that we can get the results in
                # order even when using multiple threads.
                if ctx.execution_options.preserve_order:
                    num_threads = 0

                if num_threads > 0:
                    # Don't spawn more threads than there are files to read.
                    if len(read_paths) < num_threads:
                        num_threads = len(read_paths)

                    logger.debug(
                        f"Reading {len(read_paths)} files with {num_threads} threads."
                    )

                    yield from make_async_gen(
                        iter(read_paths),
                        read_files,
                        num_workers=num_threads,
                    )
                else:
                    logger.debug(f"Reading {len(read_paths)} files.")
                    yield from read_files(read_paths)

            return read_task_fn

        # fix https://github.com/ray-project/ray/issues/24296
        parallelism = min(parallelism, len(paths))

        read_tasks = []
        split_paths = np.array_split(paths, parallelism)
        split_file_sizes = np.array_split(file_sizes, parallelism)

        for read_paths, file_sizes in zip(split_paths, split_file_sizes):
            if len(read_paths) <= 0:
                continue

            meta = self._meta_provider(
                read_paths,
                self._schema,
                rows_per_file=self._rows_per_file(),
                file_sizes=file_sizes,
            )

            read_task_fn = create_read_task_fn(read_paths, self._NUM_THREADS_PER_TASK)

            read_task = ReadTask(read_task_fn, meta)

            read_tasks.append(read_task)

        return read_tasks

    def _open_input_source(
        self,
        filesystem: "pyarrow.fs.FileSystem",
        path: str,
        **open_args,
    ) -> "pyarrow.NativeFile":
        """Opens a source path for reading and returns the associated Arrow NativeFile.

        The default implementation opens the source path as a sequential input stream,
        using ctx.streaming_read_buffer_size as the buffer size if none is given by the
        caller.

        Implementations that do not support streaming reads (e.g. that require random
        access) should override this method.
        """
        import pyarrow as pa
        from pyarrow.fs import HadoopFileSystem

        ctx = DataContext.get_current()

        compression = open_args.get("compression", None)
        if compression is None:
            try:
                # If no compression manually given, try to detect
                # compression codec from path.
                compression = pa.Codec.detect(path).name
            except (ValueError, TypeError):
                # Arrow's compression inference on the file path
                # doesn't work for Snappy, so we double-check ourselves.
                import pathlib

                suffix = pathlib.Path(path).suffix
                if suffix and suffix[1:] == "snappy":
                    compression = "snappy"
                else:
                    compression = None

        buffer_size = open_args.pop("buffer_size", None)
        if buffer_size is None:
            buffer_size = ctx.streaming_read_buffer_size

        if compression == "snappy":
            # Arrow doesn't support streaming Snappy decompression since the canonical
            # C++ Snappy library doesn't natively support streaming decompression. We
            # works around this by manually decompressing the file with python-snappy.
            open_args["compression"] = None
        else:
            open_args["compression"] = compression

        file = call_with_retry(
            lambda: filesystem.open_input_stream(
                path, buffer_size=buffer_size, **open_args
            ),
            description=f"open file {path}",
            match=ctx.retried_io_errors,
        )

        if compression == "snappy":
            import snappy

            # Decompress the whole stream into memory, then hand Arrow a
            # seekable in-memory file.
            stream = io.BytesIO()
            if isinstance(filesystem, HadoopFileSystem):
                snappy.hadoop_snappy.stream_decompress(src=file, dst=stream)
            else:
                snappy.stream_decompress(src=file, dst=stream)
            stream.seek(0)

            file = pa.PythonFile(stream, mode="r")

        return file

    def _rows_per_file(self):
        """Returns the number of rows per file, or None if unknown."""
        return None

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Streaming read a single file.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError(
            "Subclasses of FileBasedDatasource must implement _read_stream()."
        )

    @property
    def supports_distributed_reads(self) -> bool:
        # False when the paths use a local scheme (see __init__).
        return self._supports_distributed_reads
+ return self._supports_distributed_reads
410
+
411
+
412
def _add_partitions(
    data: Union["pyarrow.Table", "pd.DataFrame"], partitions: Dict[str, Any]
) -> Union["pyarrow.Table", "pd.DataFrame"]:
    """Attach partition key/value columns to a block, dispatching on block type."""
    import pandas as pd
    import pyarrow as pa

    assert isinstance(data, (pa.Table, pd.DataFrame))
    if isinstance(data, pd.DataFrame):
        return _add_partitions_to_dataframe(data, partitions)
    return _add_partitions_to_table(data, partitions)
+ return _add_partitions_to_dataframe(data, partitions)
423
+
424
+
425
def _add_partitions_to_table(
    table: "pyarrow.Table", partitions: Dict[str, Any]
) -> "pyarrow.Table":
    """Add partition values as columns of a pyarrow Table.

    Existing columns with the same name are overwritten after verifying that
    the in-data values match the partition value.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    existing_names = set(table.column_names)
    for name, value in partitions.items():
        new_column = pa.array([value] * len(table))
        if name not in existing_names:
            table = table.append_column(name, new_column)
            continue

        # TODO: Handle cast error.
        target_type = table.schema.field(name).type
        new_column = new_column.cast(target_type)

        matches = pc.all(pc.equal(new_column, table[name])).as_py()
        if not matches:
            raise ValueError(
                f"Partition column {name} exists in table data, but partition "
                f"value '{value}' is different from in-data values: "
                f"{table[name].unique().to_pylist()}."
            )

        index = table.schema.get_field_index(name)
        table = table.set_column(index, name, new_column)

    return table
+
456
+
457
+ def _add_partitions_to_dataframe(
458
+ df: "pd.DataFrame", partitions: Dict[str, Any]
459
+ ) -> "pd.DataFrame":
460
+ import pandas as pd
461
+
462
+ for field, value in partitions.items():
463
+ column = pd.Series(data=[value] * len(df), name=field)
464
+
465
+ if field in df:
466
+ column = column.astype(df[field].dtype)
467
+ mask = df[field].notna()
468
+ if not df[field][mask].equals(column[mask]):
469
+ raise ValueError(
470
+ f"Partition column {field} exists in table data, but partition "
471
+ f"value '{value}' is different from in-data values: "
472
+ f"{list(df[field].unique())}."
473
+ )
474
+
475
+ df[field] = column
476
+
477
+ return df
478
+
479
+
480
def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"):
    # pa.fs.S3FileSystem assumes pa.fs is already imported before
    # deserialization (see #17085), so S3 filesystems are wrapped in a helper
    # that re-imports pyarrow.fs when reconstructed.
    import pyarrow as pa
    import pyarrow.fs

    if not isinstance(filesystem, pa.fs.S3FileSystem):
        return filesystem
    return _S3FileSystemWrapper(filesystem)
+ return filesystem
489
+
490
+
491
+ def _unwrap_s3_serialization_workaround(
492
+ filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"]
493
+ ):
494
+ if isinstance(filesystem, _S3FileSystemWrapper):
495
+ return filesystem.unwrap()
496
+ else:
497
+ return filesystem
498
+
499
+
500
+ class _S3FileSystemWrapper:
501
+ def __init__(self, fs: "pyarrow.fs.S3FileSystem"):
502
+ self._fs = fs
503
+
504
+ def unwrap(self):
505
+ return self._fs
506
+
507
+ @classmethod
508
+ def _reconstruct(cls, fs_reconstruct, fs_args):
509
+ # Implicitly trigger S3 subsystem initialization by importing
510
+ # pyarrow.fs.
511
+ import pyarrow.fs # noqa: F401
512
+
513
+ return cls(fs_reconstruct(*fs_args))
514
+
515
+ def __reduce__(self):
516
+ return _S3FileSystemWrapper._reconstruct, self._fs.__reduce__()
517
+
518
+
519
+ def _wrap_arrow_serialization_workaround(kwargs: dict) -> dict:
520
+ if "filesystem" in kwargs:
521
+ kwargs["filesystem"] = _wrap_s3_serialization_workaround(kwargs["filesystem"])
522
+
523
+ return kwargs
524
+
525
+
526
+ def _unwrap_arrow_serialization_workaround(kwargs: dict) -> dict:
527
+ if isinstance(kwargs.get("filesystem"), _S3FileSystemWrapper):
528
+ kwargs["filesystem"] = kwargs["filesystem"].unwrap()
529
+ return kwargs
530
+
531
+
532
+ def _resolve_kwargs(
533
+ kwargs_fn: Callable[[], Dict[str, Any]], **kwargs
534
+ ) -> Dict[str, Any]:
535
+ if kwargs_fn:
536
+ kwarg_overrides = kwargs_fn()
537
+ kwargs.update(kwarg_overrides)
538
+ return kwargs
539
+
540
+
541
+ def _open_file_with_retry(
542
+ file_path: str,
543
+ open_file: Callable[[], "pyarrow.NativeFile"],
544
+ ) -> "pyarrow.NativeFile":
545
+ """Open file with an exponential backoff retry strategy.
546
+
547
+ This is to avoid transient task failure with remote storage (such as S3),
548
+ when the remote storage throttles the requests.
549
+ """
550
+ if OPEN_FILE_MAX_ATTEMPTS < 1:
551
+ raise ValueError(
552
+ "OPEN_FILE_MAX_ATTEMPTS cannot be negative or 0. Get: "
553
+ f"{OPEN_FILE_MAX_ATTEMPTS}"
554
+ )
555
+
556
+ return call_with_retry(
557
+ open_file,
558
+ description=f"open file {file_path}",
559
+ match=DataContext.get_current().retried_io_errors,
560
+ max_attempts=OPEN_FILE_MAX_ATTEMPTS,
561
+ max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS,
562
+ )
563
+
564
+
565
+ def _validate_shuffle_arg(shuffle: Optional[str]) -> None:
566
+ if not (
567
+ shuffle is None or shuffle == "files" or isinstance(shuffle, FileShuffleConfig)
568
+ ):
569
+ raise ValueError(
570
+ f"Invalid value for 'shuffle': {shuffle}. "
571
+ "Valid values are None, 'files', `FileShuffleConfig`."
572
+ )
.venv/lib/python3.11/site-packages/ray/data/datasource/file_meta_provider.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Callable,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ TypeVar,
14
+ Union,
15
+ )
16
+
17
+ import numpy as np
18
+
19
+ import ray
20
+ from ray.data._internal.progress_bar import ProgressBar
21
+ from ray.data._internal.remote_fn import cached_remote_fn
22
+ from ray.data._internal.util import call_with_retry
23
+ from ray.data.block import BlockMetadata
24
+ from ray.data.datasource.partitioning import Partitioning
25
+ from ray.util.annotations import DeveloperAPI
26
+
27
+ if TYPE_CHECKING:
28
+ import pyarrow
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
@DeveloperAPI
class FileMetadataProvider:
    """Abstract callable that provides metadata for the files of a single dataset block.

    Current subclasses:
        - :class:`BaseFileMetadataProvider`
        - :class:`ParquetMetadataProvider`
    """

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        **kwargs,
    ) -> BlockMetadata:
        """Resolves and returns block metadata for files in the given paths.

        All file paths provided should belong to a single dataset block.

        Args:
            paths: The file paths for a single dataset block.
            schema: The user-provided or inferred schema for the given paths,
                if any.

        Returns:
            BlockMetadata aggregated across the given paths.
        """
        raise NotImplementedError

    def __call__(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        **kwargs,
    ) -> BlockMetadata:
        # Providers are used as callables; delegate to the subclass hook.
        return self._get_block_metadata(paths, schema, **kwargs)
70
+
71
+
72
@DeveloperAPI
class BaseFileMetadataProvider(FileMetadataProvider):
    """Abstract callable that provides metadata for
    :class:`~ray.data.datasource.file_based_datasource.FileBasedDatasource`
    implementations that reuse the base :meth:`~ray.data.Datasource.prepare_read`
    method.

    Also supports file and file size discovery in input directory paths.

    Current subclasses:
        - :class:`DefaultFileMetadataProvider`
    """

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        *,
        rows_per_file: Optional[int],
        file_sizes: List[Optional[int]],
    ) -> BlockMetadata:
        """Resolves and returns block metadata for files of a single dataset block.

        Args:
            paths: The file paths for a single dataset block. These
                paths will always be a subset of those previously returned from
                :meth:`.expand_paths`.
            schema: The user-provided or inferred schema for the given file
                paths, if any.
            rows_per_file: The fixed number of rows per input file, or None.
            file_sizes: Optional file size per input file previously returned
                from :meth:`.expand_paths`, where `file_sizes[i]` holds the size of
                the file at `paths[i]`.

        Returns:
            BlockMetadata aggregated across the given file paths.
        """
        raise NotImplementedError

    def expand_paths(
        self,
        paths: List[str],
        filesystem: Optional["pyarrow.fs.FileSystem"],
        partitioning: Optional[Partitioning] = None,
        ignore_missing_paths: bool = False,
    ) -> Iterator[Tuple[str, int]]:
        """Expands all paths into concrete file paths by walking directories.

        Also returns a sidecar of file sizes.

        The input paths must be normalized for compatibility with the input
        filesystem prior to invocation.

        Args:
            paths: A list of file and/or directory paths compatible with the
                given filesystem.
            filesystem: The filesystem implementation that should be used for
                expanding all paths and reading their files.
            partitioning: Optional partitioning scheme used during expansion.
            ignore_missing_paths: If True, ignores any file paths in ``paths`` that
                are not found. Defaults to False.

        Returns:
            An iterator of `(file_path, file_size)` pairs. None may be returned for the
            file size if it is either unknown or will be fetched later by
            `_get_block_metadata()`, but the length of
            both lists must be equal.
        """
        raise NotImplementedError
140
+
141
+
142
@DeveloperAPI
class DefaultFileMetadataProvider(BaseFileMetadataProvider):
    """Default metadata provider for
    :class:`~ray.data.datasource.file_based_datasource.FileBasedDatasource`
    implementations that reuse the base `prepare_read` method.

    Calculates block size in bytes as the sum of its constituent file sizes,
    and assumes a fixed number of rows per file.
    """

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        *,
        rows_per_file: Optional[int],
        file_sizes: List[Optional[int]],
    ) -> BlockMetadata:
        # Row count is only known when every file has a fixed row count.
        num_rows = None if rows_per_file is None else len(paths) * rows_per_file
        # Any unknown file size makes the aggregate size unknown too.
        if None in file_sizes:
            size_bytes = None
        else:
            size_bytes = int(sum(file_sizes))
        # Exec stats are filled in later, after the read actually runs.
        return BlockMetadata(
            num_rows=num_rows,
            size_bytes=size_bytes,
            schema=schema,
            input_files=paths,
            exec_stats=None,
        )

    def expand_paths(
        self,
        paths: List[str],
        filesystem: "pyarrow.fs.FileSystem",
        partitioning: Optional[Partitioning] = None,
        ignore_missing_paths: bool = False,
    ) -> Iterator[Tuple[str, int]]:
        # Delegate directory walking and size collection to the shared helper.
        yield from _expand_paths(paths, filesystem, partitioning, ignore_missing_paths)
180
+
181
+
182
@DeveloperAPI
class FastFileMetadataProvider(DefaultFileMetadataProvider):
    """Fast Metadata provider for
    :class:`~ray.data.datasource.file_based_datasource.FileBasedDatasource`
    implementations.

    Offers improved performance vs.
    :class:`DefaultFileMetadataProvider`
    by skipping directory path expansion and file size collection.
    While this performance improvement may be negligible for local filesystems,
    it can be substantial for cloud storage service providers.

    This should only be used when all input paths exist and are known to be files.
    """

    def expand_paths(
        self,
        paths: List[str],
        filesystem: "pyarrow.fs.FileSystem",
        partitioning: Optional[Partitioning] = None,
        ignore_missing_paths: bool = False,
    ) -> Iterator[Tuple[str, int]]:
        # This provider never touches storage, so it cannot tell which paths
        # are missing — refuse the combination outright.
        if ignore_missing_paths:
            raise ValueError(
                "`ignore_missing_paths` cannot be set when used with "
                "`FastFileMetadataProvider`. All paths must exist when "
                "using `FastFileMetadataProvider`."
            )

        logger.warning(
            f"Skipping expansion of {len(paths)} path(s). If your paths contain "
            f"directories or if file size collection is required, try rerunning this "
            f"read with `meta_provider=DefaultFileMetadataProvider()`."
        )

        # Emit each path as-is with an unknown (None) size.
        for path in paths:
            yield path, None
218
+
219
+
220
+ def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> str:
221
+ # NOTE: this is not comprehensive yet, and should be extended as more errors arise.
222
+ # NOTE: The latter patterns are raised in Arrow 10+, while the former is raised in
223
+ # Arrow < 10.
224
+ aws_error_pattern = (
225
+ r"^(?:(.*)AWS Error \[code \d+\]: No response body\.(.*))|"
226
+ r"(?:(.*)AWS Error UNKNOWN \(HTTP status 400\) during HeadObject operation: "
227
+ r"No response body\.(.*))|"
228
+ r"(?:(.*)AWS Error ACCESS_DENIED during HeadObject operation: No response "
229
+ r"body\.(.*))$"
230
+ )
231
+ if re.match(aws_error_pattern, str(error)):
232
+ # Specially handle AWS error when reading files, to give a clearer error
233
+ # message to avoid confusing users. The real issue is most likely that the AWS
234
+ # S3 file credentials have not been properly configured yet.
235
+ if isinstance(paths, str):
236
+ # Quote to highlight single file path in error message for better
237
+ # readability. List of file paths will be shown up as ['foo', 'boo'],
238
+ # so only quote single file path here.
239
+ paths = f'"{paths}"'
240
+ raise OSError(
241
+ (
242
+ f"Failing to read AWS S3 file(s): {paths}. "
243
+ "Please check that file exists and has properly configured access. "
244
+ "You can also run AWS CLI command to get more detailed error message "
245
+ "(e.g., aws s3 ls <file-name>). "
246
+ "See https://awscli.amazonaws.com/v2/documentation/api/latest/reference/s3/index.html " # noqa
247
+ "and https://docs.ray.io/en/latest/data/creating-datasets.html#reading-from-remote-storage " # noqa
248
+ "for more information."
249
+ )
250
+ )
251
+ else:
252
+ raise error
253
+
254
+
255
def _expand_paths(
    paths: List[str],
    filesystem: "pyarrow.fs.FileSystem",
    partitioning: Optional[Partitioning],
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    """Get the file sizes for all provided file paths.

    Args:
        paths: File or directory paths to resolve.
        filesystem: The pyarrow filesystem used to look up the paths.
        partitioning: Optional partitioning scheme; its ``base_dir`` is used
            to detect the common-prefix fast path.
        ignore_missing_paths: If True, missing paths are silently skipped by
            the downstream helpers instead of raising.

    Yields:
        (file path, file size in bytes) tuples; size may be None when unknown.
    """
    from pyarrow.fs import LocalFileSystem

    from ray.data.datasource.file_based_datasource import (
        FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD,
    )
    from ray.data.datasource.path_util import _unwrap_protocol

    # We break down our processing paths into a few key cases:
    # 1. If len(paths) < threshold, fetch the file info for the individual files/paths
    # serially.
    # 2. If all paths are contained under the same parent directory (or base directory,
    # if using partitioning), fetch all file infos at this prefix and filter to the
    # provided paths on the client; this should be a single file info request.
    # 3. If more than threshold requests required, parallelize them via Ray tasks.
    # 1. Small # of paths case.
    if (
        len(paths) < FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD
        # Local file systems are very fast to hit.
        or isinstance(filesystem, LocalFileSystem)
    ):
        yield from _get_file_infos_serial(paths, filesystem, ignore_missing_paths)
    else:
        # 2. Common path prefix case.
        # Get longest common path of all paths.
        common_path = os.path.commonpath(paths)
        # If parent directory (or base directory, if using partitioning) is common to
        # all paths, fetch all file infos at that prefix and filter the response to the
        # provided paths.
        if (
            partitioning is not None
            and common_path == _unwrap_protocol(partitioning.base_dir)
        ) or all(str(pathlib.Path(path).parent) == common_path for path in paths):
            yield from _get_file_infos_common_path_prefix(
                paths, common_path, filesystem, ignore_missing_paths
            )
        # 3. Parallelization case.
        else:
            # Parallelize requests via Ray tasks.
            yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths)
301
+
302
+
303
def _get_file_infos_serial(
    paths: List[str],
    filesystem: "pyarrow.fs.FileSystem",
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    """Resolve each path one at a time, yielding (path, size) pairs."""
    for single_path in paths:
        for info in _get_file_infos(single_path, filesystem, ignore_missing_paths):
            yield info
310
+
311
+
312
def _get_file_infos_common_path_prefix(
    paths: List[str],
    common_path: str,
    filesystem: "pyarrow.fs.FileSystem",
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    """Resolve file sizes with one listing of the shared path prefix.

    Lists everything under ``common_path`` in a single request, then filters
    the listing down to the requested ``paths``. Falls back to the parallel
    strategy if any requested path did not show up as a file in the listing.
    """
    sizes_by_path = {path: None for path in paths}
    for listed_path, listed_size in _get_file_infos(
        common_path, filesystem, ignore_missing_paths
    ):
        if listed_path in sizes_by_path:
            sizes_by_path[listed_path] = listed_size

    # A path left without a size wasn't returned as a file by the prefix
    # listing (e.g. it is actually a directory); detect the first such path.
    unresolved_path = None
    for path in paths:
        if sizes_by_path[path] is None:
            unresolved_path = path
            break

    if unresolved_path is not None:
        logger.debug(
            f"Finding path {unresolved_path} not have file size metadata. "
            "Fall back to get files metadata in parallel for all paths."
        )
        # Parallelize requests via Ray tasks.
        yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths)
    else:
        # Yield in the caller-provided order. Don't iterate the dict: `paths`
        # may contain duplicates (reading the same file several times), which
        # the dict would collapse into a single entry.
        for path in paths:
            yield path, sizes_by_path[path]
348
+
349
+
350
def _get_file_infos_parallel(
    paths: List[str],
    filesystem: "pyarrow.fs.FileSystem",
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    """Resolve (file path, file size) pairs for ``paths`` via parallel Ray tasks.

    Args:
        paths: File or directory paths to resolve.
        filesystem: The pyarrow filesystem used to look up the paths.
        ignore_missing_paths: If True, missing paths are skipped instead of
            raising.

    Yields:
        (file path, file size in bytes) tuples.
    """
    from ray.data.datasource.file_based_datasource import (
        PATHS_PER_FILE_SIZE_FETCH_TASK,
        _unwrap_s3_serialization_workaround,
        _wrap_s3_serialization_workaround,
    )

    logger.warning(
        f"Expanding {len(paths)} path(s). This may be a HIGH LATENCY "
        f"operation on some cloud storage services. Moving all the "
        "paths to a common parent directory will lead to faster "
        "metadata fetching."
    )

    # Capture the filesystem in the fetcher func closure, but wrap it in our
    # serialization workaround to make sure that the pickle roundtrip works as expected.
    filesystem = _wrap_s3_serialization_workaround(filesystem)

    def _file_infos_fetcher(paths: List[str]) -> List[Tuple[str, int]]:
        # Runs inside a Ray task: unwrap the pickled filesystem, then resolve
        # every path in this chunk and flatten the per-path results.
        fs = _unwrap_s3_serialization_workaround(filesystem)
        return list(
            itertools.chain.from_iterable(
                _get_file_infos(path, fs, ignore_missing_paths) for path in paths
            )
        )

    yield from _fetch_metadata_parallel(
        paths, _file_infos_fetcher, PATHS_PER_FILE_SIZE_FETCH_TASK
    )
383
+
384
+
385
+ Uri = TypeVar("Uri")
386
+ Meta = TypeVar("Meta")
387
+
388
+
389
def _fetch_metadata_parallel(
    uris: List[Uri],
    fetch_func: Callable[[List[Uri]], List[Meta]],
    desired_uris_per_task: int,
    **ray_remote_args,
) -> Iterator[Meta]:
    """Fetch file metadata in parallel using Ray tasks.

    Splits ``uris`` into chunks, runs ``fetch_func`` on each chunk as a Ray
    task, and yields the flattened results with a progress bar.
    """
    remote_fetch_func = cached_remote_fn(fetch_func)
    if ray_remote_args:
        remote_fetch_func = remote_fetch_func.options(**ray_remote_args)
    # Pick a parallelism that amortizes per-task overhead (many URIs per
    # task) while still fanning out; always launch at least two tasks.
    parallelism = max(len(uris) // desired_uris_per_task, 2)
    progress = ProgressBar(
        "Metadata Fetch Progress", total=parallelism, unit="task"
    )
    pending = [
        remote_fetch_func.remote(uri_chunk)
        for uri_chunk in np.array_split(uris, parallelism)
        if len(uri_chunk) > 0
    ]
    for chunk_result in progress.fetch_until_complete(pending):
        yield from chunk_result
413
+
414
+
415
def _get_file_infos(
    path: str, filesystem: "pyarrow.fs.FileSystem", ignore_missing_path: bool = False
) -> List[Tuple[str, int]]:
    """Get the file info for all files at or under the provided path.

    Args:
        path: A file or directory path to resolve.
        filesystem: The pyarrow filesystem used to look up the path.
        ignore_missing_path: If True, a missing path produces an empty result
            instead of raising ``FileNotFoundError``.

    Returns:
        A list of (file path, file size in bytes) tuples. A directory path is
        expanded recursively into the files it contains.
    """
    from pyarrow.fs import FileType

    file_infos = []
    try:
        ctx = ray.data.DataContext.get_current()
        # Retry transient I/O failures per the DataContext's configured
        # retryable error patterns.
        file_info = call_with_retry(
            lambda: filesystem.get_file_info(path),
            description="get file info",
            match=ctx.retried_io_errors,
        )
    except OSError as e:
        # `_handle_read_os_error` always raises, so `file_info` is guaranteed
        # to be bound when control reaches the checks below.
        _handle_read_os_error(e, path)
    if file_info.type == FileType.Directory:
        for file_path, file_size in _expand_directory(path, filesystem):
            file_infos.append((file_path, file_size))
    elif file_info.type == FileType.File:
        file_infos.append((path, file_info.size))
    elif file_info.type == FileType.NotFound and ignore_missing_path:
        # Missing path tolerated by request: return no entries.
        pass
    else:
        raise FileNotFoundError(path)

    return file_infos
442
+
443
+
444
def _expand_directory(
    path: str,
    filesystem: "pyarrow.fs.FileSystem",
    exclude_prefixes: Optional[List[str]] = None,
    ignore_missing_path: bool = False,
) -> List[Tuple[str, int]]:
    """
    Expand the provided directory path to a list of file paths.

    Args:
        path: The directory path to expand.
        filesystem: The filesystem implementation that should be used for
            reading these files.
        exclude_prefixes: The file relative path prefixes that should be
            excluded from the returned file set. Default excluded prefixes are
            "." and "_".
        ignore_missing_path: If True, a missing directory yields an empty
            result instead of raising.

    Returns:
        A sorted list of (file_path, file_size) tuples.
    """
    from pyarrow.fs import FileSelector

    if exclude_prefixes is None:
        exclude_prefixes = [".", "_"]

    selector = FileSelector(path, recursive=True, allow_not_found=ignore_missing_path)
    listed = filesystem.get_file_info(selector)
    base_path = selector.base_dir
    expanded = []
    for entry in listed:
        if not entry.is_file:
            continue
        file_path = entry.path
        if not file_path.startswith(base_path):
            continue
        # Exclusion prefixes apply to the path relative to the base directory
        # (e.g. hidden "." files and "_" metadata files).
        relative = file_path[len(base_path) :]
        if any(relative.startswith(prefix) for prefix in exclude_prefixes):
            continue
        expanded.append((file_path, entry.size))
    # Sort to guarantee a stable, deterministic ordering.
    return sorted(expanded)
.venv/lib/python3.11/site-packages/ray/data/datasource/filename_provider.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ from ray.data.block import Block
4
+ from ray.util.annotations import PublicAPI
5
+
6
+
7
+ @PublicAPI(stability="alpha")
8
+ class FilenameProvider:
9
+ """Generates filenames when you write a :class:`~ray.data.Dataset`.
10
+
11
+ Use this class to customize the filenames used when writing a Dataset.
12
+
13
+ Some methods write each row to a separate file, while others write each block to a
14
+ separate file. For example, :meth:`ray.data.Dataset.write_images` writes individual
15
+ rows, and :func:`ray.data.Dataset.write_parquet` writes blocks of data. For more
16
+ information about blocks, see :ref:`Data internals <datasets_scheduling>`.
17
+
18
+ If you're writing each row to a separate file, implement
19
+ :meth:`~FilenameProvider.get_filename_for_row`. Otherwise, implement
20
+ :meth:`~FilenameProvider.get_filename_for_block`.
21
+
22
+ Example:
23
+
24
+ This snippet shows you how to encode labels in written files. For example, if
25
+ `"cat"` is a label, you might write a file named `cat_000000_000000_000000.png`.
26
+
27
+ .. testcode::
28
+
29
+ import ray
30
+ from ray.data.datasource import FilenameProvider
31
+
32
+ class ImageFilenameProvider(FilenameProvider):
33
+
34
+ def __init__(self, file_format: str):
35
+ self.file_format = file_format
36
+
37
+ def get_filename_for_row(self, row, task_index, block_index, row_index):
38
+ return (
39
+ f"{row['label']}_{task_index:06}_{block_index:06}"
40
+ f"_{row_index:06}.{self.file_format}"
41
+ )
42
+
43
+ ds = ray.data.read_parquet("s3://anonymous@ray-example-data/images.parquet")
44
+ ds.write_images(
45
+ "/tmp/results",
46
+ column="image",
47
+ filename_provider=ImageFilenameProvider("png")
48
+ )
49
+ """ # noqa: E501
50
+
51
+ def get_filename_for_block(
52
+ self, block: Block, task_index: int, block_index: int
53
+ ) -> str:
54
+ """Generate a filename for a block of data.
55
+
56
+ .. note::
57
+ Filenames must be unique and deterministic for a given task and block index.
58
+
59
+ A block consists of multiple rows and corresponds to a single output file.
60
+ Each task might produce a different number of blocks.
61
+
62
+ Args:
63
+ block: The block that will be written to a file.
64
+ task_index: The index of the the write task.
65
+ block_index: The index of the block *within* the write task.
66
+ """
67
+ raise NotImplementedError
68
+
69
+ def get_filename_for_row(
70
+ self, row: Dict[str, Any], task_index: int, block_index: int, row_index: int
71
+ ) -> str:
72
+ """Generate a filename for a row.
73
+
74
+ .. note::
75
+ Filenames must be unique and deterministic for a given task, block, and row
76
+ index.
77
+
78
+ A block consists of multiple rows, and each row corresponds to a single
79
+ output file. Each task might produce a different number of blocks, and each
80
+ block might contain a different number of rows.
81
+
82
+ .. tip::
83
+ If you require a contiguous row index into the global dataset, use
84
+ :meth:`~ray.data.Dataset.iter_rows`. This method is single-threaded and
85
+ isn't recommended for large datasets.
86
+
87
+ Args:
88
+ row: The row that will be written to a file.
89
+ task_index: The index of the the write task.
90
+ block_index: The index of the block *within* the write task.
91
+ row_index: The index of the row *within* the block.
92
+ """
93
+ raise NotImplementedError
94
+
95
+
96
class _DefaultFilenameProvider(FilenameProvider):
    """Default filename scheme: ``[<uuid>_]<task>_<block>[_<row>][.<format>]``."""

    def __init__(
        self, dataset_uuid: Optional[str] = None, file_format: Optional[str] = None
    ):
        self._dataset_uuid = dataset_uuid
        self._file_format = file_format

    def get_filename_for_block(
        self, block: Block, task_index: int, block_index: int
    ) -> str:
        # Zero-padded task/block indices keep filenames sortable.
        return self._generate_filename(f"{task_index:06}_{block_index:06}")

    def get_filename_for_row(
        self, row: Dict[str, Any], task_index: int, block_index: int, row_index: int
    ) -> str:
        return self._generate_filename(
            f"{task_index:06}_{block_index:06}_{row_index:06}"
        )

    def _generate_filename(self, file_id: str) -> str:
        # Wrap the id with the optional dataset-UUID prefix and format suffix.
        prefix = f"{self._dataset_uuid}_" if self._dataset_uuid is not None else ""
        suffix = f".{self._file_format}" if self._file_format is not None else ""
        return f"{prefix}{file_id}{suffix}"
.venv/lib/python3.11/site-packages/ray/data/datasource/parquet_meta_provider.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, List, Optional, Union
2
+
3
+ import ray.cloudpickle as cloudpickle
4
+ from ray.data._internal.util import call_with_retry
5
+ from ray.data.block import BlockMetadata
6
+ from ray.data.datasource.file_meta_provider import (
7
+ FileMetadataProvider,
8
+ _fetch_metadata_parallel,
9
+ )
10
+ from ray.util.annotations import DeveloperAPI
11
+
12
+ if TYPE_CHECKING:
13
+ import pyarrow
14
+
15
+ from ray.data._internal.datasource.parquet_datasource import SerializedFragment
16
+
17
+
18
# Number of Parquet file fragments whose metadata is fetched by one Ray task.
FRAGMENTS_PER_META_FETCH = 6
# Minimum number of fragments before metadata fetching is parallelized with
# Ray tasks instead of being done inline.
PARALLELIZE_META_FETCH_THRESHOLD = 24

# The application-level exceptions to retry for metadata prefetching task.
# Default to retry on access denied and read timeout errors because AWS S3 would throw
# these transient errors when load is too high.
RETRY_EXCEPTIONS_FOR_META_FETCH_TASK = ["AWS Error ACCESS_DENIED", "Timeout"]
# Maximum number of retries for metadata prefetching task due to transient errors.
RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK = 32
# Maximum retry back-off interval in seconds for failed metadata prefetching task.
RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK = 64
29
+
30
+
31
+ class _ParquetFileFragmentMetaData:
32
+ """Class to store metadata of a Parquet file fragment. This includes
33
+ all attributes from `pyarrow.parquet.FileMetaData` except for `schema`,
34
+ which is stored in `self.schema_pickled` as a pickled object from
35
+ `cloudpickle.loads()`, used in deduplicating schemas across multiple fragments."""
36
+
37
+ def __init__(self, fragment_metadata: "pyarrow.parquet.FileMetaData"):
38
+ self.created_by = fragment_metadata.created_by
39
+ self.format_version = fragment_metadata.format_version
40
+ self.num_columns = fragment_metadata.num_columns
41
+ self.num_row_groups = fragment_metadata.num_row_groups
42
+ self.num_rows = fragment_metadata.num_rows
43
+ self.serialized_size = fragment_metadata.serialized_size
44
+ # This is a pickled schema object, to be set later with
45
+ # `self.set_schema_pickled()`. To get the underlying schema, use
46
+ # `cloudpickle.loads(self.schema_pickled)`.
47
+ self.schema_pickled = None
48
+
49
+ # Calculate the total byte size of the file fragment using the original
50
+ # object, as it is not possible to access row groups from this class.
51
+ self.total_byte_size = 0
52
+ for row_group_idx in range(fragment_metadata.num_row_groups):
53
+ row_group_metadata = fragment_metadata.row_group(row_group_idx)
54
+ self.total_byte_size += row_group_metadata.total_byte_size
55
+
56
+ def set_schema_pickled(self, schema_pickled: bytes):
57
+ """Note: to get the underlying schema, use
58
+ `cloudpickle.loads(self.schema_pickled)`."""
59
+ self.schema_pickled = schema_pickled
60
+
61
+
62
@DeveloperAPI
class ParquetMetadataProvider(FileMetadataProvider):
    """Provides block metadata for Arrow Parquet file fragments."""

    def _get_block_metadata(
        self,
        paths: List[str],
        schema: Optional[Union[type, "pyarrow.lib.Schema"]],
        *,
        num_fragments: int,
        prefetched_metadata: Optional[List["_ParquetFileFragmentMetaData"]],
    ) -> BlockMetadata:
        """Resolves and returns block metadata for files of a single dataset block.

        Args:
            paths: The file paths for a single dataset block.
            schema: The user-provided or inferred schema for the given file
                paths, if any.
            num_fragments: The number of Parquet file fragments derived from the input
                file paths.
            prefetched_metadata: Metadata previously returned from
                `prefetch_file_metadata()` for each file fragment, where
                `prefetched_metadata[i]` contains the metadata for `fragments[i]`.

        Returns:
            BlockMetadata aggregated across the given file paths.
        """
        # Only trust the prefetched metadata when it covers every fragment;
        # a partial result would under-count rows/bytes.
        if (
            prefetched_metadata is not None
            and len(prefetched_metadata) == num_fragments
            and all(m is not None for m in prefetched_metadata)
        ):
            # Fragment metadata was available, construct a normal
            # BlockMetadata.
            block_metadata = BlockMetadata(
                num_rows=sum(m.num_rows for m in prefetched_metadata),
                size_bytes=sum(m.total_byte_size for m in prefetched_metadata),
                schema=schema,
                input_files=paths,
                exec_stats=None,
            )  # Exec stats filled in later.
        else:
            # Fragment metadata was not available, construct an empty
            # BlockMetadata.
            block_metadata = BlockMetadata(
                num_rows=None,
                size_bytes=None,
                schema=schema,
                input_files=paths,
                exec_stats=None,
            )
        return block_metadata

    def prefetch_file_metadata(
        self,
        fragments: List["pyarrow.dataset.ParquetFileFragment"],
        **ray_remote_args,
    ) -> Optional[List[_ParquetFileFragmentMetaData]]:
        """Pre-fetches file metadata for all Parquet file fragments in a single batch.

        Subsets of the metadata returned will be provided as input to subsequent calls
        to ``_get_block_metadata`` together with their corresponding Parquet file
        fragments.

        Args:
            fragments: The Parquet file fragments to fetch metadata for.

        Returns:
            Metadata resolved for each input file fragment, or `None`. Metadata
            must be returned in the same order as all input file fragments, such
            that `metadata[i]` always contains the metadata for `fragments[i]`.
        """
        from ray.data._internal.datasource.parquet_datasource import SerializedFragment

        if len(fragments) > PARALLELIZE_META_FETCH_THRESHOLD:
            # Wrap Parquet fragments in serialization workaround.
            fragments = [SerializedFragment(fragment) for fragment in fragments]
            # Fetch Parquet metadata in parallel using Ray tasks.

            def fetch_func(fragments):
                return _fetch_metadata_serialization_wrapper(
                    fragments,
                    # Ensure that retry settings are propagated to remote tasks.
                    retry_match=RETRY_EXCEPTIONS_FOR_META_FETCH_TASK,
                    retry_max_attempts=RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK,
                    retry_max_interval=RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK,
                )

            raw_metadata = list(
                _fetch_metadata_parallel(
                    fragments,
                    fetch_func,
                    FRAGMENTS_PER_META_FETCH,
                    **ray_remote_args,
                )
            )
        else:
            # Few fragments: fetch inline without spawning Ray tasks.
            raw_metadata = _fetch_metadata(fragments)

        # Deduplicate identical schemas to reduce memory usage.
        return _dedupe_metadata(raw_metadata)
162
+
163
+
164
def _fetch_metadata_serialization_wrapper(
    fragments: List["SerializedFragment"],
    retry_match: Optional[List[str]],
    retry_max_attempts: int,
    retry_max_interval: int,
) -> List["pyarrow.parquet.FileMetaData"]:
    """Deserialize the wrapped fragments, then fetch their metadata with retries.

    Intended to run inside a Ray task: the fragments arrive wrapped in the
    pickling workaround (`SerializedFragment`) and are unwrapped here.

    Args:
        fragments: Serialized Parquet file fragments.
        retry_match: Error-message substrings considered transient/retryable.
        retry_max_attempts: Maximum number of fetch attempts.
        retry_max_interval: Maximum retry back-off interval, in seconds.

    Returns:
        The `FileMetaData` for each fragment, in input order.

    Raises:
        RuntimeError: If the fetch still fails with an ``OSError`` after
            exhausting all retries; chained to the original error.
    """
    from ray.data._internal.datasource.parquet_datasource import (
        _deserialize_fragments_with_retry,
    )

    deserialized_fragments = _deserialize_fragments_with_retry(fragments)
    try:
        metadata = call_with_retry(
            lambda: _fetch_metadata(deserialized_fragments),
            # Fixed typo: was "fetch metdata".
            description="fetch metadata",
            match=retry_match,
            max_attempts=retry_max_attempts,
            max_backoff_s=retry_max_interval,
        )
    except OSError as e:
        raise RuntimeError(
            f"Exceeded maximum number of attempts ({retry_max_attempts}) to retry "
            "metadata fetching task. Metadata fetching tasks can fail due to transient "
            "errors like rate limiting.\n"
            "\n"
            "To increase the maximum number of attempts, configure "
            "`RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK`. For example:\n"
            "```\n"
            "ray.data._internal.datasource.parquet_datasource.RETRY_MAX_ATTEMPTS_FOR_META_FETCH_TASK = 64\n"  # noqa: E501
            "```\n"
            "To increase the maximum retry backoff interval, configure "
            "`RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK`. For example:\n"
            "```\n"
            "ray.data._internal.datasource.parquet_datasource.RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK = 128\n"  # noqa: E501
            "```\n"
            "If the error continues to occur, you can also try decreasing the "
            "concurrency of metadata fetching tasks by setting "
            "`NUM_CPUS_FOR_META_FETCH_TASK` to a larger value. For example:\n"
            "```\n"
            "ray.data._internal.datasource.parquet_datasource.NUM_CPUS_FOR_META_FETCH_TASK = 4.\n"  # noqa: E501
            "```\n"
            "To change which exceptions to retry on, set "
            "`RETRY_EXCEPTIONS_FOR_META_FETCH_TASK` to a list of error messages. For "
            "example:\n"
            "```\n"
            'ray.data._internal.datasource.parquet_datasource.RETRY_EXCEPTIONS_FOR_META_FETCH_TASK = ["AWS Error ACCESS_DENIED", "Timeout"]\n'  # noqa: E501
            "```"
        ) from e
    return metadata
213
+
214
+
215
+ def _fetch_metadata(
216
+ fragments: List["pyarrow.dataset.ParquetFileFragment"],
217
+ ) -> List["pyarrow.parquet.FileMetaData"]:
218
+ fragment_metadata = []
219
+ for f in fragments:
220
+ try:
221
+ fragment_metadata.append(f.metadata)
222
+ except AttributeError:
223
+ break
224
+ return fragment_metadata
225
+
226
+
227
def _dedupe_metadata(
    raw_metadatas: List["pyarrow.parquet.FileMetaData"],
) -> List[_ParquetFileFragmentMetaData]:
    """For datasets with a large number of columns, the FileMetaData
    (in particular the schema) can be very large. We can reduce the
    memory usage by only keeping unique schema objects across all
    file fragments. This method deduplicates the schemas and returns
    a list of `_ParquetFileFragmentMetaData` objects."""
    # NOTE(review): the original comments on these two dicts were swapped;
    # corrected to match how they are populated below.
    schema_to_id = {}  # serialized_schema -> schema_id
    id_to_schema = {}  # schema_id -> serialized_schema
    stripped_metadatas = []
    for fragment_metadata in raw_metadatas:
        stripped_md = _ParquetFileFragmentMetaData(fragment_metadata)

        schema_ser = cloudpickle.dumps(fragment_metadata.schema.to_arrow_schema())
        if schema_ser not in schema_to_id:
            schema_id = len(schema_to_id)
            schema_to_id[schema_ser] = schema_id
            id_to_schema[schema_id] = schema_ser
            stripped_md.set_schema_pickled(schema_ser)
        else:
            # Reuse the first-seen bytes object so equal schemas share one
            # allocation in memory.
            schema_id = schema_to_id.get(schema_ser)
            existing_schema_ser = id_to_schema[schema_id]
            stripped_md.set_schema_pickled(existing_schema_ser)
        stripped_metadatas.append(stripped_md)
    return stripped_metadatas
.venv/lib/python3.11/site-packages/ray/data/exceptions.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Callable
3
+
4
+ from ray.data._internal.logging import get_log_directory
5
+ from ray.data.context import DataContext
6
+ from ray.exceptions import UserCodeException
7
+ from ray.util import log_once
8
+ from ray.util.annotations import DeveloperAPI
9
+ from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@DeveloperAPI
class RayDataUserCodeException(UserCodeException):
    """Exception originating from user code, such as a user-specified UDF
    used in a Ray Data transformation.

    By default, stack frames belonging to Ray Data internal files are omitted
    from the trace printed to stdout, but are still written to the Ray Data
    log file. Set `DataContext.log_internal_stack_trace_to_stdout` to True to
    emit every frame to stdout as well."""
25
+
26
+
27
@DeveloperAPI
class SystemException(Exception):
    """Exception originating from Ray Data internal code or Ray Core private
    code paths, rather than from user code. Seeing this exception usually
    indicates a bug in Ray Data or Ray Core."""
35
+
36
+
37
@DeveloperAPI
def omit_traceback_stdout(fn: Callable) -> Callable:
    """Decorator which runs the function, and if there is an exception raised,
    drops the stack trace before re-raising the exception. The original exception,
    including the full unmodified stack trace, is always written to the Ray Data
    log file at `data_exception_logger._log_path`.

    This is useful for stripping long stack traces of internal Ray Data code,
    which can otherwise obfuscate user code errors."""
    # Local import keeps this module's top-level import block unchanged.
    from functools import wraps

    # `wraps` preserves the wrapped function's name/docstring/signature
    # metadata, which the original implementation clobbered.
    @wraps(fn)
    def handle_trace(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Only log the full internal stack trace to stdout when configured
            # via DataContext, or when the Ray Debugger is enabled.
            # The full stack trace will always be emitted to the Ray Data log file.
            log_to_stdout = DataContext.get_current().log_internal_stack_trace_to_stdout
            if _is_ray_debugger_post_mortem_enabled():
                logger.exception("Full stack trace:")
                raise e

            is_user_code_exception = isinstance(e, UserCodeException)
            if is_user_code_exception:
                # Exception has occurred in user code.
                if not log_to_stdout and log_once("ray_data_exception_internal_hidden"):
                    logger.error(
                        "Exception occurred in user code, with the abbreviated stack "
                        "trace below. By default, the Ray Data internal stack trace "
                        "is omitted from stdout, and only written to the Ray Data log "
                        f"files at {get_log_directory()}. To "
                        "output the full stack trace to stdout, set "
                        "`DataContext.log_internal_stack_trace_to_stdout` to True."
                    )
            else:
                # Exception has occurred in internal Ray Data / Ray Core code.
                logger.error(
                    "Exception occurred in Ray Data or Ray Core internal code. "
                    "If you continue to see this error, please open an issue on "
                    "the Ray project GitHub page with the full stack trace below: "
                    "https://github.com/ray-project/ray/issues/new/choose"
                )

            should_hide_traceback = is_user_code_exception and not log_to_stdout
            logger.exception(
                "Full stack trace:",
                exc_info=True,
                extra={"hide": should_hide_traceback},
            )
            if is_user_code_exception:
                raise e.with_traceback(None)
            else:
                raise e.with_traceback(None) from SystemException()

    return handle_trace
.venv/lib/python3.11/site-packages/ray/data/grouped_data.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
3
+
4
+ from ray.data._internal.aggregate import Count, Max, Mean, Min, Std, Sum
5
+ from ray.data._internal.compute import ComputeStrategy
6
+ from ray.data._internal.logical.interfaces import LogicalPlan
7
+ from ray.data._internal.logical.operators.all_to_all_operator import Aggregate
8
+ from ray.data.aggregate import AggregateFn
9
+ from ray.data.block import (
10
+ BlockAccessor,
11
+ CallableClass,
12
+ DataBatch,
13
+ UserDefinedFunction,
14
+ _get_block_boundaries,
15
+ )
16
+ from ray.data.dataset import Dataset
17
+ from ray.util.annotations import PublicAPI
18
+
19
+ CDS_API_GROUP = "Computations or Descriptive Stats"
20
+ FA_API_GROUP = "Function Application"
21
+
22
+
23
class GroupedData:
    """A lazily-evaluated view of a :class:`Dataset` grouped by key.

    Instances are created via ``Dataset.groupby()``. No grouping work happens
    until an aggregation (or ``map_groups``) is applied.
    """

    def __init__(
        self,
        dataset: Dataset,
        key: Optional[Union[str, List[str]]],
    ):
        """Internal constructor — use ``Dataset.groupby()`` to build one.

        Args:
            dataset: The dataset to group.
            key: A column name, a list of column names, or ``None`` to treat
                the whole dataset as a single group.
        """
        self._dataset = dataset
        self._key = key

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dataset={self._dataset}, key={self._key!r})"

    @PublicAPI(api_group=FA_API_GROUP)
    def aggregate(self, *aggs: AggregateFn) -> Dataset:
        """Implements an accumulator-based aggregation.

        Args:
            aggs: Aggregations to do.

        Returns:
            A dataset of ``n + 1`` columns where the first column is the
            groupby key and the second through ``n + 1`` columns are the
            results of the aggregations. If the groupby key is ``None``, the
            key part of the return is omitted.
        """
        new_plan = self._dataset._plan.copy()
        agg_op = Aggregate(
            self._dataset._logical_plan.dag,
            key=self._key,
            aggs=aggs,
        )
        new_logical_plan = LogicalPlan(agg_op, self._dataset.context)
        return Dataset(new_plan, new_logical_plan)

    def _aggregate_on(
        self,
        agg_cls: type,
        on: Union[str, List[str]],
        *args,
        **kwargs,
    ):
        """Helper for aggregating on a particular subset of the dataset.

        Validates the ``on`` argument and expands a list of column names into
        a multi-column aggregation. A null ``on`` results in a
        multi-aggregation over all columns (Arrow datasets) or a single
        whole-row aggregation (simple datasets). The groupby key columns are
        always excluded from the aggregation.
        """
        per_column_aggs = self._dataset._build_multicolumn_aggs(
            agg_cls, on, *args, skip_cols=self._key, **kwargs
        )
        return self.aggregate(*per_column_aggs)

    @PublicAPI(api_group=FA_API_GROUP)
    def map_groups(
        self,
        fn: UserDefinedFunction[DataBatch, DataBatch],
        *,
        compute: Union[str, ComputeStrategy] = None,
        batch_format: Optional[str] = "default",
        fn_args: Optional[Iterable[Any]] = None,
        fn_kwargs: Optional[Dict[str, Any]] = None,
        fn_constructor_args: Optional[Iterable[Any]] = None,
        fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
        num_cpus: Optional[float] = None,
        num_gpus: Optional[float] = None,
        concurrency: Optional[Union[int, Tuple[int, int]]] = None,
        **ray_remote_args,
    ) -> "Dataset":
        """Apply the given function to each group of records of this dataset.

        While ``map_groups()`` is very flexible, it has downsides: it may be
        slower than more specific methods such as ``min()``/``max()``, and it
        requires that each group fits in memory on a single node. In general,
        prefer ``aggregate()``.

        .. warning::
            Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is
            experimental and may result in scheduling or stability issues.
            Please `report any issues
            <https://github.com/ray-project/ray/issues/new/choose>`_ to the
            Ray team.

        Examples:
            >>> # Return a single record per group (list of multiple records in,
            >>> # list of a single record out).
            >>> import ray
            >>> import pandas as pd
            >>> import numpy as np
            >>> # Get first value per group.
            >>> ds = ray.data.from_items([ # doctest: +SKIP
            ...     {"group": 1, "value": 1},
            ...     {"group": 1, "value": 2},
            ...     {"group": 2, "value": 3},
            ...     {"group": 2, "value": 4}])
            >>> ds.groupby("group").map_groups( # doctest: +SKIP
            ...     lambda g: {"result": np.array([g["value"][0]])})

        Args:
            fn: The function to apply to each group of records, or a class
                type that can be instantiated to create such a callable. It
                takes as input a batch of all records from a single group, and
                returns a batch of zero or more records, similar to
                ``map_batches()``.
            compute: The compute strategy, either "tasks" (default) to use Ray
                tasks, ``ray.data.ActorPoolStrategy(size=n)`` for a fixed-size
                actor pool, or
                ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` for an
                autoscaling actor pool.
            batch_format: ``"default"`` (NumPy), ``"pandas"``, ``"pyarrow"``,
                ``"numpy"``, or ``None`` to pass the underlying block through
                with no additional formatting.
            fn_args: Arguments to ``fn``.
            fn_kwargs: Keyword arguments to ``fn``.
            fn_constructor_args: Positional arguments for ``fn``'s constructor
                (callable-class ``fn`` only).
            fn_constructor_kwargs: Keyword arguments for ``fn``'s constructor
                (callable-class ``fn`` only).
            num_cpus: The number of CPUs to reserve for each parallel map
                worker.
            num_gpus: The number of GPUs to reserve for each parallel map
                worker.
            concurrency: Concurrency setting forwarded to the map operation.
            ray_remote_args: Additional resource requirements to request from
                Ray; see :func:`ray.remote` for details.

        Returns:
            The return type is determined by the return type of ``fn``; the
            return value combines the results of all groups.
        """
        # Globally sort records by key (or collapse to one block when there is
        # no key) so that all rows of any given group land in the same block.
        if self._key is not None:
            prepared_ds = self._dataset.sort(self._key)
        else:
            prepared_ds = self._dataset.repartition(1)

        # Each batch is an entire block, because batch_size=None is passed to
        # the map call below.
        def apply_udf_to_groups(udf, batch, *args, **kwargs):
            block = BlockAccessor.batch_to_block(batch)
            accessor = BlockAccessor.for_block(block)

            # Compute group boundaries (including the first start and the
            # last end index).
            if self._key:
                key_cols = accessor.to_numpy(self._key)
                # _get_block_boundaries() expects a list of arrays.
                if isinstance(self._key, str):
                    key_cols = [key_cols]
                else:
                    # key_cols is a dict of arrays when multiple keys are used.
                    key_cols = list(key_cols.values())
                boundaries = _get_block_boundaries(key_cols)
            else:
                boundaries = [0, accessor.num_rows()]

            for start, end in zip(boundaries[:-1], boundaries[1:]):
                group_accessor = BlockAccessor.for_block(
                    accessor.slice(start, end, copy=False)
                )
                # The block format may differ from the requested batch format
                # (e.g. Arrow block vs. NumPy batch), so convert per group.
                group_batch = group_accessor.to_batch_format(batch_format)
                yield udf(group_batch, *args, **kwargs)

        if isinstance(fn, CallableClass):

            class wrapped_fn:
                def __init__(self, *args, **kwargs):
                    self.fn = fn(*args, **kwargs)

                def __call__(self, batch, *args, **kwargs):
                    yield from apply_udf_to_groups(self.fn, batch, *args, **kwargs)

        else:

            def wrapped_fn(batch, *args, **kwargs):
                yield from apply_udf_to_groups(fn, batch, *args, **kwargs)

        # Surface the user's function name (rather than "wrapped_fn") in the
        # progress bar.
        if isinstance(fn, partial):
            wrapped_fn.__name__ = fn.func.__name__
        else:
            wrapped_fn.__name__ = fn.__name__

        # batch_size=None makes each block a single batch, which guarantees
        # every group is contained within one batch in its entirety.
        return prepared_ds._map_batches_without_batch_size_validation(
            wrapped_fn,
            batch_size=None,
            compute=compute,
            batch_format=batch_format,
            zero_copy_batch=False,
            fn_args=fn_args,
            fn_kwargs=fn_kwargs,
            fn_constructor_args=fn_constructor_args,
            fn_constructor_kwargs=fn_constructor_kwargs,
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            concurrency=concurrency,
            ray_remote_args_fn=None,
            **ray_remote_args,
        )

    @PublicAPI(api_group=CDS_API_GROUP)
    def count(self) -> Dataset:
        """Compute count aggregation.

        Examples:
            >>> import ray
            >>> ray.data.from_items([ # doctest: +SKIP
            ...     {"A": x % 3, "B": x} for x in range(100)]).groupby( # doctest: +SKIP
            ...     "A").count() # doctest: +SKIP

        Returns:
            A dataset of ``[k, v]`` columns where ``k`` is the groupby key and
            ``v`` is the number of rows with that key. If the groupby key is
            ``None``, the key part of the return is omitted.
        """
        return self.aggregate(Count())

    @PublicAPI(api_group=CDS_API_GROUP)
    def sum(
        self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
    ) -> Dataset:
        """Compute grouped sum aggregation.

        Examples:
            >>> import ray
            >>> ray.data.range(100).groupby("id").sum() # doctest: +SKIP

        Args:
            on: A column name or a list of column names to aggregate.
            ignore_nulls: Whether to ignore null values (``np.nan``, ``None``,
                ``pd.NaT``). If ``False`` and a null value is encountered, the
                output is null. Default is ``True``.

        Returns:
            The sum result. With ``on=None``, a dataset containing the groupby
            key column and a column-wise sum for every original column; with
            ``on=["col_1", ..., "col_n"]``, a dataset of ``n + 1`` columns
            (key first, then one result column per aggregation). If the
            groupby key is ``None``, the key part of the return is omitted.
        """
        return self._aggregate_on(Sum, on, ignore_nulls=ignore_nulls)

    @PublicAPI(api_group=CDS_API_GROUP)
    def min(
        self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
    ) -> Dataset:
        """Compute grouped min aggregation.

        Examples:
            >>> import ray
            >>> ray.data.le(100).groupby("value").min() # doctest: +SKIP

        Args:
            on: A column name or a list of column names to aggregate.
            ignore_nulls: Whether to ignore null values (``np.nan``, ``None``,
                ``pd.NaT``). If ``False`` and a null value is encountered, the
                output is null. Default is ``True``.

        Returns:
            The min result. With ``on=None``, a dataset containing the groupby
            key column and a column-wise min for every original column; with
            ``on=["col_1", ..., "col_n"]``, a dataset of ``n + 1`` columns
            (key first, then one result column per aggregation). If the
            groupby key is ``None``, the key part of the return is omitted.
        """
        return self._aggregate_on(Min, on, ignore_nulls=ignore_nulls)

    @PublicAPI(api_group=CDS_API_GROUP)
    def max(
        self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
    ) -> Dataset:
        """Compute grouped max aggregation.

        Examples:
            >>> import ray
            >>> ray.data.le(100).groupby("value").max() # doctest: +SKIP

        Args:
            on: A column name or a list of column names to aggregate.
            ignore_nulls: Whether to ignore null values (``np.nan``, ``None``,
                ``pd.NaT``). If ``False`` and a null value is encountered, the
                output is null. Default is ``True``.

        Returns:
            The max result. With ``on=None``, a dataset containing the groupby
            key column and a column-wise max for every original column; with
            ``on=["col_1", ..., "col_n"]``, a dataset of ``n + 1`` columns
            (key first, then one result column per aggregation). If the
            groupby key is ``None``, the key part of the return is omitted.
        """
        return self._aggregate_on(Max, on, ignore_nulls=ignore_nulls)

    @PublicAPI(api_group=CDS_API_GROUP)
    def mean(
        self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
    ) -> Dataset:
        """Compute grouped mean aggregation.

        Examples:
            >>> import ray
            >>> ray.data.le(100).groupby("value").mean() # doctest: +SKIP

        Args:
            on: A column name or a list of column names to aggregate.
            ignore_nulls: Whether to ignore null values (``np.nan``, ``None``,
                ``pd.NaT``). If ``False`` and a null value is encountered, the
                output is null. Default is ``True``.

        Returns:
            The mean result. With ``on=None``, a dataset containing the
            groupby key column and a column-wise mean for every original
            column; with ``on=["col_1", ..., "col_n"]``, a dataset of
            ``n + 1`` columns (key first, then one result column per
            aggregation). If the groupby key is ``None``, the key part of the
            return is omitted.
        """
        return self._aggregate_on(Mean, on, ignore_nulls=ignore_nulls)

    @PublicAPI(api_group=CDS_API_GROUP)
    def std(
        self,
        on: Union[str, List[str]] = None,
        ddof: int = 1,
        ignore_nulls: bool = True,
    ) -> Dataset:
        """Compute grouped standard deviation aggregation.

        Examples:
            >>> import ray
            >>> ray.data.range(100).groupby("id").std(ddof=0) # doctest: +SKIP

        NOTE: This uses Welford's online method for an accumulator-style
        computation of the standard deviation, chosen for its numerical
        stability and single-pass computability. It may give different (but
        more accurate) results than NumPy, Pandas, and sklearn, which use a
        less numerically stable two-pass algorithm. See
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

        Args:
            on: A column name or a list of column names to aggregate.
            ddof: Delta Degrees of Freedom. The divisor used in calculations
                is ``N - ddof``, where ``N`` represents the number of
                elements.
            ignore_nulls: Whether to ignore null values (``np.nan``, ``None``,
                ``pd.NaT``). If ``False`` and a null value is encountered, the
                output is null. Default is ``True``.

        Returns:
            The standard deviation result. With ``on=None``, a dataset
            containing the groupby key column and a column-wise std for every
            original column; with ``on=["col_1", ..., "col_n"]``, a dataset of
            ``n + 1`` columns (key first, then one result column per
            aggregation). If the groupby key is ``None``, the key part of the
            return is omitted.
        """
        return self._aggregate_on(Std, on, ignore_nulls=ignore_nulls, ddof=ddof)


# Backwards compatibility alias.
GroupedDataset = GroupedData
.venv/lib/python3.11/site-packages/ray/data/iterator.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import time
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ Iterable,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ TypeVar,
14
+ Union,
15
+ )
16
+
17
+ import numpy as np
18
+
19
+ from ray.data._internal.block_batching.iter_batches import iter_batches
20
+ from ray.data._internal.execution.interfaces import RefBundle
21
+ from ray.data._internal.logical.operators.input_data_operator import InputData
22
+ from ray.data._internal.logical.optimizers import LogicalPlan
23
+ from ray.data._internal.plan import ExecutionPlan
24
+ from ray.data._internal.stats import DatasetStats, StatsManager
25
+ from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format
26
+ from ray.util.annotations import PublicAPI
27
+
28
+ if TYPE_CHECKING:
29
+ import tensorflow as tf
30
+ import torch
31
+
32
+ from ray.data.dataset import (
33
+ CollatedData,
34
+ MaterializedDataset,
35
+ Schema,
36
+ TensorFlowTensorBatchType,
37
+ TorchBatchType,
38
+ )
39
+
40
+
41
+ T = TypeVar("T")
42
+
43
+
44
+ class _IterableFromIterator(Iterable[T]):
45
+ def __init__(self, iterator_gen: Callable[[], Iterator[T]]):
46
+ """Constructs an Iterable from an iterator generator.
47
+
48
+ Args:
49
+ iterator_gen: A function that returns an iterator each time it
50
+ is called. For example, this can be a generator function.
51
+ """
52
+ self.iterator_gen = iterator_gen
53
+
54
+ def __iter__(self):
55
+ return self.iterator_gen()
56
+
57
+
58
+ @PublicAPI
59
+ class DataIterator(abc.ABC):
60
+ """An iterator for reading records from a :class:`~Dataset`.
61
+
62
+ For Datasets, each iteration call represents a complete read of all items in the
63
+ Dataset.
64
+
65
+ If using Ray Train, each trainer actor should get its own iterator by calling
66
+ :meth:`ray.train.get_dataset_shard("train")
67
+ <ray.train.get_dataset_shard>`.
68
+
69
+ Examples:
70
+ >>> import ray
71
+ >>> ds = ray.data.range(5)
72
+ >>> ds
73
+ Dataset(num_rows=5, schema={id: int64})
74
+ >>> ds.iterator()
75
+ DataIterator(Dataset(num_rows=5, schema={id: int64}))
76
+ """
77
+
78
+ @abc.abstractmethod
79
+ def _to_ref_bundle_iterator(
80
+ self,
81
+ ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
82
+ """Returns the iterator to use for `iter_batches`.
83
+
84
+ Returns:
85
+ A tuple. The first item of the tuple is an iterator over RefBundles.
86
+ The second item of the tuple is a DatasetStats object used for recording
87
+ stats during iteration.
88
+ The third item is a boolean indicating if the blocks can be safely cleared
89
+ after use.
90
+ """
91
+ raise NotImplementedError
92
+
93
+ @PublicAPI
94
+ def iter_batches(
95
+ self,
96
+ *,
97
+ prefetch_batches: int = 1,
98
+ batch_size: int = 256,
99
+ batch_format: Optional[str] = "default",
100
+ drop_last: bool = False,
101
+ local_shuffle_buffer_size: Optional[int] = None,
102
+ local_shuffle_seed: Optional[int] = None,
103
+ _collate_fn: Optional[Callable[[DataBatch], "CollatedData"]] = None,
104
+ _finalize_fn: Optional[Callable[[Any], Any]] = None,
105
+ ) -> Iterable[DataBatch]:
106
+ """Return a batched iterable over the dataset.
107
+
108
+ Examples:
109
+ >>> import ray
110
+ >>> for batch in ray.data.range(
111
+ ... 1000000
112
+ ... ).iterator().iter_batches(): # doctest: +SKIP
113
+ ... print(batch) # doctest: +SKIP
114
+
115
+ Time complexity: O(1)
116
+
117
+ Args:
118
+ prefetch_batches: The number of batches to fetch ahead of the current batch
119
+ to fetch. If set to greater than 0, a separate threadpool will be used
120
+ to fetch the objects to the local node, format the batches, and apply
121
+ the collate_fn. Defaults to 1.
122
+ batch_size: The number of rows in each batch, or None to use entire blocks
123
+ as batches (blocks may contain different number of rows).
124
+ The final batch may include fewer than ``batch_size`` rows if
125
+ ``drop_last`` is ``False``. Defaults to 256.
126
+ batch_format: Specify ``"default"`` to use the default block format
127
+ (NumPy), ``"pandas"`` to select ``pandas.DataFrame``, "pyarrow" to
128
+ select ``pyarrow.Table``, or ``"numpy"`` to select
129
+ ``Dict[str, numpy.ndarray]``, or None to return the underlying block
130
+ exactly as is with no additional formatting.
131
+ drop_last: Whether to drop the last batch if it's incomplete.
132
+ local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
133
+ using a local in-memory shuffle buffer, and this value will serve as the
134
+ minimum number of rows that must be in the local in-memory shuffle
135
+ buffer in order to yield a batch. When there are no more rows to add to
136
+ the buffer, the remaining rows in the buffer will be drained.
137
+ local_shuffle_seed: The seed to use for the local random shuffle.
138
+
139
+ Returns:
140
+ An iterable over record batches.
141
+ """
142
+ batch_format = _apply_batch_format(batch_format)
143
+
144
+ def _create_iterator() -> Iterator[DataBatch]:
145
+ time_start = time.perf_counter()
146
+ # Iterate through the dataset from the start each time
147
+ # _iterator_gen is called.
148
+ # This allows multiple iterations of the dataset without
149
+ # needing to explicitly call `iter_batches()` multiple times.
150
+ (
151
+ ref_bundles_iterator,
152
+ stats,
153
+ blocks_owned_by_consumer,
154
+ ) = self._to_ref_bundle_iterator()
155
+
156
+ iterator = iter(
157
+ iter_batches(
158
+ ref_bundles_iterator,
159
+ stats=stats,
160
+ clear_block_after_read=blocks_owned_by_consumer,
161
+ batch_size=batch_size,
162
+ batch_format=batch_format,
163
+ drop_last=drop_last,
164
+ collate_fn=_collate_fn,
165
+ finalize_fn=_finalize_fn,
166
+ shuffle_buffer_min_size=local_shuffle_buffer_size,
167
+ shuffle_seed=local_shuffle_seed,
168
+ prefetch_batches=prefetch_batches,
169
+ )
170
+ )
171
+
172
+ dataset_tag = self._get_dataset_tag()
173
+
174
+ if stats:
175
+ stats.iter_initialize_s.add(time.perf_counter() - time_start)
176
+
177
+ for batch in iterator:
178
+ yield batch
179
+ StatsManager.update_iteration_metrics(stats, dataset_tag)
180
+ StatsManager.clear_iteration_metrics(dataset_tag)
181
+
182
+ if stats:
183
+ stats.iter_total_s.add(time.perf_counter() - time_start)
184
+
185
+ return _IterableFromIterator(_create_iterator)
186
+
187
+ def _get_dataset_tag(self) -> str:
188
+ return "unknown_dataset"
189
+
190
+ @PublicAPI
191
+ def iter_rows(self) -> Iterable[Dict[str, Any]]:
192
+ """Return a local row iterable over the dataset.
193
+
194
+ If the dataset is a tabular dataset (Arrow/Pandas blocks), dicts
195
+ are yielded for each row by the iterator. If the dataset is not tabular,
196
+ the raw row is yielded.
197
+
198
+ Examples:
199
+ >>> import ray
200
+ >>> dataset = ray.data.range(10)
201
+ >>> next(iter(dataset.iterator().iter_rows()))
202
+ {'id': 0}
203
+
204
+ Time complexity: O(1)
205
+
206
+ Returns:
207
+ An iterable over rows of the dataset.
208
+ """
209
+ batch_iterable = self.iter_batches(
210
+ batch_size=None, batch_format=None, prefetch_batches=1
211
+ )
212
+
213
+ def _wrapped_iterator():
214
+ for batch in batch_iterable:
215
+ batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch))
216
+ for row in batch.iter_rows(public_row_format=True):
217
+ yield row
218
+
219
+ return _IterableFromIterator(_wrapped_iterator)
220
+
221
+ @abc.abstractmethod
222
+ @PublicAPI
223
+ def stats(self) -> str:
224
+ """Returns a string containing execution timing information."""
225
+ raise NotImplementedError
226
+
227
+ @abc.abstractmethod
228
+ def schema(self) -> "Schema":
229
+ """Return the schema of the dataset iterated over."""
230
+ raise NotImplementedError
231
+
232
+ @PublicAPI
233
+ def iter_torch_batches(
234
+ self,
235
+ *,
236
+ prefetch_batches: int = 1,
237
+ batch_size: Optional[int] = 256,
238
+ dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
239
+ device: str = "auto",
240
+ collate_fn: Optional[Callable[[Dict[str, np.ndarray]], "CollatedData"]] = None,
241
+ drop_last: bool = False,
242
+ local_shuffle_buffer_size: Optional[int] = None,
243
+ local_shuffle_seed: Optional[int] = None,
244
+ ) -> Iterable["TorchBatchType"]:
245
+ """Return a batched iterable of Torch Tensors over the dataset.
246
+
247
+ This iterable yields a dictionary of column-tensors. If you are looking for
248
+ more flexibility in the tensor conversion (e.g. casting dtypes) or the batch
249
+ format, try using :meth:`~ray.data.DataIterator.iter_batches` directly.
250
+
251
+ Examples:
252
+ >>> import ray
253
+ >>> for batch in ray.data.range(
254
+ ... 12,
255
+ ... ).iterator().iter_torch_batches(batch_size=4):
256
+ ... print(batch)
257
+ {'id': tensor([0, 1, 2, 3])}
258
+ {'id': tensor([4, 5, 6, 7])}
259
+ {'id': tensor([ 8, 9, 10, 11])}
260
+
261
+ Use the ``collate_fn`` to customize how the tensor batch is created.
262
+
263
+ >>> from typing import Any, Dict
264
+ >>> import torch
265
+ >>> import numpy as np
266
+ >>> import ray
267
+ >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
268
+ ... return torch.stack(
269
+ ... [torch.as_tensor(array) for array in batch.values()],
270
+ ... axis=1
271
+ ... )
272
+ >>> iterator = ray.data.from_items([
273
+ ... {"col_1": 1, "col_2": 2},
274
+ ... {"col_1": 3, "col_2": 4}]).iterator()
275
+ >>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn):
276
+ ... print(batch)
277
+ tensor([[1, 2],
278
+ [3, 4]])
279
+
280
+ Time complexity: O(1)
281
+
282
+ Args:
283
+ prefetch_batches: The number of batches to fetch ahead of the current batch
284
+ to fetch. If set to greater than 0, a separate threadpool will be used
285
+ to fetch the objects to the local node, format the batches, and apply
286
+ the collate_fn. Defaults to 1.
287
+ batch_size: The number of rows in each batch, or None to use entire blocks
288
+ as batches (blocks may contain different number of rows).
289
+ The final batch may include fewer than ``batch_size`` rows if
290
+ ``drop_last`` is ``False``. Defaults to 256.
291
+ dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype
292
+ will be inferred from the tensor data. You can't use this parameter
293
+ with ``collate_fn``.
294
+ device: The device on which the tensor should be placed. Defaults to
295
+ "auto" which moves the tensors to the appropriate device when the
296
+ Dataset is passed to Ray Train and ``collate_fn`` is not provided.
297
+ Otherwise, defaults to CPU. You can't use this parameter with
298
+ ``collate_fn``.
299
+ collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
300
+ When this parameter is specified, the user should manually handle the
301
+ host to device data transfer outside of ``collate_fn``.
302
+ This is useful for further processing the data after it has been
303
+ batched. Potential use cases include collating along a dimension other
304
+ than the first, padding sequences of various lengths, or generally
305
+ handling batches of different length tensors. If not provided, the
306
+ default collate function is used which simply converts the batch of
307
+ numpy arrays to a batch of PyTorch tensors. This API is still
308
+ experimental and is subject to change. You can't use this parameter in
309
+ conjunction with ``dtypes`` or ``device``.
310
+ drop_last: Whether to drop the last batch if it's incomplete.
311
+ local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
312
+ using a local in-memory shuffle buffer, and this value will serve as the
313
+ minimum number of rows that must be in the local in-memory shuffle
314
+ buffer in order to yield a batch. When there are no more rows to add to
315
+ the buffer, the remaining rows in the buffer will be drained. This
316
+ buffer size must be greater than or equal to ``batch_size``, and
317
+ therefore ``batch_size`` must also be specified when using local
318
+ shuffling.
319
+ local_shuffle_seed: The seed to use for the local random shuffle.
320
+
321
+ Returns:
322
+ An iterable over Torch Tensor batches.
323
+ """
324
+
325
+ from ray.air._internal.torch_utils import (
326
+ convert_ndarray_batch_to_torch_tensor_batch,
327
+ )
328
+ from ray.train.torch import get_device
329
+
330
+ if collate_fn is not None and (dtypes is not None or device != "auto"):
331
+ raise ValueError(
332
+ "collate_fn cannot be used with dtypes and device."
333
+ "You should manually move the output Torch tensors to the"
334
+ "desired dtype and device outside of collate_fn."
335
+ )
336
+
337
+ if device == "auto":
338
+ # Use the appropriate device for Ray Train, or falls back to CPU if
339
+ # Ray Train is not being used.
340
+ device = get_device()
341
+
342
+ if collate_fn is None:
343
+ # The default collate_fn handles formatting and Tensor creation.
344
+ # Here, we set device=None to defer host to device data transfer
345
+ # to the subsequent finalize_fn.
346
+ def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]):
347
+ return convert_ndarray_batch_to_torch_tensor_batch(
348
+ batch,
349
+ dtypes=dtypes,
350
+ device=None,
351
+ )
352
+
353
+ # The default finalize_fn handles the host to device data transfer.
354
+ # This is executed in a 1-thread pool separately from collate_fn
355
+ # to allow independent parallelism of these steps.
356
+ def finalize_fn(batch: Union["torch.Tensor", Dict[str, "torch.Tensor"]]):
357
+ if device is not None:
358
+ if isinstance(batch, dict):
359
+ for k, t in batch.items():
360
+ batch[k] = t.to(device=device)
361
+ else:
362
+ batch = batch.to(device=device)
363
+ return batch
364
+
365
+ else:
366
+ finalize_fn = None
367
+
368
+ return self.iter_batches(
369
+ prefetch_batches=prefetch_batches,
370
+ batch_size=batch_size,
371
+ drop_last=drop_last,
372
+ local_shuffle_buffer_size=local_shuffle_buffer_size,
373
+ local_shuffle_seed=local_shuffle_seed,
374
+ _collate_fn=collate_fn,
375
+ _finalize_fn=finalize_fn,
376
+ )
377
+
378
+ def iter_tf_batches(
379
+ self,
380
+ *,
381
+ prefetch_batches: int = 1,
382
+ batch_size: Optional[int] = 256,
383
+ dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None,
384
+ drop_last: bool = False,
385
+ local_shuffle_buffer_size: Optional[int] = None,
386
+ local_shuffle_seed: Optional[int] = None,
387
+ ) -> Iterable["TensorFlowTensorBatchType"]:
388
+ """Return a batched iterable of TensorFlow Tensors over the dataset.
389
+
390
+ This iterable will yield single-tensor batches of the underlying dataset
391
+ consists of a single column; otherwise, it will yield a dictionary of
392
+ column-tensors.
393
+
394
+ .. tip::
395
+ If you don't need the additional flexibility provided by this method,
396
+ consider using :meth:`~ray.data.Dataset.to_tf` instead. It's easier
397
+ to use.
398
+
399
+ Examples:
400
+ >>> import ray
401
+ >>> for batch in ray.data.range( # doctest: +SKIP
402
+ ... 12,
403
+ ... ).iter_tf_batches(batch_size=4):
404
+ ... print(batch.shape) # doctest: +SKIP
405
+ (4, 1)
406
+ (4, 1)
407
+ (4, 1)
408
+
409
+ Time complexity: O(1)
410
+
411
+ Args:
412
+ prefetch_batches: The number of batches to fetch ahead of the current batch
413
+ to fetch. If set to greater than 0, a separate threadpool will be used
414
+ to fetch the objects to the local node, format the batches, and apply
415
+ the collate_fn. Defaults to 1.
416
+ batch_size: The number of rows in each batch, or None to use entire blocks
417
+ as batches (blocks may contain different number of rows).
418
+ The final batch may include fewer than ``batch_size`` rows if
419
+ ``drop_last`` is ``False``. Defaults to 256.
420
+ dtypes: The TensorFlow dtype(s) for the created tensor(s); if None, the
421
+ dtype will be inferred from the tensor data.
422
+ drop_last: Whether to drop the last batch if it's incomplete.
423
+ local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
424
+ using a local in-memory shuffle buffer, and this value will serve as the
425
+ minimum number of rows that must be in the local in-memory shuffle
426
+ buffer in order to yield a batch. When there are no more rows to add to
427
+ the buffer, the remaining rows in the buffer will be drained. This
428
+ buffer size must be greater than or equal to ``batch_size``, and
429
+ therefore ``batch_size`` must also be specified when using local
430
+ shuffling.
431
+ local_shuffle_seed: The seed to use for the local random shuffle.
432
+
433
+ Returns:
434
+ An iterator over TensorFlow Tensor batches.
435
+ """
436
+ from ray.air._internal.tensorflow_utils import (
437
+ convert_ndarray_batch_to_tf_tensor_batch,
438
+ )
439
+
440
+ batch_iterable = self.iter_batches(
441
+ prefetch_batches=prefetch_batches,
442
+ batch_size=batch_size,
443
+ drop_last=drop_last,
444
+ local_shuffle_buffer_size=local_shuffle_buffer_size,
445
+ local_shuffle_seed=local_shuffle_seed,
446
+ )
447
+ mapped_iterable = map(
448
+ lambda batch: convert_ndarray_batch_to_tf_tensor_batch(
449
+ batch, dtypes=dtypes
450
+ ),
451
+ batch_iterable,
452
+ )
453
+
454
+ return mapped_iterable
455
+
456
    def to_torch(
        self,
        *,
        label_column: Optional[str] = None,
        feature_columns: Optional[
            Union[List[str], List[List[str]], Dict[str, List[str]]]
        ] = None,
        label_column_dtype: Optional["torch.dtype"] = None,
        feature_column_dtypes: Optional[
            Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]]
        ] = None,
        batch_size: int = 1,
        prefetch_batches: int = 1,
        drop_last: bool = False,
        local_shuffle_buffer_size: Optional[int] = None,
        local_shuffle_seed: Optional[int] = None,
        unsqueeze_label_tensor: bool = True,
        unsqueeze_feature_tensors: bool = True,
    ) -> "torch.utils.data.IterableDataset":
        """Return a Torch IterableDataset over this dataset.

        This is only supported for datasets convertible to Arrow records.

        It is recommended to use the returned ``IterableDataset`` directly
        instead of passing it into a torch ``DataLoader``.

        Each element in IterableDataset will be a tuple consisting of 2
        elements. The first item contains the feature tensor(s), and the
        second item is the label tensor. Those can take on different
        forms, depending on the specified arguments.

        For the features tensor (N is the ``batch_size`` and n, m, k
        are the number of features per tensor):

        * If ``feature_columns`` is a ``List[str]``, the features will be
          a tensor of shape (N, n), with columns corresponding to
          ``feature_columns``

        * If ``feature_columns`` is a ``List[List[str]]``, the features will be
          a list of tensors of shape [(N, m),...,(N, k)], with columns of each
          tensor corresponding to the elements of ``feature_columns``

        * If ``feature_columns`` is a ``Dict[str, List[str]]``, the features
          will be a dict of key-tensor pairs of shape
          {key1: (N, m),..., keyN: (N, k)}, with columns of each
          tensor corresponding to the value of ``feature_columns`` under the
          key.

        If ``unsqueeze_label_tensor=True`` (default), the label tensor will be
        of shape (N, 1). Otherwise, it will be of shape (N,).
        If ``label_column`` is specified as ``None``, then no column from the
        ``Dataset`` will be treated as the label, and the output label tensor
        will be ``None``.

        Note that you probably want to call ``.split()`` on this dataset if
        there are to be multiple Torch workers consuming the data.

        Time complexity: O(1)

        Args:
            label_column: The name of the column used as the
                label (second element of the output list). Can be None for
                prediction, in which case the second element of returned
                tuple will also be None.
            feature_columns: The names of the columns
                to use as the features. Can be a list of lists or
                a dict of string-list pairs for multi-tensor output.
                If None, then use all columns except the label column as
                the features.
            label_column_dtype: The torch dtype to
                use for the label column. If None, then automatically infer
                the dtype.
            feature_column_dtypes: The dtypes to use for the feature
                tensors. This should match the format of ``feature_columns``,
                or be a single dtype, in which case it will be applied to
                all tensors. If None, then automatically infer the dtype.
            batch_size: How many samples per batch to yield at a time.
                Defaults to 1.
            prefetch_batches: The number of batches to fetch ahead of the current batch
                to fetch. If set to greater than 0, a separate threadpool will be used
                to fetch the objects to the local node, format the batches, and apply
                the collate_fn. Defaults to 1.
            drop_last: Set to True to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If
                False and the size of dataset is not divisible by the batch
                size, then the last batch will be smaller. Defaults to False.
            local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                using a local in-memory shuffle buffer, and this value will serve as the
                minimum number of rows that must be in the local in-memory shuffle
                buffer in order to yield a batch. When there are no more rows to add to
                the buffer, the remaining rows in the buffer will be drained. This
                buffer size must be greater than or equal to ``batch_size``, and
                therefore ``batch_size`` must also be specified when using local
                shuffling.
            local_shuffle_seed: The seed to use for the local random shuffle.
            unsqueeze_label_tensor: If set to True, the label tensor
                will be unsqueezed (reshaped to (N, 1)). Otherwise, it will
                be left as is, that is (N, ). In general, regression loss
                functions expect an unsqueezed tensor, while classification
                loss functions expect a squeezed one. Defaults to True.
            unsqueeze_feature_tensors: If set to True, the features tensors
                will be unsqueezed (reshaped to (N, 1)) before being concatenated into
                the final features tensor. Otherwise, they will be left as is, that is
                (N, ). Defaults to True.

        Returns:
            A torch IterableDataset.
        """
        import torch

        from ray.air._internal.torch_utils import convert_pandas_to_torch_tensor
        from ray.data._internal.torch_iterable_dataset import TorchIterableDataset

        # If an empty collection is passed in, treat it the same as None
        if not feature_columns:
            feature_columns = None

        # Validate that the structure of `feature_column_dtypes` matches the
        # structure of `feature_columns` (dict keys must match, list lengths
        # must match). A single `torch.dtype` is always allowed and applies to
        # every feature tensor.
        if feature_column_dtypes and not isinstance(feature_column_dtypes, torch.dtype):
            if isinstance(feature_columns, dict):
                if not isinstance(feature_column_dtypes, dict):
                    raise TypeError(
                        "If `feature_columns` is a dict, "
                        "`feature_column_dtypes` must be None, `torch.dtype`,"
                        f" or dict, got {type(feature_column_dtypes)}."
                    )
                if set(feature_columns) != set(feature_column_dtypes):
                    raise ValueError(
                        "`feature_columns` and `feature_column_dtypes` "
                        "must have the same keys."
                    )
                if any(not subcolumns for subcolumns in feature_columns.values()):
                    raise ValueError("column list may not be empty")
            elif isinstance(feature_columns[0], (list, tuple)):
                if not isinstance(feature_column_dtypes, (list, tuple)):
                    raise TypeError(
                        "If `feature_columns` is a list of lists, "
                        "`feature_column_dtypes` must be None, `torch.dtype`,"
                        f" or a sequence, got {type(feature_column_dtypes)}."
                    )
                if len(feature_columns) != len(feature_column_dtypes):
                    raise ValueError(
                        "`feature_columns` and `feature_column_dtypes` "
                        "must have the same length."
                    )
                if any(not subcolumns for subcolumns in feature_columns):
                    raise ValueError("column list may not be empty")

        def make_generator():
            # Iterates pandas-format batches and converts each one into
            # (features_tensor, label_tensor) pairs.
            for batch in self.iter_batches(
                batch_size=batch_size,
                batch_format="pandas",
                prefetch_batches=prefetch_batches,
                drop_last=drop_last,
                local_shuffle_buffer_size=local_shuffle_buffer_size,
                local_shuffle_seed=local_shuffle_seed,
            ):
                if label_column:
                    label_tensor = convert_pandas_to_torch_tensor(
                        batch,
                        [label_column],
                        label_column_dtype,
                        unsqueeze=unsqueeze_label_tensor,
                    )
                    # Remove the label column so that, when `feature_columns`
                    # is None, it isn't also included among the features.
                    batch.pop(label_column)
                else:
                    label_tensor = None

                if isinstance(feature_columns, dict):
                    # One tensor per dict key; each dtype entry may be a dict
                    # keyed the same way, or a single dtype for all tensors.
                    features_tensor = {
                        key: convert_pandas_to_torch_tensor(
                            batch,
                            feature_columns[key],
                            (
                                feature_column_dtypes[key]
                                if isinstance(feature_column_dtypes, dict)
                                else feature_column_dtypes
                            ),
                            unsqueeze=unsqueeze_feature_tensors,
                        )
                        for key in feature_columns
                    }
                else:
                    # Covers both the List[str] (single tensor) and
                    # List[List[str]] (list of tensors) cases.
                    features_tensor = convert_pandas_to_torch_tensor(
                        batch,
                        columns=feature_columns,
                        column_dtypes=feature_column_dtypes,
                        unsqueeze=unsqueeze_feature_tensors,
                    )

                yield (features_tensor, label_tensor)

        return TorchIterableDataset(make_generator)
648
+
649
    @PublicAPI
    def to_tf(
        self,
        feature_columns: Union[str, List[str]],
        label_columns: Union[str, List[str]],
        *,
        additional_columns: Union[Optional[str], Optional[List[str]]] = None,
        prefetch_batches: int = 1,
        batch_size: int = 1,
        drop_last: bool = False,
        local_shuffle_buffer_size: Optional[int] = None,
        local_shuffle_seed: Optional[int] = None,
        feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
        label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
        additional_type_spec: Union[
            Optional["tf.TypeSpec"], Optional[Dict[str, "tf.TypeSpec"]]
        ] = None,
    ) -> "tf.data.Dataset":
        """Return a TF Dataset over this dataset.

        .. warning::
            If your dataset contains ragged tensors, this method errors. To prevent
            errors, :ref:`resize your tensors <transforming_tensors>`.

        Examples:
            >>> import ray
            >>> ds = ray.data.read_csv(
            ...     "s3://anonymous@air-example-data/iris.csv"
            ... )
            >>> it = ds.iterator(); it
            DataIterator(Dataset(
               num_rows=?,
               schema={
                  sepal length (cm): double,
                  sepal width (cm): double,
                  petal length (cm): double,
                  petal width (cm): double,
                  target: int64
               }
            ))

            If your model accepts a single tensor as input, specify a single feature column.

            >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target")
            <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>

            If your model accepts a dictionary as input, specify a list of feature columns.

            >>> it.to_tf(["sepal length (cm)", "sepal width (cm)"], "target")
            <_OptionsDataset element_spec=({'sepal length (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), 'sepal width (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal width (cm)')}, TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>

            If your dataset contains multiple features but your model accepts a single
            tensor as input, combine features with
            :class:`~ray.data.preprocessors.Concatenator`.

            >>> from ray.data.preprocessors import Concatenator
            >>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
            >>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features")
            >>> it = preprocessor.transform(ds).iterator()
            >>> it
            DataIterator(Concatenator
            +- Dataset(
                  num_rows=?,
                  schema={
                     sepal length (cm): double,
                     sepal width (cm): double,
                     petal length (cm): double,
                     petal width (cm): double,
                     target: int64
                  }
               ))
            >>> it.to_tf("features", "target")
            <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>

            If your model accepts different types, shapes, or names of tensors as input, specify the type spec.
            If type specs are not specified, they are automatically inferred from the schema of the iterator.

            >>> import tensorflow as tf
            >>> it.to_tf(
            ...     feature_columns="features",
            ...     label_columns="target",
            ...     feature_type_spec=tf.TensorSpec(shape=(None, 4), dtype=tf.float32, name="features"),
            ...     label_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="label")
            ... )
            <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name='features'), TensorSpec(shape=(None,), dtype=tf.float32, name='label'))>

            If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns.
            A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.

            >>> import pandas as pd
            >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
            >>> it = ds.iterator()
            >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target", additional_columns="sample weights")
            <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>

            If your model accepts different types, shapes, or names for the additional metadata, specify the type spec of the additional column.

            >>> it.to_tf(
            ...     feature_columns="sepal length (cm)",
            ...     label_columns="target",
            ...     additional_columns="sample weights",
            ...     additional_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="weight")
            ... )
            <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.float32, name='weight'))>

        Args:
            feature_columns: Columns that correspond to model inputs. If this is a
                string, the input data is a tensor. If this is a list, the input data
                is a ``dict`` that maps column names to their tensor representation.
            label_columns: Columns that correspond to model targets. If this is a
                string, the target data is a tensor. If this is a list, the target data
                is a ``dict`` that maps column names to their tensor representation.
            additional_columns: Columns that correspond to sample weights or other metadata.
                If this is a string, the weight data is a tensor. If this is a list, the
                weight data is a ``dict`` that maps column names to their tensor representation.
            prefetch_batches: The number of batches to fetch ahead of the current batch
                to fetch. If set to greater than 0, a separate threadpool will be used
                to fetch the objects to the local node, format the batches, and apply
                the collate_fn. Defaults to 1.
            batch_size: Record batch size. Defaults to 1.
            drop_last: Set to True to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If
                False and the size of dataset is not divisible by the batch
                size, then the last batch will be smaller. Defaults to False.
            local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                using a local in-memory shuffle buffer, and this value will serve as the
                minimum number of rows that must be in the local in-memory shuffle
                buffer in order to yield a batch. When there are no more rows to add to
                the buffer, the remaining rows in the buffer will be drained. This
                buffer size must be greater than or equal to ``batch_size``, and
                therefore ``batch_size`` must also be specified when using local
                shuffling.
            local_shuffle_seed: The seed to use for the local random shuffle.
            feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
                only one column, specify a `tf.TypeSpec`. If there are multiple columns,
                specify a ``dict`` that maps column names to their `tf.TypeSpec`.
                Default is `None` to automatically infer the type of each column.
            label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
                only one column, specify a `tf.TypeSpec`. If there are multiple columns,
                specify a ``dict`` that maps column names to their `tf.TypeSpec`.
                Default is `None` to automatically infer the type of each column.
            additional_type_spec: The `tf.TypeSpec` of `additional_columns`. If there
                is only one column, specify a `tf.TypeSpec`. If there are multiple
                columns, specify a ``dict`` that maps column names to their `tf.TypeSpec`.
                Default is `None` to automatically infer the type of each column.

        Returns:
            A ``tf.data.Dataset`` that yields inputs and targets.
        """  # noqa: E501

        from ray.air._internal.tensorflow_utils import (
            convert_ndarray_to_tf_tensor,
            get_type_spec,
        )

        try:
            import tensorflow as tf
        except ImportError:
            raise ValueError("tensorflow must be installed!")

        def validate_column(column: str) -> None:
            # NOTE: `valid_columns` is a closure variable bound in the
            # type-spec inference branches below, before this helper is called.
            if column not in valid_columns:
                raise ValueError(
                    f"You specified '{column}' in `feature_columns`, "
                    f"`label_columns`, or `additional_columns`, but there's no "
                    f"column named '{column}' in the dataset. "
                    f"Valid column names are: {valid_columns}."
                )

        def validate_columns(columns: Union[str, List]) -> None:
            # Accepts either a single column name or a list of names.
            if isinstance(columns, list):
                for column in columns:
                    validate_column(column)
            else:
                validate_column(columns)

        def convert_batch_to_tensors(
            batch: Dict[str, np.ndarray],
            *,
            columns: Union[str, List[str]],
            type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]],
        ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
            # A single column name yields one tensor; a list of names yields a
            # dict of column name -> tensor (matching the element_spec shape).
            if isinstance(columns, str):
                return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec)
            return {
                column: convert_ndarray_to_tf_tensor(
                    batch[column], type_spec=type_spec[column]
                )
                for column in columns
            }

        def generator():
            # Pulls Numpy-format batches and converts each into the
            # (features, labels[, additional]) structure that
            # `tf.data.Dataset.from_generator` expects.
            for batch in self.iter_batches(
                prefetch_batches=prefetch_batches,
                batch_size=batch_size,
                drop_last=drop_last,
                local_shuffle_buffer_size=local_shuffle_buffer_size,
                local_shuffle_seed=local_shuffle_seed,
            ):
                assert isinstance(batch, dict)
                features = convert_batch_to_tensors(
                    batch, columns=feature_columns, type_spec=feature_type_spec
                )
                labels = convert_batch_to_tensors(
                    batch, columns=label_columns, type_spec=label_type_spec
                )

                if additional_columns is None:
                    yield features, labels
                else:
                    additional_metadata = convert_batch_to_tensors(
                        batch,
                        columns=additional_columns,
                        type_spec=additional_type_spec,
                    )
                    yield features, labels, additional_metadata

        # Infer any type specs the caller did not provide from the dataset
        # schema, validating the requested column names along the way.
        if feature_type_spec is None or label_type_spec is None:
            schema = self.schema()
            valid_columns = set(schema.names)
            validate_columns(feature_columns)
            validate_columns(label_columns)
            feature_type_spec = get_type_spec(schema, columns=feature_columns)
            label_type_spec = get_type_spec(schema, columns=label_columns)

        if additional_columns is not None and additional_type_spec is None:
            schema = self.schema()
            valid_columns = set(schema.names)
            validate_columns(additional_columns)
            additional_type_spec = get_type_spec(schema, columns=additional_columns)

        # The output signature is a 2-tuple or 3-tuple depending on whether
        # additional (e.g. sample-weight) columns were requested.
        if additional_columns is not None:
            dataset = tf.data.Dataset.from_generator(
                generator,
                output_signature=(
                    feature_type_spec,
                    label_type_spec,
                    additional_type_spec,
                ),
            )
        else:
            dataset = tf.data.Dataset.from_generator(
                generator, output_signature=(feature_type_spec, label_type_spec)
            )

        # Disable tf.distribute auto-sharding: sharding is handled by Ray Data
        # splitting the stream, not by tf.data.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.OFF
        )
        return dataset.with_options(options)
899
+
900
+ @PublicAPI
901
+ def materialize(self) -> "MaterializedDataset":
902
+ """Execute and materialize this data iterator into object store memory.
903
+
904
+ .. note::
905
+ This method triggers the execution and materializes all blocks
906
+ of the iterator, returning its contents as a
907
+ :class:`~ray.data.dataset.MaterializedDataset` for further processing.
908
+ """
909
+
910
+ from ray.data.dataset import MaterializedDataset
911
+
912
+ ref_bundles_iter, stats, _ = self._to_ref_bundle_iterator()
913
+
914
+ ref_bundles = list(ref_bundles_iter)
915
+ execution_plan = ExecutionPlan(stats)
916
+ logical_plan = LogicalPlan(
917
+ InputData(input_data=ref_bundles),
918
+ execution_plan._context,
919
+ )
920
+ return MaterializedDataset(
921
+ execution_plan,
922
+ logical_plan,
923
+ )
924
+
925
+ def __del__(self):
926
+ # Clear metrics on deletion in case the iterator was not fully consumed.
927
+ StatsManager.clear_iteration_metrics(self._get_dataset_tag())
928
+
929
+
930
# Backwards compatibility alias: keep the old public name `DatasetIterator`
# importable for user code written before the class was renamed to
# `DataIterator`.
DatasetIterator = DataIterator
.venv/lib/python3.11/site-packages/ray/data/preprocessor.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import base64
3
+ import collections
4
+ import pickle
5
+ import warnings
6
+ from enum import Enum
7
+ from typing import TYPE_CHECKING, Any, Dict, Union
8
+
9
+ from ray.air.util.data_batch_conversion import BatchFormat
10
+ from ray.util.annotations import DeveloperAPI, PublicAPI
11
+
12
+ if TYPE_CHECKING:
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+ from ray.air.data_batch_type import DataBatchType
17
+ from ray.data import Dataset
18
+
19
+
20
@PublicAPI(stability="beta")
class PreprocessorNotFittedException(RuntimeError):
    """Error raised when the preprocessor needs to be fitted first.

    Raised by :meth:`Preprocessor.transform` and
    :meth:`Preprocessor.transform_batch` when called on a fittable
    preprocessor before :meth:`Preprocessor.fit` has been called.
    """
25
+
26
+
27
+ @PublicAPI(stability="beta")
28
+ class Preprocessor(abc.ABC):
29
+ """Implements an ML preprocessing operation.
30
+
31
+ Preprocessors are stateful objects that can be fitted against a Dataset and used
32
+ to transform both local data batches and distributed data. For example, a
33
+ Normalization preprocessor may calculate the mean and stdev of a field during
34
+ fitting, and uses these attributes to implement its normalization transform.
35
+
36
+ Preprocessors can also be stateless and transform data without needed to be fitted.
37
+ For example, a preprocessor may simply remove a column, which does not require
38
+ any state to be fitted.
39
+
40
+ If you are implementing your own Preprocessor sub-class, you should override the
41
+ following:
42
+
43
+ * ``_fit`` if your preprocessor is stateful. Otherwise, set
44
+ ``_is_fittable=False``.
45
+ * ``_transform_pandas`` and/or ``_transform_numpy`` for best performance,
46
+ implement both. Otherwise, the data will be converted to the match the
47
+ implemented method.
48
+ """
49
+
50
    class FitStatus(str, Enum):
        """The fit status of preprocessor."""

        # The preprocessor is stateless and never needs fitting.
        NOT_FITTABLE = "NOT_FITTABLE"
        # The preprocessor is fittable but has not been fitted yet.
        NOT_FITTED = "NOT_FITTED"
        # Only meaningful for Chain preprocessors.
        # At least one contained preprocessor in the chain preprocessor
        # is fitted and at least one that can be fitted is not fitted yet.
        # This is a state that shows up if the caller only interacts
        # with the chain preprocessor through the intended Preprocessor APIs.
        PARTIALLY_FITTED = "PARTIALLY_FITTED"
        # Fitting is complete; the preprocessor is ready to transform data.
        FITTED = "FITTED"
62
+
63
+ # Preprocessors that do not need to be fitted must override this.
64
+ _is_fittable = True
65
+
66
+ def _check_has_fitted_state(self):
67
+ """Checks if the Preprocessor has fitted state.
68
+
69
+ This is also used as an indiciation if the Preprocessor has been fit, following
70
+ convention from Ray versions prior to 2.6.
71
+ This allows preprocessors that have been fit in older versions of Ray to be
72
+ used to transform data in newer versions.
73
+ """
74
+
75
+ fitted_vars = [v for v in vars(self) if v.endswith("_")]
76
+ return bool(fitted_vars)
77
+
78
+ def fit_status(self) -> "Preprocessor.FitStatus":
79
+ if not self._is_fittable:
80
+ return Preprocessor.FitStatus.NOT_FITTABLE
81
+ elif (
82
+ hasattr(self, "_fitted") and self._fitted
83
+ ) or self._check_has_fitted_state():
84
+ return Preprocessor.FitStatus.FITTED
85
+ else:
86
+ return Preprocessor.FitStatus.NOT_FITTED
87
+
88
+ def fit(self, ds: "Dataset") -> "Preprocessor":
89
+ """Fit this Preprocessor to the Dataset.
90
+
91
+ Fitted state attributes will be directly set in the Preprocessor.
92
+
93
+ Calling it more than once will overwrite all previously fitted state:
94
+ ``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``.
95
+
96
+ Args:
97
+ ds: Input dataset.
98
+
99
+ Returns:
100
+ Preprocessor: The fitted Preprocessor with state attributes.
101
+ """
102
+ fit_status = self.fit_status()
103
+ if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
104
+ # No-op as there is no state to be fitted.
105
+ return self
106
+
107
+ if fit_status in (
108
+ Preprocessor.FitStatus.FITTED,
109
+ Preprocessor.FitStatus.PARTIALLY_FITTED,
110
+ ):
111
+ warnings.warn(
112
+ "`fit` has already been called on the preprocessor (or at least one "
113
+ "contained preprocessors if this is a chain). "
114
+ "All previously fitted state will be overwritten!"
115
+ )
116
+
117
+ fitted_ds = self._fit(ds)
118
+ self._fitted = True
119
+ return fitted_ds
120
+
121
+ def fit_transform(self, ds: "Dataset") -> "Dataset":
122
+ """Fit this Preprocessor to the Dataset and then transform the Dataset.
123
+
124
+ Calling it more than once will overwrite all previously fitted state:
125
+ ``preprocessor.fit_transform(A).fit_transform(B)``
126
+ is equivalent to ``preprocessor.fit_transform(B)``.
127
+
128
+ Args:
129
+ ds: Input Dataset.
130
+
131
+ Returns:
132
+ ray.data.Dataset: The transformed Dataset.
133
+ """
134
+ self.fit(ds)
135
+ return self.transform(ds)
136
+
137
+ def transform(self, ds: "Dataset") -> "Dataset":
138
+ """Transform the given dataset.
139
+
140
+ Args:
141
+ ds: Input Dataset.
142
+
143
+ Returns:
144
+ ray.data.Dataset: The transformed Dataset.
145
+
146
+ Raises:
147
+ PreprocessorNotFittedException: if ``fit`` is not called yet.
148
+ """
149
+ fit_status = self.fit_status()
150
+ if fit_status in (
151
+ Preprocessor.FitStatus.PARTIALLY_FITTED,
152
+ Preprocessor.FitStatus.NOT_FITTED,
153
+ ):
154
+ raise PreprocessorNotFittedException(
155
+ "`fit` must be called before `transform`, "
156
+ "or simply use fit_transform() to run both steps"
157
+ )
158
+ transformed_ds = self._transform(ds)
159
+ return transformed_ds
160
+
161
+ def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
162
+ """Transform a single batch of data.
163
+
164
+ The data will be converted to the format supported by the Preprocessor,
165
+ based on which ``_transform_*`` methods are implemented.
166
+
167
+ Args:
168
+ data: Input data batch.
169
+
170
+ Returns:
171
+ DataBatchType:
172
+ The transformed data batch. This may differ
173
+ from the input type depending on which ``_transform_*`` methods
174
+ are implemented.
175
+ """
176
+ fit_status = self.fit_status()
177
+ if fit_status in (
178
+ Preprocessor.FitStatus.PARTIALLY_FITTED,
179
+ Preprocessor.FitStatus.NOT_FITTED,
180
+ ):
181
+ raise PreprocessorNotFittedException(
182
+ "`fit` must be called before `transform_batch`."
183
+ )
184
+ return self._transform_batch(data)
185
+
186
    @DeveloperAPI
    def _fit(self, ds: "Dataset") -> "Preprocessor":
        """Sub-classes should override this instead of fit().

        Implementations should compute fitted state from ``ds`` and return
        the fitted preprocessor (typically ``self``).
        """
        raise NotImplementedError()
190
+
191
+ def _determine_transform_to_use(self) -> BatchFormat:
192
+ """Determine which batch format to use based on Preprocessor implementation.
193
+
194
+ * If only `_transform_pandas` is implemented, then use ``pandas`` batch format.
195
+ * If only `_transform_numpy` is implemented, then use ``numpy`` batch format.
196
+ * If both are implemented, then use the Preprocessor defined preferred batch
197
+ format.
198
+ """
199
+
200
+ has_transform_pandas = (
201
+ self.__class__._transform_pandas != Preprocessor._transform_pandas
202
+ )
203
+ has_transform_numpy = (
204
+ self.__class__._transform_numpy != Preprocessor._transform_numpy
205
+ )
206
+
207
+ if has_transform_numpy and has_transform_pandas:
208
+ return self.preferred_batch_format()
209
+ elif has_transform_numpy:
210
+ return BatchFormat.NUMPY
211
+ elif has_transform_pandas:
212
+ return BatchFormat.PANDAS
213
+ else:
214
+ raise NotImplementedError(
215
+ "None of `_transform_numpy` or `_transform_pandas` are implemented. "
216
+ "At least one of these transform functions must be implemented "
217
+ "for Preprocessor transforms."
218
+ )
219
+
220
+ def _transform(self, ds: "Dataset") -> "Dataset":
221
+ # TODO(matt): Expose `batch_size` or similar configurability.
222
+ # The default may be too small for some datasets and too large for others.
223
+ transform_type = self._determine_transform_to_use()
224
+
225
+ # Our user-facing batch format should only be pandas or NumPy, other
226
+ # formats {arrow, simple} are internal.
227
+ kwargs = self._get_transform_config()
228
+ if transform_type == BatchFormat.PANDAS:
229
+ return ds.map_batches(
230
+ self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
231
+ )
232
+ elif transform_type == BatchFormat.NUMPY:
233
+ return ds.map_batches(
234
+ self._transform_numpy, batch_format=BatchFormat.NUMPY, **kwargs
235
+ )
236
+ else:
237
+ raise ValueError(
238
+ "Invalid transform type returned from _determine_transform_to_use; "
239
+ f'"pandas" and "numpy" allowed, but got: {transform_type}'
240
+ )
241
+
242
    def _get_transform_config(self) -> Dict[str, Any]:
        """Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`.

        This can be implemented by subclassing preprocessors.
        """
        # Base implementation: no extra map_batches configuration.
        return {}
248
+
249
+ def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
250
+ # For minimal install to locally import air modules
251
+ import numpy as np
252
+ import pandas as pd
253
+
254
+ from ray.air.util.data_batch_conversion import (
255
+ _convert_batch_type_to_numpy,
256
+ _convert_batch_type_to_pandas,
257
+ )
258
+
259
+ try:
260
+ import pyarrow
261
+ except ImportError:
262
+ pyarrow = None
263
+
264
+ if not isinstance(
265
+ data, (pd.DataFrame, pyarrow.Table, collections.abc.Mapping, np.ndarray)
266
+ ):
267
+ raise ValueError(
268
+ "`transform_batch` is currently only implemented for Pandas "
269
+ "DataFrames, pyarrow Tables, NumPy ndarray and dictionary of "
270
+ f"ndarray. Got {type(data)}."
271
+ )
272
+
273
+ transform_type = self._determine_transform_to_use()
274
+
275
+ if transform_type == BatchFormat.PANDAS:
276
+ return self._transform_pandas(_convert_batch_type_to_pandas(data))
277
+ elif transform_type == BatchFormat.NUMPY:
278
+ return self._transform_numpy(_convert_batch_type_to_numpy(data))
279
+
280
    @DeveloperAPI
    def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run the transformation on a data batch in a Pandas DataFrame format.

        Subclasses override this to support the pandas batch format.
        """
        raise NotImplementedError()
284
+
285
    @DeveloperAPI
    def _transform_numpy(
        self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
    ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
        """Run the transformation on a data batch in a NumPy ndarray format.

        Subclasses override this to support the numpy batch format; the input
        is either a single ndarray or a column-name -> ndarray mapping.
        """
        raise NotImplementedError()
291
+
292
    @classmethod
    @DeveloperAPI
    def preferred_batch_format(cls) -> BatchFormat:
        """Batch format hint for upstream producers to try yielding best block format.

        The preferred batch format to use if both `_transform_pandas` and
        `_transform_numpy` are implemented. Defaults to Pandas.

        Can be overridden by Preprocessor classes depending on which transform
        path is the most optimal.
        """
        return BatchFormat.PANDAS
304
+
305
+ @DeveloperAPI
306
+ def serialize(self) -> str:
307
+ """Return this preprocessor serialized as a string.
308
+ Note: this is not a stable serialization format as it uses `pickle`.
309
+ """
310
+ # Convert it to a plain string so that it can be included as JSON metadata
311
+ # in Trainer checkpoints.
312
+ return base64.b64encode(pickle.dumps(self)).decode("ascii")
313
+
314
+ @staticmethod
315
+ @DeveloperAPI
316
+ def deserialize(serialized: str) -> "Preprocessor":
317
+ """Load the original preprocessor serialized via `self.serialize()`."""
318
+ return pickle.loads(base64.b64decode(serialized))
.venv/lib/python3.11/site-packages/ray/data/random_access_dataset.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import logging
3
+ import random
4
+ import time
5
+ from collections import defaultdict
6
+ from typing import TYPE_CHECKING, Any, List, Optional
7
+
8
+ import numpy as np
9
+
10
+ import ray
11
+ from ray.data._internal.execution.interfaces.ref_bundle import (
12
+ _ref_bundles_iterator_to_block_refs_list,
13
+ )
14
+ from ray.data._internal.remote_fn import cached_remote_fn
15
+ from ray.data.block import BlockAccessor
16
+ from ray.data.context import DataContext
17
+ from ray.types import ObjectRef
18
+ from ray.util.annotations import PublicAPI
19
+
20
+ try:
21
+ import pyarrow as pa
22
+ except ImportError:
23
+ pa = None
24
+
25
+ if TYPE_CHECKING:
26
+ from ray.data import Dataset
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
@PublicAPI(stability="alpha")
class RandomAccessDataset:
    """A class that provides distributed, random access to a Dataset.

    See: ``Dataset.to_random_access_dataset()``.
    """

    def __init__(
        self,
        ds: "Dataset",
        key: str,
        num_workers: int,
    ):
        """Construct a RandomAccessDataset (internal API).

        The constructor is a private API. Use ``ds.to_random_access_dataset()``
        to construct a RandomAccessDataset.
        """
        # A schema that is a plain Python type indicates simple (non-Arrow)
        # blocks, which this index does not support.
        schema = ds.schema(fetch_if_missing=True)
        if schema is None or isinstance(schema, type):
            raise ValueError("RandomAccessDataset only supports Arrow-format blocks.")

        start = time.perf_counter()
        logger.info("[setup] Indexing dataset by sort key.")
        # Sorting by the key makes per-block binary search possible and gives
        # globally ordered block bounds.
        sorted_ds = ds.sort(key)
        get_bounds = cached_remote_fn(_get_bounds)
        bundles = sorted_ds.iter_internal_ref_bundles()
        blocks = _ref_bundles_iterator_to_block_refs_list(bundles)

        logger.info("[setup] Computing block range bounds.")
        # _get_bounds returns (first, last) key per block, or None for empty
        # blocks; empty blocks are dropped from the index.
        bounds = ray.get([get_bounds.remote(b, key) for b in blocks])
        self._non_empty_blocks = []
        self._lower_bound = None
        self._upper_bounds = []
        for i, b in enumerate(bounds):
            if b:
                self._non_empty_blocks.append(blocks[i])
                if self._lower_bound is None:
                    self._lower_bound = b[0]
                self._upper_bounds.append(b[1])

        logger.info("[setup] Creating {} random access workers.".format(num_workers))
        ctx = DataContext.get_current()
        scheduling_strategy = ctx.scheduling_strategy
        self._workers = [
            _RandomAccessWorker.options(scheduling_strategy=scheduling_strategy).remote(
                key
            )
            for _ in range(num_workers)
        ]
        (
            self._block_to_workers_map,
            self._worker_to_blocks_map,
        ) = self._compute_block_to_worker_assignments()

        logger.info(
            "[setup] Worker to blocks assignment: {}".format(self._worker_to_blocks_map)
        )
        # Each worker materializes (ray.get) its assigned blocks into memory.
        ray.get(
            [
                w.assign_blocks.remote(
                    {
                        i: self._non_empty_blocks[i]
                        for i in self._worker_to_blocks_map[w]
                    }
                )
                for w in self._workers
            ]
        )

        logger.info("[setup] Finished assigning blocks to workers.")
        self._build_time = time.perf_counter() - start

    def _compute_block_to_worker_assignments(self):
        """Assign each non-empty block to one or more workers, preferring
        workers co-located with the block's object-store copy."""
        # Return values.
        block_to_workers: dict[int, List["ray.ActorHandle"]] = defaultdict(list)
        worker_to_blocks: dict["ray.ActorHandle", List[int]] = defaultdict(list)

        # Aux data structures.
        loc_to_workers: dict[str, List["ray.ActorHandle"]] = defaultdict(list)
        locs = ray.get([w.ping.remote() for w in self._workers])
        for i, loc in enumerate(locs):
            loc_to_workers[loc].append(self._workers[i])
        block_locs = ray.experimental.get_object_locations(self._non_empty_blocks)

        # First, try to assign all blocks to all workers at its location.
        for block_idx, block in enumerate(self._non_empty_blocks):
            block_info = block_locs[block]
            locs = block_info.get("node_ids", [])
            for loc in locs:
                for worker in loc_to_workers[loc]:
                    block_to_workers[block_idx].append(worker)
                    worker_to_blocks[worker].append(block_idx)

        # Randomly assign any leftover blocks to at least one worker.
        # TODO: the load balancing here could be improved.
        for block_idx, block in enumerate(self._non_empty_blocks):
            if len(block_to_workers[block_idx]) == 0:
                worker = random.choice(self._workers)
                block_to_workers[block_idx].append(worker)
                worker_to_blocks[worker].append(block_idx)

        return block_to_workers, worker_to_blocks

    def get_async(self, key: Any) -> ObjectRef[Any]:
        """Asynchronously finds the record for a single key.

        Args:
            key: The key of the record to find.

        Returns:
            ObjectRef containing the record (in pydict form), or None if not found.
        """
        block_index = self._find_le(key)
        if block_index is None:
            # Key is outside the indexed range; resolve to None.
            return ray.put(None)
        return self._worker_for(block_index).get.remote(block_index, key)

    def multiget(self, keys: List[Any]) -> List[Optional[Any]]:
        """Synchronously find the records for a list of keys.

        Args:
            keys: List of keys to find the records for.

        Returns:
            List of found records (in pydict form), or None for missing records.
        """
        # Group keys by the block that could contain them so each worker gets
        # a single batched RPC.
        batches = defaultdict(list)
        for k in keys:
            batches[self._find_le(k)].append(k)
        futures = {}
        for index, keybatch in batches.items():
            if index is None:
                # Keys outside the indexed range resolve to None below.
                continue
            fut = self._worker_for(index).multiget.remote(
                [index] * len(keybatch), keybatch
            )
            futures[index] = fut
        results = {}
        for i, fut in futures.items():
            keybatch = batches[i]
            values = ray.get(fut)
            for k, v in zip(keybatch, values):
                results[k] = v
        # Preserve input order; missing keys map to None.
        return [results.get(k) for k in keys]

    def stats(self) -> str:
        """Returns a string containing access timing information."""
        stats = ray.get([w.stats.remote() for w in self._workers])
        total_time = sum(s["total_time"] for s in stats)
        accesses = [s["num_accesses"] for s in stats]
        blocks = [s["num_blocks"] for s in stats]
        msg = "RandomAccessDataset:\n"
        msg += "- Build time: {}s\n".format(round(self._build_time, 2))
        msg += "- Num workers: {}\n".format(len(stats))
        msg += "- Blocks per worker: {} min, {} max, {} mean\n".format(
            min(blocks), max(blocks), int(sum(blocks) / len(blocks))
        )
        msg += "- Accesses per worker: {} min, {} max, {} mean\n".format(
            min(accesses), max(accesses), int(sum(accesses) / len(accesses))
        )
        # "+ 1" guards against division by zero when there were no accesses.
        msg += "- Mean access time: {}us\n".format(
            int(total_time / (1 + sum(accesses)) * 1e6)
        )
        return msg

    def _worker_for(self, block_index: int):
        # Randomly pick among the workers holding this block (load spreading).
        return random.choice(self._block_to_workers_map[block_index])

    def _find_le(self, x: Any) -> Optional[int]:
        """Return the index of the block whose key range may contain ``x``,
        or None if ``x`` falls outside the indexed range."""
        # First block whose upper bound is >= x.
        i = bisect.bisect_left(self._upper_bounds, x)
        if i >= len(self._upper_bounds) or x < self._lower_bound:
            return None
        return i
205
+
206
+
207
@ray.remote(num_cpus=0)
class _RandomAccessWorker:
    """Actor that pins a subset of sorted blocks in memory and serves
    point lookups against them by the configured key field."""

    def __init__(self, key_field):
        # Mapping from global block index -> materialized block; populated
        # by assign_blocks().
        self.blocks = None
        self.key_field = key_field
        # Access statistics reported via stats().
        self.num_accesses = 0
        self.total_time = 0

    def assign_blocks(self, block_ref_dict):
        """Resolve the given block refs and hold them in actor memory."""
        self.blocks = {k: ray.get(ref) for k, ref in block_ref_dict.items()}

    def get(self, block_index, key):
        """Look up ``key`` in the block at ``block_index``; None if absent."""
        start = time.perf_counter()
        result = self._get(block_index, key)
        self.total_time += time.perf_counter() - start
        self.num_accesses += 1
        return result

    def multiget(self, block_indices, keys):
        """Look up multiple keys; vectorized when all target one Arrow block."""
        start = time.perf_counter()
        # Fetch the first block once; the original re-indexed self.blocks
        # for both the type check and the fast path.
        block = self.blocks[block_indices[0]]
        if len(set(block_indices)) == 1 and isinstance(block, pa.Table):
            # Fast path: use np.searchsorted for vectorized search on a single
            # block. This is ~3x faster than the naive case.
            col = block[self.key_field]
            indices = np.searchsorted(col, keys)
            acc = BlockAccessor.for_block(block)
            result = [acc._get_row(i) for i in indices]
            # assert result == [self._get(i, k) for i, k in zip(block_indices, keys)]
        else:
            result = [self._get(i, k) for i, k in zip(block_indices, keys)]
        self.total_time += time.perf_counter() - start
        self.num_accesses += 1
        return result

    def ping(self):
        """Return this actor's node id (used for locality-aware assignment)."""
        return ray.get_runtime_context().get_node_id()

    def stats(self) -> dict:
        """Return block count and access timing counters."""
        return {
            "num_blocks": len(self.blocks),
            "num_accesses": self.num_accesses,
            "total_time": self.total_time,
        }

    def _get(self, block_index, key):
        if block_index is None:
            return None
        block = self.blocks[block_index]
        column = block[self.key_field]
        if isinstance(block, pa.Table):
            # Wrap the Arrow column so bisect sees plain Python values.
            column = _ArrowListWrapper(column)
        i = _binary_search_find(column, key)
        if i is None:
            return None
        acc = BlockAccessor.for_block(block)
        return acc._get_row(i)
267
+
268
+
269
+ def _binary_search_find(column, x):
270
+ i = bisect.bisect_left(column, x)
271
+ if i != len(column) and column[i] == x:
272
+ return i
273
+ return None
274
+
275
+
276
class _ArrowListWrapper:
    """Adapts an Arrow column to the sequence protocol, converting each
    element to a plain Python value so ``bisect`` can compare it."""

    def __init__(self, arrow_col):
        self.arrow_col = arrow_col

    def __len__(self):
        return len(self.arrow_col)

    def __getitem__(self, i):
        item = self.arrow_col[i]
        return item.as_py()
285
+
286
+
287
def _get_bounds(block, key):
    """Return the (first, last) values of ``key`` in a sorted block,
    or None when the block is empty."""
    if len(block) == 0:
        return None
    column = block[key]
    lo, hi = column[0], column[len(block) - 1]
    if isinstance(block, pa.Table):
        # Arrow scalars must be converted to plain Python values.
        lo, hi = lo.as_py(), hi.as_py()
    return (lo, hi)
.venv/lib/python3.11/site-packages/ray/data/read_api.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/ray/includes/__init__.pxd ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/includes/common.pxd ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from libcpp cimport bool as c_bool
2
+ from libcpp.memory cimport shared_ptr, unique_ptr
3
+ from libcpp.string cimport string as c_string
4
+
5
+ from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t, uint32_t
6
+ from libcpp.unordered_map cimport unordered_map
7
+ from libcpp.vector cimport vector as c_vector
8
+ from libcpp.pair cimport pair as c_pair
9
+ from ray.includes.optional cimport (
10
+ optional,
11
+ )
12
+ from ray.includes.unique_ids cimport (
13
+ CActorID,
14
+ CJobID,
15
+ CClusterID,
16
+ CWorkerID,
17
+ CObjectID,
18
+ CTaskID,
19
+ CPlacementGroupID,
20
+ CNodeID,
21
+ )
22
+ from ray.includes.function_descriptor cimport (
23
+ CFunctionDescriptor,
24
+ )
25
+
26
+
27
# Inline C++ shim exposing std::move to Cython under the name `move`,
# with overloads for both lvalue and rvalue references.
cdef extern from * namespace "polyfill" nogil:
    """
    namespace polyfill {

    template <typename T>
    inline typename std::remove_reference<T>::type&& move(T& t) {
      return std::move(t);
    }

    template <typename T>
    inline typename std::remove_reference<T>::type&& move(T&& t) {
      return std::move(t);
    }

    } // namespace polyfill
    """
    cdef T move[T](T)
44
+
45
+
46
# Declarations for ray::Status / ray::StatusCode: constructors, static
# factory methods for each error category, Is*() predicates, and accessors.
cdef extern from "ray/common/status.h" namespace "ray" nogil:
    # TODO(ryw) in Cython 3.x we can directly use `cdef enum class CStatusCode`
    cdef cppclass CStatusCode "ray::StatusCode":
        pass
    cdef CStatusCode CStatusCode_OK "ray::StatusCode::OK"
    c_bool operator==(CStatusCode lhs, CStatusCode rhs)

    cdef cppclass CRayStatus "ray::Status":
        CRayStatus()
        CRayStatus(CStatusCode code, const c_string &msg)
        CRayStatus(CStatusCode code, const c_string &msg, int rpc_code)
        CRayStatus(const CRayStatus &s)

        # Static factories, one per status category.
        @staticmethod
        CRayStatus OK()

        @staticmethod
        CRayStatus OutOfMemory(const c_string &msg)

        @staticmethod
        CRayStatus KeyError(const c_string &msg)

        @staticmethod
        CRayStatus Invalid(const c_string &msg)

        @staticmethod
        CRayStatus IOError(const c_string &msg)

        @staticmethod
        CRayStatus TypeError(const c_string &msg)

        @staticmethod
        CRayStatus UnknownError(const c_string &msg)

        @staticmethod
        CRayStatus NotImplemented(const c_string &msg)

        @staticmethod
        CRayStatus ObjectStoreFull(const c_string &msg)

        @staticmethod
        CRayStatus RedisError(const c_string &msg)

        @staticmethod
        CRayStatus TimedOut(const c_string &msg)

        @staticmethod
        CRayStatus InvalidArgument(const c_string &msg)

        @staticmethod
        CRayStatus Interrupted(const c_string &msg)

        @staticmethod
        CRayStatus IntentionalSystemExit(const c_string &msg)

        @staticmethod
        CRayStatus UnexpectedSystemExit(const c_string &msg)

        @staticmethod
        CRayStatus CreationTaskError(const c_string &msg)

        @staticmethod
        CRayStatus NotFound()

        @staticmethod
        CRayStatus ObjectRefEndOfStream()

        # Category predicates.
        c_bool ok()
        c_bool IsOutOfMemory()
        c_bool IsKeyError()
        c_bool IsInvalid()
        c_bool IsIOError()
        c_bool IsTypeError()
        c_bool IsUnknownError()
        c_bool IsNotImplemented()
        c_bool IsObjectStoreFull()
        c_bool IsAlreadyExists()
        c_bool IsOutOfDisk()
        c_bool IsRedisError()
        c_bool IsTimedOut()
        c_bool IsInvalidArgument()
        c_bool IsInterrupted()
        c_bool ShouldExitWorker()
        c_bool IsObjectNotFound()
        c_bool IsNotFound()
        c_bool IsObjectUnknownOwner()
        c_bool IsRpcError()
        c_bool IsOutOfResource()
        c_bool IsObjectRefEndOfStream()
        c_bool IsIntentionalSystemExit()
        c_bool IsUnexpectedSystemExit()
        c_bool IsChannelError()
        c_bool IsChannelTimeoutError()

        # Accessors.
        c_string ToString()
        c_string CodeAsString()
        CStatusCode code()
        c_string message()
        int rpc_code()

    # We can later add more of the common status factory methods as needed
    cdef CRayStatus RayStatus_OK "Status::OK"()
    cdef CRayStatus RayStatus_Invalid "Status::Invalid"()
    cdef CRayStatus RayStatus_NotImplemented "Status::NotImplemented"()
+ cdef CRayStatus RayStatus_NotImplemented "Status::NotImplemented"()
150
+
151
+
152
# Deterministic task-id generation from job id, parent task id and counter.
cdef extern from "ray/common/id.h" namespace "ray" nogil:
    const CTaskID GenerateTaskId(const CJobID &job_id,
                                 const CTaskID &parent_task_id,
                                 int parent_task_counter)
+ int parent_task_counter)
156
+
157
+
158
# Protobuf message and enum wrapper declarations from common.pb.h:
# opaque enum types, scheduling-strategy messages and their setters,
# address/object-reference messages, and label-matching expressions.
cdef extern from "src/ray/protobuf/common.pb.h" nogil:
    cdef cppclass CLanguage "Language":
        pass
    cdef cppclass CWorkerType "ray::core::WorkerType":
        pass
    cdef cppclass CWorkerExitType "ray::rpc::WorkerExitType":
        pass
    cdef cppclass CTaskType "ray::TaskType":
        pass
    cdef cppclass CPlacementStrategy "ray::core::PlacementStrategy":
        pass
    cdef cppclass CDefaultSchedulingStrategy "ray::rpc::DefaultSchedulingStrategy":  # noqa: E501
        CDefaultSchedulingStrategy()
    cdef cppclass CSpreadSchedulingStrategy "ray::rpc::SpreadSchedulingStrategy":  # noqa: E501
        CSpreadSchedulingStrategy()
    cdef cppclass CPlacementGroupSchedulingStrategy "ray::rpc::PlacementGroupSchedulingStrategy":  # noqa: E501
        CPlacementGroupSchedulingStrategy()
        void set_placement_group_id(const c_string& placement_group_id)
        void set_placement_group_bundle_index(int64_t placement_group_bundle_index)  # noqa: E501
        void set_placement_group_capture_child_tasks(c_bool placement_group_capture_child_tasks)  # noqa: E501
    cdef cppclass CNodeAffinitySchedulingStrategy "ray::rpc::NodeAffinitySchedulingStrategy":  # noqa: E501
        CNodeAffinitySchedulingStrategy()
        void set_node_id(const c_string& node_id)
        void set_soft(c_bool soft)
        void set_spill_on_unavailable(c_bool spill_on_unavailable)
        void set_fail_on_unavailable(c_bool fail_on_unavailable)
    cdef cppclass CSchedulingStrategy "ray::rpc::SchedulingStrategy":
        CSchedulingStrategy()
        void clear_scheduling_strategy()
        CSpreadSchedulingStrategy* mutable_spread_scheduling_strategy()
        CDefaultSchedulingStrategy* mutable_default_scheduling_strategy()
        CPlacementGroupSchedulingStrategy* mutable_placement_group_scheduling_strategy()  # noqa: E501
        CNodeAffinitySchedulingStrategy* mutable_node_affinity_scheduling_strategy()
        CNodeLabelSchedulingStrategy* mutable_node_label_scheduling_strategy()
    cdef cppclass CAddress "ray::rpc::Address":
        CAddress()
        const c_string &SerializeAsString() const
        void ParseFromString(const c_string &serialized)
        void CopyFrom(const CAddress& address)
        const c_string &worker_id()
    cdef cppclass CObjectReference "ray::rpc::ObjectReference":
        CObjectReference()
        CAddress owner_address() const
        const c_string &object_id() const
        const c_string &call_site() const
    cdef cppclass CNodeLabelSchedulingStrategy "ray::rpc::NodeLabelSchedulingStrategy":  # noqa: E501
        CNodeLabelSchedulingStrategy()
        CLabelMatchExpressions* mutable_hard()
        CLabelMatchExpressions* mutable_soft()
    cdef cppclass CLabelMatchExpressions "ray::rpc::LabelMatchExpressions":  # noqa: E501
        CLabelMatchExpressions()
        CLabelMatchExpression* add_expressions()
    cdef cppclass CLabelMatchExpression "ray::rpc::LabelMatchExpression":  # noqa: E501
        CLabelMatchExpression()
        void set_key(const c_string &key)
        CLabelOperator* mutable_operator_()
    cdef cppclass CLabelIn "ray::rpc::LabelIn":  # noqa: E501
        CLabelIn()
        void add_values(const c_string &value)
    cdef cppclass CLabelNotIn "ray::rpc::LabelNotIn":  # noqa: E501
        CLabelNotIn()
        void add_values(const c_string &value)
    cdef cppclass CLabelExists "ray::rpc::LabelExists":  # noqa: E501
        CLabelExists()
    cdef cppclass CLabelDoesNotExist "ray::rpc::LabelDoesNotExist":  # noqa: E501
        CLabelDoesNotExist()
    # NOTE(review): CLabelNotIn is declared twice (also above) — harmless
    # redeclaration, but a candidate for cleanup upstream.
    cdef cppclass CLabelNotIn "ray::rpc::LabelNotIn":  # noqa: E501
        CLabelNotIn()
        void add_values(const c_string &value)
    cdef cppclass CLabelOperator "ray::rpc::LabelOperator":  # noqa: E501
        CLabelOperator()
        CLabelIn* mutable_label_in()
        CLabelNotIn* mutable_label_not_in()
        CLabelExists* mutable_label_exists()
        CLabelDoesNotExist* mutable_label_does_not_exist()
    cdef cppclass CLineageReconstructionTask "ray::rpc::LineageReconstructionTask":
        CLineageReconstructionTask()
        const c_string &SerializeAsString() const
+ const c_string &SerializeAsString() const
236
+
237
+
238
# This is a workaround for C++ enum class since Cython has no corresponding
# representation.
cdef extern from "src/ray/protobuf/common.pb.h" nogil:
    cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON"
    cdef CLanguage LANGUAGE_CPP "Language::CPP"
    cdef CLanguage LANGUAGE_JAVA "Language::JAVA"
244
+
245
# Worker-type and worker-exit-type enum values exposed as named constants.
cdef extern from "src/ray/protobuf/common.pb.h" nogil:
    cdef CWorkerType WORKER_TYPE_WORKER "ray::core::WorkerType::WORKER"
    cdef CWorkerType WORKER_TYPE_DRIVER "ray::core::WorkerType::DRIVER"
    cdef CWorkerType WORKER_TYPE_SPILL_WORKER "ray::core::WorkerType::SPILL_WORKER"  # noqa: E501
    cdef CWorkerType WORKER_TYPE_RESTORE_WORKER "ray::core::WorkerType::RESTORE_WORKER"  # noqa: E501
    cdef CWorkerType WORKER_TYPE_UTIL_WORKER "ray::core::WorkerType::UTIL_WORKER"  # noqa: E501
    cdef CWorkerExitType WORKER_EXIT_TYPE_USER_ERROR "ray::rpc::WorkerExitType::USER_ERROR"  # noqa: E501
    cdef CWorkerExitType WORKER_EXIT_TYPE_SYSTEM_ERROR "ray::rpc::WorkerExitType::SYSTEM_ERROR"  # noqa: E501
    cdef CWorkerExitType WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM_ERROR "ray::rpc::WorkerExitType::INTENDED_SYSTEM_EXIT"  # noqa: E501
254
+
255
# Task-type enum values exposed as named constants.
cdef extern from "src/ray/protobuf/common.pb.h" nogil:
    cdef CTaskType TASK_TYPE_NORMAL_TASK "ray::TaskType::NORMAL_TASK"
    cdef CTaskType TASK_TYPE_ACTOR_CREATION_TASK "ray::TaskType::ACTOR_CREATION_TASK"  # noqa: E501
    cdef CTaskType TASK_TYPE_ACTOR_TASK "ray::TaskType::ACTOR_TASK"
259
+
260
# Placement-group strategy enum values exposed as named constants.
cdef extern from "src/ray/protobuf/common.pb.h" nogil:
    cdef CPlacementStrategy PLACEMENT_STRATEGY_PACK \
        "ray::core::PlacementStrategy::PACK"
    cdef CPlacementStrategy PLACEMENT_STRATEGY_SPREAD \
        "ray::core::PlacementStrategy::SPREAD"
    cdef CPlacementStrategy PLACEMENT_STRATEGY_STRICT_PACK \
        "ray::core::PlacementStrategy::STRICT_PACK"
    cdef CPlacementStrategy PLACEMENT_STRATEGY_STRICT_SPREAD \
        "ray::core::PlacementStrategy::STRICT_SPREAD"
269
+
270
# Buffer abstractions: the CBuffer base plus local-memory and
# shared-memory (plasma-backed slice) implementations.
cdef extern from "ray/common/buffer.h" namespace "ray" nogil:
    cdef cppclass CBuffer "ray::Buffer":
        uint8_t *Data() const
        size_t Size() const
        c_bool IsPlasmaBuffer() const

    cdef cppclass LocalMemoryBuffer(CBuffer):
        LocalMemoryBuffer(uint8_t *data, size_t size, c_bool copy_data)
        LocalMemoryBuffer(size_t size)

    cdef cppclass SharedMemoryBuffer(CBuffer):
        SharedMemoryBuffer(
            const shared_ptr[CBuffer] &buffer,
            int64_t offset,
            int64_t size)
        c_bool IsPlasmaBuffer() const
286
+
287
# ray::RayObject: a (data, metadata, nested object refs) triple with
# accessors for the underlying buffers.
cdef extern from "ray/common/ray_object.h" nogil:
    cdef cppclass CRayObject "ray::RayObject":
        CRayObject(const shared_ptr[CBuffer] &data,
                   const shared_ptr[CBuffer] &metadata,
                   const c_vector[CObjectReference] &nested_refs)
        c_bool HasData() const
        c_bool HasMetadata() const
        const size_t DataSize() const
        const shared_ptr[CBuffer] &GetData()
        const shared_ptr[CBuffer] &GetMetadata() const
        c_bool IsInPlasmaError() const
298
+
299
# Core-worker option/argument types: RayFunction, task arguments (by
# reference or by value), task/actor/placement-group creation options,
# and object-location metadata.
cdef extern from "ray/core_worker/common.h" nogil:
    cdef cppclass CRayFunction "ray::core::RayFunction":
        CRayFunction()
        CRayFunction(CLanguage language,
                     const CFunctionDescriptor &function_descriptor)
        CLanguage GetLanguage()
        const CFunctionDescriptor GetFunctionDescriptor()

    cdef cppclass CTaskArg "ray::TaskArg":
        pass

    cdef cppclass CTaskArgByReference "ray::TaskArgByReference":
        CTaskArgByReference(const CObjectID &object_id,
                            const CAddress &owner_address,
                            const c_string &call_site)

    cdef cppclass CTaskArgByValue "ray::TaskArgByValue":
        CTaskArgByValue(const shared_ptr[CRayObject] &data)

    # Overloaded constructors mirror the optional runtime-env / task-event /
    # label parameters on the C++ side.
    cdef cppclass CTaskOptions "ray::core::TaskOptions":
        CTaskOptions()
        CTaskOptions(c_string name, int num_returns,
                     unordered_map[c_string, double] &resources,
                     c_string concurrency_group_name,
                     int64_t generator_backpressure_num_objects)
        CTaskOptions(c_string name, int num_returns,
                     unordered_map[c_string, double] &resources,
                     c_string concurrency_group_name,
                     int64_t generator_backpressure_num_objects,
                     c_string serialized_runtime_env)
        CTaskOptions(c_string name, int num_returns,
                     unordered_map[c_string, double] &resources,
                     c_string concurrency_group_name,
                     int64_t generator_backpressure_num_objects,
                     c_string serialized_runtime_env, c_bool enable_task_events,
                     const unordered_map[c_string, c_string] &labels)

    cdef cppclass CActorCreationOptions "ray::core::ActorCreationOptions":
        CActorCreationOptions()
        CActorCreationOptions(
            int64_t max_restarts,
            int64_t max_task_retries,
            int32_t max_concurrency,
            const unordered_map[c_string, double] &resources,
            const unordered_map[c_string, double] &placement_resources,
            const c_vector[c_string] &dynamic_worker_options,
            optional[c_bool] is_detached, c_string &name, c_string &ray_namespace,
            c_bool is_asyncio,
            const CSchedulingStrategy &scheduling_strategy,
            c_string serialized_runtime_env,
            const c_vector[CConcurrencyGroup] &concurrency_groups,
            c_bool execute_out_of_order,
            int32_t max_pending_calls,
            c_bool enable_task_events,
            const unordered_map[c_string, c_string] &labels)

    cdef cppclass CPlacementGroupCreationOptions \
            "ray::core::PlacementGroupCreationOptions":
        CPlacementGroupCreationOptions()
        CPlacementGroupCreationOptions(
            const c_string &name,
            CPlacementStrategy strategy,
            const c_vector[unordered_map[c_string, double]] &bundles,
            c_bool is_detached,
            double max_cpu_fraction_per_node,
            CNodeID soft_target_node_id,
        )

    cdef cppclass CObjectLocation "ray::core::ObjectLocation":
        const CNodeID &GetPrimaryNodeID() const
        const int64_t GetObjectSize() const
        const c_vector[CNodeID] &GetNodeIDs() const
        c_bool IsSpilled() const
        const c_string &GetSpilledURL() const
        const CNodeID &GetSpilledNodeID() const
        const c_bool GetDidSpill() const
375
+
376
+ cdef extern from "ray/gcs/gcs_client/python_callbacks.h" namespace "ray::gcs":
377
+ cdef cppclass MultiItemPyCallback[T]:
378
+ MultiItemPyCallback(
379
+ object (*)(CRayStatus, c_vector[T] &&) nogil,
380
+ void (object, object) nogil,
381
+ object) nogil
382
+
383
+ cdef cppclass OptionalItemPyCallback[T]:
384
+ OptionalItemPyCallback(
385
+ object (*)(CRayStatus, const optional[T]&) nogil,
386
+ void (object, object) nogil,
387
+ object) nogil
388
+
389
+ cdef cppclass StatusPyCallback:
390
+ StatusPyCallback(
391
+ object (*)(CRayStatus) nogil,
392
+ void (object, object) nogil,
393
+ object) nogil
394
+
395
+ cdef extern from "ray/gcs/gcs_client/accessor.h" nogil:
396
+ cdef cppclass CActorInfoAccessor "ray::gcs::ActorInfoAccessor":
397
+ CRayStatus AsyncGetAllByFilter(
398
+ const optional[CActorID] &actor_id,
399
+ const optional[CJobID] &job_id,
400
+ const optional[c_string] &actor_state_name,
401
+ const MultiItemPyCallback[CActorTableData] &callback,
402
+ int64_t timeout_ms)
403
+
404
+ CRayStatus AsyncKillActor(const CActorID &actor_id,
405
+ c_bool force_kill,
406
+ c_bool no_restart,
407
+ const StatusPyCallback &callback,
408
+ int64_t timeout_ms)
409
+
410
+ cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor":
411
+ CRayStatus GetAll(
412
+ const optional[c_string] &job_or_submission_id,
413
+ c_bool skip_submission_job_info_field,
414
+ c_bool skip_is_running_tasks_field,
415
+ c_vector[CJobTableData] &result,
416
+ int64_t timeout_ms)
417
+
418
+ CRayStatus AsyncGetAll(
419
+ const optional[c_string] &job_or_submission_id,
420
+ c_bool skip_submission_job_info_field,
421
+ c_bool skip_is_running_tasks_field,
422
+ const MultiItemPyCallback[CJobTableData] &callback,
423
+ int64_t timeout_ms)
424
+
425
+ cdef cppclass CNodeInfoAccessor "ray::gcs::NodeInfoAccessor":
426
+ CRayStatus CheckAlive(
427
+ const c_vector[c_string] &raylet_addresses,
428
+ int64_t timeout_ms,
429
+ c_vector[c_bool] &result)
430
+
431
+ CRayStatus AsyncCheckAlive(
432
+ const c_vector[c_string] &raylet_addresses,
433
+ int64_t timeout_ms,
434
+ const MultiItemPyCallback[c_bool] &callback)
435
+
436
+ CRayStatus DrainNodes(
437
+ const c_vector[CNodeID] &node_ids,
438
+ int64_t timeout_ms,
439
+ c_vector[c_string] &drained_node_ids)
440
+
441
+ CRayStatus GetAllNoCache(
442
+ int64_t timeout_ms,
443
+ c_vector[CGcsNodeInfo] &result)
444
+
445
+ CRayStatus AsyncGetAll(
446
+ const MultiItemPyCallback[CGcsNodeInfo] &callback,
447
+ int64_t timeout_ms,
448
+ optional[CNodeID] node_id)
449
+
450
+ cdef cppclass CNodeResourceInfoAccessor "ray::gcs::NodeResourceInfoAccessor":
451
+ CRayStatus GetAllResourceUsage(
452
+ int64_t timeout_ms,
453
+ CGetAllResourceUsageReply &serialized_reply)
454
+
455
+ cdef cppclass CInternalKVAccessor "ray::gcs::InternalKVAccessor":
456
+ CRayStatus Keys(
457
+ const c_string &ns,
458
+ const c_string &prefix,
459
+ int64_t timeout_ms,
460
+ c_vector[c_string] &value)
461
+
462
+ CRayStatus Put(
463
+ const c_string &ns,
464
+ const c_string &key,
465
+ const c_string &value,
466
+ c_bool overwrite,
467
+ int64_t timeout_ms,
468
+ c_bool &added)
469
+
470
+ CRayStatus Get(
471
+ const c_string &ns,
472
+ const c_string &key,
473
+ int64_t timeout_ms,
474
+ c_string &value)
475
+
476
+ CRayStatus MultiGet(
477
+ const c_string &ns,
478
+ const c_vector[c_string] &keys,
479
+ int64_t timeout_ms,
480
+ unordered_map[c_string, c_string] &values)
481
+
482
+ CRayStatus Del(
483
+ const c_string &ns,
484
+ const c_string &key,
485
+ c_bool del_by_prefix,
486
+ int64_t timeout_ms,
487
+ int& num_deleted)
488
+
489
+ CRayStatus Exists(
490
+ const c_string &ns,
491
+ const c_string &key,
492
+ int64_t timeout_ms,
493
+ c_bool &exists)
494
+
495
+ CRayStatus AsyncInternalKVKeys(
496
+ const c_string &ns,
497
+ const c_string &prefix,
498
+ int64_t timeout_ms,
499
+ const OptionalItemPyCallback[c_vector[c_string]] &callback)
500
+
501
+ CRayStatus AsyncInternalKVGet(
502
+ const c_string &ns,
503
+ const c_string &key,
504
+ int64_t timeout_ms,
505
+ const OptionalItemPyCallback[c_string] &callback)
506
+
507
+ CRayStatus AsyncInternalKVMultiGet(
508
+ const c_string &ns,
509
+ const c_vector[c_string] &keys,
510
+ int64_t timeout_ms,
511
+ const OptionalItemPyCallback[unordered_map[c_string, c_string]] &callback)
512
+
513
+ CRayStatus AsyncInternalKVPut(
514
+ const c_string &ns,
515
+ const c_string &key,
516
+ const c_string &value,
517
+ c_bool overwrite,
518
+ int64_t timeout_ms,
519
+ const OptionalItemPyCallback[c_bool] &callback)
520
+
521
+ CRayStatus AsyncInternalKVExists(
522
+ const c_string &ns,
523
+ const c_string &key,
524
+ int64_t timeout_ms,
525
+ const OptionalItemPyCallback[c_bool] &callback)
526
+
527
+ CRayStatus AsyncInternalKVDel(
528
+ const c_string &ns,
529
+ const c_string &key,
530
+ c_bool del_by_prefix,
531
+ int64_t timeout_ms,
532
+ const OptionalItemPyCallback[int] &callback)
533
+
534
+ cdef cppclass CRuntimeEnvAccessor "ray::gcs::RuntimeEnvAccessor":
535
+ CRayStatus PinRuntimeEnvUri(
536
+ const c_string &uri,
537
+ int expiration_s,
538
+ int64_t timeout_ms)
539
+
540
+ cdef cppclass CAutoscalerStateAccessor "ray::gcs::AutoscalerStateAccessor":
541
+
542
+ CRayStatus RequestClusterResourceConstraint(
543
+ int64_t timeout_ms,
544
+ const c_vector[unordered_map[c_string, double]] &bundles,
545
+ const c_vector[int64_t] &count_array
546
+ )
547
+
548
+ CRayStatus GetClusterResourceState(
549
+ int64_t timeout_ms,
550
+ c_string &serialized_reply
551
+ )
552
+
553
+ CRayStatus GetClusterStatus(
554
+ int64_t timeout_ms,
555
+ c_string &serialized_reply
556
+ )
557
+
558
+ CRayStatus ReportAutoscalingState(
559
+ int64_t timeout_ms,
560
+ const c_string &serialized_state
561
+ )
562
+
563
+ CRayStatus ReportClusterConfig(
564
+ int64_t timeout_ms,
565
+ const c_string &serialized_cluster_config
566
+ )
567
+
568
+ CRayStatus DrainNode(
569
+ const c_string &node_id,
570
+ int32_t reason,
571
+ const c_string &reason_message,
572
+ int64_t deadline_timestamp_ms,
573
+ int64_t timeout_ms,
574
+ c_bool &is_accepted,
575
+ c_string &rejection_reason_message
576
+ )
577
+
578
+
579
+ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
580
+ cdef enum CGrpcStatusCode "grpc::StatusCode":
581
+ UNAVAILABLE "grpc::StatusCode::UNAVAILABLE",
582
+ UNKNOWN "grpc::StatusCode::UNKNOWN",
583
+ DEADLINE_EXCEEDED "grpc::StatusCode::DEADLINE_EXCEEDED",
584
+ RESOURCE_EXHAUSTED "grpc::StatusCode::RESOURCE_EXHAUSTED",
585
+ UNIMPLEMENTED "grpc::StatusCode::UNIMPLEMENTED",
586
+
587
+ cdef cppclass CGcsClientOptions "ray::gcs::GcsClientOptions":
588
+ CGcsClientOptions(
589
+ const c_string &gcs_address, int port, CClusterID cluster_id,
590
+ c_bool allow_cluster_id_nil, c_bool fetch_cluster_id_if_nil)
591
+
592
+ cdef cppclass CGcsClient "ray::gcs::GcsClient":
593
+ CGcsClient(const CGcsClientOptions &options)
594
+
595
+ c_pair[c_string, int] GetGcsServerAddress() const
596
+ CClusterID GetClusterId() const
597
+
598
+ CActorInfoAccessor& Actors()
599
+ CJobInfoAccessor& Jobs()
600
+ CInternalKVAccessor& InternalKV()
601
+ CNodeInfoAccessor& Nodes()
602
+ CNodeResourceInfoAccessor& NodeResources()
603
+ CRuntimeEnvAccessor& RuntimeEnvs()
604
+ CAutoscalerStateAccessor& Autoscaler()
605
+
606
+ cdef CRayStatus ConnectOnSingletonIoContext(CGcsClient &gcs_client, int timeout_ms)
607
+
608
+ cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil:
609
+ unordered_map[c_string, double] PythonGetResourcesTotal(
610
+ const CGcsNodeInfo& node_info)
611
+
612
+ cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" nogil:
613
+
614
+ cdef cppclass CPythonGcsPublisher "ray::gcs::PythonGcsPublisher":
615
+
616
+ CPythonGcsPublisher(const c_string& gcs_address)
617
+
618
+ CRayStatus Connect()
619
+
620
+ CRayStatus PublishError(
621
+ const c_string &key_id, const CErrorTableData &data, int64_t num_retries)
622
+
623
+ CRayStatus PublishLogs(const c_string &key_id, const CLogBatch &data)
624
+
625
+ cdef cppclass CPythonGcsSubscriber "ray::gcs::PythonGcsSubscriber":
626
+
627
+ CPythonGcsSubscriber(
628
+ const c_string& gcs_address, int gcs_port, CChannelType channel_type,
629
+ const c_string& subscriber_id, const c_string& worker_id)
630
+
631
+ CRayStatus Subscribe()
632
+
633
+ int64_t last_batch_size()
634
+
635
+ CRayStatus PollError(
636
+ c_string* key_id, int64_t timeout_ms, CErrorTableData* data)
637
+
638
+ CRayStatus PollLogs(
639
+ c_string* key_id, int64_t timeout_ms, CLogBatch* data)
640
+
641
+ CRayStatus PollActor(
642
+ c_string* key_id, int64_t timeout_ms, CActorTableData* data)
643
+
644
+ CRayStatus Close()
645
+
646
+ cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" namespace "ray::gcs" nogil:
647
+ c_vector[c_string] PythonGetLogBatchLines(const CLogBatch& log_batch)
648
+
649
+ cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil:
650
+ unordered_map[c_string, c_string] PythonGetNodeLabels(
651
+ const CGcsNodeInfo& node_info)
652
+
653
+ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
654
+ cdef enum CChannelType "ray::rpc::ChannelType":
655
+ RAY_ERROR_INFO_CHANNEL "ray::rpc::ChannelType::RAY_ERROR_INFO_CHANNEL",
656
+ RAY_LOG_CHANNEL "ray::rpc::ChannelType::RAY_LOG_CHANNEL",
657
+ GCS_ACTOR_CHANNEL "ray::rpc::ChannelType::GCS_ACTOR_CHANNEL",
658
+
659
+ cdef cppclass CJobConfig "ray::rpc::JobConfig":
660
+ c_string ray_namespace() const
661
+ const c_string &SerializeAsString() const
662
+
663
+ cdef cppclass CNodeDeathInfo "ray::rpc::NodeDeathInfo":
664
+ int reason() const
665
+ c_string reason_message() const
666
+
667
+ cdef cppclass CGcsNodeInfo "ray::rpc::GcsNodeInfo":
668
+ c_string node_id() const
669
+ c_string node_name() const
670
+ int state() const
671
+ c_string node_manager_address() const
672
+ c_string node_manager_hostname() const
673
+ int node_manager_port() const
674
+ int object_manager_port() const
675
+ c_string object_store_socket_name() const
676
+ c_string raylet_socket_name() const
677
+ int metrics_export_port() const
678
+ int runtime_env_agent_port() const
679
+ CNodeDeathInfo death_info() const
680
+ void ParseFromString(const c_string &serialized)
681
+ const c_string& SerializeAsString() const
682
+
683
+ cdef enum CGcsNodeState "ray::rpc::GcsNodeInfo_GcsNodeState":
684
+ ALIVE "ray::rpc::GcsNodeInfo_GcsNodeState_ALIVE",
685
+
686
+ cdef cppclass CJobTableData "ray::rpc::JobTableData":
687
+ c_string job_id() const
688
+ c_bool is_dead() const
689
+ CJobConfig config() const
690
+ const c_string &SerializeAsString() const
691
+
692
+ cdef cppclass CGetAllResourceUsageReply "ray::rpc::GetAllResourceUsageReply":
693
+ const c_string& SerializeAsString() const
694
+
695
+ cdef cppclass CPythonFunction "ray::rpc::PythonFunction":
696
+ void set_key(const c_string &key)
697
+ c_string key() const
698
+
699
+ cdef cppclass CErrorTableData "ray::rpc::ErrorTableData":
700
+ c_string job_id() const
701
+ c_string type() const
702
+ c_string error_message() const
703
+ double timestamp() const
704
+
705
+ void set_job_id(const c_string &job_id)
706
+ void set_type(const c_string &type)
707
+ void set_error_message(const c_string &error_message)
708
+ void set_timestamp(double timestamp)
709
+
710
+ cdef cppclass CLogBatch "ray::rpc::LogBatch":
711
+ c_string ip() const
712
+ c_string pid() const
713
+ c_string job_id() const
714
+ c_bool is_error() const
715
+ c_string actor_name() const
716
+ c_string task_name() const
717
+
718
+ void set_ip(const c_string &ip)
719
+ void set_pid(const c_string &pid)
720
+ void set_job_id(const c_string &job_id)
721
+ void set_is_error(c_bool is_error)
722
+ void add_lines(const c_string &line)
723
+ void set_actor_name(const c_string &actor_name)
724
+ void set_task_name(const c_string &task_name)
725
+
726
+ cdef cppclass CActorTableData "ray::rpc::ActorTableData":
727
+ CAddress address() const
728
+ void ParseFromString(const c_string &serialized)
729
+ const c_string &SerializeAsString() const
730
+
731
+ cdef extern from "ray/common/task/task_spec.h" nogil:
732
+ cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
733
+ CConcurrencyGroup(
734
+ const c_string &name,
735
+ uint32_t max_concurrency,
736
+ const c_vector[CFunctionDescriptor] &c_fds)
737
+ CConcurrencyGroup()
738
+ c_string GetName() const
739
+ uint32_t GetMaxConcurrency() const
740
+ c_vector[CFunctionDescriptor] GetFunctionDescriptors() const
741
+
742
+ cdef extern from "ray/common/constants.h" nogil:
743
+ cdef const char[] kWorkerSetupHookKeyName
744
+ cdef int kResourceUnitScaling
745
+ cdef const char[] kImplicitResourcePrefix
746
+ cdef int kStreamingGeneratorReturn
747
+ cdef const char[] kGcsAutoscalerStateNamespace
748
+ cdef const char[] kGcsAutoscalerV2EnabledKey
749
+ cdef const char[] kGcsAutoscalerClusterConfigKey
.venv/lib/python3.11/site-packages/ray/includes/function_descriptor.pxd ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from libc.stdint cimport uint8_t, uint64_t
2
+ from libcpp cimport bool as c_bool
3
+ from libcpp.memory cimport unique_ptr, shared_ptr
4
+ from libcpp.string cimport string as c_string
5
+ from libcpp.unordered_map cimport unordered_map
6
+ from libcpp.vector cimport vector as c_vector
7
+
8
+ from ray.includes.common cimport (
9
+ CLanguage,
10
+ )
11
+ from ray.includes.unique_ids cimport (
12
+ CActorID,
13
+ CJobID,
14
+ CObjectID,
15
+ CTaskID,
16
+ )
17
+
18
+ cdef extern from "src/ray/protobuf/common.pb.h" nogil:
19
+ cdef cppclass CFunctionDescriptorType \
20
+ "ray::FunctionDescriptorType":
21
+ pass
22
+
23
+ cdef CFunctionDescriptorType EmptyFunctionDescriptorType \
24
+ "ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET"
25
+ cdef CFunctionDescriptorType JavaFunctionDescriptorType \
26
+ "ray::FunctionDescriptorType::kJavaFunctionDescriptor"
27
+ cdef CFunctionDescriptorType PythonFunctionDescriptorType \
28
+ "ray::FunctionDescriptorType::kPythonFunctionDescriptor"
29
+ cdef CFunctionDescriptorType CppFunctionDescriptorType \
30
+ "ray::FunctionDescriptorType::kCppFunctionDescriptor"
31
+
32
+
33
+ cdef extern from "ray/common/function_descriptor.h" nogil:
34
+ cdef cppclass CFunctionDescriptorInterface \
35
+ "ray::FunctionDescriptorInterface":
36
+ CFunctionDescriptorType Type()
37
+ c_string ToString()
38
+ c_string Serialize()
39
+
40
+ ctypedef shared_ptr[CFunctionDescriptorInterface] CFunctionDescriptor \
41
+ "ray::FunctionDescriptor"
42
+
43
+ cdef cppclass CFunctionDescriptorBuilder "ray::FunctionDescriptorBuilder":
44
+ @staticmethod
45
+ CFunctionDescriptor Empty()
46
+
47
+ @staticmethod
48
+ CFunctionDescriptor BuildJava(const c_string &class_name,
49
+ const c_string &function_name,
50
+ const c_string &signature)
51
+
52
+ @staticmethod
53
+ CFunctionDescriptor BuildPython(const c_string &module_name,
54
+ const c_string &class_name,
55
+ const c_string &function_name,
56
+ const c_string &function_source_hash)
57
+
58
+ @staticmethod
59
+ CFunctionDescriptor BuildCpp(const c_string &function_name,
60
+ const c_string &caller,
61
+ const c_string &class_name)
62
+
63
+ @staticmethod
64
+ CFunctionDescriptor Deserialize(const c_string &serialized_binary)
65
+
66
+ cdef cppclass CJavaFunctionDescriptor "ray::JavaFunctionDescriptor":
67
+ c_string ClassName()
68
+ c_string FunctionName()
69
+ c_string Signature()
70
+
71
+ cdef cppclass CPythonFunctionDescriptor "ray::PythonFunctionDescriptor":
72
+ c_string ModuleName()
73
+ c_string ClassName()
74
+ c_string FunctionName()
75
+ c_string FunctionHash()
76
+
77
+ cdef cppclass CCppFunctionDescriptor "ray::CppFunctionDescriptor":
78
+ c_string FunctionName()
79
+ c_string Caller()
80
+ c_string ClassName()
.venv/lib/python3.11/site-packages/ray/includes/global_state_accessor.pxd ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from libcpp.string cimport string as c_string
2
+ from libcpp cimport bool as c_bool
3
+ from libcpp.vector cimport vector as c_vector
4
+ from libcpp.unordered_map cimport unordered_map
5
+ from libcpp.memory cimport unique_ptr
6
+ from libc.stdint cimport (
7
+ int32_t as c_int32_t,
8
+ uint32_t as c_uint32_t,
9
+ int64_t as c_int64_t,
10
+ )
11
+ from ray.includes.unique_ids cimport (
12
+ CActorID,
13
+ CJobID,
14
+ CNodeID,
15
+ CObjectID,
16
+ CWorkerID,
17
+ CPlacementGroupID,
18
+ )
19
+ from ray.includes.common cimport (
20
+ CRayStatus,
21
+ CGcsClientOptions,
22
+ )
23
+ from ray.includes.optional cimport (
24
+ optional
25
+ )
26
+
27
+ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
28
+ cdef cppclass CGlobalStateAccessor "ray::gcs::GlobalStateAccessor":
29
+ CGlobalStateAccessor(const CGcsClientOptions&)
30
+ c_bool Connect()
31
+ void Disconnect()
32
+ c_vector[c_string] GetAllJobInfo(
33
+ c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field)
34
+ CJobID GetNextJobID()
35
+ c_vector[c_string] GetAllNodeInfo()
36
+ c_vector[c_string] GetAllAvailableResources()
37
+ c_vector[c_string] GetAllTotalResources()
38
+ unordered_map[CNodeID, c_int64_t] GetDrainingNodes()
39
+ unique_ptr[c_string] GetInternalKV(
40
+ const c_string &namespace, const c_string &key)
41
+ c_vector[c_string] GetAllTaskEvents()
42
+ unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id)
43
+ unique_ptr[c_string] GetAllResourceUsage()
44
+ c_vector[c_string] GetAllActorInfo(
45
+ optional[CActorID], optional[CJobID], optional[c_string])
46
+ unique_ptr[c_string] GetActorInfo(const CActorID &actor_id)
47
+ unique_ptr[c_string] GetWorkerInfo(const CWorkerID &worker_id)
48
+ c_vector[c_string] GetAllWorkerInfo()
49
+ c_bool AddWorkerInfo(const c_string &serialized_string)
50
+ c_bool UpdateWorkerDebuggerPort(const CWorkerID &worker_id,
51
+ const c_uint32_t debuger_port)
52
+ c_bool UpdateWorkerNumPausedThreads(const CWorkerID &worker_id,
53
+ const c_int32_t num_paused_threads_delta)
54
+ c_uint32_t GetWorkerDebuggerPort(const CWorkerID &worker_id)
55
+ unique_ptr[c_string] GetPlacementGroupInfo(
56
+ const CPlacementGroupID &placement_group_id)
57
+ unique_ptr[c_string] GetPlacementGroupByName(
58
+ const c_string &placement_group_name,
59
+ const c_string &ray_namespace,
60
+ )
61
+ c_vector[c_string] GetAllPlacementGroupInfo()
62
+ c_string GetSystemConfig()
63
+ CRayStatus GetNodeToConnectForDriver(
64
+ const c_string &node_ip_address,
65
+ c_string *node_to_connect)
66
+ CRayStatus GetNode(
67
+ const c_string &node_id_hex_str,
68
+ c_string *node_info)
69
+
70
+ cdef extern from * namespace "ray::gcs" nogil:
71
+ """
72
+ #include <thread>
73
+ #include "ray/gcs/gcs_server/store_client_kv.h"
74
+ namespace ray {
75
+ namespace gcs {
76
+
77
+ bool RedisGetKeySync(const std::string& host,
78
+ int32_t port,
79
+ const std::string& username,
80
+ const std::string& password,
81
+ bool use_ssl,
82
+ const std::string& config,
83
+ const std::string& key,
84
+ std::string* data) {
85
+ // Logging default value see class `RayLog`.
86
+ InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
87
+ ray::RayLog::ShutDownRayLog,
88
+ "ray_init",
89
+ ray::RayLogLevel::WARNING,
90
+ /*log_filepath=*/"",
91
+ /*log_rotation_max_size=*/1ULL << 29,
92
+ /*log_rotation_file_num=*/10);
93
+
94
+ RedisClientOptions options(host, port, username, password, use_ssl);
95
+
96
+ std::string config_list;
97
+ RAY_CHECK(absl::Base64Unescape(config, &config_list));
98
+ RayConfig::instance().initialize(config_list);
99
+
100
+ instrumented_io_context io_service;
101
+
102
+ auto redis_client = std::make_shared<RedisClient>(options);
103
+ auto status = redis_client->Connect(io_service);
104
+ RAY_CHECK_OK(status) << "Failed to connect to redis.";
105
+
106
+ auto cli = std::make_unique<StoreClientInternalKV>(
107
+ std::make_unique<RedisStoreClient>(std::move(redis_client)));
108
+
109
+ bool ret_val = false;
110
+ cli->Get("session", key, {[&](std::optional<std::string> result) {
111
+ if (result.has_value()) {
112
+ *data = result.value();
113
+ ret_val = true;
114
+ } else {
115
+ RAY_LOG(INFO) << "Failed to retrieve the key " << key
116
+ << " from persistent storage.";
117
+ ret_val = false;
118
+ }
119
+ }, io_service});
120
+ io_service.run_for(std::chrono::milliseconds(1000));
121
+
122
+ return ret_val;
123
+ }
124
+
125
+ }
126
+ }
127
+ """
128
+ c_bool RedisGetKeySync(const c_string& host,
129
+ c_int32_t port,
130
+ const c_string& username,
131
+ const c_string& password,
132
+ c_bool use_ssl,
133
+ const c_string& config,
134
+ const c_string& key,
135
+ c_string* data)
136
+
137
+
138
+ cdef extern from * namespace "ray::gcs" nogil:
139
+ c_bool RedisDelKeyPrefixSync(const c_string& host,
140
+ c_int32_t port,
141
+ const c_string& username,
142
+ const c_string& password,
143
+ c_bool use_ssl,
144
+ const c_string& key_prefix)
.venv/lib/python3.11/site-packages/ray/includes/libcoreworker.pxd ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cython: profile = False
2
+ # distutils: language = c++
3
+ # cython: embedsignature = True
4
+
5
+ from libc.stdint cimport int64_t, uint64_t
6
+ from libcpp cimport bool as c_bool
7
+ from libcpp.memory cimport shared_ptr, unique_ptr
8
+ from libcpp.pair cimport pair as c_pair
9
+ from libcpp.string cimport string as c_string
10
+ from libcpp.unordered_map cimport unordered_map
11
+ from libcpp.utility cimport pair
12
+ from libcpp.vector cimport vector as c_vector
13
+
14
+ from ray.includes.unique_ids cimport (
15
+ CActorID,
16
+ CClusterID,
17
+ CNodeID,
18
+ CJobID,
19
+ CTaskID,
20
+ CObjectID,
21
+ CPlacementGroupID,
22
+ CWorkerID,
23
+ ObjectIDIndexType,
24
+ )
25
+
26
+ from ray.includes.common cimport (
27
+ CAddress,
28
+ CObjectReference,
29
+ CActorCreationOptions,
30
+ CBuffer,
31
+ CPlacementGroupCreationOptions,
32
+ CObjectLocation,
33
+ CObjectReference,
34
+ CRayFunction,
35
+ CRayObject,
36
+ CRayStatus,
37
+ CTaskArg,
38
+ CTaskOptions,
39
+ CTaskType,
40
+ CWorkerType,
41
+ CLanguage,
42
+ CGcsClientOptions,
43
+ LocalMemoryBuffer,
44
+ CJobConfig,
45
+ CConcurrencyGroup,
46
+ CSchedulingStrategy,
47
+ CWorkerExitType,
48
+ CLineageReconstructionTask,
49
+ )
50
+ from ray.includes.function_descriptor cimport (
51
+ CFunctionDescriptor,
52
+ )
53
+
54
+ from ray.includes.optional cimport (
55
+ optional,
56
+ )
57
+
58
+ ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \
59
+ ResourceMappingType
60
+
61
+ ctypedef void (*ray_callback_function) \
62
+ (shared_ptr[CRayObject] result_object,
63
+ CObjectID object_id, void* user_data)
64
+
65
+ ctypedef void (*plasma_callback_function) \
66
+ (CObjectID object_id, int64_t data_size, int64_t metadata_size)
67
+
68
+ # NOTE: This ctypedef is needed, because Cython doesn't compile
69
+ # "pair[shared_ptr[const CActorHandle], CRayStatus]".
70
+ # This is a bug of cython: https://github.com/cython/cython/issues/3967.
71
+ ctypedef shared_ptr[const CActorHandle] ActorHandleSharedPtr
72
+
73
+
74
+ cdef extern from "ray/core_worker/profile_event.h" nogil:
75
+ cdef cppclass CProfileEvent "ray::core::worker::ProfileEvent":
76
+ void SetExtraData(const c_string &extra_data)
77
+
78
+ cdef extern from "ray/core_worker/fiber.h" nogil:
79
+ cdef cppclass CFiberEvent "ray::core::FiberEvent":
80
+ CFiberEvent()
81
+ void Wait()
82
+ void Notify()
83
+
84
+ cdef extern from "ray/core_worker/experimental_mutable_object_manager.h" nogil:
85
+ cdef cppclass CReaderRefInfo "ray::experimental::ReaderRefInfo":
86
+ CReaderRefInfo()
87
+ CObjectID reader_ref_id
88
+ CActorID owner_reader_actor_id
89
+ int64_t num_reader_actors
90
+
91
+
92
+ cdef extern from "ray/core_worker/context.h" nogil:
93
+ cdef cppclass CWorkerContext "ray::core::WorkerContext":
94
+ c_bool CurrentActorIsAsync()
95
+ const c_string &GetCurrentSerializedRuntimeEnv()
96
+ int CurrentActorMaxConcurrency()
97
+ const CActorID &GetRootDetachedActorID()
98
+
99
+ cdef extern from "ray/core_worker/generator_waiter.h" nogil:
100
+ cdef cppclass CGeneratorBackpressureWaiter "ray::core::GeneratorBackpressureWaiter": # noqa
101
+ CGeneratorBackpressureWaiter(
102
+ int64_t generator_backpressure_num_objects,
103
+ (CRayStatus() nogil) check_signals)
104
+ CRayStatus WaitAllObjectsReported()
105
+
106
+ cdef extern from "ray/core_worker/core_worker.h" nogil:
107
+ cdef cppclass CActorHandle "ray::core::ActorHandle":
108
+ CActorID GetActorID() const
109
+ CJobID CreationJobID() const
110
+ CLanguage ActorLanguage() const
111
+ CFunctionDescriptor ActorCreationTaskFunctionDescriptor() const
112
+ c_string ExtensionData() const
113
+ int MaxPendingCalls() const
114
+ int MaxTaskRetries() const
115
+ c_bool EnableTaskEvents() const
116
+
117
+ cdef cppclass CCoreWorker "ray::core::CoreWorker":
118
+ CWorkerType GetWorkerType()
119
+ CLanguage GetLanguage()
120
+
121
+ c_vector[CObjectReference] SubmitTask(
122
+ const CRayFunction &function,
123
+ const c_vector[unique_ptr[CTaskArg]] &args,
124
+ const CTaskOptions &options,
125
+ int max_retries,
126
+ c_bool retry_exceptions,
127
+ const CSchedulingStrategy &scheduling_strategy,
128
+ c_string debugger_breakpoint,
129
+ c_string serialized_retry_exception_allowlist,
130
+ c_string call_site,
131
+ const CTaskID current_task_id)
132
+ CRayStatus CreateActor(
133
+ const CRayFunction &function,
134
+ const c_vector[unique_ptr[CTaskArg]] &args,
135
+ const CActorCreationOptions &options,
136
+ const c_string &extension_data,
137
+ c_string call_site,
138
+ CActorID *actor_id)
139
+ CRayStatus CreatePlacementGroup(
140
+ const CPlacementGroupCreationOptions &options,
141
+ CPlacementGroupID *placement_group_id)
142
+ CRayStatus RemovePlacementGroup(
143
+ const CPlacementGroupID &placement_group_id)
144
+ CRayStatus WaitPlacementGroupReady(
145
+ const CPlacementGroupID &placement_group_id, int64_t timeout_seconds)
146
+ CRayStatus SubmitActorTask(
147
+ const CActorID &actor_id, const CRayFunction &function,
148
+ const c_vector[unique_ptr[CTaskArg]] &args,
149
+ const CTaskOptions &options,
150
+ int max_retries,
151
+ c_bool retry_exceptions,
152
+ c_string serialized_retry_exception_allowlist,
153
+ c_string call_site,
154
+ c_vector[CObjectReference] &task_returns,
155
+ const CTaskID current_task_id)
156
+ CRayStatus KillActor(
157
+ const CActorID &actor_id, c_bool force_kill,
158
+ c_bool no_restart)
159
+ CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill,
160
+ c_bool recursive)
161
+
162
+ unique_ptr[CProfileEvent] CreateProfileEvent(
163
+ const c_string &event_type)
164
+ CRayStatus AllocateReturnObject(
165
+ const CObjectID &object_id,
166
+ const size_t &data_size,
167
+ const shared_ptr[CBuffer] &metadata,
168
+ const c_vector[CObjectID] &contained_object_id,
169
+ const CAddress &caller_address,
170
+ int64_t *task_output_inlined_bytes,
171
+ shared_ptr[CRayObject] *return_object)
172
+ CRayStatus SealReturnObject(
173
+ const CObjectID &return_id,
174
+ const shared_ptr[CRayObject] &return_object,
175
+ const CObjectID &generator_id,
176
+ const CAddress &caller_address
177
+ )
178
+ c_bool PinExistingReturnObject(
179
+ const CObjectID &return_id,
180
+ shared_ptr[CRayObject] *return_object,
181
+ const CObjectID &generator_id,
182
+ const CAddress &caller_address)
183
+ void AsyncDelObjectRefStream(const CObjectID &generator_id)
184
+ CRayStatus TryReadObjectRefStream(
185
+ const CObjectID &generator_id,
186
+ CObjectReference *object_ref_out)
187
+ c_bool StreamingGeneratorIsFinished(const CObjectID &generator_id) const
188
+ pair[CObjectReference, c_bool] PeekObjectRefStream(
189
+ const CObjectID &generator_id)
190
+ CObjectID AllocateDynamicReturnId(
191
+ const CAddress &owner_address,
192
+ const CTaskID &task_id,
193
+ optional[ObjectIDIndexType] put_index)
194
+
195
+ CJobID GetCurrentJobId()
196
+ CTaskID GetCurrentTaskId()
197
+ const c_string GetCurrentTaskName()
198
+ const c_string GetCurrentTaskFunctionName()
199
+ void UpdateTaskIsDebuggerPaused(
200
+ const CTaskID &task_id,
201
+ const c_bool is_debugger_paused)
202
+ int64_t GetCurrentTaskAttemptNumber()
203
+ CNodeID GetCurrentNodeId()
204
+ int64_t GetTaskDepth()
205
+ c_bool GetCurrentTaskRetryExceptions()
206
+ CPlacementGroupID GetCurrentPlacementGroupId()
207
+ CWorkerID GetWorkerID()
208
+ c_bool ShouldCaptureChildTasksInPlacementGroup()
209
+ const CActorID &GetActorId()
210
+ const c_string GetActorName()
211
+ void SetActorTitle(const c_string &title)
212
+ void SetActorReprName(const c_string &repr_name)
213
+ void SetWebuiDisplay(const c_string &key, const c_string &message)
214
+ CTaskID GetCallerId()
215
+ const ResourceMappingType &GetResourceIDs() const
216
+ void RemoveActorHandleReference(const CActorID &actor_id)
217
+ optional[int] GetLocalActorState(const CActorID &actor_id) const
218
+ CActorID DeserializeAndRegisterActorHandle(const c_string &bytes, const
219
+ CObjectID &outer_object_id,
220
+ c_bool add_local_ref)
221
+ CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
222
+ *bytes,
223
+ CObjectID *c_actor_handle_id)
224
+ ActorHandleSharedPtr GetActorHandle(const CActorID &actor_id) const
225
+ pair[ActorHandleSharedPtr, CRayStatus] GetNamedActorHandle(
226
+ const c_string &name, const c_string &ray_namespace)
227
+ pair[c_vector[c_pair[c_string, c_string]], CRayStatus] ListNamedActors(
228
+ c_bool all_namespaces)
229
+ void AddLocalReference(const CObjectID &object_id)
230
+ void RemoveLocalReference(const CObjectID &object_id)
231
+ void PutObjectIntoPlasma(const CRayObject &object,
232
+ const CObjectID &object_id)
233
+ const CAddress &GetRpcAddress() const
234
+ CRayStatus GetOwnerAddress(const CObjectID &object_id,
235
+ CAddress *owner_address) const
236
+ c_vector[CObjectReference] GetObjectRefs(
237
+ const c_vector[CObjectID] &object_ids) const
238
+
239
+ CRayStatus GetOwnershipInfo(const CObjectID &object_id,
240
+ CAddress *owner_address,
241
+ c_string *object_status)
242
+ void RegisterOwnershipInfoAndResolveFuture(
243
+ const CObjectID &object_id,
244
+ const CObjectID &outer_object_id,
245
+ const CAddress &owner_address,
246
+ const c_string &object_status)
247
+
248
+ CRayStatus Put(const CRayObject &object,
249
+ const c_vector[CObjectID] &contained_object_ids,
250
+ CObjectID *object_id)
251
+ CRayStatus Put(const CRayObject &object,
252
+ const c_vector[CObjectID] &contained_object_ids,
253
+ const CObjectID &object_id)
254
+ CRayStatus CreateOwnedAndIncrementLocalRef(
255
+ c_bool is_mutable,
256
+ const shared_ptr[CBuffer] &metadata,
257
+ const size_t data_size,
258
+ const c_vector[CObjectID] &contained_object_ids,
259
+ CObjectID *object_id, shared_ptr[CBuffer] *data,
260
+ c_bool created_by_worker,
261
+ const unique_ptr[CAddress] &owner_address,
262
+ c_bool inline_small_object)
263
+ CRayStatus CreateExisting(const shared_ptr[CBuffer] &metadata,
264
+ const size_t data_size,
265
+ const CObjectID &object_id,
266
+ const CAddress &owner_address,
267
+ shared_ptr[CBuffer] *data,
268
+ c_bool created_by_worker)
269
+ CRayStatus ExperimentalChannelWriteAcquire(
270
+ const CObjectID &object_id,
271
+ const shared_ptr[CBuffer] &metadata,
272
+ uint64_t data_size,
273
+ int64_t num_readers,
274
+ int64_t timeout_ms,
275
+ shared_ptr[CBuffer] *data)
276
+ CRayStatus ExperimentalChannelWriteRelease(
277
+ const CObjectID &object_id)
278
+ CRayStatus ExperimentalChannelSetError(
279
+ const CObjectID &object_id)
280
+ CRayStatus ExperimentalRegisterMutableObjectWriter(
281
+ const CObjectID &writer_object_id,
282
+ const c_vector[CNodeID] &remote_reader_node_ids)
283
+ CRayStatus ExperimentalRegisterMutableObjectReader(const CObjectID &object_id)
284
+ CRayStatus ExperimentalRegisterMutableObjectReaderRemote(
285
+ const CObjectID &object_id,
286
+ const c_vector[CReaderRefInfo] &remote_reader_ref_info)
287
+ CRayStatus SealOwned(const CObjectID &object_id, c_bool pin_object,
288
+ const unique_ptr[CAddress] &owner_address)
289
+ CRayStatus SealExisting(const CObjectID &object_id, c_bool pin_object,
290
+ const CObjectID &generator_id,
291
+ const unique_ptr[CAddress] &owner_address)
292
+ CRayStatus Get(const c_vector[CObjectID] &ids, int64_t timeout_ms,
293
+ c_vector[shared_ptr[CRayObject]] results)
294
+ CRayStatus GetIfLocal(
295
+ const c_vector[CObjectID] &ids,
296
+ c_vector[shared_ptr[CRayObject]] *results)
297
+ CRayStatus Contains(const CObjectID &object_id, c_bool *has_object,
298
+ c_bool *is_in_plasma)
299
+ CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects,
300
+ int64_t timeout_ms, c_vector[c_bool] *results,
301
+ c_bool fetch_local)
302
+ CRayStatus Delete(const c_vector[CObjectID] &object_ids,
303
+ c_bool local_only)
304
+ CRayStatus GetLocalObjectLocations(
305
+ const c_vector[CObjectID] &object_ids,
306
+ c_vector[optional[CObjectLocation]] *results)
307
+ CRayStatus GetLocationFromOwner(
308
+ const c_vector[CObjectID] &object_ids,
309
+ int64_t timeout_ms,
310
+ c_vector[shared_ptr[CObjectLocation]] *results)
311
+ CRayStatus TriggerGlobalGC()
312
+ CRayStatus ReportGeneratorItemReturns(
313
+ const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object,
314
+ const CObjectID &generator_id,
315
+ const CAddress &caller_address,
316
+ int64_t item_index,
317
+ uint64_t attempt_number,
318
+ shared_ptr[CGeneratorBackpressureWaiter] waiter)
319
+ c_string MemoryUsageString()
320
+ int GetMemoryStoreSize()
321
+
322
+ CWorkerContext &GetWorkerContext()
323
+ void YieldCurrentFiber(CFiberEvent &coroutine_done)
324
+
325
+ unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts()
326
+ c_vector[CTaskID] GetPendingChildrenTasks(const CTaskID &task_id) const
327
+
328
+ void GetAsync(const CObjectID &object_id,
329
+ ray_callback_function success_callback,
330
+ void* python_user_callback)
331
+
332
+ CRayStatus PushError(const CJobID &job_id, const c_string &type,
333
+ const c_string &error_message, double timestamp)
334
+ CRayStatus SetResource(const c_string &resource_name,
335
+ const double capacity,
336
+ const CNodeID &client_Id)
337
+
338
+ CJobConfig GetJobConfig()
339
+
340
+ int64_t GetNumTasksSubmitted() const
341
+
342
+ int64_t GetNumLeasesRequested() const
343
+
344
+ int64_t GetLocalMemoryStoreBytesUsed() const
345
+
346
+ void RecordTaskLogStart(
347
+ const CTaskID &task_id,
348
+ int attempt_number,
349
+ const c_string& stdout_path,
350
+ const c_string& stderr_path,
351
+ int64_t stdout_start_offset,
352
+ int64_t stderr_start_offset) const
353
+
354
+ void RecordTaskLogEnd(
355
+ const CTaskID &task_id,
356
+ int attempt_number,
357
+ int64_t stdout_end_offset,
358
+ int64_t stderr_end_offset) const
359
+
360
+ void Exit(const CWorkerExitType exit_type,
361
+ const c_string &detail,
362
+ const shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes)
363
+
364
+ unordered_map[CLineageReconstructionTask, uint64_t] \
365
+ GetLocalOngoingLineageReconstructionTasks() const
366
+
367
+ cdef cppclass CCoreWorkerOptions "ray::core::CoreWorkerOptions":
368
+ CWorkerType worker_type
369
+ CLanguage language
370
+ c_string store_socket
371
+ c_string raylet_socket
372
+ CJobID job_id
373
+ CGcsClientOptions gcs_options
374
+ c_bool enable_logging
375
+ c_string log_dir
376
+ c_bool install_failure_signal_handler
377
+ c_bool interactive
378
+ c_string node_ip_address
379
+ int node_manager_port
380
+ c_string raylet_ip_address
381
+ c_string driver_name
382
+ c_string stdout_file
383
+ c_string stderr_file
384
+ (CRayStatus(
385
+ const CAddress &caller_address,
386
+ CTaskType task_type,
387
+ const c_string name,
388
+ const CRayFunction &ray_function,
389
+ const unordered_map[c_string, double] &resources,
390
+ const c_vector[shared_ptr[CRayObject]] &args,
391
+ const c_vector[CObjectReference] &arg_refs,
392
+ const c_string debugger_breakpoint,
393
+ const c_string serialized_retry_exception_allowlist,
394
+ c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns,
395
+ c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *dynamic_returns,
396
+ c_vector[c_pair[CObjectID, c_bool]] *streaming_generator_returns,
397
+ shared_ptr[LocalMemoryBuffer]
398
+ &creation_task_exception_pb_bytes,
399
+ c_bool *is_retryable_error,
400
+ c_string *application_error,
401
+ const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
402
+ const c_string name_of_concurrency_group_to_execute,
403
+ c_bool is_reattempt,
404
+ c_bool is_streaming_generator,
405
+ c_bool should_retry_exceptions,
406
+ int64_t generator_backpressure_num_objects
407
+ ) nogil) task_execution_callback
408
+ (void(const CWorkerID &) nogil) on_worker_shutdown
409
+ (CRayStatus() nogil) check_signals
410
+ (void(c_bool) nogil) gc_collect
411
+ (c_vector[c_string](
412
+ const c_vector[CObjectReference] &) nogil) spill_objects
413
+ (int64_t(
414
+ const c_vector[CObjectReference] &,
415
+ const c_vector[c_string] &) nogil) restore_spilled_objects
416
+ (void(
417
+ const c_vector[c_string]&,
418
+ CWorkerType) nogil) delete_spilled_objects
419
+ (void(
420
+ const c_string&,
421
+ const c_vector[c_string]&) nogil) run_on_util_worker_handler
422
+ (void(const CRayObject&) nogil) unhandled_exception_handler
423
+ (void(
424
+ const CTaskID &c_task_id,
425
+ const CRayFunction &ray_function,
426
+ const c_string c_name_of_concurrency_group_to_execute
427
+ ) nogil) cancel_async_task
428
+ (void(c_string *stack_out) nogil) get_lang_stack
429
+ c_bool is_local_mode
430
+ int num_workers
431
+ (c_bool(const CTaskID &) nogil) kill_main
432
+ CCoreWorkerOptions()
433
+ (void() nogil) terminate_asyncio_thread
434
+ c_string serialized_job_config
435
+ int metrics_agent_port
436
+ int runtime_env_hash
437
+ int startup_token
438
+ CClusterID cluster_id
439
+ c_string session_name
440
+ c_string entrypoint
441
+ int64_t worker_launch_time_ms
442
+ int64_t worker_launched_time_ms
443
+
444
+ cdef cppclass CCoreWorkerProcess "ray::core::CoreWorkerProcess":
445
+ @staticmethod
446
+ void Initialize(const CCoreWorkerOptions &options)
447
+ # Only call this in CoreWorker.__cinit__,
448
+ # use CoreWorker.core_worker to access C++ CoreWorker.
449
+
450
+ @staticmethod
451
+ CCoreWorker &GetCoreWorker()
452
+
453
+ @staticmethod
454
+ void Shutdown()
455
+
456
+ @staticmethod
457
+ void RunTaskExecutionLoop()
.venv/lib/python3.11/site-packages/ray/includes/metric.pxd ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from libcpp.string cimport string as c_string
2
+ from libcpp.unordered_map cimport unordered_map
3
+ from libcpp.vector cimport vector as c_vector
4
+
5
+ cdef extern from "opencensus/tags/tag_key.h" nogil:
6
+ cdef cppclass CTagKey "opencensus::tags::TagKey":
7
+ @staticmethod
8
+ CTagKey Register(c_string &name)
9
+ const c_string &name() const
10
+
11
+ cdef extern from "ray/stats/metric.h" nogil:
12
+ cdef cppclass CMetric "ray::stats::Metric":
13
+ CMetric(const c_string &name,
14
+ const c_string &description,
15
+ const c_string &unit,
16
+ const c_vector[c_string] &tag_keys)
17
+ c_string GetName() const
18
+ void Record(double value)
19
+ void Record(double value,
20
+ unordered_map[c_string, c_string] &tags)
21
+
22
+ cdef cppclass CGauge "ray::stats::Gauge":
23
+ CGauge(const c_string &name,
24
+ const c_string &description,
25
+ const c_string &unit,
26
+ const c_vector[c_string] &tag_keys)
27
+
28
+ cdef cppclass CCount "ray::stats::Count":
29
+ CCount(const c_string &name,
30
+ const c_string &description,
31
+ const c_string &unit,
32
+ const c_vector[c_string] &tag_keys)
33
+
34
+ cdef cppclass CSum "ray::stats::Sum":
35
+ CSum(const c_string &name,
36
+ const c_string &description,
37
+ const c_string &unit,
38
+ const c_vector[c_string] &tag_keys)
39
+
40
+ cdef cppclass CHistogram "ray::stats::Histogram":
41
+ CHistogram(const c_string &name,
42
+ const c_string &description,
43
+ const c_string &unit,
44
+ const c_vector[double] &boundaries,
45
+ const c_vector[c_string] &tag_keys)
.venv/lib/python3.11/site-packages/ray/includes/optional.pxd ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Currently Cython does not support std::optional.
2
+ # See: https://github.com/cython/cython/pull/3294
3
+ from libcpp cimport bool
4
+
5
+ cdef extern from "<optional>" namespace "std" nogil:
6
+ cdef cppclass nullopt_t:
7
+ nullopt_t()
8
+
9
+ cdef nullopt_t nullopt
10
+
11
+ cdef cppclass optional[T]:
12
+ ctypedef T value_type
13
+ optional()
14
+ optional(nullopt_t)
15
+ optional(optional&) except +
16
+ optional(T&) except +
17
+ bool has_value()
18
+ T& value()
19
+ T& value_or[U](U& default_value)
20
+ void swap(optional&)
21
+ void reset()
22
+ T& emplace(...)
23
+ T& operator*()
24
+ # T* operator->() # Not Supported
25
+ optional& operator=(optional&)
26
+ optional& operator=[U](U&)
27
+ bool operator bool()
28
+ bool operator!()
29
+ bool operator==[U](optional&, U&)
30
+ bool operator!=[U](optional&, U&)
31
+ bool operator<[U](optional&, U&)
32
+ bool operator>[U](optional&, U&)
33
+ bool operator<=[U](optional&, U&)
34
+ bool operator>=[U](optional&, U&)
35
+
36
+ optional[T] make_optional[T](...) except +
.venv/lib/python3.11/site-packages/ray/includes/ray_config.pxd ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from libcpp cimport bool as c_bool
2
+ from libc.stdint cimport int64_t, uint64_t, uint32_t
3
+ from libcpp.string cimport string as c_string
4
+ from libcpp.unordered_map cimport unordered_map
5
+
6
+
7
+ cdef extern from "ray/common/ray_config.h" nogil:
8
+ cdef cppclass RayConfig "RayConfig":
9
+ @staticmethod
10
+ RayConfig &instance()
11
+
12
+ void initialize(const c_string& config_list)
13
+
14
+ int64_t ray_cookie() const
15
+
16
+ int64_t handler_warning_timeout_ms() const
17
+
18
+ int64_t debug_dump_period_milliseconds() const
19
+
20
+ int64_t object_timeout_milliseconds() const
21
+
22
+ int64_t raylet_client_num_connect_attempts() const
23
+
24
+ int64_t raylet_client_connect_timeout_milliseconds() const
25
+
26
+ int64_t raylet_fetch_timeout_milliseconds() const
27
+
28
+ int64_t kill_worker_timeout_milliseconds() const
29
+
30
+ int64_t worker_register_timeout_seconds() const
31
+
32
+ int64_t redis_db_connect_retries()
33
+
34
+ int64_t redis_db_connect_wait_milliseconds() const
35
+
36
+ int object_manager_pull_timeout_ms() const
37
+
38
+ int object_manager_push_timeout_ms() const
39
+
40
+ uint64_t object_manager_default_chunk_size() const
41
+
42
+ uint32_t maximum_gcs_deletion_batch_size() const
43
+
44
+ int64_t max_direct_call_object_size() const
45
+
46
+ int64_t task_rpc_inlined_bytes_limit() const
47
+
48
+ uint64_t metrics_report_interval_ms() const
49
+
50
+ c_bool enable_timeline() const
51
+
52
+ uint32_t max_grpc_message_size() const
53
+
54
+ c_bool record_ref_creation_sites() const
55
+
56
+ c_string REDIS_CA_CERT() const
57
+
58
+ c_string REDIS_CA_PATH() const
59
+
60
+ c_string REDIS_CLIENT_CERT() const
61
+
62
+ c_string REDIS_CLIENT_KEY() const
63
+
64
+ c_string REDIS_SERVER_NAME() const
65
+
66
+ int64_t health_check_initial_delay_ms() const
67
+
68
+ int64_t health_check_period_ms() const
69
+
70
+ int64_t health_check_timeout_ms() const
71
+
72
+ int64_t health_check_failure_threshold() const
73
+
74
+ uint64_t memory_monitor_refresh_ms() const
75
+
76
+ int64_t grpc_keepalive_time_ms() const
77
+
78
+ int64_t grpc_keepalive_timeout_ms() const
79
+
80
+ int64_t grpc_client_keepalive_time_ms() const
81
+
82
+ int64_t grpc_client_keepalive_timeout_ms() const
83
+
84
+ c_bool enable_autoscaler_v2() const
85
+
86
+ c_string predefined_unit_instance_resources() const
87
+
88
+ c_string custom_unit_instance_resources() const
89
+
90
+ int64_t nums_py_gcs_reconnect_retry() const
91
+
92
+ int64_t py_gcs_connect_timeout_s() const
93
+
94
+ int gcs_rpc_server_reconnect_timeout_s() const
95
+
96
+ int maximum_gcs_destroyed_actor_cached_count() const
97
+
98
+ c_bool record_task_actor_creation_sites() const
.venv/lib/python3.11/site-packages/ray/includes/unique_ids.pxd ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from libcpp cimport bool as c_bool
2
+ from libcpp.string cimport string as c_string
3
+ from libc.stdint cimport uint8_t, uint32_t, int64_t
4
+
5
+ cdef extern from "ray/common/id.h" namespace "ray" nogil:
6
+ cdef cppclass CBaseID[T]:
7
+ @staticmethod
8
+ T FromBinary(const c_string &binary)
9
+
10
+ @staticmethod
11
+ T FromHex(const c_string &hex_str)
12
+
13
+ @staticmethod
14
+ const T Nil()
15
+
16
+ @staticmethod
17
+ size_t Size()
18
+
19
+ size_t Hash() const
20
+ c_bool IsNil() const
21
+ c_bool operator==(const CBaseID &rhs) const
22
+ c_bool operator!=(const CBaseID &rhs) const
23
+ const uint8_t *data() const
24
+
25
+ c_string Binary() const
26
+ c_string Hex() const
27
+
28
+ cdef cppclass CUniqueID "ray::UniqueID"(CBaseID):
29
+ CUniqueID()
30
+
31
+ @staticmethod
32
+ size_t Size()
33
+
34
+ @staticmethod
35
+ CUniqueID FromRandom()
36
+
37
+ @staticmethod
38
+ CUniqueID FromBinary(const c_string &binary)
39
+
40
+ @staticmethod
41
+ const CUniqueID Nil()
42
+
43
+ @staticmethod
44
+ size_t Size()
45
+
46
+ cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID):
47
+
48
+ @staticmethod
49
+ CActorClassID FromBinary(const c_string &binary)
50
+
51
+ @staticmethod
52
+ CActorClassID FromHex(const c_string &hex_str)
53
+
54
+ cdef cppclass CActorID "ray::ActorID"(CBaseID[CActorID]):
55
+
56
+ @staticmethod
57
+ CActorID FromBinary(const c_string &binary)
58
+
59
+ @staticmethod
60
+ CActorID FromHex(const c_string &hex_str)
61
+
62
+ @staticmethod
63
+ const CActorID Nil()
64
+
65
+ @staticmethod
66
+ size_t Size()
67
+
68
+ @staticmethod
69
+ CActorID Of(CJobID job_id, CTaskID parent_task_id,
70
+ int64_t parent_task_counter)
71
+
72
+ CJobID JobId()
73
+
74
+ cdef cppclass CNodeID "ray::NodeID"(CUniqueID):
75
+
76
+ @staticmethod
77
+ CNodeID FromBinary(const c_string &binary)
78
+
79
+ @staticmethod
80
+ CNodeID FromHex(const c_string &hex_str)
81
+
82
+ @staticmethod
83
+ const CNodeID Nil()
84
+
85
+ cdef cppclass CConfigID "ray::ConfigID"(CUniqueID):
86
+
87
+ @staticmethod
88
+ CConfigID FromBinary(const c_string &binary)
89
+
90
+ cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID):
91
+
92
+ @staticmethod
93
+ CFunctionID FromBinary(const c_string &binary)
94
+
95
+ @staticmethod
96
+ CFunctionID FromHex(const c_string &hex_str)
97
+
98
+ cdef cppclass CJobID "ray::JobID"(CBaseID[CJobID]):
99
+
100
+ @staticmethod
101
+ CJobID FromBinary(const c_string &binary)
102
+
103
+ @staticmethod
104
+ CJobID FromHex(const c_string &hex_str)
105
+
106
+ @staticmethod
107
+ const CJobID Nil()
108
+
109
+ @staticmethod
110
+ size_t Size()
111
+
112
+ @staticmethod
113
+ CJobID FromInt(uint32_t value)
114
+
115
+ uint32_t ToInt()
116
+
117
+ cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]):
118
+
119
+ @staticmethod
120
+ CTaskID FromBinary(const c_string &binary)
121
+
122
+ @staticmethod
123
+ CTaskID FromHex(const c_string &hex_str)
124
+
125
+ @staticmethod
126
+ const CTaskID Nil()
127
+
128
+ @staticmethod
129
+ size_t Size()
130
+
131
+ @staticmethod
132
+ CTaskID ForDriverTask(const CJobID &job_id)
133
+
134
+ @staticmethod
135
+ CTaskID FromRandom(const CJobID &job_id)
136
+
137
+ @staticmethod
138
+ CTaskID ForActorCreationTask(CActorID actor_id)
139
+
140
+ @staticmethod
141
+ CTaskID ForActorTask(CJobID job_id, CTaskID parent_task_id,
142
+ int64_t parent_task_counter, CActorID actor_id)
143
+
144
+ @staticmethod
145
+ CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id,
146
+ int64_t parent_task_counter)
147
+
148
+ CActorID ActorId() const
149
+
150
+ CJobID JobId() const
151
+
152
+ cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]):
153
+
154
+ @staticmethod
155
+ int64_t MaxObjectIndex()
156
+
157
+ @staticmethod
158
+ CObjectID FromBinary(const c_string &binary)
159
+
160
+ @staticmethod
161
+ CObjectID FromRandom()
162
+
163
+ @staticmethod
164
+ const CObjectID Nil()
165
+
166
+ @staticmethod
167
+ CObjectID FromIndex(const CTaskID &task_id, int64_t index)
168
+
169
+ @staticmethod
170
+ size_t Size()
171
+
172
+ c_bool is_put()
173
+
174
+ int64_t ObjectIndex() const
175
+
176
+ CTaskID TaskId() const
177
+
178
+ cdef cppclass CClusterID "ray::ClusterID"(CUniqueID):
179
+
180
+ @staticmethod
181
+ CClusterID FromBinary(const c_string &binary)
182
+
183
+ @staticmethod
184
+ CClusterID FromHex(const c_string &hex_str)
185
+
186
+ @staticmethod
187
+ CClusterID FromRandom()
188
+
189
+ @staticmethod
190
+ const CClusterID Nil()
191
+
192
+ cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID):
193
+
194
+ @staticmethod
195
+ CWorkerID FromBinary(const c_string &binary)
196
+
197
+ @staticmethod
198
+ CWorkerID FromHex(const c_string &hex_str)
199
+
200
+ cdef cppclass CPlacementGroupID "ray::PlacementGroupID" \
201
+ (CBaseID[CPlacementGroupID]):
202
+
203
+ @staticmethod
204
+ CPlacementGroupID FromBinary(const c_string &binary)
205
+
206
+ @staticmethod
207
+ CPlacementGroupID FromHex(const c_string &hex_str)
208
+
209
+ @staticmethod
210
+ const CPlacementGroupID Nil()
211
+
212
+ @staticmethod
213
+ size_t Size()
214
+
215
+ @staticmethod
216
+ CPlacementGroupID Of(CJobID job_id)
217
+
218
+ ctypedef uint32_t ObjectIDIndexType
.venv/lib/python3.11/site-packages/ray/runtime_env/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ray._private.runtime_env.mpi import mpi_init # noqa: E402,F401
2
+ from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig # noqa: E402,F401
3
+
4
+ __all__ = [
5
+ "RuntimeEnvConfig",
6
+ "RuntimeEnv",
7
+ "mpi_init",
8
+ ]
.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (423 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/runtime_env/__pycache__/runtime_env.cpython-311.pyc ADDED
Binary file (31.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/runtime_env/runtime_env.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from copy import deepcopy
5
+ from dataclasses import asdict, is_dataclass
6
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
7
+
8
+ import ray
9
+ from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
10
+ from ray._private.runtime_env.conda import get_uri as get_conda_uri
11
+ from ray._private.runtime_env.default_impl import get_image_uri_plugin_cls
12
+ from ray._private.runtime_env.pip import get_uri as get_pip_uri
13
+ from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager
14
+ from ray._private.runtime_env.uv import get_uri as get_uv_uri
15
+ from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
16
+ from ray._private.thirdparty.dacite import from_dict
17
+ from ray.core.generated.runtime_env_common_pb2 import (
18
+ RuntimeEnvConfig as ProtoRuntimeEnvConfig,
19
+ )
20
+ from ray.util.annotations import PublicAPI
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @PublicAPI(stability="stable")
26
+ class RuntimeEnvConfig(dict):
27
+ """Used to specify configuration options for a runtime environment.
28
+
29
+ The config is not included when calculating the runtime_env hash,
30
+ which means that two runtime_envs with the same options but different
31
+ configs are considered the same for caching purposes.
32
+
33
+ Args:
34
+ setup_timeout_seconds: The timeout of runtime environment
35
+ creation, timeout is in seconds. The value `-1` means disable
36
+ timeout logic, except `-1`, `setup_timeout_seconds` cannot be
37
+ less than or equal to 0. The default value of `setup_timeout_seconds`
38
+ is 600 seconds.
39
+ eager_install: Indicates whether to install the runtime environment
40
+ on the cluster at `ray.init()` time, before the workers are leased.
41
+ This flag is set to `True` by default.
42
+ """
43
+
44
+ known_fields: Set[str] = {"setup_timeout_seconds", "eager_install", "log_files"}
45
+
46
+ _default_config: Dict = {
47
+ "setup_timeout_seconds": DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS,
48
+ "eager_install": True,
49
+ "log_files": [],
50
+ }
51
+
52
+ def __init__(
53
+ self,
54
+ setup_timeout_seconds: int = DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS,
55
+ eager_install: bool = True,
56
+ log_files: Optional[List[str]] = None,
57
+ ):
58
+ super().__init__()
59
+ if not isinstance(setup_timeout_seconds, int):
60
+ raise TypeError(
61
+ "setup_timeout_seconds must be of type int, "
62
+ f"got: {type(setup_timeout_seconds)}"
63
+ )
64
+ elif setup_timeout_seconds <= 0 and setup_timeout_seconds != -1:
65
+ raise ValueError(
66
+ "setup_timeout_seconds must be greater than zero "
67
+ f"or equals to -1, got: {setup_timeout_seconds}"
68
+ )
69
+ self["setup_timeout_seconds"] = setup_timeout_seconds
70
+
71
+ if not isinstance(eager_install, bool):
72
+ raise TypeError(
73
+ f"eager_install must be a boolean. got {type(eager_install)}"
74
+ )
75
+ self["eager_install"] = eager_install
76
+
77
+ if log_files is not None:
78
+ if not isinstance(log_files, list):
79
+ raise TypeError(
80
+ "log_files must be a list of strings or None, got "
81
+ f"{log_files} with type {type(log_files)}."
82
+ )
83
+ for file_name in log_files:
84
+ if not isinstance(file_name, str):
85
+ raise TypeError("Each item in log_files must be a string.")
86
+ else:
87
+ log_files = self._default_config["log_files"]
88
+
89
+ self["log_files"] = log_files
90
+
91
+ @staticmethod
92
+ def parse_and_validate_runtime_env_config(
93
+ config: Union[Dict, "RuntimeEnvConfig"]
94
+ ) -> "RuntimeEnvConfig":
95
+ if isinstance(config, RuntimeEnvConfig):
96
+ return config
97
+ elif isinstance(config, Dict):
98
+ unknown_fields = set(config.keys()) - RuntimeEnvConfig.known_fields
99
+ if len(unknown_fields):
100
+ logger.warning(
101
+ "The following unknown entries in the runtime_env_config "
102
+ f"dictionary will be ignored: {unknown_fields}."
103
+ )
104
+ config_dict = dict()
105
+ for field in RuntimeEnvConfig.known_fields:
106
+ if field in config:
107
+ config_dict[field] = config[field]
108
+ return RuntimeEnvConfig(**config_dict)
109
+ else:
110
+ raise TypeError(
111
+ "runtime_env['config'] must be of type dict or RuntimeEnvConfig, "
112
+ f"got: {type(config)}"
113
+ )
114
+
115
+ @classmethod
116
+ def default_config(cls):
117
+ return RuntimeEnvConfig(**cls._default_config)
118
+
119
+ def build_proto_runtime_env_config(self) -> ProtoRuntimeEnvConfig:
120
+ runtime_env_config = ProtoRuntimeEnvConfig()
121
+ runtime_env_config.setup_timeout_seconds = self["setup_timeout_seconds"]
122
+ runtime_env_config.eager_install = self["eager_install"]
123
+ if self["log_files"] is not None:
124
+ runtime_env_config.log_files.extend(self["log_files"])
125
+ return runtime_env_config
126
+
127
+ @classmethod
128
+ def from_proto(cls, runtime_env_config: ProtoRuntimeEnvConfig):
129
+ setup_timeout_seconds = runtime_env_config.setup_timeout_seconds
130
+ # Cause python class RuntimeEnvConfig has validate to avoid
131
+ # setup_timeout_seconds equals zero, so setup_timeout_seconds
132
+ # on RuntimeEnvConfig is zero means other Language(except python)
133
+ # dosn't assign value to setup_timeout_seconds. So runtime_env_agent
134
+ # assign the default value to setup_timeout_seconds.
135
+ if setup_timeout_seconds == 0:
136
+ setup_timeout_seconds = cls._default_config["setup_timeout_seconds"]
137
+ return cls(
138
+ setup_timeout_seconds=setup_timeout_seconds,
139
+ eager_install=runtime_env_config.eager_install,
140
+ log_files=list(runtime_env_config.log_files),
141
+ )
142
+
143
+ def to_dict(self) -> Dict:
144
+ return dict(deepcopy(self))
145
+
146
+
147
+ # Due to circular reference, field config can only be assigned a value here
148
+ OPTION_TO_VALIDATION_FN[
149
+ "config"
150
+ ] = RuntimeEnvConfig.parse_and_validate_runtime_env_config
151
+
152
+
153
+ @PublicAPI
154
+ class RuntimeEnv(dict):
155
+ """This class is used to define a runtime environment for a job, task,
156
+ or actor.
157
+
158
+ See :ref:`runtime-environments` for detailed documentation.
159
+
160
+ This class can be used interchangeably with an unstructured dictionary
161
+ in the relevant API calls.
162
+
163
+ Can specify a runtime environment whole job, whether running a script
164
+ directly on the cluster, using Ray Job submission, or using Ray Client:
165
+
166
+ .. code-block:: python
167
+
168
+ from ray.runtime_env import RuntimeEnv
169
+ # Starting a single-node local Ray cluster
170
+ ray.init(runtime_env=RuntimeEnv(...))
171
+
172
+ .. code-block:: python
173
+
174
+ from ray.runtime_env import RuntimeEnv
175
+ # Connecting to remote cluster using Ray Client
176
+ ray.init("ray://123.456.7.89:10001", runtime_env=RuntimeEnv(...))
177
+
178
+ Can specify different runtime environments per-actor or per-task using
179
+ ``.options()`` or the ``@ray.remote`` decorator:
180
+
181
+ .. code-block:: python
182
+
183
+ from ray.runtime_env import RuntimeEnv
184
+ # Invoke a remote task that runs in a specified runtime environment.
185
+ f.options(runtime_env=RuntimeEnv(...)).remote()
186
+
187
+ # Instantiate an actor that runs in a specified runtime environment.
188
+ actor = SomeClass.options(runtime_env=RuntimeEnv(...)).remote()
189
+
190
+ # Specify a runtime environment in the task definition. Future invocations via
191
+ # `g.remote()` use this runtime environment unless overridden by using
192
+ # `.options()` as above.
193
+ @ray.remote(runtime_env=RuntimeEnv(...))
194
+ def g():
195
+ pass
196
+
197
+ # Specify a runtime environment in the actor definition. Future instantiations
198
+ # via `MyClass.remote()` use this runtime environment unless overridden by
199
+ # using `.options()` as above.
200
+ @ray.remote(runtime_env=RuntimeEnv(...))
201
+ class MyClass:
202
+ pass
203
+
204
+ Here are some examples of RuntimeEnv initialization:
205
+
206
+ .. code-block:: python
207
+
208
+ # Example for using conda
209
+ RuntimeEnv(conda={
210
+ "channels": ["defaults"], "dependencies": ["codecov"]})
211
+ RuntimeEnv(conda="pytorch_p36") # Found on DLAMIs
212
+
213
+ # Example for using container
214
+ RuntimeEnv(
215
+ container={"image": "anyscale/ray-ml:nightly-py38-cpu",
216
+ "run_options": ["--cap-drop SYS_ADMIN","--log-level=debug"]})
217
+
218
+ # Example for set env_vars
219
+ RuntimeEnv(env_vars={"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"})
220
+
221
+ # Example for set pip
222
+ RuntimeEnv(
223
+ pip={"packages":["tensorflow", "requests"], "pip_check": False,
224
+ "pip_version": "==22.0.2;python_version=='3.8.11'"})
225
+
226
+ # Example for using image_uri
227
+ RuntimeEnv(
228
+ image_uri="rayproject/ray:2.39.0-py312-cu123")
229
+
230
+ Args:
231
+ py_modules: List of URIs (either in the GCS or external
232
+ storage), each of which is a zip file that Ray unpacks and
233
+ inserts into the PYTHONPATH of the workers.
234
+ working_dir: URI (either in the GCS or external storage) of a zip
235
+ file that Ray unpacks in the directory of each task/actor.
236
+ pip: Either a list of pip packages, a string
237
+ containing the path to a pip requirements.txt file, or a Python
238
+ dictionary that has three fields: 1) ``packages`` (required, List[str]): a
239
+ list of pip packages, 2) ``pip_check`` (optional, bool): whether enable
240
+ pip check at the end of pip install, defaults to False.
241
+ 3) ``pip_version`` (optional, str): the version of pip, Ray prepends
242
+ the package name "pip" in front of the ``pip_version`` to form the final
243
+ requirement string, the syntax of a requirement specifier is defined in
244
+ full in PEP 508.
245
+ uv: Either a list of pip packages, or a Python dictionary that has one field:
246
+ 1) ``packages`` (required, List[str]).
247
+ conda: Either the conda YAML config, the name of a
248
+ local conda env (e.g., "pytorch_p36"), or the path to a conda
249
+ environment.yaml file.
250
+ Ray automatically injects the dependency into the conda
251
+ env to ensure compatibility with the cluster Ray. Ray may automatically
252
+ mangle the conda name to avoid conflicts between runtime envs.
253
+ This field can't be specified at the same time as the 'pip' field.
254
+ To use pip with conda, specify your pip dependencies within
255
+ the conda YAML config:
256
+ https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#create-env-file-manually
257
+ container: Require a given (Docker) container image,
258
+ The Ray worker process runs in a container with this image.
259
+ This parameter only works alone, or with the ``config`` or
260
+ ``env_vars`` parameters.
261
+ The `run_options` list spec is here:
262
+ https://docs.docker.com/engine/reference/run/
263
+ env_vars: Environment variables to set.
264
+ worker_process_setup_hook: (Experimental) The setup hook that's
265
+ called after workers start and before Tasks and Actors are scheduled.
266
+ A module name (string type) or callable (function) can be passed.
267
+ When a module name is passed, Ray worker should be able to access the
268
+ module name. When a callable is passed, callable should be serializable.
269
+ When a runtime env is specified by job submission API,
270
+ only a module name (string) is allowed.
271
+ nsight: Dictionary mapping nsight profile option name to it's value.
272
+ config: config for runtime environment. Either
273
+ a dict or a RuntimeEnvConfig. Field: (1) setup_timeout_seconds, the
274
+ timeout of runtime environment creation, timeout is in seconds.
275
+ image_uri: URI to a container image. The Ray worker process runs
276
+ in a container with this image. This parameter only works alone,
277
+ or with the ``config`` or ``env_vars`` parameters.
278
+ """
279
+
280
    # Every field name that RuntimeEnv treats as a first-class option.
    # Keys outside this set are treated as custom plugin entries and are
    # surfaced via `plugins()`.
    known_fields: Set[str] = {
        "py_modules",
        "java_jars",
        "working_dir",
        "conda",
        "pip",
        "uv",
        "container",
        "excludes",
        "env_vars",
        "_ray_release",
        "_ray_commit",
        "_inject_current_ray",
        "config",
        # TODO(SongGuyang): We add this because the test
        # `test_experimental_package_github` set a `docker`
        # field which is not supported. We should remove it
        # with the test.
        "docker",
        "worker_process_setup_hook",
        "_nsight",
        "mpi",
        "image_uri",
    }

    # Internal/experimental fields; readable only through `get_extension()`.
    extensions_fields: Set[str] = {
        "_ray_release",
        "_ray_commit",
        "_inject_current_ray",
    }
310
+
311
    def __init__(
        self,
        *,
        py_modules: Optional[List[str]] = None,
        working_dir: Optional[str] = None,
        pip: Optional[List[str]] = None,
        conda: Optional[Union[Dict[str, str], str]] = None,
        container: Optional[Dict[str, str]] = None,
        env_vars: Optional[Dict[str, str]] = None,
        worker_process_setup_hook: Optional[Union[Callable, str]] = None,
        nsight: Optional[Union[str, Dict[str, str]]] = None,
        config: Optional[Union[Dict, RuntimeEnvConfig]] = None,
        _validate: bool = True,
        mpi: Optional[Dict] = None,
        image_uri: Optional[str] = None,
        uv: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__()

        # Gather every explicitly-passed option (plus unknown/plugin kwargs)
        # into one dict before committing it to `self`.
        runtime_env = kwargs
        if py_modules is not None:
            runtime_env["py_modules"] = py_modules
        if working_dir is not None:
            runtime_env["working_dir"] = working_dir
        if pip is not None:
            runtime_env["pip"] = pip
        if uv is not None:
            runtime_env["uv"] = uv
        if conda is not None:
            runtime_env["conda"] = conda
        if nsight is not None:
            # Stored under the internal "_nsight" key.
            runtime_env["_nsight"] = nsight
        if container is not None:
            runtime_env["container"] = container
        if env_vars is not None:
            runtime_env["env_vars"] = env_vars
        if config is not None:
            runtime_env["config"] = config
        if worker_process_setup_hook is not None:
            runtime_env["worker_process_setup_hook"] = worker_process_setup_hook
        if mpi is not None:
            runtime_env["mpi"] = mpi
        if image_uri is not None:
            runtime_env["image_uri"] = image_uri
        # "java_jars" can only arrive through **kwargs; re-assign it so the
        # key is present explicitly when truthy.
        if runtime_env.get("java_jars"):
            runtime_env["java_jars"] = runtime_env.get("java_jars")

        # NOTE: dict.update bypasses `__setitem__`; per-field validation is
        # re-applied in the OPTION_TO_VALIDATION_FN loop below.
        self.update(runtime_env)

        # Blindly trust that the runtime_env has already been validated.
        # This is dangerous and should only be used internally (e.g., on the
        # deserialization codepath.
        if not _validate:
            return

        # Cross-field constraint: at most one of conda / pip / uv may be set.
        # (bool values are summed as 0/1 here.)
        if (self.get("conda") is not None) + (self.get("pip") is not None) + (
            self.get("uv") is not None
        ) > 1:
            raise ValueError(
                "The 'pip' field, 'uv' field, and 'conda' field of "
                "runtime_env cannot be specified at the same time.\n"
                f"specified pip field: {self.get('pip')}\n"
                f"specified conda field: {self.get('conda')}\n"
                f"specified uv field: {self.get('uv')}\n"
                "To use pip with conda, please only set the 'conda'"
                "field, and specify your pip dependencies within the conda YAML "
                "config dict: see https://conda.io/projects/conda/en/latest/"
                "user-guide/tasks/manage-environments.html"
                "#create-env-file-manually"
            )

        # "container" is only compatible with "config" and "env_vars".
        if self.get("container"):
            invalid_keys = set(runtime_env.keys()) - {"container", "config", "env_vars"}
            if len(invalid_keys):
                raise ValueError(
                    "The 'container' field currently cannot be used "
                    "together with other fields of runtime_env. "
                    f"Specified fields: {invalid_keys}"
                )

            logger.warning(
                "The `container` runtime environment field is DEPRECATED and will be "
                "removed after July 31, 2025. Use `image_uri` instead. See "
                "https://docs.ray.io/en/latest/serve/advanced-guides/multi-app-container.html."  # noqa
            )

        # "image_uri" compatibility is delegated to the image-uri plugin.
        if self.get("image_uri"):
            image_uri_plugin_cls = get_image_uri_plugin_cls()
            invalid_keys = (
                set(runtime_env.keys()) - image_uri_plugin_cls.get_compatible_keys()
            )
            if len(invalid_keys):
                raise ValueError(
                    "The 'image_uri' field currently cannot be used "
                    "together with other fields of runtime_env. "
                    f"Specified fields: {invalid_keys}"
                )

        # Delete-and-reassign each validatable option so the value passes
        # through `__setitem__`, which applies OPTION_TO_VALIDATION_FN.
        for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
            option_val = self.get(option)
            if option_val is not None:
                del self[option]
                self[option] = option_val

        # Pin the Ray commit when an isolated pip/conda env will be built.
        if "_ray_commit" not in self:
            if self.get("pip") or self.get("conda"):
                self["_ray_commit"] = ray.__commit__

        # Used for testing wheels that have not yet been merged into master.
        # If this is set to True, then we do not inject Ray into the conda
        # or pip dependencies.
        if "_inject_current_ray" not in self:
            if "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
                self["_inject_current_ray"] = True

        # NOTE(architkulkarni): This allows worker caching code in C++ to check
        # if a runtime env is empty without deserializing it. This is a catch-
        # all; for validated inputs we won't set the key if the value is None.
        if all(val is None for val in self.values()):
            self.clear()
432
+
433
+ def __setitem__(self, key: str, value: Any) -> None:
434
+ if is_dataclass(value):
435
+ jsonable_type = asdict(value)
436
+ else:
437
+ jsonable_type = value
438
+ RuntimeEnvPluginSchemaManager.validate(key, jsonable_type)
439
+ res_value = jsonable_type
440
+ if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
441
+ res_value = OPTION_TO_VALIDATION_FN[key](jsonable_type)
442
+ if res_value is None:
443
+ return
444
+ return super().__setitem__(key, res_value)
445
+
446
+ def set(self, name: str, value: Any) -> None:
447
+ self.__setitem__(name, value)
448
+
449
+ def get(self, name, default=None, data_class=None):
450
+ if name not in self:
451
+ return default
452
+ if not data_class:
453
+ return self.__getitem__(name)
454
+ else:
455
+ return from_dict(data_class=data_class, data=self.__getitem__(name))
456
+
457
+ @classmethod
458
+ def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
459
+ return cls(_validate=False, **json.loads(serialized_runtime_env))
460
+
461
+ def serialize(self) -> str:
462
+ # To ensure the accuracy of Proto, `__setitem__` can only guarantee the
463
+ # accuracy of a certain field, not the overall accuracy
464
+ runtime_env = type(self)(_validate=True, **self)
465
+ return json.dumps(
466
+ runtime_env,
467
+ sort_keys=True,
468
+ )
469
+
470
+ def to_dict(self) -> Dict:
471
+ runtime_env_dict = dict(deepcopy(self))
472
+
473
+ # Replace strongly-typed RuntimeEnvConfig with a dict to allow the returned
474
+ # dict to work properly as a field in a dataclass. Details in issue #26986
475
+ if runtime_env_dict.get("config"):
476
+ runtime_env_dict["config"] = runtime_env_dict["config"].to_dict()
477
+
478
+ return runtime_env_dict
479
+
480
    def has_working_dir(self) -> bool:
        # True even for an empty-string working_dir (only None counts as unset).
        return self.get("working_dir") is not None

    def working_dir_uri(self) -> Optional[str]:
        # URI of the working_dir archive, or None when unset.
        return self.get("working_dir")

    def py_modules_uris(self) -> List[str]:
        # Copy of the py_modules URI list; empty list when unset.
        if "py_modules" in self:
            return list(self["py_modules"])
        return []

    def conda_uri(self) -> Optional[str]:
        # Cache URI derived from the conda field, or None when unset.
        if "conda" in self:
            return get_conda_uri(self)
        return None

    def pip_uri(self) -> Optional[str]:
        # Cache URI derived from the pip field, or None when unset.
        if "pip" in self:
            return get_pip_uri(self)
        return None

    def uv_uri(self) -> Optional[str]:
        # Cache URI derived from the uv field, or None when unset.
        if "uv" in self:
            return get_uv_uri(self)
        return None

    def plugin_uris(self) -> List[str]:
        """Not implemented yet, always returns an empty list."""
        return []

    def working_dir(self) -> str:
        # Unlike `working_dir_uri`, an unset field maps to "" here.
        return self.get("working_dir", "")

    def py_modules(self) -> List[str]:
        # Copy of the py_modules list; empty list when unset.
        if "py_modules" in self:
            return list(self["py_modules"])
        return []

    def java_jars(self) -> List[str]:
        # Copy of the java_jars list; empty list when unset.
        if "java_jars" in self:
            return list(self["java_jars"])
        return []

    def mpi(self) -> Optional[Union[str, Dict[str, str]]]:
        return self.get("mpi", None)

    def nsight(self) -> Optional[Union[str, Dict[str, str]]]:
        # nsight options are stored under the internal "_nsight" key.
        return self.get("_nsight", None)

    def env_vars(self) -> Dict:
        return self.get("env_vars", {})
531
+
532
+ def has_conda(self) -> str:
533
+ if self.get("conda"):
534
+ return True
535
+ return False
536
+
537
    def conda_env_name(self) -> Optional[str]:
        # A string-valued "conda" field names a preexisting conda env;
        # any other shape (e.g. inline YAML dict) yields None.
        if not self.has_conda() or not isinstance(self["conda"], str):
            return None
        return self["conda"]

    def conda_config(self) -> Optional[str]:
        # A dict-valued "conda" field is an inline environment config;
        # serialize it with sorted keys so equal configs compare equal.
        if not self.has_conda() or not isinstance(self["conda"], dict):
            return None
        return json.dumps(self["conda"], sort_keys=True)
546
+
547
    def has_pip(self) -> bool:
        # Truthiness check: an empty pip spec counts as "no pip field".
        if self.get("pip"):
            return True
        return False

    def has_uv(self) -> bool:
        # Truthiness check: an empty uv spec counts as "no uv field".
        if self.get("uv"):
            return True
        return False
556
+
557
    def virtualenv_name(self) -> Optional[str]:
        # A string-valued "pip" field names a preexisting virtualenv.
        if not self.has_pip() or not isinstance(self["pip"], str):
            return None
        return self["pip"]

    def pip_config(self) -> Dict:
        if not self.has_pip() or isinstance(self["pip"], str):
            return {}
        # Re-assigning routes the value through `__setitem__`, which parses
        # and validates the pip field into its canonical dict form.
        self["pip"] = self["pip"]
        return self["pip"]

    def uv_config(self) -> Dict:
        if not self.has_uv() or isinstance(self["uv"], str):
            return {}
        # Re-assigning routes the value through `__setitem__`, which parses
        # and validates the uv field into its canonical dict form.
        self["uv"] = self["uv"]
        return self["uv"]
575
+
576
+ def get_extension(self, key) -> Optional[str]:
577
+ if key not in RuntimeEnv.extensions_fields:
578
+ raise ValueError(
579
+ f"Extension key must be one of {RuntimeEnv.extensions_fields}, "
580
+ f"got: {key}"
581
+ )
582
+ return self.get(key)
583
+
584
    def has_py_container(self) -> bool:
        # Truthiness check on the (deprecated) "container" field.
        if self.get("container"):
            return True
        return False

    def py_container_image(self) -> Optional[str]:
        # Container image name; "" when the container dict omits "image".
        if not self.has_py_container():
            return None
        return self["container"].get("image", "")

    def py_container_worker_path(self) -> Optional[str]:
        # Worker entrypoint path inside the container; "" when omitted.
        if not self.has_py_container():
            return None
        return self["container"].get("worker_path", "")

    def py_container_run_options(self) -> Optional[List]:
        # Docker run options list; None when no container is configured.
        # (Annotation corrected from `List`: the no-container path returns None.)
        if not self.has_py_container():
            return None
        return self["container"].get("run_options", [])
603
+
604
    def image_uri(self) -> Optional[str]:
        # Container image URI set via the `image_uri` option, if any.
        return self.get("image_uri")

    def plugins(self) -> List[Tuple[str, Any]]:
        # Every key outside `known_fields` is treated as a custom plugin
        # entry and returned as a (name, config) pair.
        result = list()
        for key, value in self.items():
            if key not in self.known_fields:
                result.append((key, value))
        return result
613
+
614
+
615
+ def _merge_runtime_env(
616
+ parent: Optional[RuntimeEnv],
617
+ child: Optional[RuntimeEnv],
618
+ override: bool = False,
619
+ ) -> Optional[RuntimeEnv]:
620
+ """Merge the parent and child runtime environments.
621
+
622
+ If override = True, the child's runtime env overrides the parent's
623
+ runtime env in the event of a conflict.
624
+
625
+ Merging happens per key (i.e., "conda", "pip", ...), but
626
+ "env_vars" are merged per env var key.
627
+
628
+ It returns None if Ray fails to merge runtime environments because
629
+ of a conflict and `override = False`.
630
+
631
+ Args:
632
+ parent: Parent runtime env.
633
+ child: Child runtime env.
634
+ override: If True, the child's runtime env overrides
635
+ conflicting fields.
636
+ Returns:
637
+ The merged runtime env's if Ray successfully merges them.
638
+ None if the runtime env's conflict. Empty dict if
639
+ parent and child are both None.
640
+ """
641
+ if parent is None:
642
+ parent = {}
643
+ if child is None:
644
+ child = {}
645
+
646
+ parent = deepcopy(parent)
647
+ child = deepcopy(child)
648
+ parent_env_vars = parent.pop("env_vars", {})
649
+ child_env_vars = child.pop("env_vars", {})
650
+
651
+ if not override:
652
+ if set(parent.keys()).intersection(set(child.keys())):
653
+ return None
654
+ if set(parent_env_vars.keys()).intersection(set(child_env_vars.keys())): # noqa
655
+ return None
656
+
657
+ parent.update(child)
658
+ parent_env_vars.update(child_env_vars)
659
+ if parent_env_vars:
660
+ parent["env_vars"] = parent_env_vars
661
+
662
+ return parent
.venv/lib/python3.11/site-packages/ray/widgets/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from ray.widgets.render import Template
2
+ from ray.widgets.util import make_table_html_repr
3
+
4
+ __all__ = ["Template", "make_table_html_repr"]
.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (362 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/render.cpython-311.pyc ADDED
Binary file (2.69 kB). View file
 
.venv/lib/python3.11/site-packages/ray/widgets/__pycache__/util.cpython-311.pyc ADDED
Binary file (9.29 kB). View file
 
.venv/lib/python3.11/site-packages/ray/widgets/render.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from typing import List
3
+
4
+ from ray.util.annotations import DeveloperAPI
5
+
6
+
7
@DeveloperAPI
class Template:
    """Minimal HTML templating backed by files in ./templates."""

    def __init__(self, file: str):
        template_path = pathlib.Path(__file__).parent / "templates" / file
        with open(template_path, "r") as f:
            self.template = f.read()

    def render(self, **kwargs) -> str:
        """Substitute `{{ key }}` placeholders with the supplied values.

        List values are concatenated into a single string first; falsy
        values render as the empty string.

        Returns:
            The template text with each `{{ key }}` marker replaced.
        """
        rendered = self.template
        for key, value in kwargs.items():
            if isinstance(value, List):
                value = "".join(value)
            rendered = rendered.replace("{{ " + key + " }}", value if value else "")
        return rendered

    @staticmethod
    def list_templates() -> List[pathlib.Path]:
        """Return the paths of all *.html.j2 files under ../templates/."""
        return (pathlib.Path(__file__).parent / "templates").glob("*.html.j2")
.venv/lib/python3.11/site-packages/ray/widgets/templates/context.html.j2 ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <div class="lm-Widget p-Widget lm-Panel p-Panel jp-Cell-outputWrapper">
2
+ <div style="margin-left: 50px;display: flex;flex-direction: row;align-items: center">
3
+ {{ context_logo }}
4
+ {{ context_table }}
5
+ </div>
6
+ </div>
.venv/lib/python3.11/site-packages/ray/widgets/templates/context_dashrow.html.j2 ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <tr>
2
+ <td style="text-align: left"><b>Dashboard:</b></td>
3
+ <td style="text-align: left"><b><a href="{{ dashboard_url }}" target="_blank">{{ dashboard_url }}</a></b></td>
4
+ </tr>
.venv/lib/python3.11/site-packages/ray/widgets/templates/context_logo.html.j2 ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div class="jp-RenderedHTMLCommon" style="display: flex; flex-direction: row;">
2
+ <svg viewBox="0 0 567 224" fill="none" xmlns="http://www.w3.org/2000/svg" style="height: 3em;">
3
+ <g clip-path="url(#clip0_4338_178347)">
4
+ <path d="M341.29 165.561H355.29L330.13 129.051C345.63 123.991 354.21 112.051 354.21 94.2307C354.21 71.3707 338.72 58.1807 311.88 58.1807H271V165.561H283.27V131.661H311.8C314.25 131.661 316.71 131.501 319.01 131.351L341.25 165.561H341.29ZM283.29 119.851V70.0007H311.82C331.3 70.0007 342.34 78.2907 342.34 94.5507C342.34 111.271 331.34 119.861 311.82 119.861L283.29 119.851ZM451.4 138.411L463.4 165.561H476.74L428.74 58.1807H416L367.83 165.561H380.83L392.83 138.411H451.4ZM446.19 126.601H398L422 72.1407L446.24 126.601H446.19ZM526.11 128.741L566.91 58.1807H554.35L519.99 114.181L485.17 58.1807H472.44L514.01 129.181V165.541H526.13V128.741H526.11Z" fill="var(--jp-ui-font-color0)"/>
5
+ <path d="M82.35 104.44C84.0187 97.8827 87.8248 92.0678 93.1671 87.9146C98.5094 83.7614 105.083 81.5067 111.85 81.5067C118.617 81.5067 125.191 83.7614 130.533 87.9146C135.875 92.0678 139.681 97.8827 141.35 104.44H163.75C164.476 101.562 165.622 98.8057 167.15 96.2605L127.45 56.5605C121.071 60.3522 113.526 61.6823 106.235 60.3005C98.9443 58.9187 92.4094 54.9203 87.8602 49.0574C83.3109 43.1946 81.0609 35.8714 81.5332 28.4656C82.0056 21.0599 85.1679 14.0819 90.4252 8.8446C95.6824 3.60726 102.672 0.471508 110.08 0.0272655C117.487 -0.416977 124.802 1.86091 130.647 6.4324C136.493 11.0039 140.467 17.5539 141.821 24.8501C143.175 32.1463 141.816 39.6859 138 46.0505L177.69 85.7505C182.31 82.9877 187.58 81.4995 192.962 81.4375C198.345 81.3755 203.648 82.742 208.33 85.3976C213.012 88.0532 216.907 91.9029 219.616 96.5544C222.326 101.206 223.753 106.492 223.753 111.875C223.753 117.258 222.326 122.545 219.616 127.197C216.907 131.848 213.012 135.698 208.33 138.353C203.648 141.009 198.345 142.375 192.962 142.313C187.58 142.251 182.31 140.763 177.69 138L138 177.7C141.808 184.071 143.155 191.614 141.79 198.91C140.424 206.205 136.44 212.75 130.585 217.313C124.731 221.875 117.412 224.141 110.004 223.683C102.596 223.226 95.6103 220.077 90.3621 214.828C85.1139 209.58 81.9647 202.595 81.5072 195.187C81.0497 187.779 83.3154 180.459 87.878 174.605C92.4405 168.751 98.9853 164.766 106.281 163.401C113.576 162.035 121.119 163.383 127.49 167.19L167.19 127.49C165.664 124.941 164.518 122.182 163.79 119.3H141.39C139.721 125.858 135.915 131.673 130.573 135.826C125.231 139.98 118.657 142.234 111.89 142.234C105.123 142.234 98.5494 139.98 93.2071 135.826C87.8648 131.673 84.0587 125.858 82.39 119.3H60C58.1878 126.495 53.8086 132.78 47.6863 136.971C41.5641 141.163 34.1211 142.972 26.7579 142.059C19.3947 141.146 12.6191 137.574 7.70605 132.014C2.79302 126.454 0.0813599 119.29 0.0813599 111.87C0.0813599 104.451 2.79302 97.2871 7.70605 91.7272C12.6191 86.1673 19.3947 82.5947 26.7579 81.6817C34.1211 80.7686 
41.5641 82.5781 47.6863 86.7696C53.8086 90.9611 58.1878 97.2456 60 104.44H82.35ZM100.86 204.32C103.407 206.868 106.759 208.453 110.345 208.806C113.93 209.159 117.527 208.258 120.522 206.256C123.517 204.254 125.725 201.276 126.771 197.828C127.816 194.38 127.633 190.677 126.253 187.349C124.874 184.021 122.383 181.274 119.205 179.577C116.027 177.88 112.359 177.337 108.826 178.042C105.293 178.746 102.113 180.654 99.8291 183.44C97.5451 186.226 96.2979 189.718 96.3 193.32C96.2985 195.364 96.7006 197.388 97.4831 199.275C98.2656 201.163 99.4132 202.877 100.86 204.32ZM204.32 122.88C206.868 120.333 208.453 116.981 208.806 113.396C209.159 109.811 208.258 106.214 206.256 103.219C204.254 100.223 201.275 98.0151 197.827 96.97C194.38 95.9249 190.676 96.1077 187.348 97.4873C184.02 98.8669 181.274 101.358 179.577 104.536C177.879 107.714 177.337 111.382 178.041 114.915C178.746 118.448 180.653 121.627 183.439 123.911C186.226 126.195 189.717 127.443 193.32 127.44C195.364 127.443 197.388 127.042 199.275 126.259C201.163 125.476 202.878 124.328 204.32 122.88ZM122.88 19.4205C120.333 16.8729 116.981 15.2876 113.395 14.9347C109.81 14.5817 106.213 15.483 103.218 17.4849C100.223 19.4868 98.0146 22.4654 96.9696 25.9131C95.9245 29.3608 96.1073 33.0642 97.4869 36.3922C98.8665 39.7202 101.358 42.4668 104.535 44.1639C107.713 45.861 111.381 46.4036 114.914 45.6992C118.447 44.9949 121.627 43.0871 123.911 40.301C126.195 37.515 127.442 34.0231 127.44 30.4205C127.44 28.3772 127.038 26.3539 126.255 24.4664C125.473 22.5788 124.326 20.8642 122.88 19.4205ZM19.42 100.86C16.8725 103.408 15.2872 106.76 14.9342 110.345C14.5813 113.93 15.4826 117.527 17.4844 120.522C19.4863 123.518 22.4649 125.726 25.9127 126.771C29.3604 127.816 33.0638 127.633 36.3918 126.254C39.7198 124.874 42.4664 122.383 44.1635 119.205C45.8606 116.027 46.4032 112.359 45.6988 108.826C44.9944 105.293 43.0866 102.114 40.3006 99.8296C37.5145 97.5455 34.0227 96.2983 30.42 96.3005C26.2938 96.3018 22.337 97.9421 19.42 100.86ZM100.86 
100.86C98.3125 103.408 96.7272 106.76 96.3742 110.345C96.0213 113.93 96.9226 117.527 98.9244 120.522C100.926 123.518 103.905 125.726 107.353 126.771C110.8 127.816 114.504 127.633 117.832 126.254C121.16 124.874 123.906 122.383 125.604 119.205C127.301 116.027 127.843 112.359 127.139 108.826C126.434 105.293 124.527 102.114 121.741 99.8296C118.955 97.5455 115.463 96.2983 111.86 96.3005C109.817 96.299 107.793 96.701 105.905 97.4835C104.018 98.2661 102.303 99.4136 100.86 100.86Z" fill="#00AEEF"/>
6
+ </g>
7
+ <defs>
8
+ <clipPath id="clip0_4338_178347">
9
+ <rect width="566.93" height="223.75" fill="white"/>
10
+ </clipPath>
11
+ </defs>
12
+ </svg>
13
+ </div>
.venv/lib/python3.11/site-packages/ray/widgets/templates/context_table.html.j2 ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <table class="jp-RenderedHTMLCommon" style="border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);">
2
+ <tr>
3
+ <td style="text-align: left"><b>Python version:</b></td>
4
+ <td style="text-align: left"><b>{{ python_version }}</b></td>
5
+ </tr>
6
+ <tr>
7
+ <td style="text-align: left"><b>Ray version:</b></td>
8
+ <td style="text-align: left"><b>{{ ray_version }}</b></td>
9
+ </tr>
10
+ {{ dashboard_row }}
11
+ </table>
.venv/lib/python3.11/site-packages/ray/widgets/templates/divider.html.j2 ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ <div class="vDivider"></div>
2
+ <style>
3
+ .vDivider {
4
+ border-left-width: var(--jp-border-width);
5
+ border-left-color: var(--jp-border-color0);
6
+ border-left-style: solid;
7
+ margin: 0.5em 1em 0.5em 1em;
8
+ }
9
+ </style>
.venv/lib/python3.11/site-packages/ray/widgets/templates/rendered_html_common.html.j2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ <div class='jp-RenderedHTMLCommon'>
2
+ {{ content }}
3
+ </div>