Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json +3 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/aggregate.py +411 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/arrow_block.py +649 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/batcher.py +325 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/block_batching.py +60 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/interfaces.py +47 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/block_builder.py +39 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/block_list.py +98 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/compute.py +143 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/delegating_block_builder.py +76 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/equalize.py +142 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/logging.py +208 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/memory_tracing.py +147 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/null_aggregate.py +276 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/numpy_support.py +233 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/output_buffer.py +109 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/pandas_block.py +728 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/plan.py +602 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/progress_bar.py +217 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/remote_fn.py +80 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/row.py +42 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/size_estimator.py +92 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/split.py +297 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/stats.py +1495 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/table_block.py +310 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/torch_iterable_dataset.py +10 -0
- .venv/lib/python3.11/site-packages/ray/data/_internal/util.py +1262 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__init__.py +67 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasink.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasource.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_based_datasource.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_datasink.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_meta_provider.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/filename_provider.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/parquet_meta_provider.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/partitioning.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/path_util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/file_datasink.py +266 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/partitioning.py +456 -0
- .venv/lib/python3.11/site-packages/ray/data/datasource/path_util.py +206 -0
- .venv/lib/python3.11/site-packages/ray/data/extensions/__init__.py +45 -0
- .venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/object_extension.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/tensor_extension.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/data/extensions/object_extension.py +10 -0
- .venv/lib/python3.11/site-packages/ray/data/extensions/tensor_extension.py +15 -0
- .venv/lib/python3.11/site-packages/ray/data/preprocessors/__init__.py +50 -0
.gitattributes
CHANGED
|
@@ -151,3 +151,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 151 |
.venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 152 |
.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 153 |
.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 151 |
.venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 152 |
.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 153 |
.venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 154 |
+
.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:386b1f98fba69b38c3de512a4eb602dc69a95dae0e54e6ce048ea3e29a2627a8
|
| 3 |
+
size 19280967
|
.venv/lib/python3.11/site-packages/ray/data/_internal/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/data/_internal/aggregate.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.null_aggregate import (
|
| 5 |
+
_null_wrap_accumulate_block,
|
| 6 |
+
_null_wrap_accumulate_row,
|
| 7 |
+
_null_wrap_finalize,
|
| 8 |
+
_null_wrap_init,
|
| 9 |
+
_null_wrap_merge,
|
| 10 |
+
)
|
| 11 |
+
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
|
| 12 |
+
from ray.data.aggregate import AggregateFn
|
| 13 |
+
from ray.data.block import AggType, Block, BlockAccessor
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
import pyarrow as pa
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _AggregateOnKeyBase(AggregateFn):
    """Base class for aggregations that operate on a single key column.

    Subclasses call ``_set_key_fn`` with the target column name so that the
    aggregation can later be validated against the dataset schema.
    """

    def _set_key_fn(self, on: str):
        # Remember the target column; used by _validate below.
        self._key_fn = on

    def _validate(self, schema: Optional[Union[type, "pa.lib.Schema"]]) -> None:
        # Delegate to SortKey, which checks that the column exists in the
        # given schema and raises if it does not.
        SortKey(self._key_fn).validate_schema(schema)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Count(AggregateFn):
    """Counts the total number of rows."""

    def __init__(self):
        def _accumulate_block(acc, block):
            # Fold one block's row count into the running total.
            return acc + BlockAccessor.for_block(block).num_rows()

        super().__init__(
            init=lambda _key: 0,
            accumulate_block=_accumulate_block,
            merge=lambda left, right: left + right,
            name="count()",
        )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Sum(_AggregateOnKeyBase):
    """Computes the sum of a column.

    Args:
        on: Column to aggregate on.
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        # Prefer the caller-supplied alias; otherwise derive a default name.
        self._rs_name = alias_name if alias_name else f"sum({str(on)})"

        null_merge = _null_wrap_merge(ignore_nulls, lambda left, right: left + right)

        def _block_sum(block: Block):
            # Vectorized per-block sum via the block accessor.
            return BlockAccessor.for_block(block).sum(on, ignore_nulls)

        super().__init__(
            init=_null_wrap_init(lambda k: 0),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, _block_sum, null_merge
            ),
            finalize=_null_wrap_finalize(lambda acc: acc),
            name=self._rs_name,
        )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Min(_AggregateOnKeyBase):
    """Computes the minimum value of a column.

    Args:
        on: Column to aggregate on.
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"min({str(on)})"

        null_merge = _null_wrap_merge(ignore_nulls, min)

        def _block_min(block: Block):
            # Vectorized per-block minimum via the block accessor.
            return BlockAccessor.for_block(block).min(on, ignore_nulls)

        super().__init__(
            # Identity element for min: positive infinity.
            init=_null_wrap_init(lambda k: float("inf")),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, _block_min, null_merge
            ),
            finalize=_null_wrap_finalize(lambda acc: acc),
            name=self._rs_name,
        )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Max(_AggregateOnKeyBase):
    """Computes the maximum value of a column.

    Args:
        on: Column to aggregate on.
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"max({str(on)})"

        null_merge = _null_wrap_merge(ignore_nulls, max)

        def _block_max(block: Block):
            # Vectorized per-block maximum via the block accessor.
            return BlockAccessor.for_block(block).max(on, ignore_nulls)

        super().__init__(
            # Identity element for max: negative infinity.
            init=_null_wrap_init(lambda k: float("-inf")),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, _block_max, null_merge
            ),
            finalize=_null_wrap_finalize(lambda acc: acc),
            name=self._rs_name,
        )
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Mean(_AggregateOnKeyBase):
    """Computes the arithmetic mean of a column.

    The accumulator is a ``[sum, count]`` pair; the mean is computed at
    finalize time as ``sum / count``.

    Args:
        on: Column to aggregate on.
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"mean({str(on)})"

        # Merging two [sum, count] pairs adds them componentwise.
        null_merge = _null_wrap_merge(
            ignore_nulls, lambda left, right: [left[0] + right[0], left[1] + right[1]]
        )

        def vectorized_mean(block: Block) -> AggType:
            acc = BlockAccessor.for_block(block)
            count = acc.count(on)
            if count == 0 or count is None:
                # Empty or all null.
                return None
            total = acc.sum(on, ignore_nulls)
            if total is None:
                # ignore_nulls=False and at least one null.
                return None
            return [total, count]

        super().__init__(
            init=_null_wrap_init(lambda k: [0, 0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls,
                vectorized_mean,
                null_merge,
            ),
            finalize=_null_wrap_finalize(lambda pair: pair[0] / pair[1]),
            name=self._rs_name,
        )
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class Std(_AggregateOnKeyBase):
    """Defines standard deviation aggregation.

    Uses Welford's online method for an accumulator-style computation of the
    standard deviation. This method was chosen due to its numerical
    stability, and it being computable in a single pass.
    This may give different (but more accurate) results than NumPy, Pandas,
    and sklearn, which use a less numerically stable two-pass algorithm.
    See
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    Args:
        on: Column to aggregate on.
        ddof: Delta degrees of freedom; divisor is ``count - ddof``
            (1 = sample standard deviation).
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ddof: int = 1,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        if alias_name:
            self._rs_name = alias_name
        else:
            self._rs_name = f"std({str(on)})"

        # Accumulator layout: [M2, mean, count] where M2 is the running sum
        # of squared differences from the current mean.
        def merge(a: List[float], b: List[float]):
            # Merges two accumulations into one.
            # See
            # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            M2_a, mean_a, count_a = a
            M2_b, mean_b, count_b = b
            delta = mean_b - mean_a
            count = count_a + count_b
            # NOTE: We use this mean calculation since it's more numerically
            # stable than mean_a + delta * count_b / count, which actually
            # deviates from Pandas in the ~15th decimal place and causes our
            # exact comparison tests to fail.
            mean = (mean_a * count_a + mean_b * count_b) / count
            # Update the sum of squared differences.
            M2 = M2_a + M2_b + (delta**2) * count_a * count_b / count
            return [M2, mean, count]

        null_merge = _null_wrap_merge(ignore_nulls, merge)

        def vectorized_std(block: Block) -> AggType:
            # Compute one block's [M2, mean, count] in a vectorized fashion.
            block_acc = BlockAccessor.for_block(block)
            count = block_acc.count(on)
            if count == 0 or count is None:
                # Empty or all null.
                return None
            sum_ = block_acc.sum(on, ignore_nulls)
            if sum_ is None:
                # ignore_nulls=False and at least one null.
                return None
            mean = sum_ / count
            M2 = block_acc.sum_of_squared_diffs_from_mean(on, ignore_nulls, mean)
            return [M2, mean, count]

        def finalize(a: List[float]):
            # Compute the final standard deviation from the accumulated
            # sum of squared differences from current mean and the count.
            M2, mean, count = a
            if count < 2:
                # Fewer than two samples: std is defined as 0 (also avoids
                # division by zero when count - ddof == 0).
                return 0.0
            return math.sqrt(M2 / (count - ddof))

        super().__init__(
            init=_null_wrap_init(lambda k: [0, 0, 0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls,
                vectorized_std,
                null_merge,
            ),
            finalize=_null_wrap_finalize(finalize),
            name=(self._rs_name),
        )
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class AbsMax(_AggregateOnKeyBase):
    """Computes the maximum absolute value of a column.

    Args:
        on: Column to aggregate on.
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        extract = _to_on_fn(on)
        self._rs_name = alias_name if alias_name else f"abs_max({str(on)})"

        def _accumulate(acc, value):
            # Track the largest magnitude observed so far.
            return max(acc, abs(value))

        super().__init__(
            init=_null_wrap_init(lambda k: 0),
            merge=_null_wrap_merge(ignore_nulls, max),
            accumulate_row=_null_wrap_accumulate_row(
                ignore_nulls, extract, _accumulate
            ),
            finalize=_null_wrap_finalize(lambda acc: acc),
            name=self._rs_name,
        )
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _to_on_fn(on: Optional[str]):
|
| 282 |
+
if on is None:
|
| 283 |
+
return lambda r: r
|
| 284 |
+
elif isinstance(on, str):
|
| 285 |
+
return lambda r: r[on]
|
| 286 |
+
else:
|
| 287 |
+
return on
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class Quantile(_AggregateOnKeyBase):
    """Defines Quantile aggregation.

    Collects all values of the target column into a list and computes the
    q-th quantile via linear interpolation between closest ranks at
    finalize time.

    Args:
        on: Column to aggregate on.
        q: Quantile to compute, in [0, 1]. Defaults to the median (0.5).
        ignore_nulls: Whether to ignore null values when aggregating.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        q: float = 0.5,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._q = q
        if alias_name:
            self._rs_name = alias_name
        else:
            self._rs_name = f"quantile({str(on)})"

        def merge(a: List[int], b: List[int]):
            # Merge two accumulators into a single list of collected values.
            # Either side may be a list of values or a bare scalar; None and
            # empty-string scalars are treated as missing and dropped.
            # NOTE: runtime checks use the builtin ``list`` — passing
            # typing.List to isinstance() is deprecated since Python 3.9.
            if isinstance(a, list) and isinstance(b, list):
                a.extend(b)
                return a
            if isinstance(a, list) and (not isinstance(b, list)):
                if b is not None and b != "":
                    a.append(b)
                return a
            if isinstance(b, list) and (not isinstance(a, list)):
                if a is not None and a != "":
                    b.append(a)
                return b

            ls = []
            if a is not None and a != "":
                ls.append(a)
            if b is not None and b != "":
                ls.append(b)
            return ls

        null_merge = _null_wrap_merge(ignore_nulls, merge)

        def block_row_ls(block: Block) -> AggType:
            # Gather the raw column values from one block into a list.
            block_acc = BlockAccessor.for_block(block)
            ls = []
            for row in block_acc.iter_rows(public_row_format=False):
                ls.append(row.get(on))
            return ls

        def percentile(input_values, key: Optional[Callable[[Any], Any]] = None):
            # Linear-interpolation quantile over the collected values.
            # Relies on the module-level ``math`` import.
            if not input_values:
                return None

            if key is None:
                key = lambda x: x  # noqa: E731

            input_values = sorted(input_values)
            k = (len(input_values) - 1) * self._q
            f = math.floor(k)
            c = math.ceil(k)
            if f == c:
                # k is an exact index: no interpolation needed.
                return key(input_values[int(k)])
            # Interpolate between the two nearest ranks.
            d0 = key(input_values[int(f)]) * (c - k)
            d1 = key(input_values[int(c)]) * (k - f)
            return round(d0 + d1, 5)

        super().__init__(
            init=_null_wrap_init(lambda k: [0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls,
                block_row_ls,
                null_merge,
            ),
            finalize=_null_wrap_finalize(percentile),
            name=(self._rs_name),
        )
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class Unique(_AggregateOnKeyBase):
    """Collects the set of distinct values in a column.

    Args:
        on: Column to aggregate on.
        alias_name: Optional name for the output column.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"unique({str(on)})"

        def _as_set(value):
            # Normalize a scalar / list / set accumulator into a set.
            if isinstance(value, set):
                return value
            if isinstance(value, list):
                return set(value)
            return {value}

        def _merge(left, right):
            return _as_set(left) | _as_set(right)

        def _block_unique(block: Block) -> AggType:
            import pyarrow.compute as pac

            # Vectorized per-block distinct values via pyarrow.
            column = BlockAccessor.for_block(block).to_arrow().column(on)
            return pac.unique(column).to_pylist()

        null_merge = _null_wrap_merge(False, _merge)

        super().__init__(
            init=_null_wrap_init(lambda x: set()),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                False,
                _block_unique,
                null_merge,
            ),
            name=self._rs_name,
            finalize=_null_wrap_finalize(lambda x: x),
        )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_block.py
ADDED
|
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import heapq
|
| 3 |
+
import logging
|
| 4 |
+
import random
|
| 5 |
+
from typing import (
|
| 6 |
+
TYPE_CHECKING,
|
| 7 |
+
Any,
|
| 8 |
+
Callable,
|
| 9 |
+
Dict,
|
| 10 |
+
Iterator,
|
| 11 |
+
List,
|
| 12 |
+
Optional,
|
| 13 |
+
Sequence,
|
| 14 |
+
Tuple,
|
| 15 |
+
TypeVar,
|
| 16 |
+
Union,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ray._private.utils import _get_pyarrow_version
|
| 22 |
+
from ray.air.constants import TENSOR_COLUMN_NAME
|
| 23 |
+
from ray.air.util.tensor_extensions.arrow import (
|
| 24 |
+
convert_to_pyarrow_array,
|
| 25 |
+
pyarrow_table_from_pydict,
|
| 26 |
+
)
|
| 27 |
+
from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
|
| 28 |
+
from ray.data._internal.numpy_support import convert_to_numpy
|
| 29 |
+
from ray.data._internal.row import TableRow
|
| 30 |
+
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
|
| 31 |
+
from ray.data._internal.util import NULL_SENTINEL, find_partitions, keys_equal
|
| 32 |
+
from ray.data.block import (
|
| 33 |
+
Block,
|
| 34 |
+
BlockAccessor,
|
| 35 |
+
BlockExecStats,
|
| 36 |
+
BlockMetadata,
|
| 37 |
+
BlockType,
|
| 38 |
+
KeyType,
|
| 39 |
+
U,
|
| 40 |
+
)
|
| 41 |
+
from ray.data.context import DataContext
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import pyarrow
|
| 45 |
+
except ImportError:
|
| 46 |
+
pyarrow = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if TYPE_CHECKING:
|
| 50 |
+
import pandas
|
| 51 |
+
|
| 52 |
+
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
|
| 53 |
+
from ray.data.aggregate import AggregateFn
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
T = TypeVar("T")
|
| 57 |
+
logger = logging.getLogger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# We offload some transformations to polars for performance.
|
| 61 |
+
def get_sort_transform(context: DataContext) -> Callable:
    """Return the sort implementation selected by the data context.

    Sorting is offloaded to polars for performance when
    ``context.use_polars`` is enabled; otherwise pyarrow is used.
    """
    return transform_polars.sort if context.use_polars else transform_pyarrow.sort
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_concat_and_sort_transform(context: DataContext) -> Callable:
    """Return the concat-then-sort implementation selected by the context.

    Uses polars for performance when ``context.use_polars`` is enabled;
    otherwise falls back to pyarrow.
    """
    if context.use_polars:
        return transform_polars.concat_and_sort
    return transform_pyarrow.concat_and_sort
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ArrowRow(TableRow):
    """
    Row of a tabular Dataset backed by a Arrow Table block.
    """

    def __getitem__(self, key: Union[str, List[str]]) -> Any:
        """Look up one column (str key) or several columns (list of keys).

        Returns None when the underlying single-row table is empty.
        """
        from ray.data.extensions import get_arrow_extension_tensor_types

        tensor_arrow_extension_types = get_arrow_extension_tensor_types()

        def get_item(keys: List[str]) -> Any:
            schema = self._row.schema
            # NOTE: only the first key's type is checked; presumably all
            # requested columns share the tensor-extension type in this
            # branch — TODO confirm with callers.
            if isinstance(schema.field(keys[0]).type, tensor_arrow_extension_types):
                # Build a tensor row.
                return tuple(
                    [
                        ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
                        for key in keys
                    ]
                )

            table = self._row.select(keys)
            if len(table) == 0:
                return None

            # First (only) element of each selected column.
            items = [col[0] for col in table.columns]
            try:
                # Try to interpret this as a pyarrow.Scalar value.
                return tuple([item.as_py() for item in items])

            except AttributeError:
                # Assume that this row is an element of an extension array, and
                # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
                return items

        is_single_item = isinstance(key, str)
        keys = [key] if is_single_item else key

        items = get_item(keys)

        if items is None:
            return None
        elif is_single_item:
            # Unwrap the 1-tuple for single-column access.
            return items[0]
        else:
            return items

    def __iter__(self) -> Iterator:
        # Iterate over column names, mirroring dict iteration semantics.
        for k in self._row.column_names:
            yield k

    def __len__(self):
        # Number of columns in the row.
        return self._row.num_columns
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ArrowBlockBuilder(TableBlockBuilder):
    """Builds Arrow ``pyarrow.Table`` blocks from added rows/blocks."""

    def __init__(self):
        if pyarrow is None:
            raise ImportError("Run `pip install pyarrow` for Arrow support")
        # Accept both in-memory tables and serialized (bytes) blocks.
        super().__init__((pyarrow.Table, bytes))

    @staticmethod
    def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
        """Convert a column-name -> values mapping into an Arrow table."""
        pa_cols: Dict[str, pyarrow.Array] = dict()

        for col_name, col_vals in columns.items():
            # Go through numpy first so heterogeneous Python values are
            # normalized before Arrow conversion.
            np_col_vals = convert_to_numpy(col_vals)

            pa_cols[col_name] = convert_to_pyarrow_array(np_col_vals, col_name)

        return pyarrow_table_from_pydict(pa_cols)

    @staticmethod
    def _concat_tables(tables: List[Block]) -> Block:
        return transform_pyarrow.concat(tables)

    @staticmethod
    def _concat_would_copy() -> bool:
        # Arrow concatenation is zero-copy (chunked tables).
        return False

    @staticmethod
    def _empty_table() -> "pyarrow.Table":
        return pyarrow_table_from_pydict({})

    def block_type(self) -> BlockType:
        return BlockType.ARROW
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ArrowBlockAccessor(TableBlockAccessor):
    """Block accessor backed by a ``pyarrow.Table``."""

    # Row wrapper type yielded when iterating rows of this block.
    ROW_TYPE = ArrowRow

    def __init__(self, table: "pyarrow.Table"):
        # Arrow is an optional dependency; fail fast with install guidance.
        if pyarrow is None:
            raise ImportError("Run `pip install pyarrow` for Arrow support")
        super().__init__(table)

    def column_names(self) -> List[str]:
        """Return the column names of the underlying table."""
        return self._table.column_names

    def append_column(self, name: str, data: Any) -> Block:
        """Return a new table with ``data`` appended as column ``name``.

        Raises:
            NotImplementedError: If any element of ``data`` is an ndarray;
                array-like per-row data is not supported by this method.
        """
        assert name not in self._table.column_names

        if any(isinstance(item, np.ndarray) for item in data):
            raise NotImplementedError(
                f"`{self.__class__.__name__}.append_column()` doesn't support "
                "array-like data."
            )

        return self._table.append_column(name, [data])

    @classmethod
    def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":
        """Deserialize an Arrow IPC stream into a new accessor."""
        reader = pyarrow.ipc.open_stream(data)
        return cls(reader.read_all())

    @staticmethod
    def _build_tensor_row(
        row: ArrowRow, col_name: str = TENSOR_COLUMN_NAME
    ) -> np.ndarray:
        """Extract ``row[col_name][0]`` as an ndarray, papering over the
        differing extension-scalar models across pyarrow versions."""
        from packaging.version import parse as parse_version

        element = row[col_name][0]
        # TODO(Clark): Reduce this to np.asarray(element) once we only support Arrow
        # 9.0.0+.
        pyarrow_version = _get_pyarrow_version()
        if pyarrow_version is not None:
            pyarrow_version = parse_version(pyarrow_version)
        if pyarrow_version is None or pyarrow_version >= parse_version("8.0.0"):
            assert isinstance(element, pyarrow.ExtensionScalar)
            if pyarrow_version is None or pyarrow_version >= parse_version("9.0.0"):
                # For Arrow 9.0.0+, accessing an element in a chunked tensor array
                # produces an ArrowTensorScalar, which we convert to an ndarray using
                # .as_py().
                element = element.as_py()
            else:
                # For Arrow 8.*, accessing an element in a chunked tensor array produces
                # an ExtensionScalar, which we convert to an ndarray using our custom
                # method.
                element = element.type._extension_scalar_to_ndarray(element)
        # For Arrow < 8.0.0, accessing an element in a chunked tensor array produces an
        # ndarray, which we return directly.
        assert isinstance(element, np.ndarray), type(element)
        return element

    def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
        """Return rows [start, end) of the table.

        When ``copy`` is False this is a zero-copy view; when True, chunks
        are combined, which materializes a copy.
        """
        view = self._table.slice(start, end - start)
        if copy:
            view = transform_pyarrow.combine_chunks(view)
        return view

    def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
        """Return the table with rows randomly permuted (seeded if given)."""
        # TODO(swang): Creating this np.array index can add a lot of memory
        # pressure when there are a large number of small rows. Investigate
        # random shuffling in place to reduce memory pressure.
        # See https://github.com/ray-project/ray/issues/42146.
        random = np.random.RandomState(random_seed)
        return self.take(random.permutation(self.num_rows()))

    def schema(self) -> "pyarrow.lib.Schema":
        """Return the Arrow schema of the underlying table."""
        return self._table.schema

    def to_pandas(self) -> "pandas.DataFrame":
        """Convert to a pandas DataFrame, optionally casting tensor-extension
        columns to plain ndarrays per the current DataContext setting."""
        from ray.air.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays

        df = self._table.to_pandas()
        ctx = DataContext.get_current()
        if ctx.enable_tensor_extension_casting:
            df = _cast_tensor_columns_to_ndarrays(df)
        return df

    def to_numpy(
        self, columns: Optional[Union[str, List[str]]] = None
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """Convert the given column(s) to numpy.

        Args:
            columns: A single column name (returns one ndarray), a list of
                names, or None for all columns (returns a name -> ndarray
                dict in the latter two cases).

        Raises:
            ValueError: If a requested column is not present in the table.
        """
        if columns is None:
            columns = self._table.column_names
            should_be_single_ndarray = False
        elif isinstance(columns, list):
            should_be_single_ndarray = False
        else:
            # A single column name was given; normalize to a list but
            # remember to unwrap the result.
            columns = [columns]
            should_be_single_ndarray = True

        column_names_set = set(self._table.column_names)
        for column in columns:
            if column not in column_names_set:
                raise ValueError(
                    f"Cannot find column {column}, available columns: "
                    f"{column_names_set}"
                )

        column_values_ndarrays = []

        for col_name in columns:
            col = self._table[col_name]

            # Combine columnar values arrays to make these contiguous
            # (making them compatible with numpy format)
            combined_array = transform_pyarrow.combine_chunked_array(col)

            column_values_ndarrays.append(
                transform_pyarrow.to_numpy(combined_array, zero_copy_only=False)
            )

        if should_be_single_ndarray:
            assert len(columns) == 1
            return column_values_ndarrays[0]
        else:
            return dict(zip(columns, column_values_ndarrays))

    def to_arrow(self) -> "pyarrow.Table":
        """Return the underlying Arrow table as-is (no copy)."""
        return self._table

    def num_rows(self) -> int:
        # Arrow may represent an empty table via an N > 0 row, 0-column table, e.g. when
        # slicing an empty table, so we return 0 if num_columns == 0.
        return self._table.num_rows if self._table.num_columns > 0 else 0

    def size_bytes(self) -> int:
        """Return the total buffer size of the table in bytes."""
        return self._table.nbytes

    def _zip(self, acc: BlockAccessor) -> "Block":
        """Append all columns of ``acc`` to this block, renaming on conflict
        (``name`` -> ``name_1``, ``name_2``, ...)."""
        r = self.to_arrow()
        s = acc.to_arrow()
        for col_name in s.column_names:
            col = s.column(col_name)
            # Ensure the column names are unique after zip.
            if col_name in r.column_names:
                i = 1
                new_name = col_name
                while new_name in r.column_names:
                    new_name = "{}_{}".format(col_name, i)
                    i += 1
                col_name = new_name
            r = r.append_column(col_name, col)
        return r

    @staticmethod
    def builder() -> ArrowBlockBuilder:
        """Return a fresh builder for Arrow blocks."""
        return ArrowBlockBuilder()

    @staticmethod
    def _empty_table() -> "pyarrow.Table":
        """Return an empty (zero-column) Arrow table."""
        return ArrowBlockBuilder._empty_table()

    def take(
        self,
        indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
    ) -> "pyarrow.Table":
        """Select rows from the underlying table.

        This method is an alternative to pyarrow.Table.take(), which breaks for
        extension arrays.
        """
        return transform_pyarrow.take_table(self._table, indices)

    def select(self, columns: List[str]) -> "pyarrow.Table":
        """Project the table onto the given columns.

        Raises:
            ValueError: If any entry of ``columns`` is not a string.
        """
        if not all(isinstance(col, str) for col in columns):
            raise ValueError(
                "Columns must be a list of column name strings when aggregating on "
                f"Arrow blocks, but got: {columns}."
            )
        return self._table.select(columns)

    def rename_columns(self, columns_rename: Dict[str, str]) -> "pyarrow.Table":
        """Rename columns per the given old-name -> new-name mapping."""
        return self._table.rename_columns(columns_rename)

    def _sample(self, n_samples: int, sort_key: "SortKey") -> "pyarrow.Table":
        """Sample ``n_samples`` rows (without replacement) of the sort-key
        columns only."""
        indices = random.sample(range(self._table.num_rows), n_samples)
        table = self._table.select(sort_key.get_columns())
        return transform_pyarrow.take_table(table, indices)

    def count(self, on: str) -> Optional[U]:
        """Count the number of non-null values in the provided column."""
        import pyarrow.compute as pac

        if not isinstance(on, str):
            raise ValueError(
                "on must be a string when aggregating on Arrow blocks, but got:"
                f"{type(on)}."
            )

        if self.num_rows() == 0:
            return None

        col = self._table[on]
        return pac.count(col).as_py()

    def _apply_arrow_compute(
        self, compute_fn: Callable, on: str, ignore_nulls: bool
    ) -> Optional[U]:
        """Helper providing null handling around applying an aggregation to a column."""
        import pyarrow as pa

        if not isinstance(on, str):
            raise ValueError(
                "on must be a string when aggregating on Arrow blocks, but got:"
                f"{type(on)}."
            )

        if self.num_rows() == 0:
            return None

        col = self._table[on]
        if pa.types.is_null(col.type):
            # All-null (null-typed) columns have no defined aggregate.
            return None
        else:
            return compute_fn(col, skip_nulls=ignore_nulls).as_py()

    def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
        """Sum of column ``on``; None for empty/null-typed columns."""
        import pyarrow.compute as pac

        return self._apply_arrow_compute(pac.sum, on, ignore_nulls)

    def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
        """Minimum of column ``on``; None for empty/null-typed columns."""
        import pyarrow.compute as pac

        return self._apply_arrow_compute(pac.min, on, ignore_nulls)

    def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
        """Maximum of column ``on``; None for empty/null-typed columns."""
        import pyarrow.compute as pac

        return self._apply_arrow_compute(pac.max, on, ignore_nulls)

    def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
        """Mean of column ``on``; None for empty/null-typed columns."""
        import pyarrow.compute as pac

        return self._apply_arrow_compute(pac.mean, on, ignore_nulls)

    def sum_of_squared_diffs_from_mean(
        self,
        on: str,
        ignore_nulls: bool,
        mean: Optional[U] = None,
    ) -> Optional[U]:
        """Sum of squared deviations of column ``on`` from its mean.

        Args:
            on: Column to aggregate.
            ignore_nulls: Whether nulls are skipped in the aggregation.
            mean: Optional precomputed mean to avoid a second pass.
        """
        import pyarrow.compute as pac

        if mean is None:
            # If precomputed mean not given, we compute it ourselves.
            mean = self.mean(on, ignore_nulls)
            if mean is None:
                return None
        return self._apply_arrow_compute(
            lambda col, skip_nulls: pac.sum(
                pac.power(pac.subtract(col, mean), 2),
                skip_nulls=skip_nulls,
            ),
            on,
            ignore_nulls,
        )

    def sort_and_partition(
        self, boundaries: List[T], sort_key: "SortKey"
    ) -> List["Block"]:
        """Sort the block and split it at the given key boundaries,
        returning ``len(boundaries) + 1`` blocks."""
        if self._table.num_rows == 0:
            # If the pyarrow table is empty we may not have schema
            # so calling sort_indices() will raise an error.
            return [self._empty_table() for _ in range(len(boundaries) + 1)]

        context = DataContext.get_current()
        sort = get_sort_transform(context)

        table = sort(self._table, sort_key)
        if len(boundaries) == 0:
            return [table]
        return find_partitions(table, boundaries, sort_key)

    def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
        """Combine rows with the same key into an accumulator.

        This assumes the block is already sorted by key in ascending order.

        Args:
            sort_key: A column name or list of column names.
            If this is ``None``, place all rows in a single group.

            aggs: The aggregations to do.

        Returns:
            A sorted block of [k, v_1, ..., v_n] columns where k is the groupby
            key and v_i is the partially combined accumulator for the ith given
            aggregation.
            If key is None then the k column is omitted.
        """
        keys: List[str] = sort_key.get_columns()

        def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
            """Creates an iterator over zero-copy group views."""
            if not keys:
                # Global aggregation consists of a single "group", so we short-circuit.
                yield tuple(), self.to_block()
                return

            start = end = 0
            iter = self.iter_rows(public_row_format=False)
            next_row = None
            while True:
                try:
                    if next_row is None:
                        next_row = next(iter)
                    next_keys = next_row[keys]
                    # Advance ``end`` over the run of rows sharing this key.
                    while keys_equal(next_row[keys], next_keys):
                        end += 1
                        try:
                            next_row = next(iter)
                        except StopIteration:
                            next_row = None
                            break
                    # Zero-copy view over the group's row range.
                    yield next_keys, self.slice(start, end)
                    start = end
                except StopIteration:
                    break

        builder = ArrowBlockBuilder()
        for group_keys, group_view in iter_groups():
            # Aggregate.
            init_vals = group_keys
            if len(group_keys) == 1:
                init_vals = group_keys[0]

            accumulators = [agg.init(init_vals) for agg in aggs]
            for i in range(len(aggs)):
                accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

            # Build the row.
            row = {}
            if keys:
                for k, gk in zip(keys, group_keys):
                    row[k] = gk

            count = collections.defaultdict(int)
            for agg, accumulator in zip(aggs, accumulators):
                name = agg.name
                # Check for conflicts with existing aggregation name.
                if count[name] > 0:
                    name = self._munge_conflict(name, count[name])
                count[name] += 1
                row[name] = accumulator

            builder.add(row)

        return builder.build()

    @staticmethod
    def merge_sorted_blocks(
        blocks: List[Block], sort_key: "SortKey"
    ) -> Tuple[Block, BlockMetadata]:
        """Merge already-sorted blocks into a single sorted block.

        Empty blocks are dropped first; mixed block types are normalized to
        Arrow before the merge.
        """
        stats = BlockExecStats.builder()
        blocks = [b for b in blocks if b.num_rows > 0]
        if len(blocks) == 0:
            ret = ArrowBlockAccessor._empty_table()
        else:
            # Handle blocks of different types.
            blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")
            concat_and_sort = get_concat_and_sort_transform(DataContext.get_current())
            ret = concat_and_sort(blocks, sort_key)
        return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())

    @staticmethod
    def aggregate_combined_blocks(
        blocks: List[Block],
        sort_key: "SortKey",
        aggs: Tuple["AggregateFn"],
        finalize: bool,
    ) -> Tuple[Block, BlockMetadata]:
        """Aggregate sorted, partially combined blocks with the same key range.

        This assumes blocks are already sorted by key in ascending order,
        so we can do merge sort to get all the rows with the same key.

        Args:
            blocks: A list of partially combined and sorted blocks.
            sort_key: The column name of key or None for global aggregation.
            aggs: The aggregations to do.
            finalize: Whether to finalize the aggregation. This is used as an
                optimization for cases where we repeatedly combine partially
                aggregated groups.

        Returns:
            A block of [k, v_1, ..., v_n] columns and its metadata where k is
            the groupby key and v_i is the corresponding aggregation result for
            the ith given aggregation.
            If key is None then the k column is omitted.
        """

        stats = BlockExecStats.builder()
        keys = sort_key.get_columns()

        def key_fn(r):
            # Global aggregation has no key columns; use a constant key so
            # all rows merge into one group.
            if keys:
                return tuple(r[keys])
            else:
                return (0,)

        # Replace Nones with NULL_SENTINEL to ensure safe sorting.
        def key_fn_with_null_sentinel(r):
            values = key_fn(r)
            return [NULL_SENTINEL if v is None else v for v in values]

        # Handle blocks of different types.
        blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")

        # Merge-sort the per-block row iterators by key.
        iter = heapq.merge(
            *[
                ArrowBlockAccessor(block).iter_rows(public_row_format=False)
                for block in blocks
            ],
            key=key_fn_with_null_sentinel,
        )
        next_row = None
        builder = ArrowBlockBuilder()
        while True:
            try:
                if next_row is None:
                    next_row = next(iter)
                next_keys = key_fn(next_row)
                next_key_columns = keys

                def gen():
                    # Yields the run of rows sharing the current key,
                    # leaving ``next_row`` at the first row of the next run.
                    nonlocal iter
                    nonlocal next_row
                    while keys_equal(key_fn(next_row), next_keys):
                        yield next_row
                        try:
                            next_row = next(iter)
                        except StopIteration:
                            next_row = None
                            break

                # Merge.
                first = True
                accumulators = [None] * len(aggs)
                resolved_agg_names = [None] * len(aggs)
                for r in gen():
                    if first:
                        count = collections.defaultdict(int)
                        for i in range(len(aggs)):
                            name = aggs[i].name
                            # Check for conflicts with existing aggregation
                            # name.
                            if count[name] > 0:
                                name = ArrowBlockAccessor._munge_conflict(
                                    name, count[name]
                                )
                            count[name] += 1
                            resolved_agg_names[i] = name
                            accumulators[i] = r[name]
                        first = False
                    else:
                        for i in range(len(aggs)):
                            accumulators[i] = aggs[i].merge(
                                accumulators[i], r[resolved_agg_names[i]]
                            )
                # Build the row.
                row = {}
                if keys:
                    for col_name, next_key in zip(next_key_columns, next_keys):
                        row[col_name] = next_key

                for agg, agg_name, accumulator in zip(
                    aggs, resolved_agg_names, accumulators
                ):
                    if finalize:
                        row[agg_name] = agg.finalize(accumulator)
                    else:
                        row[agg_name] = accumulator

                builder.add(row)
            except StopIteration:
                break

        ret = builder.build()
        return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())

    def block_type(self) -> BlockType:
        """Identify this accessor's block type."""
        return BlockType.ARROW
|
.venv/lib/python3.11/site-packages/ray/data/_internal/batcher.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.arrow_block import ArrowBlockAccessor
|
| 4 |
+
from ray.data._internal.arrow_ops import transform_pyarrow
|
| 5 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 6 |
+
from ray.data.block import Block, BlockAccessor
|
| 7 |
+
|
| 8 |
+
# pyarrow.Table.slice is slow when the table has many chunks
# so we combine chunks into a single one to make slice faster
# with the cost of an extra copy.
# See https://github.com/ray-project/ray/issues/31108 for more details.
# TODO(jjyao): remove this once
# https://github.com/apache/arrow/issues/35126 is resolved.
# Threshold (in number of chunks) above which a block is compacted
# before slicing.
MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS = 10

# Delay compaction until the shuffle buffer has reached this ratio over the min
# shuffle buffer size. Setting this to 1 minimizes memory usage, at the cost of
# frequent compactions. Setting this to higher values increases memory usage but
# reduces compaction frequency.
SHUFFLE_BUFFER_COMPACTION_RATIO = 1.5
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BatcherInterface:
    """Abstract interface for chunking added blocks into batches.

    Implementations accept blocks via ``add()``, are sealed with
    ``done_adding()``, and yield batches via ``next_batch()``.
    """

    def add(self, block: Block):
        """Add a block to the block buffer.

        Args:
            block: Block to add to the block buffer.
        """
        raise NotImplementedError()

    def done_adding(self) -> bool:
        """Indicate to the batcher that no more blocks will be added to the buffer."""
        raise NotImplementedError()

    def has_batch(self) -> bool:
        """Whether this Batcher has any full batches."""
        raise NotImplementedError()

    def has_any(self) -> bool:
        """Whether this Batcher has any data."""
        raise NotImplementedError()

    def next_batch(self) -> Block:
        """Get the next batch from the block buffer.

        Returns:
            A batch represented as a Block.
        """
        raise NotImplementedError()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Batcher(BatcherInterface):
    """Chunks blocks into batches."""

    # Implementation Note: When there are multiple batches per block, this batcher will
    # slice off and return each batch and add the remaining block back to the buffer
    # instead of optimally slicing and returning all batches from the block at once.
    # This will result in extra (and nested) block slicing. However, since slices are
    # zero-copy views, we sacrifice what should be a small performance hit for better
    # readability.

    def __init__(self, batch_size: Optional[int], ensure_copy: bool = False):
        """
        Construct a batcher that yields batches of batch_sizes rows.

        Args:
            batch_size: The size of batches to yield.
            ensure_copy: Whether batches are always copied from the underlying base
                blocks (not zero-copy views).
        """
        # Target rows per yielded batch; None means "yield whole blocks".
        self._batch_size = batch_size
        # FIFO of pending (non-empty) blocks.
        self._buffer = []
        # Total row count across self._buffer.
        self._buffer_size = 0
        self._done_adding = False
        self._ensure_copy = ensure_copy

    def add(self, block: Block):
        """Add a block to the block buffer.

        Note empty block is not added to buffer.

        Args:
            block: Block to add to the block buffer.
        """
        if BlockAccessor.for_block(block).num_rows() > 0:
            self._buffer.append(block)
            self._buffer_size += BlockAccessor.for_block(block).num_rows()

    def done_adding(self) -> None:
        """Indicate to the batcher that no more blocks will be added to the batcher."""
        self._done_adding = True

    def has_batch(self) -> bool:
        """Whether this Batcher has any full batches."""
        return self.has_any() and (
            self._batch_size is None or self._buffer_size >= self._batch_size
        )

    def has_any(self) -> bool:
        """Whether this Batcher has any data."""
        return self._buffer_size > 0

    def next_batch(self) -> Block:
        """Get the next batch from the block buffer.

        Returns:
            A batch represented as a Block.
        """
        # Either a full batch is buffered, or input is sealed and we drain
        # whatever remains (possibly a short final batch).
        assert self.has_batch() or (self._done_adding and self.has_any())
        needs_copy = self._ensure_copy
        # If no batch size, short-circuit.
        if self._batch_size is None:
            assert len(self._buffer) == 1
            block = self._buffer[0]
            if needs_copy:
                # Copy block if needing to ensure fresh batch copy.
                block = BlockAccessor.for_block(block)
                block = block.slice(0, block.num_rows(), copy=True)
            self._buffer = []
            self._buffer_size = 0
            return block
        output = DelegatingBlockBuilder()
        leftover = []
        needed = self._batch_size
        for block in self._buffer:
            accessor = BlockAccessor.for_block(block)
            if needed <= 0:
                # We already have a full batch, so add this block to
                # the leftovers.
                leftover.append(block)
            elif accessor.num_rows() <= needed:
                # Entire block fits into the current batch.
                output.add_block(accessor.to_block())
                needed -= accessor.num_rows()
            else:
                if (
                    isinstance(accessor, ArrowBlockAccessor)
                    and block.num_columns > 0
                    and block.column(0).num_chunks
                    >= MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS
                ):
                    # Compact highly-chunked Arrow blocks before slicing;
                    # see MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS above.
                    accessor = BlockAccessor.for_block(
                        transform_pyarrow.combine_chunks(block)
                    )
                # We only need part of the block to fill out a batch.
                output.add_block(accessor.slice(0, needed, copy=False))
                # Add the rest of the block to the leftovers.
                leftover.append(accessor.slice(needed, accessor.num_rows(), copy=False))
                needed = 0

        # Move the leftovers into the block buffer so they're the first
        # blocks consumed on the next batch extraction.
        self._buffer = leftover
        # NOTE(review): for a short final batch (buffer < batch_size after
        # done_adding), this can drive _buffer_size negative; has_any() still
        # reports False in that case, so draining terminates — confirm intent.
        self._buffer_size -= self._batch_size
        needs_copy = needs_copy and not output.will_build_yield_copy()
        batch = output.build()
        if needs_copy:
            # Need to ensure that the batch is a fresh copy.
            batch = BlockAccessor.for_block(batch)
            batch = batch.slice(0, batch.num_rows(), copy=True)
        return batch
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class ShufflingBatcher(BatcherInterface):
|
| 165 |
+
"""Chunks blocks into shuffled batches, using a local in-memory shuffle buffer."""
|
| 166 |
+
|
| 167 |
+
# Implementation Note:
|
| 168 |
+
#
|
| 169 |
+
# This shuffling batcher lazily builds a shuffle buffer from added blocks, and once
|
| 170 |
+
# a batch is requested via .next_batch(), it concatenates the blocks into a concrete
|
| 171 |
+
# shuffle buffer and randomly shuffles the entire buffer.
|
| 172 |
+
#
|
| 173 |
+
# Adding of more blocks can be intermixed with retrieving batches, but it should be
|
| 174 |
+
# noted that we can end up performing two expensive operations on each retrieval:
|
| 175 |
+
# 1. Build added blocks into a concrete shuffle buffer.
|
| 176 |
+
# 2. Shuffling the entire buffer.
|
| 177 |
+
# To amortize the overhead of this process, we only shuffle the blocks after a
|
| 178 |
+
# delay designated by SHUFFLE_BUFFER_COMPACTION_RATIO.
|
| 179 |
+
#
|
| 180 |
+
# Similarly, adding blocks is very cheap. Each added block will be appended to a
|
| 181 |
+
# list, with concatenation of the underlying data delayed until the next batch
|
| 182 |
+
# compaction.
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
batch_size: Optional[int],
|
| 187 |
+
shuffle_buffer_min_size: int,
|
| 188 |
+
shuffle_seed: Optional[int] = None,
|
| 189 |
+
):
|
| 190 |
+
"""Constructs a random-shuffling block batcher.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
batch_size: Record batch size.
|
| 194 |
+
shuffle_buffer_min_size: Minimum number of rows that must be in the local
|
| 195 |
+
in-memory shuffle buffer in order to yield a batch. When there are no
|
| 196 |
+
more rows to be added to the buffer, the number of rows in the buffer
|
| 197 |
+
*will* decrease below this value while yielding the remaining batches,
|
| 198 |
+
and the final batch may have less than ``batch_size`` rows. Increasing
|
| 199 |
+
this will improve the randomness of the shuffle but may increase the
|
| 200 |
+
latency to the first batch.
|
| 201 |
+
shuffle_seed: The seed to use for the local random shuffle.
|
| 202 |
+
"""
|
| 203 |
+
if batch_size is None:
|
| 204 |
+
raise ValueError("Must specify a batch_size if using a local shuffle.")
|
| 205 |
+
self._batch_size = batch_size
|
| 206 |
+
self._shuffle_seed = shuffle_seed
|
| 207 |
+
if shuffle_buffer_min_size < batch_size:
|
| 208 |
+
# Round it up internally to `batch_size` since our algorithm requires it.
|
| 209 |
+
# This is harmless since it only offers extra randomization.
|
| 210 |
+
shuffle_buffer_min_size = batch_size
|
| 211 |
+
self._buffer_min_size = shuffle_buffer_min_size
|
| 212 |
+
self._builder = DelegatingBlockBuilder()
|
| 213 |
+
self._shuffle_buffer: Block = None
|
| 214 |
+
self._batch_head = 0
|
| 215 |
+
self._done_adding = False
|
| 216 |
+
|
| 217 |
+
def add(self, block: Block):
|
| 218 |
+
"""Add a block to the shuffle buffer.
|
| 219 |
+
|
| 220 |
+
Note empty block is not added to buffer.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
block: Block to add to the shuffle buffer.
|
| 224 |
+
"""
|
| 225 |
+
if BlockAccessor.for_block(block).num_rows() > 0:
|
| 226 |
+
self._builder.add_block(block)
|
| 227 |
+
|
| 228 |
+
def done_adding(self) -> bool:
|
| 229 |
+
"""Indicate to the batcher that no more blocks will be added to the batcher.
|
| 230 |
+
|
| 231 |
+
No more blocks should be added to the batcher after calling this.
|
| 232 |
+
"""
|
| 233 |
+
self._done_adding = True
|
| 234 |
+
|
| 235 |
+
def has_any(self) -> bool:
|
| 236 |
+
"""Whether this batcher has any data."""
|
| 237 |
+
return self._buffer_size() > 0
|
| 238 |
+
|
| 239 |
+
def has_batch(self) -> bool:
|
| 240 |
+
"""Whether this batcher has any batches."""
|
| 241 |
+
buffer_size = self._buffer_size()
|
| 242 |
+
|
| 243 |
+
if not self._done_adding:
|
| 244 |
+
# Delay pulling of batches until the buffer is large enough in order to
|
| 245 |
+
# amortize compaction overhead.
|
| 246 |
+
return self._materialized_buffer_size() >= self._buffer_min_size or (
|
| 247 |
+
buffer_size - self._batch_size
|
| 248 |
+
>= self._buffer_min_size * SHUFFLE_BUFFER_COMPACTION_RATIO
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
return buffer_size >= self._batch_size
|
| 252 |
+
|
| 253 |
+
def _buffer_size(self) -> int:
    """Total buffered rows: pending (builder) plus materialized buffer."""
    return self._builder.num_rows() + self._materialized_buffer_size()
|
| 258 |
+
|
| 259 |
+
def _materialized_buffer_size(self) -> int:
    """Return the number of not-yet-yielded rows in the compacted buffer.

    ``_batch_head`` counts rows already yielded from the current
    materialized buffer, so the remaining size is the buffer's total row
    count minus that head position, clamped at zero.
    """
    if self._shuffle_buffer is None:
        return 0
    total_rows = BlockAccessor.for_block(self._shuffle_buffer).num_rows()
    return max(0, total_rows - self._batch_head)
|
| 270 |
+
|
| 271 |
+
def next_batch(self) -> Block:
    """Get the next shuffled batch from the shuffle buffer.

    Pulls pending rows from the builder into the materialized (shuffled)
    buffer when needed, then slices off the next batch.

    Returns:
        A batch represented as a Block.
    """
    # Either a full batch is available, or we are draining the final
    # (possibly short) batch after done_adding().
    assert self.has_batch() or (self._done_adding and self.has_any())
    # Add rows in the builder to the shuffle buffer. Note that we delay compaction
    # as much as possible to amortize the concatenation overhead. Compaction is
    # only necessary when the materialized buffer size falls below the min size.
    if self._builder.num_rows() > 0 and (
        self._done_adding
        or self._materialized_buffer_size() <= self._buffer_min_size
    ):
        if self._shuffle_buffer is not None:
            if self._batch_head > 0:
                # Compact the materialized shuffle buffer: drop the rows
                # before _batch_head, which have already been yielded.
                block = BlockAccessor.for_block(self._shuffle_buffer)
                self._shuffle_buffer = block.slice(
                    self._batch_head, block.num_rows()
                )
            # Add the unyielded rows from the existing shuffle buffer.
            self._builder.add_block(self._shuffle_buffer)
        # Build the new shuffle buffer.
        self._shuffle_buffer = self._builder.build()
        self._shuffle_buffer = BlockAccessor.for_block(
            self._shuffle_buffer
        ).random_shuffle(self._shuffle_seed)
        # Advance the seed so successive re-shuffles differ while remaining
        # deterministic for a fixed initial seed.
        if self._shuffle_seed is not None:
            self._shuffle_seed += 1
        # For Arrow blocks, repeated concatenation can leave many small
        # chunks; combine them once the chunk count crosses the threshold.
        if (
            isinstance(
                BlockAccessor.for_block(self._shuffle_buffer), ArrowBlockAccessor
            )
            and self._shuffle_buffer.num_columns > 0
            and self._shuffle_buffer.column(0).num_chunks
            >= MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS
        ):
            self._shuffle_buffer = transform_pyarrow.combine_chunks(
                self._shuffle_buffer
            )
        # Reset the builder.
        self._builder = DelegatingBlockBuilder()
        self._batch_head = 0

    assert self._shuffle_buffer is not None
    buffer_size = BlockAccessor.for_block(self._shuffle_buffer).num_rows()
    # Truncate the batch to the buffer size, if necessary.
    batch_size = min(self._batch_size, buffer_size)
    slice_start = self._batch_head
    self._batch_head += batch_size
    # Yield the shuffled batch.
    return BlockAccessor.for_block(self._shuffle_buffer).slice(
        slice_start, self._batch_head
    )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/block_batching.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import nullcontext
|
| 2 |
+
from typing import Callable, Iterator, Optional, TypeVar
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.block_batching.util import (
|
| 5 |
+
blocks_to_batches,
|
| 6 |
+
collate,
|
| 7 |
+
extract_data_from_batch,
|
| 8 |
+
format_batches,
|
| 9 |
+
)
|
| 10 |
+
from ray.data._internal.stats import DatasetStats
|
| 11 |
+
from ray.data.block import Block, DataBatch
|
| 12 |
+
|
| 13 |
+
T = TypeVar("T")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def batch_blocks(
    blocks: Iterator[Block],
    *,
    stats: Optional[DatasetStats] = None,
    batch_size: Optional[int] = None,
    batch_format: str = "default",
    drop_last: bool = False,
    collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None,
    shuffle_buffer_min_size: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    ensure_copy: bool = False,
) -> Iterator[DataBatch]:
    """Create formatted batches of data from 1 or more blocks.

    This function takes in an iterator of already fetched blocks. Consequently, this
    function doesn't support block prefetching.
    """
    # Slice (and optionally shuffle) blocks into batches, then convert each
    # batch to the requested format.
    pipeline = format_batches(
        blocks_to_batches(
            block_iter=blocks,
            stats=stats,
            batch_size=batch_size,
            drop_last=drop_last,
            shuffle_buffer_min_size=shuffle_buffer_min_size,
            shuffle_seed=shuffle_seed,
            ensure_copy=ensure_copy,
        ),
        batch_format=batch_format,
        stats=stats,
    )
    if collate_fn is not None:
        pipeline = collate(pipeline, collate_fn=collate_fn, stats=stats)
    pipeline = extract_data_from_batch(pipeline)

    for formatted_batch in pipeline:
        # Time spent by the caller between yields is attributed to user time.
        user_timer = stats.iter_user_s.timer() if stats else nullcontext()
        with user_timer:
            yield formatted_batch
|
.venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/interfaces.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, List
|
| 4 |
+
|
| 5 |
+
from ray.data.block import Block, DataBatch
|
| 6 |
+
from ray.types import ObjectRef
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class Batch:
    """A batch of data with a corresponding index.

    Attributes:
        batch_idx: The global index of this batch so that downstream operations can
            maintain ordering.
        data: The batch of data.
    """

    # Position used by downstream operations to restore ordering.
    batch_idx: int
    data: DataBatch
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CollatedBatch(Batch):
    """A batch of collated data with a corresponding index.

    Note: this subclass is not itself decorated with ``@dataclass``; it
    inherits ``Batch``'s generated ``__init__``. The annotations below only
    re-declare the field types for type checkers (``data`` widens to Any).

    Attributes:
        batch_idx: The global index of this batch so that downstream operations can
            maintain ordering.
        data: The batch of data which is the output of a user provided collate_fn
            Therefore, the type of this data can be Any.
    """

    batch_idx: int
    data: Any
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BlockPrefetcher(metaclass=abc.ABCMeta):
    """Interface for prefetching blocks."""

    @abc.abstractmethod
    def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
        """Prefetch the provided blocks to this node."""
        pass

    def stop(self):
        """Stop prefetching and release resources.

        Optional hook; the default implementation is a no-op.
        """
        pass
|
.venv/lib/python3.11/site-packages/ray/data/_internal/block_builder.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Generic
|
| 2 |
+
|
| 3 |
+
from ray.data.block import Block, BlockAccessor, BlockType, T
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BlockBuilder(Generic[T]):
    """A builder class for blocks.

    Abstract interface: concrete builders override each method below; every
    default implementation raises ``NotImplementedError``.
    """

    @staticmethod
    def for_block(block: Block) -> "BlockBuilder":
        # Dispatch to the builder matching the given block's concrete type.
        return BlockAccessor.for_block(block).builder()

    def add(self, item: T) -> None:
        """Append a single row to the block being built."""
        raise NotImplementedError

    def add_block(self, block: Block) -> None:
        """Append an entire block to the block being built."""
        raise NotImplementedError

    def will_build_yield_copy(self) -> bool:
        """Whether building this block will yield a new block copy."""
        raise NotImplementedError

    def build(self) -> Block:
        """Build the block."""
        raise NotImplementedError

    def num_rows(self) -> int:
        """Return the number of rows added in the block."""
        raise NotImplementedError

    def get_estimated_memory_usage(self) -> int:
        """Return the estimated memory usage so far in bytes."""
        raise NotImplementedError

    def block_type(self) -> BlockType:
        """Return the block type."""
        raise NotImplementedError
|
.venv/lib/python3.11/site-packages/ray/data/_internal/block_list.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterator, List, Tuple
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.memory_tracing import trace_allocation
|
| 4 |
+
from ray.data.block import Block, BlockMetadata
|
| 5 |
+
from ray.types import ObjectRef
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BlockList:
    """A list of object references to blocks, plus their metadata.

    Every block in a ``BlockList`` is known ahead of time (no pending
    computation is tracked here).
    """

    def __init__(
        self,
        blocks: List[ObjectRef[Block]],
        metadata: List[BlockMetadata],
        *,
        owned_by_consumer: bool,
    ):
        assert len(blocks) == len(metadata), (blocks, metadata)
        for block_ref in blocks:
            trace_allocation(block_ref, "BlockList.__init__")
        self._blocks: List[ObjectRef[Block]] = blocks
        self._num_blocks = len(self._blocks)
        self._metadata: List[BlockMetadata] = metadata
        # When True, consuming APIs own this list and may eagerly delete
        # blocks after the consumer has read them.
        self._owned_by_consumer = owned_by_consumer
        # Optional estimate of the output block count, since each read task
        # may produce multiple output blocks after splitting.
        self._estimated_num_blocks = None

    def __repr__(self):
        return f"BlockList(owned_by_consumer={self._owned_by_consumer})"

    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
        """Get the metadata for all blocks."""
        return self._metadata.copy()

    def copy(self) -> "BlockList":
        """Perform a shallow copy of this BlockList."""
        return BlockList(
            self._blocks, self._metadata, owned_by_consumer=self._owned_by_consumer
        )

    def clear(self) -> None:
        """Erase references to the tasks tracked by the BlockList."""
        self._blocks = None

    def is_cleared(self) -> bool:
        """Whether this BlockList has been cleared."""
        return self._blocks is None

    def _check_if_cleared(self) -> None:
        """Raise an error if this BlockList has been previously cleared."""
        if not self.is_cleared():
            return
        raise ValueError(
            "This Dataset's blocks have been moved, which means that you "
            "can no longer use this Dataset."
        )

    def get_blocks(self) -> List[ObjectRef[Block]]:
        """Get list of the blocks of this block list.

        This blocks on the execution of the tasks generating block outputs.
        The length of this iterator is not known until execution.
        """
        self._check_if_cleared()
        return list(self._blocks)

    def get_blocks_with_metadata(self) -> List[Tuple[ObjectRef[Block], BlockMetadata]]:
        """Bulk version of iter_blocks_with_metadata().

        Prefer calling this instead of the iter form for performance if you
        don't need lazy evaluation.
        """
        # Runs the cleared-state check (and any execution) before zipping.
        self.get_blocks()
        return list(self.iter_blocks_with_metadata())

    def iter_blocks_with_metadata(
        self,
    ) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
        """Iterate over the blocks along with their runtime metadata.

        This blocks on the execution of the tasks generating block outputs.
        The length of this iterator is not known until execution.
        """
        self._check_if_cleared()
        return zip(self._blocks, self._metadata)

    def initial_num_blocks(self) -> int:
        """Returns the number of blocks of this BlockList."""
        return self._num_blocks

    def estimated_num_blocks(self) -> int:
        """Estimate of number of output blocks, without triggering actual execution."""
        return self._estimated_num_blocks or self._num_blocks
|
.venv/lib/python3.11/site-packages/ray/data/_internal/compute.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Callable, Iterable, Optional, TypeVar, Union
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.execution.interfaces import TaskContext
|
| 5 |
+
from ray.data.block import Block, UserDefinedFunction
|
| 6 |
+
from ray.util.annotations import DeveloperAPI
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
T = TypeVar("T")
|
| 11 |
+
U = TypeVar("U")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Block transform function applied by task and actor pools.
|
| 15 |
+
BlockTransform = Union[
|
| 16 |
+
# TODO(Clark): Once Ray only supports Python 3.8+, use protocol to constrain block
|
| 17 |
+
# transform type.
|
| 18 |
+
# Callable[[Block, ...], Iterable[Block]]
|
| 19 |
+
# Callable[[Block, UserDefinedFunction, ...], Iterable[Block]],
|
| 20 |
+
Callable[[Iterable[Block], TaskContext], Iterable[Block]],
|
| 21 |
+
Callable[[Iterable[Block], TaskContext, UserDefinedFunction], Iterable[Block]],
|
| 22 |
+
Callable[..., Iterable[Block]],
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@DeveloperAPI
class ComputeStrategy:
    """Marker base class for Dataset compute strategies.

    Concrete strategies (``TaskPoolStrategy``, ``ActorPoolStrategy``) select
    how a Dataset transform is executed.
    """

    pass
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@DeveloperAPI
class TaskPoolStrategy(ComputeStrategy):
    """Compute strategy that runs a Dataset transform with plain Ray tasks."""

    def __init__(
        self,
        size: Optional[int] = None,
    ):
        """Construct TaskPoolStrategy for a Dataset transform.

        Args:
            size: Specify the maximum size of the task pool.

        Raises:
            ValueError: If ``size`` is provided and is smaller than 1.
        """
        if size is not None and size < 1:
            raise ValueError("`size` must be >= 1", size)
        self.size = size

    def __eq__(self, other: Any) -> bool:
        # Equal to another TaskPoolStrategy with the same size, or to the
        # legacy string "tasks" when no explicit size was configured.
        if isinstance(other, TaskPoolStrategy):
            return self.size == other.size
        return other == "tasks" and self.size is None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ActorPoolStrategy(ComputeStrategy):
    """Specify the compute strategy for a Dataset transform.

    ActorPoolStrategy specifies that an autoscaling pool of actors should be used
    for a given Dataset transform. This is useful for stateful setup of callable
    classes.

    For a fixed-sized pool of size ``n``, specify ``compute=ActorPoolStrategy(size=n)``.
    To autoscale from ``m`` to ``n`` actors, specify
    ``ActorPoolStrategy(min_size=m, max_size=n)``.

    To increase opportunities for pipelining task dependency prefetching with
    computation and avoiding actor startup delays, set max_tasks_in_flight_per_actor
    to 2 or greater; to try to decrease the delay due to queueing of tasks on the worker
    actors, set max_tasks_in_flight_per_actor to 1.
    """

    def __init__(
        self,
        *,
        size: Optional[int] = None,
        min_size: Optional[int] = None,
        max_size: Optional[int] = None,
        max_tasks_in_flight_per_actor: Optional[int] = None,
    ):
        """Construct ActorPoolStrategy for a Dataset transform.

        Args:
            size: Specify a fixed size actor pool of this size. It is an error to
                specify both `size` and `min_size` or `max_size`.
            min_size: The minimize size of the actor pool.
            max_size: The maximum size of the actor pool.
            max_tasks_in_flight_per_actor: The maximum number of tasks to concurrently
                send to a single actor worker. Increasing this will increase
                opportunities for pipelining task dependency prefetching with
                computation and avoiding actor startup delays, but will also increase
                queueing delay.

        Raises:
            ValueError: On invalid or conflicting size arguments.
        """
        min_size, max_size = self._resolve_pool_bounds(size, min_size, max_size)
        if (
            max_tasks_in_flight_per_actor is not None
            and max_tasks_in_flight_per_actor < 1
        ):
            raise ValueError(
                "max_tasks_in_flight_per_actor must be >= 1, got: ",
                max_tasks_in_flight_per_actor,
            )
        # Effective pool bounds default to [1, unbounded).
        self.min_size = min_size or 1
        self.max_size = max_size or float("inf")
        self.max_tasks_in_flight_per_actor = max_tasks_in_flight_per_actor
        # Runtime bookkeeping fields.
        self.num_workers = 0
        self.ready_to_total_workers_ratio = 0.8

    @staticmethod
    def _resolve_pool_bounds(
        size: Optional[int],
        min_size: Optional[int],
        max_size: Optional[int],
    ):
        """Validate the size arguments; return the effective (min, max) pair."""
        if size is not None:
            if size < 1:
                raise ValueError("size must be >= 1", size)
            if max_size is not None or min_size is not None:
                raise ValueError(
                    "min_size and max_size cannot be set at the same time as `size`"
                )
            # A fixed-size pool is simply min == max == size.
            min_size = size
            max_size = size
        if min_size is not None and min_size < 1:
            raise ValueError("min_size must be >= 1", min_size)
        if max_size is not None:
            if min_size is None:
                min_size = 1  # Legacy default.
            if min_size > max_size:
                raise ValueError("min_size must be <= max_size", min_size, max_size)
        return min_size, max_size

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, ActorPoolStrategy):
            return False
        return (
            self.min_size == other.min_size
            and self.max_size == other.max_size
            and self.max_tasks_in_flight_per_actor
            == other.max_tasks_in_flight_per_actor
        )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def get_compute(compute_spec: Union[str, ComputeStrategy]) -> ComputeStrategy:
    """Resolve a compute spec to a concrete ``ComputeStrategy`` instance.

    Since Ray 2.5 only strategy instances are accepted; any other value
    (including the legacy "tasks"/"actors" strings and None) is rejected by
    the first branch. The string-comparison branches below are therefore
    mostly legacy dead code — except that ``compute_spec == "tasks"`` can
    still match a ``TaskPoolStrategy(size=None)`` via that class's custom
    ``__eq__``.
    """
    if not isinstance(compute_spec, (TaskPoolStrategy, ActorPoolStrategy)):
        raise ValueError(
            "In Ray 2.5, the compute spec must be either "
            f"TaskPoolStrategy or ActorPoolStrategy, was: {compute_spec}."
        )
    elif not compute_spec or compute_spec == "tasks":
        # Reached when compute_spec is a TaskPoolStrategy with size=None
        # (its __eq__ treats it as equal to "tasks"); returns a fresh,
        # equivalent instance.
        return TaskPoolStrategy()
    elif compute_spec == "actors":
        # Unreachable: ActorPoolStrategy.__eq__ never equals a plain string.
        return ActorPoolStrategy()
    elif isinstance(compute_spec, ComputeStrategy):
        return compute_spec
    else:
        raise ValueError("compute must be one of [`tasks`, `actors`, ComputeStrategy]")
|
.venv/lib/python3.11/site-packages/ray/data/_internal/delegating_block_builder.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
from typing import Any, Mapping, Optional
|
| 3 |
+
|
| 4 |
+
from ray.data._internal.arrow_block import ArrowBlockBuilder
|
| 5 |
+
from ray.data._internal.block_builder import BlockBuilder
|
| 6 |
+
from ray.data.block import Block, BlockAccessor, BlockType, DataBatch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DelegatingBlockBuilder(BlockBuilder):
    """A builder that defers choosing a concrete block type until data arrives."""

    def __init__(self):
        # Concrete builder, instantiated lazily from the first item/block added.
        self._builder = None
        # Remembered zero-row block; used as a fallback schema in build().
        self._empty_block = None

    @property
    def _inferred_block_type(self) -> Optional[BlockType]:
        """The block type inferred from the first item added to the builder."""
        if self._builder is None:
            return None
        return self._builder.block_type()

    def add(self, item: Mapping[str, Any]) -> None:
        assert isinstance(item, collections.abc.Mapping), item
        if self._builder is None:
            # Individual rows default to the Arrow representation.
            self._builder = ArrowBlockBuilder()
        self._builder.add(item)

    def add_batch(self, batch: DataBatch):
        """Add a user-facing data batch to the builder.

        This data batch will be converted to an internal block and then added to the
        underlying builder.
        """
        converted = BlockAccessor.batch_to_block(batch, self._inferred_block_type)
        return self.add_block(converted)

    def add_block(self, block: Block):
        accessor = BlockAccessor.for_block(block)
        if accessor.num_rows() == 0:
            # Don't infer types from empty data; stash the block so build()
            # can still return it if nothing else arrives.
            # See https://github.com/ray-project/ray/issues/20290
            self._empty_block = block
            return
        if self._builder is None:
            self._builder = accessor.builder()
        else:
            block_type = accessor.block_type()
            assert block_type == self._inferred_block_type, (
                block_type,
                self._inferred_block_type,
            )

        self._builder.add_block(accessor.to_block())

    def will_build_yield_copy(self) -> bool:
        if self._builder is None:
            return True
        return self._builder.will_build_yield_copy()

    def build(self) -> Block:
        if self._builder is None:
            if self._empty_block is not None:
                # Only empty data was seen; rebuild from the stored block so
                # its schema is preserved.
                self._builder = BlockAccessor.for_block(self._empty_block).builder()
                self._builder.add_block(self._empty_block)
            else:
                self._builder = ArrowBlockBuilder()
        return self._builder.build()

    def num_rows(self) -> int:
        if self._builder is None:
            return 0
        return self._builder.num_rows()

    def get_estimated_memory_usage(self) -> int:
        return 0 if self._builder is None else self._builder.get_estimated_memory_usage()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/equalize.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.execution.interfaces import RefBundle
|
| 4 |
+
from ray.data._internal.split import _calculate_blocks_rows, _split_at_indices
|
| 5 |
+
from ray.data.block import Block, BlockMetadata, BlockPartition
|
| 6 |
+
from ray.types import ObjectRef
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _equalize(
    per_split_bundles: List[RefBundle],
    owned_by_consumer: bool,
) -> List[RefBundle]:
    """Equalize split ref bundles into equal number of rows.

    Args:
        per_split_bundles: ref bundles to equalize.
        owned_by_consumer: whether the produced bundles own their blocks
            (propagated to each output ``RefBundle``).
    Returns:
        the equalized ref bundles.
    """
    if len(per_split_bundles) == 0:
        return per_split_bundles
    per_split_blocks_with_metadata = [bundle.blocks for bundle in per_split_bundles]
    per_split_num_rows: List[List[int]] = [
        _calculate_blocks_rows(split) for split in per_split_blocks_with_metadata
    ]
    # Target is the floor-average row count; each output split ends up with
    # exactly this many rows (remainder rows are excluded from the outputs).
    total_rows = sum([sum(blocks_rows) for blocks_rows in per_split_num_rows])
    target_split_size = total_rows // len(per_split_blocks_with_metadata)

    # phase 1: shave the current splits by dropping blocks (into leftovers)
    # and calculate num rows needed to the meet target.
    shaved_splits, per_split_needed_rows, leftovers = _shave_all_splits(
        per_split_blocks_with_metadata, per_split_num_rows, target_split_size
    )

    # validate invariants: each shaved split is at or below the target, and
    # its reported deficit exactly accounts for the difference.
    for shaved_split, split_needed_row in zip(shaved_splits, per_split_needed_rows):
        num_shaved_rows = sum([meta.num_rows for _, meta in shaved_split])
        assert num_shaved_rows <= target_split_size
        assert num_shaved_rows + split_needed_row == target_split_size

    # phase 2: based on the num rows needed for each shaved split, split the leftovers
    # in the shape that exactly matches the rows needed.
    leftover_bundle = RefBundle(leftovers, owns_blocks=owned_by_consumer)
    leftover_splits = _split_leftovers(leftover_bundle, per_split_needed_rows)

    # phase 3: merge the shaved splits and leftover splits and return.
    for i, leftover_split in enumerate(leftover_splits):
        shaved_splits[i].extend(leftover_split)

        # validate invariants: each merged split now hits the target exactly.
        num_shaved_rows = sum([meta.num_rows for _, meta in shaved_splits[i]])
        assert num_shaved_rows == target_split_size

    # Compose the result back to RefBundle
    equalized_ref_bundles: List[RefBundle] = []
    for split in shaved_splits:
        equalized_ref_bundles.append(RefBundle(split, owns_blocks=owned_by_consumer))
    return equalized_ref_bundles
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _shave_one_split(
    split: BlockPartition, num_rows_per_block: List[int], target_size: int
) -> Tuple[BlockPartition, int, BlockPartition]:
    """Shave a block list down to at most ``target_size`` rows.

    Blocks are considered in order; a block is kept only when its rows still
    fit under the target, otherwise it is routed to the leftovers (a later,
    smaller block may still be kept).

    Args:
        split: the block list to shave.
        num_rows_per_block: num rows for each block in the list.
        target_size: the upper bound target size of the shaved list.
    Returns:
        A tuple of:
            - shaved block list.
            - num of rows needed for the block list to meet the target size.
            - leftover blocks.
    """
    kept = []
    dropped = []
    kept_rows = 0
    for entry, row_count in zip(split, num_rows_per_block):
        if kept_rows + row_count <= target_size:
            kept.append(entry)
            kept_rows += row_count
        else:
            dropped.append(entry)
    return kept, target_size - kept_rows, dropped
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _shave_all_splits(
    input_splits: List[BlockPartition],
    per_split_num_rows: List[List[int]],
    target_size: int,
) -> Tuple[List[BlockPartition], List[int], BlockPartition]:
    """Shave every block list to the target size.

    Args:
        input_splits: all block lists to shave.
        per_split_num_rows: num rows (per block) for each block list.
        target_size: the upper bound target size of the shaved lists.
    Returns:
        A tuple of:
            - all shaved block lists.
            - num of rows each list still needs to meet the target size.
            - leftover blocks pooled across all lists.
    """
    shaved_splits = []
    per_split_needed_rows = []
    leftovers = []

    for split, row_counts in zip(input_splits, per_split_num_rows):
        shaved, rows_needed, extra = _shave_one_split(split, row_counts, target_size)
        shaved_splits.append(shaved)
        per_split_needed_rows.append(rows_needed)
        leftovers.extend(extra)

    return shaved_splits, per_split_needed_rows, leftovers
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _split_leftovers(
    leftovers: RefBundle, per_split_needed_rows: List[int]
) -> List[BlockPartition]:
    """Split leftover blocks by the num of rows needed."""
    num_splits = len(per_split_needed_rows)
    # Convert the per-split row needs into cumulative split boundaries.
    split_indices = []
    running_total = 0
    for rows_needed in per_split_needed_rows:
        running_total += rows_needed
        split_indices.append(running_total)
    split_result: Tuple[
        List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
    ] = _split_at_indices(
        leftovers.blocks,
        split_indices,
        leftovers.owns_blocks,
    )
    block_refs_per_split, meta_per_split = split_result
    # Pair each block ref with its metadata; drop any trailing remainder split.
    paired_splits = [
        list(zip(refs, metas))
        for refs, metas in zip(block_refs_per_split, meta_per_split)
    ]
    return paired_splits[:num_splits]
|
.venv/lib/python3.11/site-packages/ray/data/_internal/logging.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
import logging
import logging.config
import os
from typing import Optional

import yaml

import ray
|
| 9 |
+
|
| 10 |
+
DEFAULT_CONFIG = {
|
| 11 |
+
"version": 1,
|
| 12 |
+
"disable_existing_loggers": False,
|
| 13 |
+
"formatters": {
|
| 14 |
+
"ray": {
|
| 15 |
+
"format": "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" # noqa: E501
|
| 16 |
+
},
|
| 17 |
+
"ray_json": {"class": "ray._private.ray_logging.formatters.JSONFormatter"},
|
| 18 |
+
},
|
| 19 |
+
"filters": {
|
| 20 |
+
"console_filter": {"()": "ray.data._internal.logging.HiddenRecordFilter"},
|
| 21 |
+
"core_context_filter": {
|
| 22 |
+
"()": "ray._private.ray_logging.filters.CoreContextFilter"
|
| 23 |
+
},
|
| 24 |
+
},
|
| 25 |
+
"handlers": {
|
| 26 |
+
"file": {
|
| 27 |
+
"class": "ray.data._internal.logging.SessionFileHandler",
|
| 28 |
+
"formatter": "ray",
|
| 29 |
+
"filename": "ray-data.log",
|
| 30 |
+
},
|
| 31 |
+
"file_json": {
|
| 32 |
+
"class": "ray.data._internal.logging.SessionFileHandler",
|
| 33 |
+
"formatter": "ray_json",
|
| 34 |
+
"filename": "ray-data.log",
|
| 35 |
+
"filters": ["core_context_filter"],
|
| 36 |
+
},
|
| 37 |
+
"console": {
|
| 38 |
+
"class": "ray._private.log.PlainRayHandler",
|
| 39 |
+
"formatter": "ray",
|
| 40 |
+
"level": "INFO",
|
| 41 |
+
"filters": ["console_filter"],
|
| 42 |
+
},
|
| 43 |
+
},
|
| 44 |
+
"loggers": {
|
| 45 |
+
"ray.data": {
|
| 46 |
+
"level": "DEBUG",
|
| 47 |
+
"handlers": ["file", "console"],
|
| 48 |
+
"propagate": False,
|
| 49 |
+
},
|
| 50 |
+
"ray.air.util.tensor_extensions": {
|
| 51 |
+
"level": "DEBUG",
|
| 52 |
+
"handlers": ["file", "console"],
|
| 53 |
+
"propagate": False,
|
| 54 |
+
},
|
| 55 |
+
},
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# Dictionary of substitutions to be performed when using JSON mode. Handlers with names
|
| 59 |
+
# corresponding to keys will be replaced by those corresponding to values.
|
| 60 |
+
RAY_DATA_LOG_HANDLER_JSON_SUBSTITUTIONS = {"file": "file_json"}
|
| 61 |
+
|
| 62 |
+
# Env. variable to specify the encoding of the file logs when using the default config.
|
| 63 |
+
RAY_DATA_LOG_ENCODING_ENV_VAR_NAME = "RAY_DATA_LOG_ENCODING"
|
| 64 |
+
|
| 65 |
+
# Env. variable to specify the logging config path use defaults if not set
|
| 66 |
+
RAY_DATA_LOGGING_CONFIG_ENV_VAR_NAME = "RAY_DATA_LOGGING_CONFIG"
|
| 67 |
+
|
| 68 |
+
# To facilitate debugging, Ray Data writes debug logs to a file. However, if Ray Data
|
| 69 |
+
# logs every scheduler loop, logging might impact performance. So, we add a "TRACE"
|
| 70 |
+
# level where logs aren't written by default.
|
| 71 |
+
#
|
| 72 |
+
# Use the following code to log a message at the "TRACE" level:
|
| 73 |
+
# ```
|
| 74 |
+
# logger.log(logging.getLevelName("TRACE"), "Your message here.")
|
| 75 |
+
# ````
|
| 76 |
+
logging.addLevelName(logging.DEBUG - 1, "TRACE")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class HiddenRecordFilter:
    """Drops log records whose "hide" attribute is truthy.

    This lets callers override the default logging behavior per-record. For
    example, if errors are printed by default but a specific error shouldn't
    be, emit it with the "hide" attribute set.

    .. testcode::

        import logging
        logger = logging.getLogger("ray.data.spam")

        # This warning won't be printed to the console.
        logger.warning("ham", extra={"hide": True})
    """

    def filter(self, record):
        should_hide = getattr(record, "hide", False)
        return not should_hide
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class SessionFileHandler(logging.Handler):
    """Handler that writes records to a file in the Ray session directory.

    The Ray session directory doesn't exist until Ray is initialized, so the
    underlying ``FileHandler`` is created lazily on the first emitted record
    rather than at construction time.

    Args:
        filename: The name of the log file. The file is created in the 'logs'
            directory of the Ray session directory.
    """

    def __init__(self, filename: str):
        super().__init__()
        self._filename = filename
        self._handler = None
        self._formatter = None
        self._path = None

    def emit(self, record):
        if self._handler is None:
            self._try_create_handler()
            if self._handler is None:
                # Session directory still unavailable; drop the record.
                return
        self._handler.emit(record)

    def setFormatter(self, fmt: logging.Formatter) -> None:
        # Remember the formatter so it can be applied once the file handler
        # is lazily created; forward it immediately if one already exists.
        if self._handler is not None:
            self._handler.setFormatter(fmt)
        self._formatter = fmt

    def _try_create_handler(self):
        assert self._handler is None

        directory = get_log_directory()
        if directory is None:
            # Ray isn't initialized yet, so there's nowhere to write.
            return

        os.makedirs(directory, exist_ok=True)
        self._path = os.path.join(directory, self._filename)
        self._handler = logging.FileHandler(self._path)
        if self._formatter is not None:
            self._handler.setFormatter(self._formatter)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def configure_logging() -> None:
    """Configure the Python logger named 'ray.data'.

    This function loads the configuration YAML specified by the
    "RAY_DATA_LOGGING_CONFIG" environment variable. If the variable isn't set,
    this function falls back to the DEFAULT_CONFIG defined in this module.

    If "RAY_DATA_LOG_ENCODING" is "JSON", JSON logging mode is enabled, but
    only when using the default logging config.
    """
    import copy

    def _load_logging_config(config_path: str):
        # Load a user-provided logging config from a YAML file.
        with open(config_path) as file:
            config = yaml.safe_load(file)
        return config

    # Read the env vars at call time so they can change between calls.
    config_path = os.environ.get(RAY_DATA_LOGGING_CONFIG_ENV_VAR_NAME)
    log_encoding = os.environ.get(RAY_DATA_LOG_ENCODING_ENV_VAR_NAME)

    if config_path is not None:
        config = _load_logging_config(config_path)
    else:
        # BUG FIX: deep-copy the default config before mutating it. The
        # previous code edited the handler lists inside the module-level
        # DEFAULT_CONFIG in place, permanently corrupting the default and
        # making a second call raise ValueError from list.remove() when JSON
        # mode was enabled.
        config = copy.deepcopy(DEFAULT_CONFIG)
        if log_encoding is not None and log_encoding.upper() == "JSON":
            # Swap the plain-text file handler(s) for their JSON equivalents.
            for logger_config in config["loggers"].values():
                for (
                    old_handler_name,
                    new_handler_name,
                ) in RAY_DATA_LOG_HANDLER_JSON_SUBSTITUTIONS.items():
                    logger_config["handlers"].remove(old_handler_name)
                    logger_config["handlers"].append(new_handler_name)

    logging.config.dictConfig(config)

    # After configuring the logger, warn if RAY_DATA_LOGGING_CONFIG is used
    # together with RAY_DATA_LOG_ENCODING, because the encoding setting is
    # ignored when a custom config is supplied.
    if config_path is not None and log_encoding is not None:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Using `RAY_DATA_LOG_ENCODING` is not supported with "
            + "`RAY_DATA_LOGGING_CONFIG`"
        )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def reset_logging() -> None:
    """Restore the logger named 'ray.data' to its pristine state.

    Used for testing.
    """
    data_logger = logging.getLogger("ray.data")
    data_logger.handlers.clear()
    data_logger.setLevel(logging.NOTSET)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def get_log_directory() -> Optional[str]:
    """Return the directory where Ray Data writes log files.

    If Ray isn't initialized, this function returns ``None``.
    """
    node = ray._private.worker._global_node
    if node is None:
        return None
    # Logs live under <session_dir>/logs/ray-data.
    return os.path.join(node.get_session_dir_path(), "logs", "ray-data")
|
.venv/lib/python3.11/site-packages/ray/data/_internal/memory_tracing.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility for debugging object store memory eager deletion in Datasets.
|
| 2 |
+
|
| 3 |
+
NOTE: the performance overhead of tracing object allocation is fairly substantial.
|
| 4 |
+
This is meant to use in unit test for debugging. Please do not enable in production,
|
| 5 |
+
without performance optimization.
|
| 6 |
+
|
| 7 |
+
Enable with RAY_DATA_TRACE_ALLOCATIONS=1.
|
| 8 |
+
|
| 9 |
+
Basic usage is to call `trace_allocation` each time a new object is created, and call
|
| 10 |
+
`trace_deallocation` when an object should be disposed of. When the workload is
|
| 11 |
+
complete, call `leak_report` to view possibly leaked objects.
|
| 12 |
+
|
| 13 |
+
Note that so called "leaked" objects will be reclaimed eventually by reference counting
|
| 14 |
+
in Ray. This is just to debug the eager deletion protocol which is more efficient.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from io import StringIO
|
| 18 |
+
from typing import Dict, List
|
| 19 |
+
|
| 20 |
+
import ray
|
| 21 |
+
from ray.data.context import DataContext
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def trace_allocation(ref: ray.ObjectRef, loc: str) -> None:
    """Record that an object has been created.

    Args:
        ref: The object created.
        loc: A human-readable string identifying the call site.
    """
    if not DataContext.get_current().trace_allocations:
        return
    tracer = _get_mem_actor()
    # The ref is passed inside a list so Ray doesn't materialize the object.
    # TODO: it would be nice to determine loc automatically based on the stack.
    ray.get(tracer.trace_alloc.remote([ref], loc))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def trace_deallocation(ref: ray.ObjectRef, loc: str, free: bool = True) -> None:
    """Record that an object has been deleted (and delete if free=True).

    Args:
        ref: The object we no longer need.
        loc: A human-readable string identifying the call site.
        free: Whether to eagerly destroy the object instead of waiting for Ray
            reference counting to kick in.
    """
    if free:
        # Eagerly release the object from the object store.
        ray._private.internal_api.free(ref, local_only=False)
    if DataContext.get_current().trace_allocations:
        tracer = _get_mem_actor()
        ray.get(tracer.trace_dealloc.remote([ref], loc, free))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def leak_report() -> str:
    """Return a human-readable report of traced objects that may have leaked."""
    return ray.get(_get_mem_actor().leak_report.remote())
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@ray.remote(num_cpus=0)
class _MemActor:
    """Singleton actor that records object allocations and deallocations.

    State transitions per ref: allocated -> deallocated (freed), or
    allocated -> skip_dealloc (dealloc traced but eager free skipped).
    """

    def __init__(self):
        # Live refs: ref -> {"size_bytes", "loc"}.
        self.allocated: Dict[ray.ObjectRef, dict] = {}
        # Freed refs: ref -> entry moved from `allocated`, plus "dealloc_loc".
        self.deallocated: Dict[ray.ObjectRef, dict] = {}
        # Refs whose dealloc was traced with free=False: ref -> call site.
        self.skip_dealloc: Dict[ray.ObjectRef, str] = {}
        # Running totals of traced object-store usage, in bytes.
        self.peak_mem = 0
        self.cur_mem = 0

    def trace_alloc(self, ref: List[ray.ObjectRef], loc: str):
        """Record an allocation, best-effort estimating the object's size."""
        ref = ref[0]  # Avoid Ray materializing the ref.
        if ref not in self.allocated:
            meta = ray.experimental.get_object_locations([ref])
            # NOTE(review): get_object_locations returns a mapping keyed by
            # ref, so this top-level .get("object_size") looks like it always
            # misses and falls into the pickling fallback below — TODO confirm.
            size_bytes = meta.get("object_size", 0)
            if not size_bytes:
                size_bytes = -1
                from ray import cloudpickle as pickle

                try:
                    # Fall back to measuring the pickled size of the object.
                    obj = ray.get(ref, timeout=5.0)
                    size_bytes = len(pickle.dumps(obj))
                except Exception:
                    # Size stays -1 to mark "unknown".
                    print("[mem_tracing] ERROR getting size")
                    size_bytes = -1
            print(f"[mem_tracing] Allocated {size_bytes} bytes at {loc}: {ref}")
            entry = {
                "size_bytes": size_bytes,
                "loc": loc,
            }
            self.allocated[ref] = entry
            self.cur_mem += size_bytes
            self.peak_mem = max(self.cur_mem, self.peak_mem)

    def trace_dealloc(self, ref: List[ray.ObjectRef], loc: str, freed: bool):
        """Record a deallocation (or a skipped one, when freed=False)."""
        ref = ref[0]  # Avoid Ray materializing the ref.
        size_bytes = self.allocated.get(ref, {}).get("size_bytes", 0)
        if freed:
            print(f"[mem_tracing] Freed {size_bytes} bytes at {loc}: {ref}")
            if ref in self.allocated:
                # Move the entry from the live table to the freed table.
                self.cur_mem -= size_bytes
                self.deallocated[ref] = self.allocated.pop(ref)
                self.deallocated[ref]["dealloc_loc"] = loc
            if ref in self.deallocated:
                # This object reference is already deallocated.
                pass
            else:
                print(f"[mem_tracing] WARNING: allocation of {ref} was not traced!")
        else:
            print(f"[mem_tracing] Skipped freeing {size_bytes} bytes at {loc}: {ref}")
            self.skip_dealloc[ref] = loc

    def leak_report(self) -> str:
        """Build a text report of leaked, skipped, and freed objects."""
        output = StringIO()
        output.write("[mem_tracing] ===== Leaked objects =====\n")
        # Anything still in `allocated` was never freed.
        for ref in self.allocated:
            size_bytes = self.allocated[ref].get("size_bytes")
            loc = self.allocated[ref].get("loc")
            if ref in self.skip_dealloc:
                dealloc_loc = self.skip_dealloc[ref]
                output.write(
                    f"[mem_tracing] Leaked object, created at {loc}, size "
                    f"{size_bytes}, skipped dealloc at {dealloc_loc}: {ref}\n"
                )
            else:
                output.write(
                    f"[mem_tracing] Leaked object, created at {loc}, "
                    f"size {size_bytes}: {ref}\n"
                )
        output.write("[mem_tracing] ===== End leaked objects =====\n")
        output.write("[mem_tracing] ===== Freed objects =====\n")
        for ref in self.deallocated:
            size_bytes = self.deallocated[ref].get("size_bytes")
            loc = self.deallocated[ref].get("loc")
            dealloc_loc = self.deallocated[ref].get("dealloc_loc")
            output.write(
                f"[mem_tracing] Freed object from {loc} at {dealloc_loc}, "
                f"size {size_bytes}: {ref}\n"
            )
        output.write("[mem_tracing] ===== End freed objects =====\n")
        output.write(f"[mem_tracing] Peak size bytes {self.peak_mem}\n")
        output.write(f"[mem_tracing] Current size bytes {self.cur_mem}\n")
        return output.getvalue()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _get_mem_actor():
    """Get (or lazily create) the named, detached memory-tracing actor."""
    actor_options = {
        "name": "mem_tracing_actor",
        "get_if_exists": True,
        "lifetime": "detached",
    }
    return _MemActor.options(**actor_options).remote()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/null_aggregate.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import ModuleType
|
| 2 |
+
from typing import Any, Callable, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.data.block import AggType, Block, KeyType, T, U
|
| 7 |
+
|
| 8 |
+
WrappedAggType = Tuple[AggType, int]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# This module contains aggregation helpers for handling nulls.
|
| 12 |
+
# The null handling policy is:
|
| 13 |
+
# 1. Mix of values and nulls - ignore_nulls=True: Ignore the nulls, return
|
| 14 |
+
# aggregation of non-null values.
|
| 15 |
+
# 2. Mix of values and nulls - ignore_nulls=False: Return None.
|
| 16 |
+
# 3. All nulls: Return None.
|
| 17 |
+
# 4. Empty dataset: Return None.
|
| 18 |
+
#
|
| 19 |
+
# This is accomplished by checking rows for null values and by propagating nulls
|
| 20 |
+
# if found AND if we're not ignoring them. If not ignoring nulls, in order to delineate
|
| 21 |
+
# between found null rows and an empty block accumulation when merging (the latter of
|
| 22 |
+
# which we want to propagate; the former of which we do not), we attach a boolean flag
|
| 23 |
+
# indicating whether or not an accumulation contains valid data to intermediate block
|
| 24 |
+
# accumulations via _wrap_acc() and _unwrap_acc(). This allows us to properly merge
|
| 25 |
+
# intermediate block accumulations under a streaming constraint.
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _wrap_acc(a: AggType, has_data: bool) -> WrappedAggType:
    """
    Tag an accumulation with a trailing numeric boolean flag marking whether
    it holds real data; if it doesn't, the accumulation is considered empty.

    Args:
        a: The accumulation value.
        has_data: Whether the accumulation contains real data.

    Returns:
        An AggType list whose last element is 1 if the accumulation holds real
        data, else 0. If the input has length n, the result has length n + 1.
    """
    flag = 1 if has_data else 0
    values = a if isinstance(a, list) else [a]
    return values + [flag]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _unwrap_acc(a: WrappedAggType) -> Tuple[AggType, bool]:
    """
    Strip the trailing "has data" flag that _wrap_acc() attached to an
    accumulation.

    Args:
        a: The wrapped accumulation value that we wish to unwrap.

    Returns:
        A tuple of (unwrapped accumulation, whether it holds real data).
        A single-element accumulation is returned as a scalar.
    """
    *values, flag = a
    has_data = flag == 1
    if len(values) == 1:
        return values[0], has_data
    return values, has_data
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _null_wrap_init(
    init: Callable[[KeyType], AggType]
) -> Callable[[KeyType], WrappedAggType]:
    """
    Wrap an accumulation initializer with null handling.

    The returned initializer appends a has_data flag (via _wrap_acc) that the
    accumulator uses to track whether an aggregation is still empty.

    Args:
        init: The core init function to wrap.

    Returns:
        A new accumulation initializer function that can handle nulls.
    """

    def _init(k: KeyType) -> AggType:
        # A freshly-initialized accumulation holds no real data yet.
        return _wrap_acc(init(k), has_data=False)

    return _init
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _null_wrap_merge(
    ignore_nulls: bool,
    merge: Callable[[AggType, AggType], AggType],
) -> Callable[[WrappedAggType, WrappedAggType], WrappedAggType]:
    """
    Wrap a merge function with null handling.

    The returned merge expects each argument to be either None or a wrapped
    accumulation of the form [acc_data_1, ..., acc_data_n, has_data].

    Null rules applied when merging:
    1. If a1 is empty and a2 is empty, return an empty accumulation.
    2. If a1 (a2) is empty and a2 (a1) is None, return None.
    3. If a1 (a2) is empty and a2 (a1) is non-None, return a2 (a1).
    4. If a1 (a2) is None, return a2 (a1) if ignoring nulls, None otherwise.
    5. If a1 and a2 are both non-null, return merge(a1, a2).

    Args:
        ignore_nulls: Whether nulls should be ignored or cause a None result.
        merge: The core merge function to wrap.

    Returns:
        A new merge function that handles nulls.
    """

    def _merge(a1: WrappedAggType, a2: WrappedAggType) -> WrappedAggType:
        if a1 is None:
            # a1 was poisoned by a null; keep a2 only when ignoring nulls.
            return a2 if ignore_nulls else None
        data1, a1_has_data = _unwrap_acc(a1)
        if not a1_has_data:
            # a1 is empty, so a2 — whether a value, empty, or None — wins.
            return a2
        if a2 is None:
            # a2 was poisoned by a null; keep a1 only when ignoring nulls.
            return a1 if ignore_nulls else None
        data2, a2_has_data = _unwrap_acc(a2)
        if not a2_has_data:
            # a2 is empty; keep a1.
            return a1
        # Both sides hold real data: apply the core merge.
        return _wrap_acc(merge(data1, data2), has_data=True)

    return _merge
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _null_wrap_accumulate_row(
    ignore_nulls: bool,
    on_fn: Callable[[T], T],
    accum: Callable[[AggType, T], AggType],
) -> Callable[[WrappedAggType, T], WrappedAggType]:
    """
    Wrap a row accumulator function with null handling.

    The returned accumulate function expects a to be either None or of the
    form: a = [acc_data_1, ..., acc_data_n, has_data].

    Null rules applied per row:
    1. If r is null and ignore_nulls=False, return None.
    2. If r is null and ignore_nulls=True, return a.
    3. If r is non-null and a is None, return None.
    4. If r is non-null and a is non-None, accumulate r into a.

    Args:
        ignore_nulls: Whether nulls should be ignored or cause a None result.
        on_fn: Function selecting a subset of the row to apply the aggregation.
        accum: The core accumulator function to wrap.

    Returns:
        A new accumulator function that handles nulls.
    """

    def _accum(a: WrappedAggType, r: T) -> WrappedAggType:
        value = on_fn(r)
        if _is_null(value):
            # Null row: either skip it, or poison the whole aggregation.
            return a if ignore_nulls else None
        if a is None:
            # A previous null already poisoned this aggregation (and we're
            # not ignoring nulls), so keep propagating the None.
            return None
        # Non-null row, live accumulation: apply the core accumulator.
        unwrapped, _ = _unwrap_acc(a)
        return _wrap_acc(accum(unwrapped, value), has_data=True)

    return _accum
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _null_wrap_accumulate_block(
    ignore_nulls: bool,
    accum_block: Callable[[Block], AggType],
    null_merge: Callable[[WrappedAggType, WrappedAggType], WrappedAggType],
) -> Callable[[WrappedAggType, Block], WrappedAggType]:
    """
    Wrap a vectorized (whole-block) aggregate function with null handling.

    Null rules for a block accumulation:
    1. Any null row with ignore_nulls=False -> None.
    2. At least one non-null row with ignore_nulls=True -> block accumulation.
    3. All rows null with ignore_nulls=True -> the base accumulation.
    4. All rows non-null -> block accumulation.

    Args:
        ignore_nulls: Whether nulls should be ignored or cause a None result.
        accum_block: The core vectorized aggregate function to wrap.
        null_merge: A null-handling merge, as returned from _null_wrap_merge().

    Returns:
        A new vectorized aggregate function that handles nulls.
    """

    def _accum_block_null(a: WrappedAggType, block: Block) -> WrappedAggType:
        block_acc = accum_block(block)
        if block_acc is None and ignore_nulls:
            # The block held only nulls and we're ignoring them: treat the
            # block as if it were empty.
            block_acc = a
        elif block_acc is not None:
            block_acc = _wrap_acc(block_acc, has_data=True)
        # Let the null-aware merge combine the block result with the base.
        return null_merge(a, block_acc)

    return _accum_block_null
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _null_wrap_finalize(
    finalize: Callable[[AggType], AggType]
) -> Callable[[WrappedAggType], U]:
    """
    Wrap a finalizer with null handling.

    If the accumulation is empty or None, the returned finalizer returns None.

    Args:
        finalize: The core finalizing function to wrap.

    Returns:
        A new finalizing function that handles nulls.
    """

    def _finalize(a: AggType) -> U:
        if a is None:
            return None
        unwrapped, has_data = _unwrap_acc(a)
        return finalize(unwrapped) if has_data else None

    return _finalize
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
LazyModule = Union[None, bool, ModuleType]
|
| 254 |
+
_pandas: LazyModule = None
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _lazy_import_pandas() -> LazyModule:
    """Import pandas on first use, caching the result in the module global.

    Returns the pandas module, or False if pandas isn't installed (the failed
    import is cached so it's only attempted once).
    """
    global _pandas
    if _pandas is None:
        try:
            import pandas as _pandas
        except ModuleNotFoundError:
            # If module is not found, set _pandas to False so we won't
            # keep trying to import it on every _lazy_import_pandas() call.
            _pandas = False
    return _pandas
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _is_null(r: Any):
    """Return whether r is a null value, using pandas' semantics if available."""
    pd = _lazy_import_pandas()
    if not pd:
        # pandas unavailable: approximate with numpy, then identity with None.
        try:
            return np.isnan(r)
        except TypeError:
            # np.isnan rejects non-numeric values.
            return r is None
    # pandas handles None, NaN, and NaT uniformly.
    return pd.isnull(r)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/numpy_support.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Any, Dict, List, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
|
| 9 |
+
from ray.data._internal.util import _truncated_repr
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def is_array_like(value: Any) -> bool:
    """Checks whether objects are array-like, excluding numpy scalars.

    An object qualifies when it exposes both the ``__array__`` protocol and
    ``__len__`` (numpy scalars implement the former but not the latter).
    """
    required_protocols = ("__array__", "__len__")
    return all(hasattr(value, attr) for attr in required_protocols)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def is_valid_udf_return(udf_return_col: Any) -> bool:
    """Check whether a UDF column is valid.

    Valid columns must either be a list of elements, or an array-like object.
    """
    if isinstance(udf_return_col, list):
        return True
    return is_array_like(udf_return_col)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_nested_list(udf_return_col: List[Any]) -> bool:
    """Return True if any element of ``udf_return_col`` is itself a list."""
    return any(isinstance(element, list) for element in udf_return_col)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) -> None:
    """Raise ``ValueError`` unless ``batch`` is a mapping of valid columns.

    Each column value must be a list or an array-like object
    (see ``is_valid_udf_return``).
    """
    if isinstance(batch, collections.abc.Mapping):
        columns_ok = all(is_valid_udf_return(col) for col in batch.values())
    else:
        columns_ok = False

    if not columns_ok:
        raise ValueError(
            "Batch must be an ndarray or dictionary of ndarrays when converting "
            f"a numpy batch to a block, got: {type(batch)} "
            f"({_truncated_repr(batch)})"
        )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str:
|
| 48 |
+
"""Detect the highest precision for a list of datetime objects.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
datetime_list: List of datetime objects.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
A string representing the highest precision among the datetime objects
|
| 55 |
+
('D', 's', 'ms', 'us', 'ns').
|
| 56 |
+
"""
|
| 57 |
+
# Define precision hierarchy
|
| 58 |
+
precision_hierarchy = ["D", "s", "ms", "us", "ns"]
|
| 59 |
+
highest_precision_index = 0 # Start with the lowest precision ("D")
|
| 60 |
+
|
| 61 |
+
for dt in datetime_list:
|
| 62 |
+
# Safely get the nanosecond value using getattr for backward compatibility
|
| 63 |
+
nanosecond = getattr(dt, "nanosecond", 0)
|
| 64 |
+
if nanosecond != 0:
|
| 65 |
+
current_precision = "ns"
|
| 66 |
+
elif dt.microsecond != 0:
|
| 67 |
+
# Check if the microsecond precision is exactly millisecond
|
| 68 |
+
if dt.microsecond % 1000 == 0:
|
| 69 |
+
current_precision = "ms"
|
| 70 |
+
else:
|
| 71 |
+
current_precision = "us"
|
| 72 |
+
elif dt.second != 0 or dt.minute != 0 or dt.hour != 0:
|
| 73 |
+
# pyarrow does not support h or m, use s for those cases to
|
| 74 |
+
current_precision = "s"
|
| 75 |
+
else:
|
| 76 |
+
current_precision = "D"
|
| 77 |
+
|
| 78 |
+
# Update highest_precision_index based on the hierarchy
|
| 79 |
+
current_index = precision_hierarchy.index(current_precision)
|
| 80 |
+
highest_precision_index = max(highest_precision_index, current_index)
|
| 81 |
+
|
| 82 |
+
# Stop early if highest possible precision is reached
|
| 83 |
+
if highest_precision_index == len(precision_hierarchy) - 1:
|
| 84 |
+
break
|
| 85 |
+
|
| 86 |
+
return precision_hierarchy[highest_precision_index]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _convert_to_datetime64(dt: datetime, precision: str) -> np.datetime64:
|
| 90 |
+
"""
|
| 91 |
+
Converts a datetime object to a numpy datetime64 object with the specified
|
| 92 |
+
precision.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
dt: A datetime object to be converted.
|
| 96 |
+
precision: The desired precision for the datetime64 conversion. Possible
|
| 97 |
+
values are 'D', 's', 'ms', 'us', 'ns'.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
np.datetime64: A numpy datetime64 object with the specified precision.
|
| 101 |
+
"""
|
| 102 |
+
if precision == "ns":
|
| 103 |
+
# Calculate nanoseconds from microsecond and nanosecond
|
| 104 |
+
microseconds_as_ns = dt.microsecond * 1000
|
| 105 |
+
# Use getattr for backward compatibility where nanosecond attribute may not
|
| 106 |
+
# exist
|
| 107 |
+
nanoseconds = getattr(dt, "nanosecond", 0)
|
| 108 |
+
total_nanoseconds = microseconds_as_ns + nanoseconds
|
| 109 |
+
# Create datetime64 from base datetime with microsecond precision
|
| 110 |
+
base_dt = np.datetime64(dt, "us")
|
| 111 |
+
# Add remaining nanoseconds as timedelta
|
| 112 |
+
return base_dt + np.timedelta64(total_nanoseconds - microseconds_as_ns, "ns")
|
| 113 |
+
else:
|
| 114 |
+
return np.datetime64(dt).astype(f"datetime64[{precision}]")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray:
    """Convert datetimes to a ``datetime64`` ndarray at a shared precision.

    Args:
        datetime_list: A list of `datetime` objects to be converted.

    Returns:
        np.ndarray: An array of `datetime64` values using the finest
        precision required by any element of the input
        (see ``_detect_highest_datetime_precision``).
    """
    unit = _detect_highest_datetime_precision(datetime_list)
    converted = [_convert_to_datetime64(dt, unit) for dt in datetime_list]
    return np.array(converted)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def convert_to_numpy(column_values: Any) -> np.ndarray:
    """Convert UDF columns (output of map_batches) to numpy, if possible.

    This includes lists of scalars, objects supporting the array protocol, and lists
    of objects supporting the array protocol, such as `[1, 2, 3]`, `Tensor([1, 2, 3])`,
    and `[array(1), array(2), array(3)]`.

    Returns:
        The input as an np.ndarray if possible, otherwise the original input.

    Raises:
        ValueError if an input was array-like but we failed to convert it to an array.
    """

    if isinstance(column_values, np.ndarray):
        # No copy/conversion needed, just keep it verbatim.
        return column_values

    elif isinstance(column_values, list):
        if len(column_values) == 1 and isinstance(column_values[0], np.ndarray):
            # Optimization to avoid conversion overhead from list to np.array.
            return np.expand_dims(column_values[0], axis=0)

        if all(isinstance(elem, datetime) for elem in column_values):
            # Datetimes get a dedicated path so a common datetime64
            # precision can be chosen for the whole column.
            return _convert_datetime_list_to_array(column_values)

        # Try to convert list values into a numpy array via
        # np.array(), so users don't need to manually cast.
        # NOTE: we don't cast generic iterables, since types like
        # `str` are also Iterable.
        try:
            # Convert array-like objects (like torch.Tensor) to `np.ndarray`s.
            if all(is_array_like(e) for e in column_values):
                # Use np.asarray() instead of np.array() to avoid copying if possible.
                column_values = [np.asarray(e) for e in column_values]

            # Collect (dtype, shape) pairs to decide whether a regular
            # ndarray can hold the values, or a ragged one is needed.
            shapes = set()
            has_object = False
            for e in column_values:
                if isinstance(e, np.ndarray):
                    shapes.add((e.dtype, e.shape))
                elif isinstance(e, bytes):
                    # Don't convert variable length binary data to Numpy arrays as it
                    # treats zero encoding as termination by default.
                    # Per recommendation from
                    # https://github.com/apache/arrow/issues/26470,
                    # we use object dtype.
                    # https://github.com/ray-project/ray/issues/35586#issuecomment-1558148261
                    has_object = True
                elif not np.isscalar(e):
                    has_object = True

            # When column values are
            # - Arrays of heterogeneous shapes
            # - Byte-strings (viewed as arrays of heterogeneous shapes)
            # - Non-scalar objects (tuples, lists, arbitrary object types)
            #
            # Custom "ragged ndarray" is created, represented as an array of
            # references (ie ndarray with dtype=object)
            if has_object or len(shapes) > 1:
                # This util works around some limitations of np.array(dtype=object).
                return create_ragged_ndarray(column_values)
            else:
                return np.array(column_values)

        except Exception as e:
            # Log with the stack trace, then surface a ValueError chained to
            # the original failure so callers see both.
            logger.error(
                f"Failed to convert column values to numpy array: "
                f"{_truncated_repr(column_values)}",
                exc_info=e,
            )

            raise ValueError(
                "Failed to convert column values to numpy array: "
                f"({_truncated_repr(column_values)}): {e}."
            ) from e

    elif is_array_like(column_values):
        # Converts other array-like objects such as torch.Tensor.
        try:
            # Use np.asarray() instead of np.array() to avoid copying if possible.
            return np.asarray(column_values)
        except Exception as e:
            logger.error(
                f"Failed to convert column values to numpy array: "
                f"{_truncated_repr(column_values)}",
                exc_info=e,
            )

            raise ValueError(
                "Failed to convert column values to numpy array: "
                f"({_truncated_repr(column_values)}): {e}."
            ) from e

    else:
        # Not a list and not array-like (e.g. a scalar or a generic
        # iterable such as str): return unchanged.
        return column_values
|
.venv/lib/python3.11/site-packages/ray/data/_internal/output_buffer.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 4 |
+
from ray.data.block import Block, BlockAccessor, DataBatch
|
| 5 |
+
from ray.data.context import MAX_SAFE_BLOCK_SIZE_FACTOR
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BlockOutputBuffer:
    """Generates output blocks of a given size given a stream of inputs.

    This class is used to turn a stream of items / blocks of arbitrary size
    into a stream of blocks of ``target_max_block_size``. The caller should
    check ``has_next()`` after each ``add()`` call, and call ``next()`` to get
    the next block when ``has_next()`` returns True.

    When all items have been added, the caller must call ``finalize()`` and
    then check ``has_next()`` one last time.

    Examples:
        >>> from ray.data._internal.output_buffer import BlockOutputBuffer
        >>> generator = ...  # doctest: +SKIP
        >>> # Yield a stream of output blocks.
        >>> output = BlockOutputBuffer(500 * 1024 * 1024)  # doctest: +SKIP
        >>> for item in generator():  # doctest: +SKIP
        ...     output.add(item)  # doctest: +SKIP
        ...     if output.has_next():  # doctest: +SKIP
        ...         yield output.next()  # doctest: +SKIP
        >>> output.finalize()  # doctest: +SKIP
        >>> if output.has_next():  # doctest: +SKIP
        ...     yield output.next()  # doctest: +SKIP
    """

    def __init__(self, target_max_block_size: int):
        # Soft size cap (in bytes) for each produced block.
        self._target_max_block_size = target_max_block_size
        # Accumulates added rows/batches/blocks until a block is emitted.
        self._buffer = DelegatingBlockBuilder()
        # Tracks whether any block was emitted, so that after finalize() at
        # least one (possibly empty) block is always produced.
        self._returned_at_least_one_block = False
        self._finalized = False

    def add(self, item: Any) -> None:
        """Add a single item to this output buffer."""
        assert not self._finalized
        self._buffer.add(item)

    def add_batch(self, batch: DataBatch) -> None:
        """Add a data batch to this output buffer."""
        assert not self._finalized
        self._buffer.add_batch(batch)

    def add_block(self, block: Block) -> None:
        """Add a data block to this output buffer."""
        assert not self._finalized
        self._buffer.add_block(block)

    def finalize(self) -> None:
        """Must be called once all items have been added."""
        assert not self._finalized
        self._finalized = True

    def has_next(self) -> bool:
        """Returns true when a complete output block is produced."""
        if self._finalized:
            # After finalize(): flush any remaining rows, and also emit one
            # empty block if nothing was ever produced.
            return not self._returned_at_least_one_block or self._buffer.num_rows() > 0
        else:
            return (
                self._buffer.get_estimated_memory_usage() > self._target_max_block_size
            )

    def next(self) -> Block:
        """Returns the next complete output block."""
        assert self.has_next()

        block_to_yield = self._buffer.build()
        block_remainder = None
        block = BlockAccessor.for_block(block_to_yield)
        if (
            block.size_bytes()
            >= MAX_SAFE_BLOCK_SIZE_FACTOR * self._target_max_block_size
        ):
            # Slice a block to respect the target max block size. We only do
            # this if we are more than 50% above the target block size, because
            # this ensures that the last block produced will be at least half
            # the block size.
            num_bytes_per_row = block.size_bytes() // block.num_rows()
            target_num_rows = max(1, self._target_max_block_size // num_bytes_per_row)

            if target_num_rows < block.num_rows():
                # NOTE: We're maintaining following protocol of slicing underlying block
                # into appropriately sized ones:
                #
                # - (Finalized) Target blocks sliced from the original one
                #   and are *copied* to avoid referencing original blocks
                # - Temporary remainder of the block should *NOT* be copied
                #   such as to avoid repeatedly copying the remainder bytes
                #   of the block, resulting in O(M * N) total bytes being
                #   copied, where N is the total number of bytes in the original
                #   block and M is the number of blocks that will be produced by
                #   this iterator
                block_to_yield = block.slice(0, target_num_rows, copy=True)
                block_remainder = block.slice(
                    target_num_rows, block.num_rows(), copy=False
                )

        # Start a fresh buffer, carrying over any un-emitted remainder rows.
        self._buffer = DelegatingBlockBuilder()
        if block_remainder is not None:
            self._buffer.add_block(block_remainder)

        self._returned_at_least_one_block = True
        return block_to_yield
|
.venv/lib/python3.11/site-packages/ray/data/_internal/pandas_block.py
ADDED
|
@@ -0,0 +1,728 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import heapq
|
| 3 |
+
import logging
|
| 4 |
+
import sys
|
| 5 |
+
from typing import (
|
| 6 |
+
TYPE_CHECKING,
|
| 7 |
+
Any,
|
| 8 |
+
Callable,
|
| 9 |
+
Dict,
|
| 10 |
+
Iterator,
|
| 11 |
+
List,
|
| 12 |
+
Optional,
|
| 13 |
+
Sequence,
|
| 14 |
+
Tuple,
|
| 15 |
+
TypeVar,
|
| 16 |
+
Union,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ray.air.constants import TENSOR_COLUMN_NAME
|
| 22 |
+
from ray.air.util.tensor_extensions.utils import _is_ndarray_tensor
|
| 23 |
+
from ray.data._internal.numpy_support import convert_to_numpy, validate_numpy_batch
|
| 24 |
+
from ray.data._internal.row import TableRow
|
| 25 |
+
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
|
| 26 |
+
from ray.data._internal.util import find_partitions, keys_equal
|
| 27 |
+
from ray.data.block import (
|
| 28 |
+
Block,
|
| 29 |
+
BlockAccessor,
|
| 30 |
+
BlockExecStats,
|
| 31 |
+
BlockMetadata,
|
| 32 |
+
BlockType,
|
| 33 |
+
KeyType,
|
| 34 |
+
U,
|
| 35 |
+
)
|
| 36 |
+
from ray.data.context import DataContext
|
| 37 |
+
|
| 38 |
+
if TYPE_CHECKING:
|
| 39 |
+
import pandas
|
| 40 |
+
import pyarrow
|
| 41 |
+
|
| 42 |
+
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
|
| 43 |
+
from ray.data.aggregate import AggregateFn
|
| 44 |
+
|
| 45 |
+
T = TypeVar("T")
|
| 46 |
+
# Max number of samples used to estimate the Pandas block size.
|
| 47 |
+
_PANDAS_SIZE_BYTES_MAX_SAMPLE_COUNT = 50
|
| 48 |
+
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
_pandas = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def lazy_import_pandas():
    """Import pandas on first use and cache the module in a global."""
    global _pandas
    if _pandas is None:
        import pandas as _imported_pandas

        _pandas = _imported_pandas
    return _pandas
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PandasRow(TableRow):
    """
    Row of a tabular Dataset backed by a Pandas DataFrame block.
    """

    def __getitem__(self, key: Union[str, List[str]]) -> Any:
        """Look up one column (str key) or several columns (list of str).

        Returns the scalar value for a single key, a pandas Series for a
        list of keys, or None if the selection is empty.
        """
        from ray.data.extensions import TensorArrayElement

        pd = lazy_import_pandas()

        def get_item(keys: List[str]) -> Any:
            # NOTE(review): `self._row` appears to be a one-row DataFrame;
            # this selects the requested columns from it — confirm against
            # TableRow's construction.
            col = self._row[keys]
            if len(col) == 0:
                return None

            items = col.iloc[0]
            if isinstance(items.iloc[0], TensorArrayElement):
                # Getting an item in a Pandas tensor column may return
                # a TensorArrayElement, which we have to convert to an ndarray.
                return pd.Series(item.to_numpy() for item in items)

            try:
                # Try to interpret this as a numpy-type value.
                # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types. # noqa: E501
                return pd.Series(item.as_py() for item in items)

            except (AttributeError, ValueError):
                # Fallback to the original form.
                return items

        is_single_item = isinstance(key, str)
        keys = [key] if is_single_item else key

        items = get_item(keys)

        if items is None:
            return None
        elif is_single_item:
            # Unwrap the single-column Series down to its scalar value.
            return items.iloc[0]
        else:
            return items

    def __iter__(self) -> Iterator:
        # Iterating a row yields its column names, dict-style.
        for k in self._row.columns:
            yield k

    def __len__(self):
        # Length of a row is its number of columns.
        return self._row.shape[1]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class PandasBlockBuilder(TableBlockBuilder):
    """Block builder that produces Pandas-DataFrame-backed blocks."""

    def __init__(self):
        pandas = lazy_import_pandas()
        super().__init__(pandas.DataFrame)

    @staticmethod
    def _table_from_pydict(columns: Dict[str, List[Any]]) -> "pandas.DataFrame":
        """Build a DataFrame from a column-name -> values mapping."""
        pandas = lazy_import_pandas()

        frame_columns: Dict[str, Any] = {}
        for name, values in columns.items():
            as_numpy = convert_to_numpy(values)
            if name == TENSOR_COLUMN_NAME or _is_ndarray_tensor(as_numpy):
                from ray.data.extensions.tensor_extension import TensorArray

                # Wrap multi-dimensional data in the tensor extension type so
                # pandas can hold it in a single column.
                frame_columns[name] = TensorArray(as_numpy)
            else:
                frame_columns[name] = as_numpy

        return pandas.DataFrame(frame_columns)

    @staticmethod
    def _concat_tables(tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
        """Concatenate DataFrames, re-applying tensor extension casting."""
        pandas = lazy_import_pandas()
        from ray.air.util.data_batch_conversion import (
            _cast_ndarray_columns_to_tensor_extension,
        )

        if len(tables) > 1:
            merged = pandas.concat(tables, ignore_index=True)
            merged.reset_index(drop=True, inplace=True)
        else:
            merged = tables[0]

        if DataContext.get_current().enable_tensor_extension_casting:
            merged = _cast_ndarray_columns_to_tensor_extension(merged)
        return merged

    @staticmethod
    def _concat_would_copy() -> bool:
        # Pandas concatenation always copies the input data.
        return True

    @staticmethod
    def _empty_table() -> "pandas.DataFrame":
        return lazy_import_pandas().DataFrame()

    def block_type(self) -> BlockType:
        return BlockType.PANDAS
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# This is to be compatible with pyarrow.lib.schema
# TODO (kfstorm): We need a format-independent way to represent schema.
# Lightweight schema record: `names` holds column labels, `types` their dtypes.
PandasBlockSchema = collections.namedtuple("PandasBlockSchema", ["names", "types"])
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class PandasBlockAccessor(TableBlockAccessor):
|
| 172 |
+
ROW_TYPE = PandasRow
|
| 173 |
+
|
| 174 |
+
    def __init__(self, table: "pandas.DataFrame"):
        # Wrap a DataFrame block; shared accessor logic lives in TableBlockAccessor.
        super().__init__(table)
|
| 176 |
+
|
| 177 |
+
def column_names(self) -> List[str]:
|
| 178 |
+
return self._table.columns.tolist()
|
| 179 |
+
|
| 180 |
+
def append_column(self, name: str, data: Any) -> Block:
|
| 181 |
+
assert name not in self._table.columns
|
| 182 |
+
|
| 183 |
+
if any(isinstance(item, np.ndarray) for item in data):
|
| 184 |
+
raise NotImplementedError(
|
| 185 |
+
f"`{self.__class__.__name__}.append_column()` doesn't support "
|
| 186 |
+
"array-like data."
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
table = self._table.copy()
|
| 190 |
+
table[name] = data
|
| 191 |
+
return table
|
| 192 |
+
|
| 193 |
+
    @staticmethod
    def _build_tensor_row(row: PandasRow) -> np.ndarray:
        """Extract the single tensor value from a tensor-column row as an ndarray."""
        from ray.data.extensions import TensorArrayElement

        tensor = row[TENSOR_COLUMN_NAME].iloc[0]
        if isinstance(tensor, TensorArrayElement):
            # Getting an item in a Pandas tensor column may return a TensorArrayElement,
            # which we have to convert to an ndarray.
            tensor = tensor.to_numpy()
        return tensor
|
| 203 |
+
|
| 204 |
+
def slice(self, start: int, end: int, copy: bool = False) -> "pandas.DataFrame":
|
| 205 |
+
view = self._table[start:end]
|
| 206 |
+
view.reset_index(drop=True, inplace=True)
|
| 207 |
+
if copy:
|
| 208 |
+
view = view.copy(deep=True)
|
| 209 |
+
return view
|
| 210 |
+
|
| 211 |
+
def take(self, indices: List[int]) -> "pandas.DataFrame":
|
| 212 |
+
table = self._table.take(indices)
|
| 213 |
+
table.reset_index(drop=True, inplace=True)
|
| 214 |
+
return table
|
| 215 |
+
|
| 216 |
+
def select(self, columns: List[str]) -> "pandas.DataFrame":
|
| 217 |
+
if not all(isinstance(col, str) for col in columns):
|
| 218 |
+
raise ValueError(
|
| 219 |
+
"Columns must be a list of column name strings when aggregating on "
|
| 220 |
+
f"Pandas blocks, but got: {columns}."
|
| 221 |
+
)
|
| 222 |
+
return self._table[columns]
|
| 223 |
+
|
| 224 |
+
def rename_columns(self, columns_rename: Dict[str, str]) -> "pandas.DataFrame":
|
| 225 |
+
return self._table.rename(columns=columns_rename, inplace=False, copy=False)
|
| 226 |
+
|
| 227 |
+
def random_shuffle(self, random_seed: Optional[int]) -> "pandas.DataFrame":
|
| 228 |
+
table = self._table.sample(frac=1, random_state=random_seed)
|
| 229 |
+
table.reset_index(drop=True, inplace=True)
|
| 230 |
+
return table
|
| 231 |
+
|
| 232 |
+
def schema(self) -> PandasBlockSchema:
|
| 233 |
+
dtypes = self._table.dtypes
|
| 234 |
+
schema = PandasBlockSchema(
|
| 235 |
+
names=dtypes.index.tolist(), types=dtypes.values.tolist()
|
| 236 |
+
)
|
| 237 |
+
# Column names with non-str types of a pandas DataFrame is not
|
| 238 |
+
# supported by Ray Dataset.
|
| 239 |
+
if any(not isinstance(name, str) for name in schema.names):
|
| 240 |
+
raise ValueError(
|
| 241 |
+
"A Pandas DataFrame with column names of non-str types"
|
| 242 |
+
" is not supported by Ray Dataset. Column names of this"
|
| 243 |
+
f" DataFrame: {schema.names!r}."
|
| 244 |
+
)
|
| 245 |
+
return schema
|
| 246 |
+
|
| 247 |
+
def to_pandas(self) -> "pandas.DataFrame":
|
| 248 |
+
from ray.air.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays
|
| 249 |
+
|
| 250 |
+
ctx = DataContext.get_current()
|
| 251 |
+
table = self._table
|
| 252 |
+
if ctx.enable_tensor_extension_casting:
|
| 253 |
+
table = _cast_tensor_columns_to_ndarrays(table)
|
| 254 |
+
return table
|
| 255 |
+
|
| 256 |
+
def to_numpy(
|
| 257 |
+
self, columns: Optional[Union[str, List[str]]] = None
|
| 258 |
+
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
| 259 |
+
if columns is None:
|
| 260 |
+
columns = self._table.columns.tolist()
|
| 261 |
+
should_be_single_ndarray = False
|
| 262 |
+
elif isinstance(columns, list):
|
| 263 |
+
should_be_single_ndarray = False
|
| 264 |
+
else:
|
| 265 |
+
columns = [columns]
|
| 266 |
+
should_be_single_ndarray = True
|
| 267 |
+
|
| 268 |
+
column_names_set = set(self._table.columns)
|
| 269 |
+
for column in columns:
|
| 270 |
+
if column not in column_names_set:
|
| 271 |
+
raise ValueError(
|
| 272 |
+
f"Cannot find column {column}, available columns: "
|
| 273 |
+
f"{self._table.columns.tolist()}"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
arrays = []
|
| 277 |
+
for column in columns:
|
| 278 |
+
arrays.append(self._table[column].to_numpy())
|
| 279 |
+
|
| 280 |
+
if should_be_single_ndarray:
|
| 281 |
+
arrays = arrays[0]
|
| 282 |
+
else:
|
| 283 |
+
arrays = dict(zip(columns, arrays))
|
| 284 |
+
return arrays
|
| 285 |
+
|
| 286 |
+
    def to_arrow(self) -> "pyarrow.Table":
        """Convert the block to an Arrow table (without an index column)."""
        import pyarrow

        # Set `preserve_index=False` so that Arrow doesn't add a '__index_level_0__'
        # column to the resulting table.
        return pyarrow.Table.from_pandas(self._table, preserve_index=False)
|
| 292 |
+
|
| 293 |
+
@staticmethod
|
| 294 |
+
def numpy_to_block(
|
| 295 |
+
batch: Union[Dict[str, np.ndarray], Dict[str, list]],
|
| 296 |
+
) -> "pandas.DataFrame":
|
| 297 |
+
validate_numpy_batch(batch)
|
| 298 |
+
|
| 299 |
+
block = PandasBlockBuilder._table_from_pydict(batch)
|
| 300 |
+
return block
|
| 301 |
+
|
| 302 |
+
def num_rows(self) -> int:
|
| 303 |
+
return self._table.shape[0]
|
| 304 |
+
|
| 305 |
+
def size_bytes(self) -> int:
    """Estimate the in-memory size of this block in bytes.

    Numeric columns are measured exactly via `DataFrame.memory_usage`;
    object/tensor-extension columns are estimated by deep-sizing a random
    sample of values and scaling up to the full column length.
    """
    from pandas.api.types import is_object_dtype

    from ray.air.util.tensor_extensions.pandas import TensorArray
    from ray.data.extensions import TensorArrayElement, TensorDtype

    pd = lazy_import_pandas()

    def get_deep_size(obj):
        """Calculates the memory size of objects,
        including nested objects using an iterative approach."""
        seen = set()
        total_size = 0
        objects = collections.deque([obj])
        while objects:
            current = objects.pop()

            # Skip interning-eligible immutable objects
            if isinstance(current, (str, bytes, int, float)):
                size = sys.getsizeof(current)
                total_size += size
                continue

            # Check if the object has been seen before
            # i.e. a = np.ndarray([1,2,3]), b = [a,a]
            # The pattern above will have only one memory copy
            if id(current) in seen:
                continue
            seen.add(id(current))

            try:
                size = sys.getsizeof(current)
            except TypeError:
                size = 0
            total_size += size

            # Handle specific cases
            if isinstance(current, np.ndarray):
                total_size += current.nbytes - size  # Avoid double counting
            elif isinstance(current, pd.DataFrame):
                total_size += (
                    current.memory_usage(index=True, deep=True).sum() - size
                )
            elif isinstance(current, (list, tuple, set)):
                objects.extend(current)
            elif isinstance(current, dict):
                objects.extend(current.keys())
                objects.extend(current.values())
            elif isinstance(current, TensorArrayElement):
                objects.extend(current.to_numpy())
        return total_size

    # Get initial memory usage including deep introspection
    memory_usage = self._table.memory_usage(index=True, deep=True)

    # TensorDtype for ray.air.util.tensor_extensions.pandas.TensorDtype
    object_need_check = (TensorDtype,)
    max_sample_count = _PANDAS_SIZE_BYTES_MAX_SAMPLE_COUNT

    # Handle object columns separately
    for column in self._table.columns:
        # Check pandas object dtype and the extension dtype
        if is_object_dtype(self._table[column].dtype) or isinstance(
            self._table[column].dtype, object_need_check
        ):
            total_size = len(self._table[column])

            # Determine the sample size based on max_sample_count
            sample_size = min(total_size, max_sample_count)
            # Following code also handles the case that sample_size == total_size
            sampled_data = self._table[column].sample(n=sample_size).values

            try:
                # Numeric tensor arrays report an exact nbytes; everything
                # else falls back to per-element deep sizing.
                if isinstance(sampled_data, TensorArray) and np.issubdtype(
                    sampled_data[0].numpy_dtype, np.number
                ):
                    column_memory_sample = sampled_data.nbytes
                else:
                    vectorized_size_calc = np.vectorize(lambda x: get_deep_size(x))
                    column_memory_sample = np.sum(
                        vectorized_size_calc(sampled_data)
                    )
                # Scale back to the full column size if we sampled
                column_memory = column_memory_sample * (total_size / sample_size)
                memory_usage[column] = int(column_memory)
            except Exception as e:
                # Best effort: on failure, keep the shallow estimate for this
                # column rather than failing the whole size computation.
                logger.warning(f"Error calculating size for column '{column}': {e}")

    # Sum up total memory usage
    total_memory_usage = memory_usage.sum()

    return int(total_memory_usage)
|
| 398 |
+
|
| 399 |
+
def _zip(self, acc: BlockAccessor) -> "pandas.DataFrame":
    """Horizontally combine this block with `acc`'s block.

    Columns from `acc` whose names collide with existing columns are
    renamed with a `_<n>` suffix so all output names stay unique.
    """
    result = self.to_pandas().copy(deep=False)
    other = acc.to_pandas()
    for name in other.columns:
        series = other[name]
        existing = list(result.columns)
        target = name
        # Ensure the column names are unique after zip.
        if target in existing:
            suffix = 1
            while target in existing:
                target = "{}_{}".format(name, suffix)
                suffix += 1
        result[target] = series
    return result
|
| 415 |
+
|
| 416 |
+
@staticmethod
def builder() -> PandasBlockBuilder:
    """Return a fresh builder for constructing pandas blocks."""
    return PandasBlockBuilder()
|
| 419 |
+
|
| 420 |
+
@staticmethod
def _empty_table() -> "pandas.DataFrame":
    """Return an empty pandas block (delegates to the builder's empty table)."""
    return PandasBlockBuilder._empty_table()
|
| 423 |
+
|
| 424 |
+
def _sample(self, n_samples: int, sort_key: "SortKey") -> "pandas.DataFrame":
    """Randomly sample ``n_samples`` rows of the sort-key columns only."""
    return self._table[sort_key.get_columns()].sample(n_samples, ignore_index=True)
|
| 426 |
+
|
| 427 |
+
def _apply_agg(
    self, agg_fn: Callable[["pandas.Series", bool], U], on: str
) -> Optional[U]:
    """Helper providing null handling around applying an aggregation to a column.

    Returns None for empty blocks and for aggregations whose result is a
    pandas null (NaN/NaT/None); raises ValueError if `on` is not a string.
    """
    pd = lazy_import_pandas()
    if on is not None and not isinstance(on, str):
        raise ValueError(
            "on must be a string or None when aggregating on Pandas blocks, but "
            f"got: {type(on)}."
        )

    # An empty block has nothing to aggregate.
    if self.num_rows() == 0:
        return None

    col = self._table[on]
    try:
        val = agg_fn(col)
    except TypeError as e:
        # Converting an all-null column in an Arrow Table to a Pandas DataFrame
        # column will result in an all-None column of object type, which will raise
        # a type error when attempting to do most binary operations. We explicitly
        # check for this type failure here so we can properly propagate a null.
        if np.issubdtype(col.dtype, np.object_) and col.isnull().all():
            return None
        raise e from None
    # Normalize null aggregation results (NaN/NaT) to None.
    if pd.isnull(val):
        return None
    return val
|
| 455 |
+
|
| 456 |
+
def count(self, on: str) -> Optional[U]:
    """Count non-null values of column ``on``; None for an empty block."""
    return self._apply_agg(lambda col: col.count(), on)
|
| 458 |
+
|
| 459 |
+
def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Sum column ``on``.

    Args:
        on: Column name to sum (must be a string; validated by the helper).
        ignore_nulls: Whether to skip nulls when summing.

    Returns:
        The column sum, or None if the block is empty, the column is
        all-null, or the result is null (e.g. nulls with ignore_nulls=False).
    """
    # Delegate validation, empty-block handling, and null-result
    # normalization to _apply_agg (consistent with count/min/max/mean).
    # The all-null short-circuit is required because Series.sum() would
    # otherwise return 0 for an all-null column, which is not the
    # null-propagation behavior we want.
    return self._apply_agg(
        lambda col: None if col.isnull().all() else col.sum(skipna=ignore_nulls),
        on,
    )
|
| 480 |
+
|
| 481 |
+
def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Minimum of column ``on``; None when the result is null (empty or all-null)."""
    return self._apply_agg(lambda col: col.min(skipna=ignore_nulls), on)
|
| 483 |
+
|
| 484 |
+
def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Maximum of column ``on``; None when the result is null (empty or all-null)."""
    return self._apply_agg(lambda col: col.max(skipna=ignore_nulls), on)
|
| 486 |
+
|
| 487 |
+
def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Mean of column ``on``; None when the result is null (empty or all-null)."""
    return self._apply_agg(lambda col: col.mean(skipna=ignore_nulls), on)
|
| 489 |
+
|
| 490 |
+
def sum_of_squared_diffs_from_mean(
    self,
    on: str,
    ignore_nulls: bool,
    mean: Optional[U] = None,
) -> Optional[U]:
    """Sum of squared deviations of column ``on`` from ``mean``.

    When ``mean`` is not supplied, the column mean is computed first.
    """
    center = mean if mean is not None else self.mean(on, ignore_nulls)

    def squared_deviation_sum(col):
        deviations = col - center
        return (deviations**2).sum(skipna=ignore_nulls)

    return self._apply_agg(squared_deviation_sum, on)
|
| 502 |
+
|
| 503 |
+
def sort_and_partition(
    self, boundaries: List[T], sort_key: "SortKey"
) -> List[Block]:
    """Sort this block by ``sort_key`` and split it at ``boundaries``.

    Returns ``len(boundaries) + 1`` blocks.
    """
    if self._table.shape[0] == 0:
        # If the table is empty we may not have schema
        # so calling sort_indices() will raise an error.
        return [self._empty_table() for _ in range(len(boundaries) + 1)]

    columns, ascending = sort_key.to_pandas_sort_args()
    table = self._table.sort_values(by=columns, ascending=ascending)
    if len(boundaries) == 0:
        # No split points: the whole sorted block is a single partition.
        return [table]

    return find_partitions(table, boundaries, sort_key)
|
| 517 |
+
|
| 518 |
+
# TODO (srinathk) Needs to handle None types correctly.
def combine(
    self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]
) -> "pandas.DataFrame":
    """Combine rows with the same key into an accumulator.

    This assumes the block is already sorted by key in ascending order.

    Args:
        sort_key: A SortKey object which holds column names/keys.
            If this is ``None``, place all rows in a single group.

        aggs: The aggregations to do.

    Returns:
        A sorted block of [k, v_1, ..., v_n] columns where k is the groupby
        key and v_i is the partially combined accumulator for the ith given
        aggregation.
        If key is None then the k column is omitted.
    """
    keys: List[str] = sort_key.get_columns()
    pd = lazy_import_pandas()

    def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
        """Creates an iterator over zero-copy group views."""
        if not keys:
            # Global aggregation consists of a single "group", so we short-circuit.
            yield tuple(), self.to_block()
            return

        start = end = 0
        # Renamed from `iter` to avoid shadowing the `iter` builtin.
        row_iter = self.iter_rows(public_row_format=False)
        next_row = None
        while True:
            try:
                if next_row is None:
                    next_row = next(row_iter)
                next_keys = next_row[keys]
                # Advance until the key changes; rows [start, end) form one group.
                while keys_equal(next_row[keys], next_keys):
                    end += 1
                    try:
                        next_row = next(row_iter)
                    except StopIteration:
                        next_row = None
                        break
                if isinstance(next_keys, pd.Series):
                    next_keys = next_keys.values
                yield next_keys, self.slice(start, end, copy=False)
                start = end
            except StopIteration:
                break

    builder = PandasBlockBuilder()
    for group_keys, group_view in iter_groups():
        # Aggregate.
        init_vals = group_keys
        if len(group_keys) == 1:
            init_vals = group_keys[0]
        accumulators = [agg.init(init_vals) for agg in aggs]
        for i in range(len(aggs)):
            accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

        # Build the row.
        row = {}
        if keys:
            for k, gk in zip(keys, group_keys):
                row[k] = gk

        count = collections.defaultdict(int)
        for agg, accumulator in zip(aggs, accumulators):
            name = agg.name
            # Check for conflicts with existing aggregation name.
            if count[name] > 0:
                name = self._munge_conflict(name, count[name])
            count[name] += 1
            row[name] = accumulator

        builder.add(row)

    return builder.build()
|
| 598 |
+
|
| 599 |
+
@staticmethod
def merge_sorted_blocks(
    blocks: List[Block], sort_key: "SortKey"
) -> Tuple["pandas.DataFrame", BlockMetadata]:
    """Merge already-sorted blocks into a single sorted pandas block.

    Returns the merged DataFrame and its BlockMetadata with exec stats.
    """
    pd = lazy_import_pandas()
    stats = BlockExecStats.builder()
    # Drop empty blocks; they contribute no rows and may lack a schema.
    blocks = [b for b in blocks if b.shape[0] > 0]
    if len(blocks) == 0:
        ret = PandasBlockAccessor._empty_table()
    else:
        # Handle blocks of different types.
        blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas")
        # NOTE: concatenates then re-sorts rather than doing a streaming
        # k-way merge; the inputs' pre-sortedness is not exploited here.
        ret = pd.concat(blocks, ignore_index=True)
        columns, ascending = sort_key.to_pandas_sort_args()
        ret = ret.sort_values(by=columns, ascending=ascending)
    return ret, PandasBlockAccessor(ret).get_metadata(exec_stats=stats.build())
|
| 615 |
+
|
| 616 |
+
@staticmethod
def aggregate_combined_blocks(
    blocks: List["pandas.DataFrame"],
    sort_key: "SortKey",
    aggs: Tuple["AggregateFn"],
    finalize: bool,
) -> Tuple["pandas.DataFrame", BlockMetadata]:
    """Aggregate sorted, partially combined blocks with the same key range.

    This assumes blocks are already sorted by key in ascending order,
    so we can do merge sort to get all the rows with the same key.

    Args:
        blocks: A list of partially combined and sorted blocks.
        sort_key: The column name of key or None for global aggregation.
        aggs: The aggregations to do.
        finalize: Whether to finalize the aggregation. This is used as an
            optimization for cases where we repeatedly combine partially
            aggregated groups.

    Returns:
        A block of [k, v_1, ..., v_n] columns and its metadata where k is
        the groupby key and v_i is the corresponding aggregation result for
        the ith given aggregation.
        If key is None then the k column is omitted.
    """

    stats = BlockExecStats.builder()
    keys = sort_key.get_columns()

    def key_fn(r):
        # Global aggregation uses a constant key so every row merges into
        # one group.
        if keys:
            return tuple(r[keys])
        else:
            return (0,)

    # Handle blocks of different types.
    blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas")

    # Renamed from `iter` to avoid shadowing the `iter` builtin.
    merged_rows = heapq.merge(
        *[
            PandasBlockAccessor(block).iter_rows(public_row_format=False)
            for block in blocks
        ],
        key=key_fn,
    )
    next_row = None
    builder = PandasBlockBuilder()
    while True:
        try:
            if next_row is None:
                next_row = next(merged_rows)
            next_keys = key_fn(next_row)
            next_key_columns = keys

            def gen():
                # Yield the run of rows sharing `next_keys`, leaving
                # `next_row` positioned at the first row of the next group.
                nonlocal merged_rows
                nonlocal next_row
                while keys_equal(key_fn(next_row), next_keys):
                    yield next_row
                    try:
                        next_row = next(merged_rows)
                    except StopIteration:
                        next_row = None
                        break

            # Merge.
            first = True
            accumulators = [None] * len(aggs)
            resolved_agg_names = [None] * len(aggs)
            for r in gen():
                if first:
                    count = collections.defaultdict(int)
                    for i in range(len(aggs)):
                        name = aggs[i].name
                        # Check for conflicts with existing aggregation
                        # name.
                        if count[name] > 0:
                            name = PandasBlockAccessor._munge_conflict(
                                name, count[name]
                            )
                        count[name] += 1
                        resolved_agg_names[i] = name
                        accumulators[i] = r[name]
                    first = False
                else:
                    for i in range(len(aggs)):
                        accumulators[i] = aggs[i].merge(
                            accumulators[i], r[resolved_agg_names[i]]
                        )
            # Build the row.
            row = {}
            if keys:
                for col_name, next_key in zip(next_key_columns, next_keys):
                    row[col_name] = next_key

            for agg, agg_name, accumulator in zip(
                aggs, resolved_agg_names, accumulators
            ):
                if finalize:
                    row[agg_name] = agg.finalize(accumulator)
                else:
                    row[agg_name] = accumulator

            builder.add(row)
        except StopIteration:
            break

    ret = builder.build()
    return ret, PandasBlockAccessor(ret).get_metadata(exec_stats=stats.build())
|
| 726 |
+
|
| 727 |
+
def block_type(self) -> BlockType:
    """Identify this accessor's block format as pandas."""
    return BlockType.PANDAS
|
.venv/lib/python3.11/site-packages/ray/data/_internal/plan.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union
|
| 5 |
+
|
| 6 |
+
import pyarrow
|
| 7 |
+
|
| 8 |
+
import ray
|
| 9 |
+
from ray._private.internal_api import get_memory_info_reply, get_state_from_address
|
| 10 |
+
from ray.data._internal.execution.interfaces import RefBundle
|
| 11 |
+
from ray.data._internal.logical.interfaces.logical_operator import LogicalOperator
|
| 12 |
+
from ray.data._internal.logical.interfaces.logical_plan import LogicalPlan
|
| 13 |
+
from ray.data._internal.logical.operators.from_operators import AbstractFrom
|
| 14 |
+
from ray.data._internal.logical.operators.input_data_operator import InputData
|
| 15 |
+
from ray.data._internal.logical.operators.read_operator import Read
|
| 16 |
+
from ray.data._internal.stats import DatasetStats
|
| 17 |
+
from ray.data._internal.util import create_dataset_tag, unify_block_metadata_schema
|
| 18 |
+
from ray.data.block import BlockMetadata
|
| 19 |
+
from ray.data.context import DataContext
|
| 20 |
+
from ray.data.exceptions import omit_traceback_stdout
|
| 21 |
+
from ray.util.debug import log_once
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
|
| 25 |
+
from ray.data._internal.execution.interfaces import Executor
|
| 26 |
+
from ray.data.dataset import Dataset
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Scheduling strategy can be inherited from prev operator if not specified.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


# Module-level logger for execution-plan diagnostics.
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ExecutionPlan:
|
| 37 |
+
"""A lazy execution plan for a Dataset.
|
| 38 |
+
|
| 39 |
+
This lazy execution plan builds up a chain of ``List[RefBundle]`` -->
|
| 40 |
+
``List[RefBundle]`` operators. Prior to execution, we apply a set of logical
|
| 41 |
+
plan optimizations, such as operator fusion, in order to reduce Ray task
|
| 42 |
+
overhead and data copies.
|
| 43 |
+
|
| 44 |
+
Internally, the execution plan holds a snapshot of a computed list of
|
| 45 |
+
blocks and their associated metadata under ``self._snapshot_bundle``,
|
| 46 |
+
where this snapshot is the cached output of executing the operator chain."""
|
| 47 |
+
|
| 48 |
+
def __init__(
    self,
    stats: DatasetStats,
    *,
    data_context: Optional[DataContext] = None,
):
    """Create a plan with no transformation operators.

    Args:
        stats: Stats for the base blocks.
        data_context: :class:`~ray.data.context.DataContext`
            object to use for execution.
    """
    self._in_stats = stats
    # A computed snapshot of some prefix of operators and their corresponding
    # output blocks and stats.
    self._snapshot_operator: Optional[LogicalOperator] = None
    self._snapshot_stats = None
    self._snapshot_bundle = None
    # Snapshot of only metadata corresponding to the final operator's
    # output bundles, used as the source of truth for the Dataset's schema
    # and count. This is calculated and cached when the plan is executed as an
    # iterator (`execute_to_iterator()`), and avoids caching
    # all of the output blocks in memory like in `self.snapshot_bundle`.
    # TODO(scottjlee): To keep the caching logic consistent, update `execute()`
    # to also store the metadata in `_snapshot_metadata` instead of
    # `_snapshot_bundle`. For example, we could store the blocks in
    # `self._snapshot_blocks` and the metadata in `self._snapshot_metadata`.
    self._snapshot_metadata: Optional[BlockMetadata] = None

    # Cached schema.
    self._schema = None
    # Set when a Dataset is constructed with this plan
    self._dataset_uuid = None

    # Optional human-readable dataset name (presumably set via Dataset
    # naming APIs -- TODO confirm); used in plan string representations.
    self._dataset_name = None

    self._has_started_execution = False

    if data_context is None:
        # Snapshot the current context, so that the config of Datasets is always
        # determined by the config at the time it was created.
        self._context = copy.deepcopy(DataContext.get_current())
    else:
        self._context = data_context
|
| 93 |
+
|
| 94 |
+
def __repr__(self) -> str:
    """Debug representation showing the dataset UUID and snapshot operator."""
    uuid = self._dataset_uuid
    snapshot_op = self._snapshot_operator
    return f"ExecutionPlan(dataset_uuid={uuid}, snapshot_operator={snapshot_op})"
|
| 101 |
+
|
| 102 |
+
def get_plan_as_string(self, dataset_cls: Type["Dataset"]) -> str:
    """Create a cosmetic string representation of this execution plan.

    Returns:
        The string representation of this execution plan.
    """
    # NOTE: this is used for Dataset.__repr__ to give a user-facing string
    # representation. Ideally ExecutionPlan.__repr__ should be replaced with this
    # method as well.

    from ray.data.dataset import MaterializedDataset

    # Do not force execution for schema, as this method is expected to be very
    # cheap.
    plan_str = ""
    plan_max_depth = 0
    if not self.has_computed_output():

        def generate_logical_plan_string(
            op: LogicalOperator,
            curr_str: str = "",
            depth: int = 0,
        ):
            """Traverse (DFS) the LogicalPlan DAG and
            return a string representation of the operators."""
            # Source operators (reads/inputs) are not rendered as plan lines.
            if isinstance(op, (Read, InputData, AbstractFrom)):
                return curr_str, depth

            curr_max_depth = depth
            op_name = op.name
            if depth == 0:
                curr_str += f"{op_name}\n"
            else:
                trailing_space = " " * ((depth - 1) * 3)
                curr_str += f"{trailing_space}+- {op_name}\n"

            for input in op.input_dependencies:
                curr_str, input_max_depth = generate_logical_plan_string(
                    input, curr_str, depth + 1
                )
                curr_max_depth = max(curr_max_depth, input_max_depth)
            return curr_str, curr_max_depth

        # generate_logical_plan_string(self._logical_plan.dag)
        plan_str, plan_max_depth = generate_logical_plan_string(
            self._logical_plan.dag
        )

        if self._snapshot_bundle is not None:
            # This plan has executed some but not all operators.
            schema = unify_block_metadata_schema(self._snapshot_bundle.metadata)
            count = self._snapshot_bundle.num_rows()
        elif self._snapshot_metadata is not None:
            schema = self._snapshot_metadata.schema
            count = self._snapshot_metadata.num_rows
        else:
            # This plan hasn't executed any operators.
            sources = self._logical_plan.sources()
            # TODO(@bveeramani): Handle schemas for n-ary operators like `Union`.
            if len(sources) > 1:
                # Multiple sources, cannot determine schema.
                schema = None
                count = None
            else:
                assert len(sources) == 1
                # Build a one-source plan just to cheaply probe schema/count.
                plan = ExecutionPlan(DatasetStats(metadata={}, parent=None))
                plan.link_logical_plan(LogicalPlan(sources[0], plan._context))
                schema = plan.schema()
                count = plan.meta_count()
    else:
        # Get schema of output blocks.
        schema = self.schema(fetch_if_missing=False)
        count = self._snapshot_bundle.num_rows()

    if schema is None:
        schema_str = "Unknown schema"
    elif isinstance(schema, type):
        schema_str = str(schema)
    else:
        schema_str = []
        for n, t in zip(schema.names, schema.types):
            if hasattr(t, "__name__"):
                t = t.__name__
            schema_str.append(f"{n}: {t}")
        schema_str = ", ".join(schema_str)
        schema_str = "{" + schema_str + "}"

    if count is None:
        count = "?"

    num_blocks = None
    if dataset_cls == MaterializedDataset:
        num_blocks = self.initial_num_blocks()
        assert num_blocks is not None

    name_str = (
        "name={}, ".format(self._dataset_name)
        if self._dataset_name is not None
        else ""
    )
    num_blocks_str = f"num_blocks={num_blocks}, " if num_blocks else ""

    dataset_str = "{}({}{}num_rows={}, schema={})".format(
        dataset_cls.__name__,
        name_str,
        num_blocks_str,
        count,
        schema_str,
    )

    # If the resulting string representation fits in one line, use it directly.
    SCHEMA_LINE_CHAR_LIMIT = 80
    MIN_FIELD_LENGTH = 10
    INDENT_STR = " " * 3
    trailing_space = INDENT_STR * plan_max_depth

    if len(dataset_str) > SCHEMA_LINE_CHAR_LIMIT:
        # If the resulting string representation exceeds the line char limit,
        # first try breaking up each `Dataset` parameter into its own line
        # and check if each line fits within the line limit. We check the
        # `schema` param's length, since this is likely the longest string.
        schema_str_on_new_line = f"{trailing_space}{INDENT_STR}schema={schema_str}"
        if len(schema_str_on_new_line) > SCHEMA_LINE_CHAR_LIMIT:
            # If the schema cannot fit on a single line, break up each field
            # into its own line.
            schema_str = []
            for n, t in zip(schema.names, schema.types):
                if hasattr(t, "__name__"):
                    t = t.__name__
                col_str = f"{trailing_space}{INDENT_STR * 2}{n}: {t}"
                # If the field line exceeds the char limit, abbreviate
                # the field name to fit while maintaining the full type
                if len(col_str) > SCHEMA_LINE_CHAR_LIMIT:
                    shortened_suffix = f"...: {str(t)}"
                    # Show at least 10 characters of the field name, even if
                    # we have already hit the line limit with the type.
                    chars_left_for_col_name = max(
                        SCHEMA_LINE_CHAR_LIMIT - len(shortened_suffix),
                        MIN_FIELD_LENGTH,
                    )
                    col_str = (
                        f"{col_str[:chars_left_for_col_name]}{shortened_suffix}"
                    )
                schema_str.append(col_str)
            schema_str = ",\n".join(schema_str)
            schema_str = (
                "{\n" + schema_str + f"\n{trailing_space}{INDENT_STR}" + "}"
            )
        name_str = (
            f"\n{trailing_space}{INDENT_STR}name={self._dataset_name},"
            if self._dataset_name is not None
            else ""
        )
        num_blocks_str = (
            f"\n{trailing_space}{INDENT_STR}num_blocks={num_blocks},"
            if num_blocks
            else ""
        )
        dataset_str = (
            f"{dataset_cls.__name__}("
            f"{name_str}"
            f"{num_blocks_str}"
            f"\n{trailing_space}{INDENT_STR}num_rows={count},"
            f"\n{trailing_space}{INDENT_STR}schema={schema_str}"
            f"\n{trailing_space})"
        )

    if plan_max_depth == 0:
        plan_str += dataset_str
    else:
        plan_str += f"{INDENT_STR * (plan_max_depth - 1)}+- {dataset_str}"
    return plan_str
|
| 274 |
+
|
| 275 |
+
def link_logical_plan(self, logical_plan: "LogicalPlan"):
    """Attach ``logical_plan`` to this execution plan.

    Used to trigger execution through the optimizer code path in this
    legacy execution plan. The plan's saved data context is propagated
    onto the logical plan so both execute under the same configuration.
    """
    self._logical_plan = logical_plan
    # Same object as self._logical_plan; propagate the saved context.
    logical_plan._context = self._context
| 284 |
+
def copy(self) -> "ExecutionPlan":
    """Create a shallow copy of this execution plan.

    The copy can be executed without mutating the original, but clearing
    the copy will also clear the original (the snapshot is shared).

    Returns:
        A shallow copy of this execution plan.
    """
    clone = ExecutionPlan(
        self._in_stats,
        data_context=self._context,
    )
    if self._snapshot_bundle is not None:
        # Share the existing snapshot with the clone (shallow copy).
        clone._snapshot_bundle = self._snapshot_bundle
        clone._snapshot_operator = self._snapshot_operator
        clone._snapshot_stats = self._snapshot_stats
    clone._dataset_name = self._dataset_name
    return clone
|
| 305 |
+
def deep_copy(self) -> "ExecutionPlan":
    """Create a deep copy of this execution plan.

    Unlike ``copy``, the result can be executed AND cleared without
    mutating the original, because the snapshot objects are duplicated.

    Returns:
        A deep copy of this execution plan.
    """
    clone = ExecutionPlan(copy.copy(self._in_stats))
    if self._snapshot_bundle:
        # Duplicate the snapshot so clearing the clone leaves the
        # original plan's snapshot intact.
        clone._snapshot_bundle = copy.copy(self._snapshot_bundle)
        clone._snapshot_operator = copy.copy(self._snapshot_operator)
        clone._snapshot_stats = copy.copy(self._snapshot_stats)
    clone._dataset_name = self._dataset_name
    return clone
|
| 322 |
+
def initial_num_blocks(self) -> Optional[int]:
    """Estimate the number of output blocks of this plan.

    Derived from the logical plan after execution-plan optimizations,
    without fully executing the dataset.
    """
    dag = self._logical_plan.dag
    return dag.estimated_num_outputs()
|
| 328 |
+
def schema(
    self, fetch_if_missing: bool = False
) -> Union[type, "pyarrow.lib.Schema"]:
    """Get the schema after applying all execution plan optimizations,
    but prior to fully executing the dataset
    (unless `fetch_if_missing` is set to True).

    Args:
        fetch_if_missing: Whether to execute the plan to fetch the schema.

    Returns:
        The schema of the output dataset.
    """
    # Fast path: a schema was previously computed or cached.
    if self._schema is not None:
        return self._schema

    schema = None
    if self.has_computed_output():
        # The final output snapshot exists; unify its per-block schemas.
        schema = unify_block_metadata_schema(self._snapshot_bundle.metadata)
    elif self._logical_plan.dag.aggregate_output_metadata().schema is not None:
        # Schema is known from aggregated metadata without executing.
        schema = self._logical_plan.dag.aggregate_output_metadata().schema
    elif fetch_if_missing:
        # Stream-execute and scan bundle metadata for a schema belonging
        # to a non-empty block (num_rows is None means "unknown", which
        # is accepted).
        # NOTE(review): the `break` only exits the inner loop; the outer
        # loop keeps consuming bundles — confirm whether this is intended.
        iter_ref_bundles, _, _ = self.execute_to_iterator()
        for ref_bundle in iter_ref_bundles:
            for metadata in ref_bundle.metadata:
                if metadata.schema is not None and (
                    metadata.num_rows is None or metadata.num_rows > 0
                ):
                    schema = metadata.schema
                    break
    elif self.is_read_only():
        # For consistency with the previous implementation, we fetch the schema if
        # the plan is read-only even if `fetch_if_missing` is False.
        iter_ref_bundles, _, _ = self.execute_to_iterator()
        try:
            # Only the first bundle is inspected in the read-only case.
            ref_bundle = next(iter(iter_ref_bundles))
            for metadata in ref_bundle.metadata:
                if metadata.schema is not None:
                    schema = metadata.schema
                    break
        except StopIteration:  # Empty dataset.
            schema = None

    # Cache the result; a None result means the check above falls through
    # again on the next call.
    self._schema = schema
    return self._schema
+
|
| 374 |
+
def cache_schema(self, schema: Union[type, "pyarrow.lib.Schema"]):
    """Store ``schema`` so subsequent ``schema()`` calls return it directly."""
    self._schema = schema
|
| 377 |
+
def input_files(self) -> Optional[List[str]]:
    """Return the input files of the dataset, if available."""
    metadata = self._logical_plan.dag.aggregate_output_metadata()
    return metadata.input_files
|
| 381 |
+
def meta_count(self) -> Optional[int]:
    """Get the number of rows after applying all plan optimizations, if possible.

    This method will never trigger any computation.

    Returns:
        The number of records of the result Dataset, or None.
    """
    if self.has_computed_output():
        # Sum row counts across the snapshot's block metadata.
        num_rows = sum(m.num_rows for m in self._snapshot_bundle.metadata)
    else:
        # Hoist the aggregate call: previously it was invoked twice
        # (once in the condition, once for the value). `num_rows` is
        # None when the metadata doesn't know the count, which matches
        # the original fallback branch.
        num_rows = self._logical_plan.dag.aggregate_output_metadata().num_rows
    return num_rows
|
| 397 |
+
@omit_traceback_stdout
def execute_to_iterator(
    self,
) -> Tuple[Iterator[RefBundle], DatasetStats, Optional["Executor"]]:
    """Execute this plan, returning an iterator.

    This will use streaming execution to generate outputs.

    Returns:
        Tuple of iterator over output RefBundles, DatasetStats, and the executor.
    """
    self._has_started_execution = True

    # Always used the saved context for execution.
    ctx = self._context

    if self.has_computed_output():
        # Output already materialized: wrap the snapshot in a one-item
        # iterator; no executor is returned in this case.
        bundle = self.execute()
        return iter([bundle]), self._snapshot_stats, None

    # Function-scope imports — presumably to avoid circular imports at
    # module load time; confirm before moving to the top of the file.
    from ray.data._internal.execution.legacy_compat import (
        execute_to_legacy_bundle_iterator,
    )
    from ray.data._internal.execution.streaming_executor import StreamingExecutor

    metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
    executor = StreamingExecutor(ctx, metrics_tag)
    bundle_iter = execute_to_legacy_bundle_iterator(executor, self)
    # Since the generator doesn't run any code until we try to fetch the first
    # value, force execution of one bundle before we call get_stats().
    gen = iter(bundle_iter)
    try:
        bundle_iter = itertools.chain([next(gen)], gen)
    except StopIteration:
        # Empty output stream; stats are still collected below.
        pass
    self._snapshot_stats = executor.get_stats()
    return bundle_iter, self._snapshot_stats, executor
|
| 435 |
+
@omit_traceback_stdout
def execute(
    self,
    preserve_order: bool = False,
) -> RefBundle:
    """Execute this plan.

    Args:
        preserve_order: Whether to preserve order in execution.

    Returns:
        The blocks of the output dataset.
    """
    self._has_started_execution = True

    # Always used the saved context for execution.
    context = self._context

    # Warn (once per process) when the cluster has no free CPUs, since
    # the job would otherwise appear to hang.
    if not ray.available_resources().get("CPU"):
        if log_once("cpu_warning"):
            logger.warning(
                "Warning: The Ray cluster currently does not have "
                "any available CPUs. The Dataset job will hang unless more CPUs "
                "are freed up. A common reason is that cluster resources are "
                "used by Actors or Tune trials; see the following link "
                "for more details: "
                "https://docs.ray.io/en/latest/data/data-internals.html#ray-data-and-tune"  # noqa: E501
            )
    if not self.has_computed_output():
        from ray.data._internal.execution.legacy_compat import (
            _get_initial_stats_from_plan,
            execute_to_legacy_block_list,
        )

        if self._logical_plan.dag.output_data() is not None:
            # If the data is already materialized (e.g., `from_pandas`), we can
            # skip execution and directly return the output data. This avoids
            # recording unnecessary metrics for an empty plan execution.
            stats = _get_initial_stats_from_plan(self)

            # TODO(@bveeramani): Make `ExecutionPlan.execute()` return
            # `List[RefBundle]` instead of `RefBundle`. Among other reasons, it'd
            # allow us to remove the unwrapping logic below.
            output_bundles = self._logical_plan.dag.output_data()
            # The combined bundle owns its blocks only if every source
            # bundle did.
            owns_blocks = all(bundle.owns_blocks for bundle in output_bundles)
            bundle = RefBundle(
                [
                    (block, metadata)
                    for bundle in output_bundles
                    for block, metadata in bundle.blocks
                ],
                owns_blocks=owns_blocks,
            )
        else:
            from ray.data._internal.execution.streaming_executor import (
                StreamingExecutor,
            )

            metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
            executor = StreamingExecutor(
                context,
                metrics_tag,
            )
            # Run the full streaming execution to completion.
            blocks = execute_to_legacy_block_list(
                executor,
                self,
                dataset_uuid=self._dataset_uuid,
                preserve_order=preserve_order,
            )
            bundle = RefBundle(
                tuple(blocks.iter_blocks_with_metadata()),
                owns_blocks=blocks._owned_by_consumer,
            )
            stats = executor.get_stats()
            stats_summary_string = stats.to_summary().to_string(
                include_parent=False
            )
            if context.enable_auto_log_stats:
                logger.info(stats_summary_string)

        # Retrieve memory-related stats from ray.
        try:
            reply = get_memory_info_reply(
                get_state_from_address(ray.get_runtime_context().gcs_address)
            )
            if reply.store_stats.spill_time_total_s > 0:
                stats.global_bytes_spilled = int(
                    reply.store_stats.spilled_bytes_total
                )
            if reply.store_stats.restore_time_total_s > 0:
                stats.global_bytes_restored = int(
                    reply.store_stats.restored_bytes_total
                )
        except Exception as e:
            # Best-effort: memory stats are optional diagnostics.
            logger.debug(
                "Skipping recording memory spilled and restored statistics due to "
                f"exception: {e}"
            )

        stats.dataset_bytes_spilled = 0

        def collect_stats(cur_stats):
            # Accumulate per-operator spill metrics across the whole
            # stats tree (recursing through parents).
            stats.dataset_bytes_spilled += cur_stats.extra_metrics.get(
                "obj_store_mem_spilled", 0
            )
            for parent in cur_stats.parents:
                collect_stats(parent)

        collect_stats(stats)

        # Set the snapshot to the output of the final operator.
        self._snapshot_bundle = bundle
        self._snapshot_operator = self._logical_plan.dag
        self._snapshot_stats = stats
        self._snapshot_stats.dataset_uuid = self._dataset_uuid

    return self._snapshot_bundle
|
| 553 |
+
@property
def has_started_execution(self) -> bool:
    """``True`` once this plan has been partially or fully executed."""
    return self._has_started_execution
|
| 558 |
+
def clear_snapshot(self) -> None:
    """Reset the cached snapshot state back to its initial (empty) form."""
    for attr in ("_snapshot_bundle", "_snapshot_operator", "_snapshot_stats"):
        setattr(self, attr, None)
|
| 564 |
+
def stats(self) -> DatasetStats:
    """Return stats for this plan.

    If the plan has not been executed, an empty stats object is returned.
    """
    if self._snapshot_stats:
        return self._snapshot_stats
    # No execution yet: hand back an empty placeholder.
    return DatasetStats(metadata={}, parent=None)
|
| 573 |
+
def has_lazy_input(self) -> bool:
    """Return whether every source of this plan is a lazy ``Read`` op."""
    sources = self._logical_plan.sources()
    return all(isinstance(op, Read) for op in sources)
|
| 577 |
+
def is_read_only(self, root_op: Optional[LogicalOperator] = None) -> bool:
    """Return whether the LogicalPlan rooted at `root_op` is a bare Read.

    By default, the last operator of the LogicalPlan is inspected. The
    op must be a ``Read`` with no upstream dependencies.
    """
    op = self._logical_plan.dag if root_op is None else root_op
    return isinstance(op, Read) and len(op.input_dependencies) == 0
|
| 585 |
+
def has_computed_output(self) -> bool:
    """Whether a computed snapshot exists for this plan's final operator,
    i.e. for the output of this plan.
    """
    if self._snapshot_bundle is None:
        return False
    # The snapshot is only valid if it was produced by the current DAG.
    return self._snapshot_operator == self._logical_plan.dag
|
| 594 |
+
def require_preserve_order(self) -> bool:
    """Whether this plan requires preserving output order."""
    from ray.data._internal.logical.operators.all_to_all_operator import Sort
    from ray.data._internal.logical.operators.n_ary_operator import Zip

    # Sort and Zip are order-sensitive: their presence anywhere in the
    # DAG forces order preservation.
    return any(
        isinstance(op, (Zip, Sort))
        for op in self._logical_plan.dag.post_order_iter()
    )
|
.venv/lib/python3.11/site-packages/ray/data/_internal/progress_bar.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import threading
|
| 3 |
+
from typing import Any, List, Optional
|
| 4 |
+
|
| 5 |
+
import ray
|
| 6 |
+
from ray.experimental import tqdm_ray
|
| 7 |
+
from ray.types import ObjectRef
|
| 8 |
+
from ray.util.debug import log_once
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import tqdm
|
| 14 |
+
|
| 15 |
+
needs_warning = False
|
| 16 |
+
except ImportError:
|
| 17 |
+
tqdm = None
|
| 18 |
+
needs_warning = True
|
| 19 |
+
|
| 20 |
+
# Used a signal to cancel execution.
|
| 21 |
+
_canceled_threads = set()
|
| 22 |
+
_canceled_threads_lock = threading.Lock()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def extract_num_rows(result: Any) -> int:
    """Extract the number of rows from a result object.

    Args:
        result: The result object from which to extract the number of rows.

    Returns:
        The number of rows, defaulting to 1 if it cannot be determined.
    """
    if hasattr(result, "num_rows"):
        return result.num_rows
    if hasattr(result, "__len__"):
        # e.g. the output is a DataFrame, as with sort_sample.
        return len(result)
    return 1
|
| 42 |
+
|
| 43 |
+
class ProgressBar:
    """Thin wrapper around tqdm to handle soft imports.

    If `total` is `None` (for example, it is unknown
    because no tasks have finished yet), doesn't display the full
    progress bar. Still displays basic progress stats from tqdm."""

    # If the name/description of the progress bar exceeds this length,
    # it will be truncated.
    MAX_NAME_LENGTH = 100

    def __init__(
        self,
        name: str,
        total: Optional[int],
        unit: str,
        position: int = 0,
        enabled: Optional[bool] = None,
    ):
        self._desc = self._truncate_name(name)
        self._progress = 0
        # Prepend a space to the unit for better formatting.
        if unit[0] != " ":
            unit = " " + unit

        # Fall back to the DataContext setting when `enabled` isn't given.
        if enabled is None:
            from ray.data import DataContext

            enabled = DataContext.get_current().enable_progress_bars
        if not enabled:
            self._bar = None
        elif tqdm:
            ctx = ray.data.context.DataContext.get_current()
            if ctx.use_ray_tqdm:
                # Use Ray's distributed-aware tqdm implementation.
                self._bar = tqdm_ray.tqdm(total=total, unit=unit, position=position)
            else:
                self._bar = tqdm.tqdm(
                    total=total or 0,
                    position=position,
                    dynamic_ncols=True,
                    unit=unit,
                    unit_scale=True,
                )
            self._bar.set_description(self._desc)
        else:
            # tqdm isn't installed; warn once per process, then disable.
            global needs_warning
            if needs_warning:
                print("[dataset]: Run `pip install tqdm` to enable progress reporting.")
                needs_warning = False
            self._bar = None

    def _truncate_name(self, name: str) -> str:
        """Shorten a long `op1->op2->...` chain to fit MAX_NAME_LENGTH."""
        ctx = ray.data.context.DataContext.get_current()
        if (
            not ctx.enable_progress_bar_name_truncation
            or len(name) <= self.MAX_NAME_LENGTH
        ):
            return name

        op_names = name.split("->")
        if len(op_names) == 1:
            return op_names[0]

        # Include as many operators as possible without approximately
        # exceeding `MAX_NAME_LENGTH`. Always include the first and
        # last operator names so it is easy to identify the DAG.
        truncated_op_names = [op_names[0]]
        for op_name in op_names[1:-1]:
            if (
                len("->".join(truncated_op_names))
                + len("->")
                + len(op_name)
                + len("->")
                + len(op_names[-1])
            ) > self.MAX_NAME_LENGTH:
                truncated_op_names.append("...")
                if log_once("ray_data_truncate_operator_name"):
                    logger.warning(
                        f"Truncating long operator name to {self.MAX_NAME_LENGTH} "
                        "characters. To disable this behavior, set "
                        "`ray.data.DataContext.get_current()."
                        "DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`."
                    )
                break
            truncated_op_names.append(op_name)
        truncated_op_names.append(op_names[-1])
        return "->".join(truncated_op_names)

    def block_until_complete(self, remaining: List[ObjectRef]) -> None:
        """Wait for all refs to finish, updating progress as they do.

        Checks for thread cancellation between waits.
        """
        t = threading.current_thread()
        while remaining:
            done, remaining = ray.wait(
                remaining, num_returns=len(remaining), fetch_local=False, timeout=0.1
            )
            total_rows_processed = 0
            for _, result in zip(done, ray.get(done)):
                num_rows = extract_num_rows(result)
                total_rows_processed += num_rows
            self.update(total_rows_processed)

            with _canceled_threads_lock:
                if t in _canceled_threads:
                    break

    def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]:
        """Wait for and fetch all refs, returning results in input order."""
        ref_to_result = {}
        remaining = refs
        t = threading.current_thread()
        # Triggering fetch_local redundantly for the same object is slower.
        # We only need to trigger the fetch_local once for each object,
        # raylet will persist these fetch requests even after ray.wait returns.
        # See https://github.com/ray-project/ray/issues/30375.
        fetch_local = True
        while remaining:
            done, remaining = ray.wait(
                remaining,
                num_returns=len(remaining),
                fetch_local=fetch_local,
                timeout=0.1,
            )
            if fetch_local:
                fetch_local = False
            total_rows_processed = 0
            for ref, result in zip(done, ray.get(done)):
                ref_to_result[ref] = result
                num_rows = extract_num_rows(result)
                total_rows_processed += num_rows
            self.update(total_rows_processed)

            with _canceled_threads_lock:
                if t in _canceled_threads:
                    break

        # Reorder fetched results to match the input ref order.
        return [ref_to_result[ref] for ref in refs]

    def set_description(self, name: str) -> None:
        # Skip the tqdm call when the (truncated) name is unchanged.
        name = self._truncate_name(name)
        if self._bar and name != self._desc:
            self._desc = name
            self._bar.set_description(self._desc)

    def get_description(self) -> str:
        return self._desc

    def refresh(self):
        if self._bar:
            self._bar.refresh()

    def update(self, i: int = 0, total: Optional[int] = None) -> None:
        """Advance progress by `i`; optionally reset the bar's total."""
        if self._bar and (i != 0 or self._bar.total != total):
            self._progress += i
            if total is not None:
                self._bar.total = total
            if self._bar.total is not None and self._progress > self._bar.total:
                # If the progress goes over 100%, update the total.
                self._bar.total = self._progress
            self._bar.update(i)

    def close(self):
        if self._bar:
            if self._bar.total is not None and self._progress != self._bar.total:
                # If the progress is not complete, update the total.
                self._bar.total = self._progress
                self._bar.refresh()
            self._bar.close()
            self._bar = None

    def __del__(self):
        self.close()

    def __getstate__(self):
        # Drop all state when pickling; the bar isn't transferable.
        return {}

    def __setstate__(self, state):
        self._bar = None  # Progress bar is disabled on remote nodes.
|
.venv/lib/python3.11/site-packages/ray/data/_internal/remote_fn.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Hashable, List
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
|
| 5 |
+
CACHED_FUNCTIONS = {}
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
    """Lazily defines a ray.remote function.

    This is used in Datasets to avoid circular import issues with ray.remote.
    (ray imports ray.data in order to allow ``ray.data.read_foo()`` to work,
    which means ray.remote cannot be used top-level in ray.data).

    NOTE: Dynamic arguments should not be passed in directly,
    and should be set with ``options`` instead:
    ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
    """
    # Hash the passed-in args deterministically (sorted KV tuples) so
    # each distinct configuration caches its own remote instantiation.
    args_hash = hash(_make_hashable(ray_remote_args))
    cache_key = (fn, args_hash)

    if cache_key not in CACHED_FUNCTIONS:
        default_ray_remote_args = {
            # Use the default scheduling strategy for all tasks so that we will
            # not inherit a placement group from the caller, if there is one.
            # The caller of this function may override the scheduling strategy
            # as needed.
            "scheduling_strategy": "DEFAULT",
            "max_retries": -1,
        }
        merged_args = {**default_ray_remote_args, **ray_remote_args}
        _add_system_error_to_retry_exceptions(merged_args)

        CACHED_FUNCTIONS[cache_key] = ray.remote(**merged_args)(fn)

    return CACHED_FUNCTIONS[cache_key]
|
| 46 |
+
|
| 47 |
+
def _make_hashable(obj):
|
| 48 |
+
if isinstance(obj, (List, tuple)):
|
| 49 |
+
return tuple([_make_hashable(o) for o in obj])
|
| 50 |
+
elif isinstance(obj, Dict):
|
| 51 |
+
converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()]
|
| 52 |
+
return tuple(sorted(converted, key=lambda t: t[0]))
|
| 53 |
+
elif isinstance(obj, Hashable):
|
| 54 |
+
return obj
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError(f"Type {type(obj)} is not hashable")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _add_system_error_to_retry_exceptions(ray_remote_args) -> None:
    """Modify the remote args so that Ray retries `RaySystemError`s.

    Ray typically automatically retries system errors. However, in some cases, Ray won't
    retry system errors if they're raised from task code. To ensure that Ray Data is
    fault tolerant to those errors, we need to add `RaySystemError` to the
    `retry_exceptions` list.

    TODO: Fix this in Ray Core. See https://github.com/ray-project/ray/pull/45079.
    """
    retry_exceptions = ray_remote_args.get("retry_exceptions", False)
    assert isinstance(retry_exceptions, (list, bool))

    system_error = ray.exceptions.RaySystemError
    if isinstance(retry_exceptions, list):
        # Extend the caller-provided list in place (if not already there).
        if system_error not in retry_exceptions:
            retry_exceptions.append(system_error)
    elif not retry_exceptions:
        # False (or unset): start a fresh list with just the system error.
        retry_exceptions = [system_error]

    ray_remote_args["retry_exceptions"] = retry_exceptions
|
.venv/lib/python3.11/site-packages/ray/data/_internal/row.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Mapping
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TableRow(Mapping):
    """
    A dict-like row of a tabular ``Dataset``.

    This implements the dictionary mapping interface, but provides more
    efficient access with less data copying than converting Arrow Tables
    or Pandas DataFrames into per-row dicts. This class must be subclassed,
    with subclasses implementing ``__getitem__``, ``__iter__``, and ``__len__``.

    Concrete subclasses include ``ray.data._internal.arrow_block.ArrowRow`` and
    ``ray.data._internal.pandas_block.PandasRow``.
    """

    def __init__(self, row: Any):
        """
        Construct a ``TableRow`` (internal API).

        Args:
            row: The tabular row that backs this row mapping.
        """
        self._row = row

    def as_pydict(self) -> dict:
        """
        Convert to a normal Python dict. This will create a new copy of the row."""
        return {key: value for key, value in self.items()}

    def __str__(self):
        return str(self.as_pydict())

    def __repr__(self):
        # Same rendering as __str__: the dict form of the row.
        return self.__str__()

    def _repr_pretty_(self, p, cycle):
        from IPython.lib.pretty import _dict_pprinter_factory

        pretty_printer = _dict_pprinter_factory("{", "}")
        return pretty_printer(self, p, cycle)
|
.venv/lib/python3.11/site-packages/ray/data/_internal/size_estimator.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
from ray import cloudpickle
|
| 5 |
+
|
| 6 |
+
_ray_initialized = False
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SizeEstimator:
    """Efficiently estimates the Ray serialized size of a stream of items.

    For efficiency, this only samples a fraction of the added items for real
    Ray-serialization.
    """

    def __init__(self):
        # Weighted running mean of the sampled items' serialized sizes.
        self._running_mean = RunningMean()
        # Total number of items observed (sampled or not).
        self._count = 0

    def add(self, item: Any) -> None:
        self._count += 1
        # Decaying sample rate: every item for the first 10, then every
        # 10th item up to 100, then every 100th item. Each sample's
        # weight compensates for the unsampled items it represents.
        if self._count <= 10:
            self._running_mean.add(self._real_size(item), weight=1)
        elif self._count <= 100:
            if self._count % 10 == 0:
                self._running_mean.add(self._real_size(item), weight=10)
        elif self._count % 100 == 0:
            self._running_mean.add(self._real_size(item), weight=100)

    def add_block(self, block: List[Any]) -> None:
        # Bulk variant of `add`: samples indices within `block` that
        # approximate what per-item `add` calls would have sampled at
        # the current `_count`.
        if self._count < 10:
            for i in range(min(10 - self._count, len(block))):
                self._running_mean.add(self._real_size(block[i]), weight=1)
        if self._count < 100:
            for i in range(
                10 - (self._count % 10), min(100 - self._count, len(block)), 10
            ):
                self._running_mean.add(self._real_size(block[i]), weight=10)
        if (len(block) + (self._count % 100)) // 100 > 1:
            for i in range(100 - (self._count % 100), len(block), 100):
                self._running_mean.add(self._real_size(block[i]), weight=100)
        self._count += len(block)

    def size_bytes(self) -> int:
        # Estimated total = mean sampled item size * number of items seen.
        return int(self._running_mean.mean * self._count)

    def _real_size(self, item: Any) -> int:
        is_client = ray.util.client.ray.is_connected()
        # In client mode, fallback to using Ray cloudpickle instead of the
        # real serializer.
        if is_client:
            return len(cloudpickle.dumps(item))

        # We're using an internal Ray API, and have to ensure it's
        # initialized # by calling a public API.
        global _ray_initialized
        if not _ray_initialized:
            _ray_initialized = True
            ray.put(None)
        return (
            ray._private.worker.global_worker.get_serialization_context()
            .serialize(item)
            .total_bytes
        )
+
|
| 66 |
+
|
| 67 |
+
# Adapted from the RLlib MeanStdFilter.
|
| 68 |
+
# Adapted from the RLlib MeanStdFilter.
class RunningMean:
    """Incrementally maintains a weighted mean of observed values."""

    def __init__(self):
        # Sum of weights seen so far, and the current weighted mean.
        self._weight = 0
        self._mean = 0

    def add(self, x: int, weight: int = 1) -> None:
        """Fold `x` into the mean, counted `weight` times (no-op if 0)."""
        if weight == 0:
            return
        combined = self._weight + weight
        self._mean = (self._weight * self._mean + weight * x) / combined
        self._weight = combined

    @property
    def n(self) -> int:
        """Total weight accumulated so far."""
        return self._weight

    @property
    def mean(self) -> float:
        """Current weighted mean of all added values."""
        return self._mean

    def __repr__(self):
        return f"(n={self.n}, mean={self.mean})"
|
.venv/lib/python3.11/site-packages/ray/data/_internal/split.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Iterable, List, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import ray
|
| 6 |
+
from ray.data._internal.memory_tracing import trace_deallocation
|
| 7 |
+
from ray.data._internal.remote_fn import cached_remote_fn
|
| 8 |
+
from ray.data.block import (
|
| 9 |
+
Block,
|
| 10 |
+
BlockAccessor,
|
| 11 |
+
BlockExecStats,
|
| 12 |
+
BlockMetadata,
|
| 13 |
+
BlockPartition,
|
| 14 |
+
)
|
| 15 |
+
from ray.types import ObjectRef
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _calculate_blocks_rows(
    blocks_with_metadata: BlockPartition,
) -> List[int]:
    """Return the number of rows in each block, fetching missing counts.

    Blocks whose metadata already carries a row count are read directly;
    otherwise the count is computed by a remote task and cached back onto
    the metadata object.
    """
    get_num_rows = cached_remote_fn(_get_num_rows)
    row_counts: List[int] = []
    for block, metadata in blocks_with_metadata:
        num_rows = metadata.num_rows
        if num_rows is None:
            # Row count unknown: compute it on a worker and memoize it on
            # the metadata so later callers don't refetch.
            num_rows = ray.get(get_num_rows.remote(block))
            metadata.num_rows = num_rows
        row_counts.append(num_rows)
    return row_counts
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _generate_valid_indices(
|
| 38 |
+
num_rows_per_block: List[int],
|
| 39 |
+
split_indices: List[int],
|
| 40 |
+
) -> List[int]:
|
| 41 |
+
"""Generate valid split indices by apply min(index, total_num_rows)
|
| 42 |
+
to every index."""
|
| 43 |
+
total_rows = sum(num_rows_per_block)
|
| 44 |
+
return [min(index, total_rows) for index in split_indices]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _generate_per_block_split_indices(
|
| 48 |
+
num_rows_per_block: List[int],
|
| 49 |
+
split_indices: List[int],
|
| 50 |
+
) -> List[List[int]]:
|
| 51 |
+
"""Given num rows per block and valid split indices, generate per block split indices.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
num_rows_per_block: num of rows per block.
|
| 55 |
+
split_indices: The (global) indices at which to split the blocks.
|
| 56 |
+
Returns:
|
| 57 |
+
Per block split indices indicates each input block's split point(s).
|
| 58 |
+
"""
|
| 59 |
+
# for each split index, we iterate though the currnet input block
|
| 60 |
+
# to see if the index falls into this block. if the index
|
| 61 |
+
# falls into this block, we push it back to the current block's
|
| 62 |
+
# split indices. Otherwise, we move on to the next block.
|
| 63 |
+
per_block_split_indices = []
|
| 64 |
+
current_input_block_id = 0
|
| 65 |
+
current_block_split_indices = []
|
| 66 |
+
current_block_global_offset = 0
|
| 67 |
+
current_index_id = 0
|
| 68 |
+
|
| 69 |
+
while current_index_id < len(split_indices):
|
| 70 |
+
split_index = split_indices[current_index_id]
|
| 71 |
+
current_block_row = num_rows_per_block[current_input_block_id]
|
| 72 |
+
if split_index - current_block_global_offset <= current_block_row:
|
| 73 |
+
current_block_split_indices.append(
|
| 74 |
+
split_index - current_block_global_offset
|
| 75 |
+
)
|
| 76 |
+
current_index_id += 1
|
| 77 |
+
continue
|
| 78 |
+
per_block_split_indices.append(current_block_split_indices)
|
| 79 |
+
current_block_split_indices = []
|
| 80 |
+
current_block_global_offset += num_rows_per_block[current_input_block_id]
|
| 81 |
+
current_input_block_id += 1
|
| 82 |
+
|
| 83 |
+
# we might finished all the indices but there are still blocks left, also
|
| 84 |
+
# current_block_split_indices might not be added yet.
|
| 85 |
+
while len(per_block_split_indices) < len(num_rows_per_block):
|
| 86 |
+
per_block_split_indices.append(current_block_split_indices)
|
| 87 |
+
current_block_split_indices = []
|
| 88 |
+
return per_block_split_indices
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _split_single_block(
    block_id: int,
    block: Block,
    meta: BlockMetadata,
    split_indices: List[int],
) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
    """Split the provided block at the given indices.

    Args:
        block_id: the id of this block in the block list.
        block: block to be split.
        meta: metadata of the block; we expect meta.num_rows to be valid.
        split_indices: the indices where the block should be split.
    Returns:
        returns block_id, split blocks metadata, and a list of blocks
        in the following form. We return blocks in this way
        so that the owner of blocks could be the caller(driver)
        instead of worker itself.
        Tuple(block_id, split_blocks_meta), block0, block1 ...
    """
    split_meta = []
    split_blocks = []
    block_accessor = BlockAccessor.for_block(block)
    prev_index = 0
    # Chain a sentinel end index (the block's row count) after the split
    # points so the final slice reaches the end of the block. Using chain()
    # instead of list.append() avoids mutating the caller's list.
    for index in itertools.chain(split_indices, [meta.num_rows]):
        logger.debug(f"slicing block {prev_index}:{index}")
        stats = BlockExecStats.builder()
        split_block = block_accessor.slice(prev_index, index)
        accessor = BlockAccessor.for_block(split_block)
        # Per-slice metadata; schema and input files are inherited from the
        # parent block, sizes are recomputed from the slice itself.
        _meta = BlockMetadata(
            num_rows=accessor.num_rows(),
            size_bytes=accessor.size_bytes(),
            schema=meta.schema,
            input_files=meta.input_files,
            exec_stats=stats.build(),
        )
        split_meta.append(_meta)
        split_blocks.append(split_block)
        prev_index = index
    # First element is the (block_id, metadata) summary; the sub-blocks
    # follow so the caller can take ownership of each as a separate return.
    results = [(block_id, split_meta)]
    results.extend(split_blocks)
    return tuple(results)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
|
| 139 |
+
"""drop split indices that creates empty block split. This could happen when there
|
| 140 |
+
are duplicated indices, or index equal to 0 (start of the block) or num_block_rows
|
| 141 |
+
(end of the block).
|
| 142 |
+
"""
|
| 143 |
+
prev_index = -1
|
| 144 |
+
optimized_indices = []
|
| 145 |
+
for index in block_split_indices:
|
| 146 |
+
if index == 0 or index == num_rows:
|
| 147 |
+
continue
|
| 148 |
+
if index == prev_index:
|
| 149 |
+
continue
|
| 150 |
+
optimized_indices.append(index)
|
| 151 |
+
prev_index = index
|
| 152 |
+
return optimized_indices
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _split_all_blocks(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    per_block_split_indices: List[List[int]],
    owned_by_consumer: bool,
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
    """Split all the input blocks based on the split indices.

    Blocks whose split-index list becomes empty (after dropping indices that
    would produce empty slices) are passed through untouched; all others are
    split by a remote task, with the driver retaining ownership of the
    resulting sub-blocks.
    """
    split_single_block = cached_remote_fn(_split_single_block)

    all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)

    # Futures for each remote task's (block_id, metadata list) first return
    # value, and the sub-block ObjectRefs that follow it.
    per_block_split_metadata_futures = []
    per_block_split_block_refs = []

    # Tracking split blocks for GC.
    blocks_splitted = []
    for block_id, block_split_indices in enumerate(per_block_split_indices):
        (block_ref, meta) = blocks_with_metadata[block_id]
        block_row = meta.num_rows
        block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
        if len(block_split_indices) == 0:
            # optimization: if no split is needed, we just need to add it to the
            # result
            all_blocks_split_results[block_id] = [(block_ref, meta)]
        else:
            # otherwise call split remote function.
            # num_returns is 2 + len(indices): one return for the metadata
            # tuple plus len(indices) + 1 sub-blocks.
            object_refs = split_single_block.options(
                scheduling_strategy="SPREAD", num_returns=2 + len(block_split_indices)
            ).remote(
                block_id,
                block_ref,
                meta,
                block_split_indices,
            )
            per_block_split_metadata_futures.append(object_refs[0])
            per_block_split_block_refs.append(object_refs[1:])

            blocks_splitted.append(block_ref)

    if per_block_split_metadata_futures:
        # Only fetch the (small) metadata; the sub-blocks themselves stay in
        # the object store as ObjectRefs.
        per_block_split_metadata = ray.get(per_block_split_metadata_futures)
        for (block_id, meta), block_refs in zip(
            per_block_split_metadata, per_block_split_block_refs
        ):
            assert len(meta) == len(block_refs)
            all_blocks_split_results[block_id] = zip(block_refs, meta)

    # We make a copy for the blocks that have been split, so the input blocks
    # can be cleared if they are owned by consumer (consumer-owned blocks will
    # only be consumed by the owner).
    if owned_by_consumer:
        for b in blocks_splitted:
            trace_deallocation(b, "split._split_all_blocks")
    else:
        for b in blocks_splitted:
            trace_deallocation(b, "split._split_all_blocks", free=False)

    return itertools.chain.from_iterable(all_blocks_split_results)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _generate_global_split_results(
|
| 216 |
+
all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
|
| 217 |
+
global_split_sizes: List[int],
|
| 218 |
+
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
|
| 219 |
+
"""Reassemble per block's split result into final split result."""
|
| 220 |
+
result_blocks = []
|
| 221 |
+
result_metas = []
|
| 222 |
+
|
| 223 |
+
current_blocks = []
|
| 224 |
+
current_meta = []
|
| 225 |
+
current_split_size = 0
|
| 226 |
+
current_split_id = 0
|
| 227 |
+
|
| 228 |
+
while current_split_id < len(global_split_sizes):
|
| 229 |
+
if current_split_size >= global_split_sizes[current_split_id]:
|
| 230 |
+
assert current_split_size == global_split_sizes[current_split_id]
|
| 231 |
+
result_blocks.append(current_blocks)
|
| 232 |
+
result_metas.append(current_meta)
|
| 233 |
+
|
| 234 |
+
current_blocks = []
|
| 235 |
+
current_meta = []
|
| 236 |
+
current_split_size = 0
|
| 237 |
+
current_split_id += 1
|
| 238 |
+
else:
|
| 239 |
+
(block_ref, meta) = next(all_blocks_split_results)
|
| 240 |
+
current_blocks.append(block_ref)
|
| 241 |
+
current_meta.append(meta)
|
| 242 |
+
current_split_size += meta.num_rows
|
| 243 |
+
|
| 244 |
+
return result_blocks, result_metas
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _split_at_indices(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    indices: List[int],
    owned_by_consumer: bool = True,
    block_rows: List[int] = None,
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
    """Split blocks at the provided indices.

    Args:
        blocks_with_metadata: Block futures to split, including the associated
            metadata.
        indices: The (global) indices at which to split the blocks.
        owned_by_consumer: Whether the provided blocks are owned by the
            consumer.
        block_rows: The number of rows for each block, in case it has already
            been computed.

    Returns:
        The block split futures and their metadata. If an index split is
        empty, the corresponding block split will be empty.
    """
    # The split happens in 3 phases.
    # Phase 1: compute the split point(s) local to each block.
    blocks_with_metadata = list(blocks_with_metadata)
    if not blocks_with_metadata:
        # No input: every one of the len(indices) + 1 splits is empty.
        num_splits = len(indices) + 1
        return ([[]] * num_splits, [[]] * num_splits)
    if block_rows is None:
        block_rows = _calculate_blocks_rows(blocks_with_metadata)
    valid_indices = _generate_valid_indices(block_rows, indices)
    per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
        block_rows, valid_indices
    )

    # Phase 2: split each block according to its local split points.
    all_blocks_split_results: Iterable[
        Tuple[ObjectRef[Block], BlockMetadata]
    ] = _split_all_blocks(
        blocks_with_metadata, per_block_split_indices, owned_by_consumer
    )

    # Phase 3: regroup the split blocks into the final global splits.
    # Split i's size is the gap between consecutive boundaries, where the
    # boundaries are 0, the clamped indices, and the total row count.
    boundaries = [0] + valid_indices + [sum(block_rows)]
    split_sizes = [end - start for start, end in zip(boundaries, boundaries[1:])]

    return _generate_global_split_results(all_blocks_split_results, split_sizes)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _get_num_rows(block: Block) -> int:
    """Get the number of rows contained in the provided block."""
    accessor = BlockAccessor.for_block(block)
    return accessor.num_rows()
|
.venv/lib/python3.11/site-packages/ray/data/_internal/stats.py
ADDED
|
@@ -0,0 +1,1495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import threading
|
| 4 |
+
import time
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
| 8 |
+
from uuid import uuid4
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import ray
|
| 13 |
+
from ray.actor import ActorHandle
|
| 14 |
+
from ray.data._internal.block_list import BlockList
|
| 15 |
+
from ray.data._internal.execution.interfaces.op_runtime_metrics import (
|
| 16 |
+
MetricsGroup,
|
| 17 |
+
OpRuntimeMetrics,
|
| 18 |
+
)
|
| 19 |
+
from ray.data._internal.util import capfirst
|
| 20 |
+
from ray.data.block import BlockMetadata
|
| 21 |
+
from ray.data.context import DataContext
|
| 22 |
+
from ray.util.annotations import DeveloperAPI
|
| 23 |
+
from ray.util.metrics import Gauge
|
| 24 |
+
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
STATS_ACTOR_NAME = "datasets_stats_actor"
|
| 29 |
+
STATS_ACTOR_NAMESPACE = "_dataset_stats_actor"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
StatsDict = Dict[str, List[BlockMetadata]]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def fmt(seconds: float) -> str:
    """Render a duration in seconds as a short human-readable string."""
    if seconds > 1:
        value, unit = seconds, "s"
    elif seconds > 0.001:
        value, unit = seconds * 1000, "ms"
    else:
        value, unit = seconds * 1000 * 1000, "us"
    return f"{round(value, 2)}{unit}"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def leveled_indent(lvl: int = 0, spaces_per_indent: int = 3) -> str:
    """Return `lvl` levels of indentation, each `spaces_per_indent` wide.

    >>> leveled_indent(2, 3)
    '      '
    """
    return " " * (spaces_per_indent * lvl)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Timer:
    """Helper class for tracking accumulated time (in seconds)."""

    def __init__(self):
        # Running total, extrema, and count of recorded durations.
        self._value: float = 0
        self._min: float = float("inf")
        self._max: float = 0
        self._total_count: float = 0

    @contextmanager
    def timer(self) -> None:
        """Context manager that records the wall time of its `with` body."""
        start = time.perf_counter()
        try:
            yield
        finally:
            self.add(time.perf_counter() - start)

    def add(self, value: float) -> None:
        """Record a single duration (in seconds)."""
        self._value += value
        self._min = min(self._min, value)
        self._max = max(self._max, value)
        self._total_count += 1

    def get(self) -> float:
        """Total accumulated time."""
        return self._value

    def min(self) -> float:
        """Smallest recorded duration (inf if nothing recorded)."""
        return self._min

    def max(self) -> float:
        """Largest recorded duration (0 if nothing recorded)."""
        return self._max

    def avg(self) -> float:
        """Mean recorded duration (inf if nothing recorded)."""
        return self._value / self._total_count if self._total_count else float("inf")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class _DatasetStatsBuilder:
    """Helper class for building dataset stats.

    Records a start time at construction; when build() or
    build_multioperator() is called with the final blocks of the new dataset,
    the elapsed time is saved as part of the resulting stats.
    """

    def __init__(
        self,
        operator_name: str,
        parent: "DatasetStats",
        override_start_time: Optional[float],
    ):
        self.operator_name = operator_name
        self.parent = parent
        # Honor a caller-supplied start time if given, else start the clock now.
        self.start_time = override_start_time or time.perf_counter()

    def build_multioperator(self, metadata: StatsDict) -> "DatasetStats":
        """Build stats keyed per sub-operator from a metadata dict."""
        op_metadata = {}
        for i, (sub_name, blocks_meta) in enumerate(metadata.items()):
            capped = capfirst(sub_name)
            if len(metadata) <= 1:
                # Single entry: key it by the operator name alone.
                key = self.operator_name
            elif i == 0:
                key = self.operator_name + capped
            else:
                # Later entries are keyed by the last fused operator's name.
                key = self.operator_name.split("->")[-1] + capped
            op_metadata[key] = blocks_meta
        stats = DatasetStats(
            metadata=op_metadata,
            parent=self.parent,
            base_name=self.operator_name,
        )
        stats.time_total_s = time.perf_counter() - self.start_time
        return stats

    def build(self, final_blocks: BlockList) -> "DatasetStats":
        """Build stats for this operator from its final block list."""
        stats = DatasetStats(
            metadata={self.operator_name: final_blocks.get_metadata()},
            parent=self.parent,
        )
        stats.time_total_s = time.perf_counter() - self.start_time
        return stats
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@ray.remote(num_cpus=0)
class _StatsActor:
    """Actor holding stats for blocks created by LazyBlockList.

    This actor is shared across all datasets created in the same cluster.
    In order to cap memory usage, we set a max number of stats to keep
    in the actor. When this limit is exceeded, the stats will be garbage
    collected in FIFO order.

    TODO(ekl) we should consider refactoring LazyBlockList so stats can be
    extracted without using an out-of-band actor."""

    def __init__(self, max_stats=1000):
        # Mapping from uuid -> (task_id -> list of blocks statistics).
        self.metadata = collections.defaultdict(dict)
        # Per-uuid timestamps of the last recorded task and of record_start().
        self.last_time = {}
        self.start_time = {}
        # Cap on tracked stats entries; oldest entries are evicted beyond this.
        self.max_stats = max_stats
        # Insertion-ordered uuids, used for FIFO eviction in record_start().
        self.fifo_queue = []

        # Assign dataset uuids with a global counter.
        self.next_dataset_id = 0
        # Dataset metadata to be queried directly by DashboardHead api.
        self.datasets: Dict[str, Any] = {}

        # Ray Data dashboard metrics
        # Everything is a gauge because we need to reset all of
        # a dataset's metrics to 0 after each finishes execution.
        op_tags_keys = ("dataset", "operator")

        # TODO(scottjlee): move these overview metrics as fields in a
        # separate dataclass, similar to OpRuntimeMetrics.
        self.spilled_bytes = Gauge(
            "data_spilled_bytes",
            description="""Bytes spilled by dataset operators.
            DataContext.enable_get_object_locations_for_metrics
            must be set to True to report this metric""",
            tag_keys=op_tags_keys,
        )
        self.allocated_bytes = Gauge(
            "data_allocated_bytes",
            description="Bytes allocated by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.freed_bytes = Gauge(
            "data_freed_bytes",
            description="Bytes freed by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.current_bytes = Gauge(
            "data_current_bytes",
            description="Bytes currently in memory store used by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.cpu_usage_cores = Gauge(
            "data_cpu_usage_cores",
            description="CPUs allocated to dataset operators",
            tag_keys=op_tags_keys,
        )
        self.gpu_usage_cores = Gauge(
            "data_gpu_usage_cores",
            description="GPUs allocated to dataset operators",
            tag_keys=op_tags_keys,
        )
        self.output_bytes = Gauge(
            "data_output_bytes",
            description="Bytes outputted by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.output_rows = Gauge(
            "data_output_rows",
            description="Rows outputted by dataset operators",
            tag_keys=op_tags_keys,
        )

        # === Metrics from OpRuntimeMetrics ===
        # Each group below is a dict of field name -> Gauge, built from the
        # metric descriptors that OpRuntimeMetrics declares for that group.
        # Inputs-related metrics
        self.execution_metrics_inputs = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.INPUTS,
                tag_keys=op_tags_keys,
            )
        )

        # Outputs-related metrics
        self.execution_metrics_outputs = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.OUTPUTS,
                tag_keys=op_tags_keys,
            )
        )

        # Task-related metrics
        self.execution_metrics_tasks = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.TASKS,
                tag_keys=op_tags_keys,
            )
        )

        # Object store memory-related metrics
        self.execution_metrics_obj_store_memory = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.OBJECT_STORE_MEMORY,
                tag_keys=op_tags_keys,
            )
        )

        # Miscellaneous metrics
        self.execution_metrics_misc = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.MISC,
                tag_keys=op_tags_keys,
            )
        )

        # Iteration gauges are tagged only by dataset (no operator dimension).
        iter_tag_keys = ("dataset",)
        self.iter_total_blocked_s = Gauge(
            "data_iter_total_blocked_seconds",
            description="Seconds user thread is blocked by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_user_s = Gauge(
            "data_iter_user_seconds",
            description="Seconds spent in user code",
            tag_keys=iter_tag_keys,
        )
        self.iter_initialize_s = Gauge(
            "data_iter_initialize_seconds",
            description="Seconds spent in iterator initialization code",
            tag_keys=iter_tag_keys,
        )

    def _create_prometheus_metrics_for_execution_metrics(
        self, metrics_group: MetricsGroup, tag_keys: Tuple[str, ...]
    ) -> Dict[str, Gauge]:
        """Create one Gauge per OpRuntimeMetrics descriptor in ``metrics_group``.

        Returns a dict keyed by the bare metric name; the exported Prometheus
        metric name carries a "data_" prefix.
        """
        metrics = {}
        for metric in OpRuntimeMetrics.get_metrics():
            if not metric.metrics_group == metrics_group:
                continue
            metric_name = f"data_{metric.name}"
            metric_description = metric.description
            metrics[metric.name] = Gauge(
                metric_name,
                description=metric_description,
                tag_keys=tag_keys,
            )
        return metrics

    def record_start(self, stats_uuid):
        """Mark the start of stats collection for ``stats_uuid``."""
        self.start_time[stats_uuid] = time.perf_counter()
        self.fifo_queue.append(stats_uuid)
        # Purge the oldest stats if the limit is exceeded.
        if len(self.fifo_queue) > self.max_stats:
            uuid = self.fifo_queue.pop(0)
            if uuid in self.start_time:
                del self.start_time[uuid]
            if uuid in self.last_time:
                del self.last_time[uuid]
            if uuid in self.metadata:
                del self.metadata[uuid]

    def record_task(
        self, stats_uuid: str, task_idx: int, blocks_metadata: List[BlockMetadata]
    ):
        """Record block metadata produced by one task of ``stats_uuid``.

        Silently ignored if record_start() was never called (or the entry was
        already evicted).
        """
        # Null out the schema to keep the stats size small.
        # TODO(chengsu): ideally schema should be null out on caller side.
        for metadata in blocks_metadata:
            metadata.schema = None
        if stats_uuid in self.start_time:
            self.metadata[stats_uuid][task_idx] = blocks_metadata
            self.last_time[stats_uuid] = time.perf_counter()

    def get(self, stats_uuid):
        """Return (task_idx -> metadata, elapsed seconds) for ``stats_uuid``.

        Returns ({}, 0.0) for unknown/evicted uuids.
        """
        if stats_uuid not in self.metadata:
            return {}, 0.0
        return (
            self.metadata[stats_uuid],
            self.last_time[stats_uuid] - self.start_time[stats_uuid],
        )

    def _get_stats_dict_size(self):
        # Debug helper: sizes of the three per-uuid bookkeeping dicts.
        return len(self.start_time), len(self.last_time), len(self.metadata)

    def get_dataset_id(self):
        """Return a new unique dataset id from the global counter (as str)."""
        dataset_id = str(self.next_dataset_id)
        self.next_dataset_id += 1
        return dataset_id

    def update_metrics(self, execution_metrics, iteration_metrics):
        """Apply batched execution and iteration metric updates.

        Each element is an argument tuple for the corresponding update method.
        """
        for metrics in execution_metrics:
            self.update_execution_metrics(*metrics)
        for metrics in iteration_metrics:
            self.update_iteration_metrics(*metrics)

    def update_execution_metrics(
        self,
        dataset_tag: str,
        op_metrics: List[Dict[str, Union[int, float]]],
        operator_tags: List[str],
        state: Dict[str, Any],
    ):
        """Set all per-operator gauges from the given metric snapshots.

        ``op_metrics`` and ``operator_tags`` are parallel lists; missing keys
        in a snapshot default the gauge to 0.
        """
        for stats, operator_tag in zip(op_metrics, operator_tags):
            tags = self._create_tags(dataset_tag, operator_tag)

            self.spilled_bytes.set(stats.get("obj_store_mem_spilled", 0), tags)
            self.freed_bytes.set(stats.get("obj_store_mem_freed", 0), tags)
            self.current_bytes.set(stats.get("obj_store_mem_used", 0), tags)
            self.output_bytes.set(stats.get("bytes_task_outputs_generated", 0), tags)
            self.output_rows.set(stats.get("rows_task_outputs_generated", 0), tags)
            self.cpu_usage_cores.set(stats.get("cpu_usage", 0), tags)
            self.gpu_usage_cores.set(stats.get("gpu_usage", 0), tags)

            for field_name, prom_metric in self.execution_metrics_inputs.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for field_name, prom_metric in self.execution_metrics_outputs.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for field_name, prom_metric in self.execution_metrics_tasks.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for (
                field_name,
                prom_metric,
            ) in self.execution_metrics_obj_store_memory.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for field_name, prom_metric in self.execution_metrics_misc.items():
                prom_metric.set(stats.get(field_name, 0), tags)

        # This update is called from a dataset's executor,
        # so all tags should contain the same dataset
        self.update_dataset(dataset_tag, state)

    def update_iteration_metrics(
        self,
        stats: "DatasetStats",
        dataset_tag,
    ):
        """Set the dataset-level iteration gauges from a DatasetStats' timers."""
        tags = self._create_tags(dataset_tag)
        self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
        self.iter_user_s.set(stats.iter_user_s.get(), tags)
        self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)

    def register_dataset(self, job_id: str, dataset_tag: str, operator_tags: List[str]):
        """Register a newly-started dataset (and its operators) for the dashboard."""
        self.datasets[dataset_tag] = {
            "job_id": job_id,
            "state": "RUNNING",
            "progress": 0,
            "total": 0,
            "start_time": time.time(),
            "end_time": None,
            "operators": {
                operator: {
                    "state": "RUNNING",
                    "progress": 0,
                    "total": 0,
                }
                for operator in operator_tags
            },
        }

    def update_dataset(self, dataset_tag, state):
        # NOTE(review): raises KeyError if the dataset was never registered via
        # register_dataset() — presumably callers guarantee registration first.
        self.datasets[dataset_tag].update(state)

    def get_datasets(self, job_id: Optional[str] = None):
        """Return dashboard metadata for all datasets, optionally filtered by job."""
        if not job_id:
            return self.datasets
        return {k: v for k, v in self.datasets.items() if v["job_id"] == job_id}

    def _create_tags(self, dataset_tag: str, operator_tag: Optional[str] = None):
        # Build the gauge tag dict; the operator dimension is optional.
        tags = {"dataset": dataset_tag}
        if operator_tag is not None:
            tags["operator"] = operator_tag
        return tags
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Creating/getting an actor from multiple threads is not safe.
# https://github.com/ray-project/ray/issues/41324
# Serializes _StatsActor creation/lookup in _get_or_create_stats_actor().
_stats_actor_lock: threading.RLock = threading.RLock()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _get_or_create_stats_actor():
    """Return a handle to the cluster-wide detached _StatsActor.

    Creates the actor if it does not exist yet (``get_if_exists=True``).
    Outside of Ray Client mode, the actor is pinned to the driver's node so
    the two fate-share.
    """
    ctx = DataContext.get_current()
    strategy = ctx.scheduling_strategy
    if not ray.util.client.ray.is_connected():
        # Pin the stats actor to the local node
        # so it fate-shares with the driver.
        strategy = NodeAffinitySchedulingStrategy(
            ray.get_runtime_context().get_node_id(),
            soft=False,
        )
    # Actor creation/lookup is not thread-safe; guard with the module lock.
    with _stats_actor_lock:
        actor_cls = _StatsActor.options(
            name=STATS_ACTOR_NAME,
            namespace=STATS_ACTOR_NAMESPACE,
            get_if_exists=True,
            lifetime="detached",
            scheduling_strategy=strategy,
        )
        return actor_cls.remote()
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
class _StatsManager:
    """A Class containing util functions that manage remote calls to _StatsActor.

    This class collects stats from execution and iteration codepaths and keeps
    track of the latest snapshot.

    An instance of this class runs a single background thread that periodically
    forwards the latest execution/iteration stats to the _StatsActor.

    This thread will terminate itself after being inactive (meaning that there are
    no active executors or iterators) for STATS_ACTOR_UPDATE_THREAD_INACTIVITY_LIMIT
    iterations. After terminating, a new thread will start if more calls are made
    to this class.
    """

    # Interval for making remote calls to the _StatsActor.
    STATS_ACTOR_UPDATE_INTERVAL_SECONDS = 5

    # After this many iterations of inactivity,
    # _StatsManager._update_thread will close itself.
    UPDATE_THREAD_INACTIVITY_LIMIT = 5

    def __init__(self):
        # Lazily get stats actor handle to avoid circular import.
        self._stats_actor_handle: Optional[ActorHandle] = None
        self._stats_actor_cluster_id = None

        # Last execution stats snapshots for all executing datasets
        self._last_execution_stats = {}
        # Last iteration stats snapshots for all running iterators.
        # Values are the (stats, dataset_tag) argument tuples forwarded to
        # _StatsActor.update_iteration_metrics().
        self._last_iteration_stats: Dict[str, Tuple["DatasetStats", str]] = {}
        # Lock for updating stats snapshots
        self._stats_lock: threading.Lock = threading.Lock()

        # Background thread to make remote calls to _StatsActor
        self._update_thread: Optional[threading.Thread] = None
        self._update_thread_lock: threading.Lock = threading.Lock()

    def _stats_actor(self, create_if_not_exists=True) -> Optional[ActorHandle]:
        """Return a handle to the _StatsActor, refreshing it on cluster change.

        Args:
            create_if_not_exists: If True, create the actor when missing;
                otherwise look it up and return None when it does not exist.

        Raises:
            RuntimeError: If Ray's global node is not initialized.
        """
        if ray._private.worker._global_node is None:
            raise RuntimeError("Global node is not initialized.")
        current_cluster_id = ray._private.worker._global_node.cluster_id
        if (
            self._stats_actor_handle is None
            or self._stats_actor_cluster_id != current_cluster_id
        ):
            # Cached handle is missing or belongs to a previous cluster.
            if create_if_not_exists:
                self._stats_actor_handle = _get_or_create_stats_actor()
            else:
                try:
                    self._stats_actor_handle = ray.get_actor(
                        name=STATS_ACTOR_NAME, namespace=STATS_ACTOR_NAMESPACE
                    )
                except ValueError:
                    return None
            self._stats_actor_cluster_id = current_cluster_id
        return self._stats_actor_handle

    def _start_thread_if_not_running(self):
        # Start background update thread if not running.
        with self._update_thread_lock:
            if self._update_thread is None or not self._update_thread.is_alive():

                def _run_update_loop():
                    iter_stats_inactivity = 0
                    while True:
                        if self._last_iteration_stats or self._last_execution_stats:
                            try:
                                # Do not create _StatsActor if it doesn't exist because
                                # this thread can be running even after the cluster is
                                # shutdown. Creating an actor will automatically start
                                # a new cluster.
                                stats_actor = self._stats_actor(
                                    create_if_not_exists=False
                                )
                                if stats_actor is None:
                                    # Fix: sleep before retrying. Previously this
                                    # `continue` skipped the sleep at the bottom of
                                    # the loop, busy-spinning while waiting for the
                                    # actor to appear.
                                    time.sleep(
                                        _StatsManager.STATS_ACTOR_UPDATE_INTERVAL_SECONDS  # noqa: E501
                                    )
                                    continue
                                stats_actor.update_metrics.remote(
                                    execution_metrics=list(
                                        self._last_execution_stats.values()
                                    ),
                                    iteration_metrics=list(
                                        self._last_iteration_stats.values()
                                    ),
                                )
                                iter_stats_inactivity = 0
                            except Exception:
                                logger.debug(
                                    "Error occurred during remote call to _StatsActor.",
                                    exc_info=True,
                                )
                                return
                        else:
                            iter_stats_inactivity += 1
                            if (
                                iter_stats_inactivity
                                >= _StatsManager.UPDATE_THREAD_INACTIVITY_LIMIT
                            ):
                                logger.debug(
                                    "Terminating StatsManager thread due to inactivity."
                                )
                                return
                        # Reference the class (not the module-level StatsManager
                        # instance) for consistency with the constant usage above.
                        time.sleep(_StatsManager.STATS_ACTOR_UPDATE_INTERVAL_SECONDS)

                self._update_thread = threading.Thread(
                    target=_run_update_loop, daemon=True
                )
                self._update_thread.start()

    # Execution methods

    def update_execution_metrics(
        self,
        dataset_tag: str,
        op_metrics: List[OpRuntimeMetrics],
        operator_tags: List[str],
        state: Dict[str, Any],
        force_update: bool = False,
    ):
        """Record the latest execution metric snapshot for a dataset.

        With ``force_update`` the snapshot is pushed to the actor immediately;
        otherwise it is cached and flushed by the background thread.
        """
        op_metrics_dicts = [metric.as_dict() for metric in op_metrics]
        args = (dataset_tag, op_metrics_dicts, operator_tags, state)
        if force_update:
            self._stats_actor().update_execution_metrics.remote(*args)
        else:
            with self._stats_lock:
                self._last_execution_stats[dataset_tag] = args
            self._start_thread_if_not_running()

    def clear_last_execution_stats(self, dataset_tag: str):
        # After dataset completes execution, remove cached execution stats.
        # Marks the dataset as finished on job page's Ray Data Overview.
        with self._stats_lock:
            if dataset_tag in self._last_execution_stats:
                del self._last_execution_stats[dataset_tag]

    # Iteration methods

    def update_iteration_metrics(self, stats: "DatasetStats", dataset_tag: str):
        """Cache the latest iteration stats; flushed by the background thread."""
        with self._stats_lock:
            self._last_iteration_stats[dataset_tag] = (stats, dataset_tag)
        self._start_thread_if_not_running()

    def clear_iteration_metrics(self, dataset_tag: str):
        # Delete the last iteration stats so that update thread will have
        # a chance to terminate.
        # Note we don't reset the actual metric values through the StatsActor
        # since the value is essentially a counter value. See
        # https://github.com/ray-project/ray/pull/48618 for more context.
        with self._stats_lock:
            if dataset_tag in self._last_iteration_stats:
                del self._last_iteration_stats[dataset_tag]

    # Other methods

    def register_dataset_to_stats_actor(self, dataset_tag, operator_tags):
        """Register the dataset with the stats actor for dashboard reporting."""
        self._stats_actor().register_dataset.remote(
            ray.get_runtime_context().get_job_id(),
            dataset_tag,
            operator_tags,
        )

    def get_dataset_id_from_stats_actor(self) -> str:
        """Return a unique dataset id, falling back to uuid4 on actor failure."""
        try:
            return ray.get(self._stats_actor().get_dataset_id.remote())
        except Exception:
            # Getting dataset id from _StatsActor may fail, in this case
            # fall back to uuid4
            return uuid4().hex
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Process-wide singleton through which execution/iteration code reports stats.
StatsManager = _StatsManager()
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class DatasetStats:
    """Holds the execution times for a given Dataset.

    This object contains a reference to the parent Dataset's stats as well,
    but not the Dataset object itself, to allow its blocks to be dropped from
    memory."""

    def __init__(
        self,
        *,
        metadata: StatsDict,
        parent: Union[Optional["DatasetStats"], List["DatasetStats"]],
        needs_stats_actor: bool = False,
        stats_uuid: Optional[str] = None,
        base_name: Optional[str] = None,
    ):
        """Create dataset stats.

        Args:
            metadata: Dict of operators used to create this Dataset from the
                previous one. Typically one entry, e.g., {"map": [...]}.
            parent: Reference to parent Dataset's stats, or a list of parents
                if there are multiple.
            needs_stats_actor: Whether this Dataset's stats needs a stats actor for
                stats collection. This is currently only used for Datasets using a
                lazy datasource (i.e. a LazyBlockList).
            stats_uuid: The uuid for the stats, used to fetch the right stats
                from the stats actor.
            base_name: The name of the base operation for a multi-operator operation.
        """

        self.metadata: StatsDict = metadata
        # Normalize the parent argument to a (possibly empty) list.
        if parent is not None and not isinstance(parent, list):
            parent = [parent]
        self.parents: List["DatasetStats"] = parent or []
        # Position in the stats chain: one more than the deepest parent.
        self.number: int = (
            0 if not self.parents else max(p.number for p in self.parents) + 1
        )
        self.base_name = base_name
        # TODO(ekl) deprecate and remove the notion of dataset UUID once we move
        # fully to streaming execution.
        self.dataset_uuid: str = "unknown_uuid"
        self.time_total_s: float = 0
        self.needs_stats_actor = needs_stats_actor
        self.stats_uuid = stats_uuid

        # Streaming executor stats
        self.streaming_exec_schedule_s: Timer = Timer()

        # Iteration stats, filled out if the user iterates over the dataset.
        self.iter_wait_s: Timer = Timer()
        self.iter_get_s: Timer = Timer()
        self.iter_next_batch_s: Timer = Timer()
        self.iter_format_batch_s: Timer = Timer()
        self.iter_collate_batch_s: Timer = Timer()
        self.iter_finalize_batch_s: Timer = Timer()
        self.iter_total_blocked_s: Timer = Timer()
        self.iter_user_s: Timer = Timer()
        self.iter_initialize_s: Timer = Timer()
        self.iter_total_s: Timer = Timer()
        self.extra_metrics = {}

        # Block fetch stats during iteration.
        # These are stats about locations of blocks when the iterator is trying to
        # consume them. The iteration performance will be affected depending on
        # whether the block is in the local object store of the node where the
        # iterator is running.
        # This serves as an indicator of block prefetching effectiveness.
        self.iter_blocks_local: int = 0
        self.iter_blocks_remote: int = 0
        self.iter_unknown_location: int = 0

        # Memory usage stats
        self.global_bytes_spilled: int = 0
        self.global_bytes_restored: int = 0
        self.dataset_bytes_spilled: int = 0

        # Streaming split coordinator stats (dataset level)
        self.streaming_split_coordinator_s: Timer = Timer()

    @property
    def stats_actor(self):
        """Handle to the shared _StatsActor (created on first access)."""
        return _get_or_create_stats_actor()

    def child_builder(
        self, name: str, override_start_time: Optional[float] = None
    ) -> _DatasetStatsBuilder:
        """Start recording stats for an op of the given name (e.g., map)."""
        return _DatasetStatsBuilder(name, self, override_start_time)

    def to_summary(self) -> "DatasetStatsSummary":
        """Generate a `DatasetStatsSummary` object from the given `DatasetStats`
        object, which can be used to generate a summary string."""
        if self.needs_stats_actor:
            ac = self.stats_actor
            # TODO(chengsu): this is a super hack, clean it up.
            # NOTE: this blocks on a remote call and mutates self.time_total_s
            # and self.metadata["Read"] in place.
            stats_map, self.time_total_s = ray.get(ac.get.remote(self.stats_uuid))
            # Only populate stats when stats from all read tasks are ready at
            # stats actor.
            if len(stats_map.items()) == len(self.metadata["Read"]):
                self.metadata["Read"] = []
                for _, blocks_metadata in sorted(stats_map.items()):
                    self.metadata["Read"] += blocks_metadata

        operators_stats = []
        # More than one metadata entry means this stage had sub-operators.
        is_sub_operator = len(self.metadata) > 1
        for name, meta in self.metadata.items():
            operators_stats.append(
                OperatorStatsSummary.from_block_metadata(
                    name,
                    meta,
                    is_sub_operator=is_sub_operator,
                )
            )

        iter_stats = IterStatsSummary(
            self.iter_wait_s,
            self.iter_get_s,
            self.iter_next_batch_s,
            self.iter_format_batch_s,
            self.iter_collate_batch_s,
            self.iter_finalize_batch_s,
            self.iter_total_blocked_s,
            self.iter_user_s,
            self.iter_initialize_s,
            self.iter_total_s,
            self.streaming_split_coordinator_s,
            self.iter_blocks_local,
            self.iter_blocks_remote,
            self.iter_unknown_location,
        )
        # Recursively summarize ancestors (self.parents is always a list here).
        stats_summary_parents = []
        if self.parents is not None:
            stats_summary_parents = [p.to_summary() for p in self.parents]
        streaming_exec_schedule_s = (
            self.streaming_exec_schedule_s.get()
            if self.streaming_exec_schedule_s
            else 0
        )
        return DatasetStatsSummary(
            operators_stats,
            iter_stats,
            stats_summary_parents,
            self.number,
            self.dataset_uuid,
            self.time_total_s,
            self.base_name,
            self.extra_metrics,
            self.global_bytes_spilled,
            self.global_bytes_restored,
            self.dataset_bytes_spilled,
            streaming_exec_schedule_s,
        )

    def runtime_metrics(self) -> str:
        """Generate a string representing the runtime metrics of a Dataset. This is
        a high level summary of the time spent in Ray Data code broken down by operator.
        It also includes the time spent in the scheduler. Times are shown as the total
        time for each operator and percentages of time are shown as a fraction of the
        total time for the whole dataset."""
        return self.to_summary().runtime_metrics()
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
@DeveloperAPI
@dataclass
class DatasetStatsSummary:
    # Per-operator summaries for this stage (one per metadata entry).
    operators_stats: List["OperatorStatsSummary"]
    # Summary of iteration-side timers and block-locality counters.
    iter_stats: "IterStatsSummary"
    # Summaries of ancestor datasets' stats (recursive structure).
    parents: List["DatasetStatsSummary"]
    # Position of this stage in the stats chain (0 for roots).
    number: int
    dataset_uuid: str
    # Total stage execution time in seconds.
    time_total_s: float
    # Name of the base operation for a multi-operator stage.
    base_name: str
    extra_metrics: Dict[str, Any]
    # Cluster-wide spill/restore byte counts.
    global_bytes_spilled: int
    global_bytes_restored: int
    # Bytes spilled attributable to this dataset alone.
    dataset_bytes_spilled: int
    # Seconds spent in the streaming executor's scheduling loop.
    streaming_exec_schedule_s: float
|
| 792 |
+
|
| 793 |
+
    def to_string(
        self,
        already_printed: Optional[Set[str]] = None,
        include_parent: bool = True,
        add_global_stats=True,
    ) -> str:
        """Return a human-readable summary of this Dataset's stats.

        Args:
            already_printed: Set of operator IDs that have already had its stats printed
                out.
            include_parent: If true, also include parent stats summary; otherwise, only
                log stats of the latest operator.
            add_global_stats: If true, includes global stats to this summary.
        Returns:
            String with summary statistics for executing the Dataset.
        """
        if already_printed is None:
            already_printed = set()

        out = ""
        # Render ancestors first so the output reads in execution order.
        if self.parents and include_parent:
            for p in self.parents:
                parent_sum = p.to_string(already_printed, add_global_stats=False)
                if parent_sum:
                    out += parent_sum
                    out += "\n"
        operators_stats_summary = None
        if len(self.operators_stats) == 1:
            # Single-operator stage: one "Operator N name:" section.
            operators_stats_summary = self.operators_stats[0]
            operator_name = operators_stats_summary.operator_name
            operator_uuid = self.dataset_uuid + operator_name
            out += "Operator {} {}: ".format(self.number, operator_name)
            if operator_uuid in already_printed:
                out += "[execution cached]\n"
            else:
                # Mutates the shared set so repeated operators print once.
                already_printed.add(operator_uuid)
                out += str(operators_stats_summary)
        elif len(self.operators_stats) > 1:
            # Multi-operator stage: header line plus one suboperator section each.
            rounded_total = round(self.time_total_s, 2)
            if rounded_total <= 0:
                # Handle -0.0 case.
                rounded_total = 0
            out += "Operator {} {}: executed in {}s\n".format(
                self.number, self.base_name, rounded_total
            )
            for n, operators_stats_summary in enumerate(self.operators_stats):
                operator_name = operators_stats_summary.operator_name
                operator_uuid = self.dataset_uuid + operator_name
                out += "\n"
                out += "\tSuboperator {} {}: ".format(n, operator_name)
                if operator_uuid in already_printed:
                    out += "\t[execution cached]\n"
                else:
                    already_printed.add(operator_uuid)
                    out += str(operators_stats_summary)
        verbose_stats_logs = DataContext.get_current().verbose_stats_logs
        if verbose_stats_logs and self.extra_metrics:
            # NOTE: operators_stats_summary here is whatever the branches above
            # left bound (the last suboperator in the multi-operator case).
            indent = (
                "\t"
                if operators_stats_summary and operators_stats_summary.is_sub_operator
                else ""
            )
            out += indent
            out += "* Extra metrics: " + str(self.extra_metrics) + "\n"
        out += str(self.iter_stats)

        if len(self.operators_stats) > 0 and add_global_stats:
            mb_spilled = round(self.global_bytes_spilled / 1e6)
            mb_restored = round(self.global_bytes_restored / 1e6)
            if mb_spilled or mb_restored:
                out += "\nCluster memory:\n"
                out += "* Spilled to disk: {}MB\n".format(mb_spilled)
                out += "* Restored from disk: {}MB\n".format(mb_restored)

            dataset_mb_spilled = round(self.dataset_bytes_spilled / 1e6)
            if dataset_mb_spilled:
                out += "\nDataset memory:\n"
                out += "* Spilled to disk: {}MB\n".format(dataset_mb_spilled)

            # For throughput, we compute both an observed Ray Data dataset throughput
            # and an estimated single node dataset throughput.

            # The observed dataset throughput is computed by dividing the total number
            # of rows produced by the total wall time of the dataset (i.e. from start to
            # finish how long did the dataset take to be processed). With the recursive
            # nature of the DatasetStatsSummary, we use get_total_wall_time to determine
            # the total wall time (this finds the difference between the earliest start
            # and latest end for any block in any operator).

            # The estimated single node dataset throughput is computed by dividing the
            # total number of rows produced the sum of the wall times across all blocks
            # of all operators. This assumes that on a single node the work done would
            # be equivalent, with no concurrency.
            output_num_rows = self.operators_stats[-1].output_num_rows
            total_num_out_rows = output_num_rows["sum"] if output_num_rows else 0
            wall_time = self.get_total_wall_time()
            total_time_all_blocks = self.get_total_time_all_blocks()
            if total_num_out_rows and wall_time and total_time_all_blocks:
                out += "\n"
                out += "Dataset throughput:\n"
                out += (
                    "\t* Ray Data throughput:"
                    f" {total_num_out_rows / wall_time} "
                    "rows/s\n"
                )
                out += (
                    "\t* Estimated single node throughput:"
                    f" {total_num_out_rows / total_time_all_blocks} "
                    "rows/s\n"
                )
            if verbose_stats_logs and add_global_stats:
                out += "\n" + self.runtime_metrics()

        return out
|
| 908 |
+
|
| 909 |
+
@staticmethod
|
| 910 |
+
def _collect_dataset_stats_summaries(
|
| 911 |
+
curr: "DatasetStatsSummary",
|
| 912 |
+
) -> List["DatasetStatsSummary"]:
|
| 913 |
+
summs = []
|
| 914 |
+
# TODO: Do operators ever have multiple parents? Do we need to deduplicate?
|
| 915 |
+
for p in curr.parents:
|
| 916 |
+
if p and p.parents:
|
| 917 |
+
summs.extend(DatasetStatsSummary._collect_dataset_stats_summaries(p))
|
| 918 |
+
return summs + [curr]
|
| 919 |
+
|
| 920 |
+
@staticmethod
|
| 921 |
+
def _find_start_and_end(summ: "DatasetStatsSummary") -> Tuple[float, float]:
|
| 922 |
+
earliest_start = min(ops.earliest_start_time for ops in summ.operators_stats)
|
| 923 |
+
latest_end = max(ops.latest_end_time for ops in summ.operators_stats)
|
| 924 |
+
return earliest_start, latest_end
|
| 925 |
+
|
| 926 |
+
def runtime_metrics(self) -> str:
|
| 927 |
+
total_wall_time = self.get_total_wall_time()
|
| 928 |
+
|
| 929 |
+
def fmt_line(name: str, time: float) -> str:
|
| 930 |
+
return f"* {name}: {fmt(time)} ({time / total_wall_time * 100:.3f}%)\n"
|
| 931 |
+
|
| 932 |
+
summaries = DatasetStatsSummary._collect_dataset_stats_summaries(self)
|
| 933 |
+
out = "Runtime Metrics:\n"
|
| 934 |
+
for summ in summaries:
|
| 935 |
+
if len(summ.operators_stats) > 0:
|
| 936 |
+
earliest_start, latest_end = DatasetStatsSummary._find_start_and_end(
|
| 937 |
+
summ
|
| 938 |
+
)
|
| 939 |
+
op_total_time = latest_end - earliest_start
|
| 940 |
+
out += fmt_line(summ.base_name, op_total_time)
|
| 941 |
+
out += fmt_line("Scheduling", self.streaming_exec_schedule_s)
|
| 942 |
+
out += fmt_line("Total", total_wall_time)
|
| 943 |
+
return out
|
| 944 |
+
|
| 945 |
+
    def __repr__(self, level=0) -> str:
        """Nested, indented debug representation of this summary tree.

        Args:
            level: Current indentation level; child summaries and operator
                stats are rendered at ``level + 2``, iterator stats at
                ``level + 1``.
        """
        indent = leveled_indent(level)
        # Render children first; each child handles its own deeper indent.
        operators_stats = "\n".join(
            [ss.__repr__(level + 2) for ss in self.operators_stats]
        )
        parent_stats = "\n".join([ps.__repr__(level + 2) for ps in self.parents])
        extra_metrics = "\n".join(
            f"{leveled_indent(level + 2)}{k}: {v},"
            for k, v in self.extra_metrics.items()
        )

        # Handle formatting case for empty outputs.
        operators_stats = (
            f"\n{operators_stats},\n{indent}   " if operators_stats else ""
        )
        parent_stats = f"\n{parent_stats},\n{indent}   " if parent_stats else ""
        extra_metrics = f"\n{extra_metrics}\n{indent}   " if extra_metrics else ""
        return (
            f"{indent}DatasetStatsSummary(\n"
            f"{indent}   dataset_uuid={self.dataset_uuid},\n"
            f"{indent}   base_name={self.base_name},\n"
            f"{indent}   number={self.number},\n"
            f"{indent}   extra_metrics={{{extra_metrics}}},\n"
            f"{indent}   operators_stats=[{operators_stats}],\n"
            f"{indent}   iter_stats={self.iter_stats.__repr__(level+1)},\n"
            f"{indent}   global_bytes_spilled={self.global_bytes_spilled / 1e6}MB,\n"
            f"{indent}   global_bytes_restored={self.global_bytes_restored / 1e6}MB,\n"
            f"{indent}   dataset_bytes_spilled={self.dataset_bytes_spilled / 1e6}MB,\n"
            f"{indent}   parents=[{parent_stats}],\n"
            f"{indent})"
        )
|
| 976 |
+
|
| 977 |
+
def get_total_wall_time(self) -> float:
|
| 978 |
+
"""Calculate the total wall time for the dataset, this is done by finding
|
| 979 |
+
the earliest start time and latest end time for any block in any operator.
|
| 980 |
+
The wall time is the difference of these two times.
|
| 981 |
+
"""
|
| 982 |
+
start_ends = [
|
| 983 |
+
DatasetStatsSummary._find_start_and_end(summ)
|
| 984 |
+
for summ in DatasetStatsSummary._collect_dataset_stats_summaries(self)
|
| 985 |
+
if len(summ.operators_stats) > 0
|
| 986 |
+
]
|
| 987 |
+
if len(start_ends) == 0:
|
| 988 |
+
return 0
|
| 989 |
+
else:
|
| 990 |
+
earliest_start = min(start_end[0] for start_end in start_ends)
|
| 991 |
+
latest_end = max(start_end[1] for start_end in start_ends)
|
| 992 |
+
return latest_end - earliest_start
|
| 993 |
+
|
| 994 |
+
def get_total_time_all_blocks(self) -> float:
|
| 995 |
+
"""Calculate the sum of the wall times across all blocks of all operators."""
|
| 996 |
+
summaries = DatasetStatsSummary._collect_dataset_stats_summaries(self)
|
| 997 |
+
return sum(
|
| 998 |
+
(
|
| 999 |
+
sum(
|
| 1000 |
+
ops.wall_time.get("sum", 0) if ops.wall_time else 0
|
| 1001 |
+
for ops in summ.operators_stats
|
| 1002 |
+
)
|
| 1003 |
+
)
|
| 1004 |
+
for summ in summaries
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
def get_total_cpu_time(self) -> float:
|
| 1008 |
+
parent_sum = sum(p.get_total_cpu_time() for p in self.parents)
|
| 1009 |
+
return parent_sum + sum(
|
| 1010 |
+
ss.cpu_time.get("sum", 0) for ss in self.operators_stats
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
def get_max_heap_memory(self) -> float:
|
| 1014 |
+
parent_memory = [p.get_max_heap_memory() for p in self.parents]
|
| 1015 |
+
parent_max = max(parent_memory) if parent_memory else 0
|
| 1016 |
+
if not self.operators_stats:
|
| 1017 |
+
return parent_max
|
| 1018 |
+
|
| 1019 |
+
return max(
|
| 1020 |
+
parent_max,
|
| 1021 |
+
*[ss.memory.get("max", 0) for ss in self.operators_stats],
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
@dataclass
class OperatorStatsSummary:
    """Aggregated execution statistics for a single Dataset operator.

    Instances are normally built from per-block metadata via
    :meth:`from_block_metadata`. Aggregate fields are dicts keyed by stat
    name ("min"/"max"/"mean"/"sum" or "count") and are ``None`` when no data
    was available to compute them.
    """

    operator_name: str
    # Whether the operator associated with this OperatorStatsSummary object
    # is a suboperator
    is_sub_operator: bool
    # This is the total walltime of the entire operator, typically obtained from
    # `DatasetStats.time_total_s`. An important distinction is that this is the
    # overall runtime of the operator, pulled from the stats actor, whereas the
    # computed walltimes in `self.wall_time` are calculated on a operator level.
    time_total_s: float
    # Earliest block start time observed across this operator's blocks.
    earliest_start_time: float
    # Latest block end time observed across this operator's blocks.
    latest_end_time: float
    # String summarizing high-level statistics from executing the operator
    block_execution_summary_str: str
    # The fields below are dicts with stats aggregated across blocks
    # processed in this operator. For example:
    # {"min": ..., "max": ..., "mean": ..., "sum": ...}
    wall_time: Optional[Dict[str, float]] = None
    cpu_time: Optional[Dict[str, float]] = None
    udf_time: Optional[Dict[str, float]] = None
    # memory: no "sum" stat
    memory: Optional[Dict[str, float]] = None
    output_num_rows: Optional[Dict[str, float]] = None
    output_size_bytes: Optional[Dict[str, float]] = None
    # node_count: "count" stat instead of "sum"
    node_count: Optional[Dict[str, float]] = None
    task_rows: Optional[Dict[str, float]] = None

    @classmethod
    def from_block_metadata(
        cls,
        operator_name: str,
        block_metas: List[BlockMetadata],
        is_sub_operator: bool,
    ) -> "OperatorStatsSummary":
        """Calculate the stats for a operator from a given list of blocks,
        and generates a `OperatorStatsSummary` object with the results.

        Args:
            block_metas: List of `BlockMetadata` to calculate stats of
            operator_name: Name of operator associated with `blocks`
            is_sub_operator: Whether this set of blocks belongs to a sub operator.
        Returns:
            A `OperatorStatsSummary` object initialized with the calculated statistics
        """
        # Blocks without exec stats (e.g. metadata-only blocks) are excluded
        # from all timing/memory aggregates below.
        exec_stats = [m.exec_stats for m in block_metas if m.exec_stats is not None]
        rounded_total = 0
        time_total_s = 0
        earliest_start_time, latest_end_time = 0, 0

        if exec_stats:
            # Calculate the total execution time of operator as
            # the difference between the latest end time and
            # the earliest start time of all blocks in the operator.
            earliest_start_time = min(s.start_time_s for s in exec_stats)
            latest_end_time = max(s.end_time_s for s in exec_stats)
            time_total_s = latest_end_time - earliest_start_time

        if is_sub_operator:
            exec_summary_str = "{} blocks produced\n".format(len(exec_stats))
        else:
            if exec_stats:
                rounded_total = round(time_total_s, 2)
                if rounded_total <= 0:
                    # Handle -0.0 case.
                    rounded_total = 0
                exec_summary_str = "{} blocks produced in {}s".format(
                    len(exec_stats), rounded_total
                )
            else:
                exec_summary_str = ""
            exec_summary_str += "\n"

        # Aggregate output rows per task index to report the per-task row
        # distribution ("Output rows per task" in __str__).
        task_rows = collections.defaultdict(int)
        for meta in block_metas:
            if meta.num_rows is not None and meta.exec_stats is not None:
                task_rows[meta.exec_stats.task_idx] += meta.num_rows
        task_rows_stats = None
        if len(task_rows) > 0:
            task_rows_stats = {
                "min": min(task_rows.values()),
                "max": max(task_rows.values()),
                "mean": int(np.mean(list(task_rows.values()))),
                "count": len(task_rows),
            }
            exec_summary_str = "{} tasks executed, {}".format(
                len(task_rows), exec_summary_str
            )

        wall_time_stats, cpu_stats, memory_stats, udf_stats = None, None, None, None
        if exec_stats:
            wall_time_stats = {
                "min": min([e.wall_time_s for e in exec_stats]),
                "max": max([e.wall_time_s for e in exec_stats]),
                "mean": np.mean([e.wall_time_s for e in exec_stats]),
                "sum": sum([e.wall_time_s for e in exec_stats]),
            }
            cpu_stats = {
                "min": min([e.cpu_time_s for e in exec_stats]),
                "max": max([e.cpu_time_s for e in exec_stats]),
                "mean": np.mean([e.cpu_time_s for e in exec_stats]),
                "sum": sum([e.cpu_time_s for e in exec_stats]),
            }

            # Peak heap memory is converted to MiB; no "sum" stat since
            # per-block peaks are not additive.
            memory_stats_mb = [
                round(e.max_rss_bytes / (1024 * 1024), 2) for e in exec_stats
            ]
            memory_stats = {
                "min": min(memory_stats_mb),
                "max": max(memory_stats_mb),
                "mean": int(np.mean(memory_stats_mb)),
            }

            udf_stats = {
                "min": min([e.udf_time_s for e in exec_stats]),
                "max": max([e.udf_time_s for e in exec_stats]),
                "mean": np.mean([e.udf_time_s for e in exec_stats]),
                "sum": sum([e.udf_time_s for e in exec_stats]),
            }

        output_num_rows_stats = None
        output_num_rows = [m.num_rows for m in block_metas if m.num_rows is not None]
        if output_num_rows:
            output_num_rows_stats = {
                "min": min(output_num_rows),
                "max": max(output_num_rows),
                "mean": int(np.mean(output_num_rows)),
                "sum": sum(output_num_rows),
            }

        output_size_bytes_stats = None
        output_size_bytes = [
            m.size_bytes for m in block_metas if m.size_bytes is not None
        ]
        if output_size_bytes:
            output_size_bytes_stats = {
                "min": min(output_size_bytes),
                "max": max(output_size_bytes),
                "mean": int(np.mean(output_size_bytes)),
                "sum": sum(output_size_bytes),
            }

        node_counts_stats = None
        if exec_stats:
            # Count distinct tasks per node to show how work spread across
            # the cluster ("count" = number of nodes used).
            node_tasks = collections.defaultdict(set)
            for s in exec_stats:
                node_tasks[s.node_id].add(s.task_idx)

            node_counts = {node: len(tasks) for node, tasks in node_tasks.items()}
            node_counts_stats = {
                "min": min(node_counts.values()),
                "max": max(node_counts.values()),
                "mean": int(np.mean(list(node_counts.values()))),
                "count": len(node_counts),
            }

        return OperatorStatsSummary(
            operator_name=operator_name,
            is_sub_operator=is_sub_operator,
            time_total_s=time_total_s,
            earliest_start_time=earliest_start_time,
            latest_end_time=latest_end_time,
            block_execution_summary_str=exec_summary_str,
            wall_time=wall_time_stats,
            cpu_time=cpu_stats,
            udf_time=udf_stats,
            memory=memory_stats,
            output_num_rows=output_num_rows_stats,
            output_size_bytes=output_size_bytes_stats,
            node_count=node_counts_stats,
            task_rows=task_rows_stats,
        )

    def __str__(self) -> str:
        """For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from
        `OperatorStatsSummary.from_block_metadata()`), returns a human-friendly string
        that summarizes operator execution statistics.

        Returns:
            String with summary statistics for executing the given operator.
        """
        # Sub-operator lines are indented one tab deeper than their parent.
        indent = "\t" if self.is_sub_operator else ""
        out = self.block_execution_summary_str

        wall_time_stats = self.wall_time
        if wall_time_stats:
            out += indent
            out += "* Remote wall time: {} min, {} max, {} mean, {} total\n".format(
                fmt(wall_time_stats["min"]),
                fmt(wall_time_stats["max"]),
                fmt(wall_time_stats["mean"]),
                fmt(wall_time_stats["sum"]),
            )

        cpu_stats = self.cpu_time
        if cpu_stats:
            out += indent
            out += "* Remote cpu time: {} min, {} max, {} mean, {} total\n".format(
                fmt(cpu_stats["min"]),
                fmt(cpu_stats["max"]),
                fmt(cpu_stats["mean"]),
                fmt(cpu_stats["sum"]),
            )

        udf_stats = self.udf_time
        if udf_stats:
            out += indent
            out += "* UDF time: {} min, {} max, {} mean, {} total\n".format(
                fmt(udf_stats["min"]),
                fmt(udf_stats["max"]),
                fmt(udf_stats["mean"]),
                fmt(udf_stats["sum"]),
            )

        memory_stats = self.memory
        if memory_stats:
            out += indent
            out += "* Peak heap memory usage (MiB): {} min, {} max, {} mean\n".format(
                memory_stats["min"],
                memory_stats["max"],
                memory_stats["mean"],
            )

        output_num_rows_stats = self.output_num_rows
        if output_num_rows_stats:
            out += indent
            out += (
                "* Output num rows per block: {} min, {} max, {} mean, {} total\n"
            ).format(
                output_num_rows_stats["min"],
                output_num_rows_stats["max"],
                output_num_rows_stats["mean"],
                output_num_rows_stats["sum"],
            )

        output_size_bytes_stats = self.output_size_bytes
        if output_size_bytes_stats:
            out += indent
            out += (
                "* Output size bytes per block: {} min, {} max, {} mean, {} total\n"
            ).format(
                output_size_bytes_stats["min"],
                output_size_bytes_stats["max"],
                output_size_bytes_stats["mean"],
                output_size_bytes_stats["sum"],
            )

        task_rows = self.task_rows
        if task_rows:
            out += indent
            out += (
                "* Output rows per task: {} min, {} max, {} mean, {} tasks used\n"
            ).format(
                task_rows["min"],
                task_rows["max"],
                task_rows["mean"],
                task_rows["count"],
            )

        node_count_stats = self.node_count
        if node_count_stats:
            out += indent
            out += "* Tasks per node: {} min, {} max, {} mean; {} nodes used\n".format(
                node_count_stats["min"],
                node_count_stats["max"],
                node_count_stats["mean"],
                node_count_stats["count"],
            )
        if output_num_rows_stats and self.time_total_s and wall_time_stats:
            # For throughput, we compute both an observed Ray Data operator throughput
            # and an estimated single node operator throughput.

            # The observed Ray Data operator throughput is computed by dividing the
            # total number of rows produced by the wall time of the operator,
            # time_total_s.

            # The estimated single node operator throughput is computed by dividing the
            # total number of rows produced by the sum of the wall times across all
            # blocks of the operator. This assumes that on a single node the work done
            # would be equivalent, with no concurrency.
            total_num_out_rows = output_num_rows_stats["sum"]
            out += indent
            out += "* Operator throughput:\n"
            out += (
                indent + "\t* Ray Data throughput:"
                f" {total_num_out_rows / self.time_total_s} "
                "rows/s\n"
            )
            out += (
                indent + "\t* Estimated single node throughput:"
                f" {total_num_out_rows / wall_time_stats['sum']} "
                "rows/s\n"
            )
        return out

    def __repr__(self, level=0) -> str:
        """For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from
        `OperatorStatsSummary.from_block_metadata()`), returns a human-friendly string
        that summarizes operator execution statistics.

        Returns:
            String with summary statistics for executing the given operator.
        """
        indent = leveled_indent(level)
        indent += leveled_indent(1) if self.is_sub_operator else ""

        # Pre-format every aggregate dict; empty dicts render as None below.
        wall_time_stats = {k: fmt(v) for k, v in (self.wall_time or {}).items()}
        cpu_stats = {k: fmt(v) for k, v in (self.cpu_time or {}).items()}
        memory_stats = {k: fmt(v) for k, v in (self.memory or {}).items()}
        output_num_rows_stats = {
            k: fmt(v) for k, v in (self.output_num_rows or {}).items()
        }
        output_size_bytes_stats = {
            k: fmt(v) for k, v in (self.output_size_bytes or {}).items()
        }
        node_conut_stats = {k: fmt(v) for k, v in (self.node_count or {}).items()}
        out = (
            f"{indent}OperatorStatsSummary(\n"
            f"{indent}   operator_name='{self.operator_name}',\n"
            f"{indent}   is_suboperator={self.is_sub_operator},\n"
            f"{indent}   time_total_s={fmt(self.time_total_s)},\n"
            # block_execution_summary_str already ends with \n
            f"{indent}   block_execution_summary_str={self.block_execution_summary_str}"
            f"{indent}   wall_time={wall_time_stats or None},\n"
            f"{indent}   cpu_time={cpu_stats or None},\n"
            f"{indent}   memory={memory_stats or None},\n"
            f"{indent}   output_num_rows={output_num_rows_stats or None},\n"
            f"{indent}   output_size_bytes={output_size_bytes_stats or None},\n"
            f"{indent}   node_count={node_conut_stats or None},\n"
            f"{indent})"
        )
        return out
|
| 1358 |
+
|
| 1359 |
+
|
| 1360 |
+
@dataclass
class IterStatsSummary:
    """Summary of where time was spent while iterating over a Dataset."""

    # Time spent in actor based prefetching, in seconds.
    wait_time: Timer
    # Time spent in `ray.get()`, in seconds
    get_time: Timer
    # Time spent in batch building, in seconds
    next_time: Timer
    # Time spent in `_format_batch_()`, in seconds
    format_time: Timer
    # Time spent in collate fn, in seconds
    collate_time: Timer
    # Time spent in finalize_fn, in seconds
    finalize_batch_time: Timer
    # Total time user thread is blocked by iter_batches
    block_time: Timer
    # Time spent in user code, in seconds
    user_time: Timer
    # Time spent in iterator initialization code, in seconds
    initialize_time: Timer
    # Total time taken by Dataset iterator, in seconds
    total_time: Timer
    # Time spent in streaming split coordinator
    streaming_split_coord_time: Timer
    # Num of blocks that are in local object store
    iter_blocks_local: int
    # Num of blocks that are in remote node and have to fetch locally
    iter_blocks_remote: int
    # Num of blocks with unknown locations
    iter_unknown_location: int

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self) -> str:
        """Render the iterator time breakdown as a human-readable report.

        Sections are emitted only when the corresponding timer recorded a
        nonzero total; block-location counts are included only when
        ``enable_get_object_locations_for_metrics`` is set in the current
        DataContext.
        """
        out = ""
        if (
            self.block_time.get()
            or self.total_time.get()
            or self.get_time.get()
            or self.next_time.get()
            or self.format_time.get()
            or self.collate_time.get()
            or self.finalize_batch_time.get()
        ):
            out += "\nDataset iterator time breakdown:\n"
            if self.total_time.get():
                out += "* Total time overall: {}\n".format(fmt(self.total_time.get()))
            if self.initialize_time.get():
                out += (
                    "    * Total time in Ray Data iterator initialization code: "
                    "{}\n".format(fmt(self.initialize_time.get()))
                )
            if self.block_time.get():
                out += (
                    "    * Total time user thread is blocked by Ray Data iter_batches: "
                    "{}\n".format(fmt(self.block_time.get()))
                )
            if self.user_time.get():
                out += "    * Total execution time for user thread: {}\n".format(
                    fmt(self.user_time.get())
                )
            out += (
                "* Batch iteration time breakdown (summed across prefetch threads):\n"
            )
            if self.get_time.get():
                out += "    * In ray.get(): {} min, {} max, {} avg, {} total\n".format(
                    fmt(self.get_time.min()),
                    fmt(self.get_time.max()),
                    fmt(self.get_time.avg()),
                    fmt(self.get_time.get()),
                )
            if self.next_time.get():
                batch_creation_str = (
                    "    * In batch creation: {} min, {} max, " "{} avg, {} total\n"
                )
                out += batch_creation_str.format(
                    fmt(self.next_time.min()),
                    fmt(self.next_time.max()),
                    fmt(self.next_time.avg()),
                    fmt(self.next_time.get()),
                )
            if self.format_time.get():
                format_str = (
                    "    * In batch formatting: {} min, {} max, " "{} avg, {} total\n"
                )
                out += format_str.format(
                    fmt(self.format_time.min()),
                    fmt(self.format_time.max()),
                    fmt(self.format_time.avg()),
                    fmt(self.format_time.get()),
                )
            if self.collate_time.get():
                out += "    * In collate_fn: {} min, {} max, {} avg, {} total\n".format(
                    fmt(self.collate_time.min()),
                    fmt(self.collate_time.max()),
                    fmt(self.collate_time.avg()),
                    fmt(self.collate_time.get()),
                )
            if self.finalize_batch_time.get():
                format_str = (
                    "    * In host->device transfer: {} min, {} max, {} avg, {} total\n"
                )
                out += format_str.format(
                    fmt(self.finalize_batch_time.min()),
                    fmt(self.finalize_batch_time.max()),
                    fmt(self.finalize_batch_time.avg()),
                    fmt(self.finalize_batch_time.get()),
                )
        if DataContext.get_current().enable_get_object_locations_for_metrics:
            out += "Block locations:\n"
            out += "    * Num blocks local: {}\n".format(self.iter_blocks_local)
            out += "    * Num blocks remote: {}\n".format(self.iter_blocks_remote)
            out += "    * Num blocks unknown location: {}\n".format(
                self.iter_unknown_location
            )
        if self.streaming_split_coord_time.get() != 0:
            out += "Streaming split coordinator overhead time: "
            out += f"{fmt(self.streaming_split_coord_time.get())}\n"

        return out

    def __repr__(self, level=0) -> str:
        """Compact debug representation; zero timers/counts render as None."""
        indent = leveled_indent(level)
        return (
            f"IterStatsSummary(\n"
            f"{indent}   wait_time={fmt(self.wait_time.get()) or None},\n"
            f"{indent}   get_time={fmt(self.get_time.get()) or None},\n"
            f"{indent}   iter_blocks_local={self.iter_blocks_local or None},\n"
            f"{indent}   iter_blocks_remote={self.iter_blocks_remote or None},\n"
            f"{indent}   iter_unknown_location={self.iter_unknown_location or None},\n"
            f"{indent}   next_time={fmt(self.next_time.get()) or None},\n"
            f"{indent}   format_time={fmt(self.format_time.get()) or None},\n"
            f"{indent}   user_time={fmt(self.user_time.get()) or None},\n"
            f"{indent}   total_time={fmt(self.total_time.get()) or None},\n"
            f"{indent})"
        )
.venv/lib/python3.11/site-packages/ray/data/_internal/table_block.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
from typing import (
|
| 3 |
+
TYPE_CHECKING,
|
| 4 |
+
Any,
|
| 5 |
+
Dict,
|
| 6 |
+
Iterator,
|
| 7 |
+
List,
|
| 8 |
+
Mapping,
|
| 9 |
+
Optional,
|
| 10 |
+
TypeVar,
|
| 11 |
+
Union,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from ray.air.constants import TENSOR_COLUMN_NAME
|
| 17 |
+
from ray.data._internal.block_builder import BlockBuilder
|
| 18 |
+
from ray.data._internal.numpy_support import is_array_like
|
| 19 |
+
from ray.data._internal.row import TableRow
|
| 20 |
+
from ray.data._internal.size_estimator import SizeEstimator
|
| 21 |
+
from ray.data._internal.util import MiB
|
| 22 |
+
from ray.data.block import Block, BlockAccessor
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Generic type parameter used by builder subclasses.
T = TypeVar("T")

# The max size of Python tuples to buffer before compacting them into a
# table in the BlockBuilder.
MAX_UNCOMPACTED_SIZE_BYTES = 50 * MiB
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TableBlockBuilder(BlockBuilder):
|
| 36 |
+
    def __init__(self, block_type):
        """Initialize an empty builder for blocks of ``block_type``.

        Args:
            block_type: The concrete table type (checked in ``add_block``)
                that this builder accumulates and produces.
        """
        # The set of uncompacted Python values buffered.
        self._columns = collections.defaultdict(list)
        # The column names of uncompacted Python values buffered.
        self._column_names = None
        # The set of compacted tables we have built so far.
        self._tables: List[Any] = []
        # Cursor into tables indicating up to which table we've accumulated table sizes.
        # This is used to defer table size calculation, which can be expensive for e.g.
        # Pandas DataFrames.
        # This cursor points to the first table for which we haven't accumulated a table
        # size.
        self._tables_size_cursor = 0
        # Accumulated table sizes, up to the table in _tables pointed to by
        # _tables_size_cursor.
        self._tables_size_bytes = 0
        # Size estimator for un-compacted table values.
        self._uncompacted_size = SizeEstimator()
        # Total rows added (buffered + compacted).
        self._num_rows = 0
        # Number of buffered-row compactions performed so far.
        self._num_compactions = 0
        self._block_type = block_type
|
| 57 |
+
|
| 58 |
+
def add(self, item: Union[dict, TableRow, np.ndarray]) -> None:
|
| 59 |
+
if isinstance(item, TableRow):
|
| 60 |
+
item = item.as_pydict()
|
| 61 |
+
elif isinstance(item, np.ndarray):
|
| 62 |
+
item = {TENSOR_COLUMN_NAME: item}
|
| 63 |
+
if not isinstance(item, collections.abc.Mapping):
|
| 64 |
+
raise ValueError(
|
| 65 |
+
"Returned elements of an TableBlock must be of type `dict`, "
|
| 66 |
+
"got {} (type {}).".format(item, type(item))
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
item_column_names = item.keys()
|
| 70 |
+
if self._column_names is not None:
|
| 71 |
+
# Check all added rows have same columns.
|
| 72 |
+
if item_column_names != self._column_names:
|
| 73 |
+
raise ValueError(
|
| 74 |
+
"Current row has different columns compared to previous rows. "
|
| 75 |
+
f"Columns of current row: {sorted(item_column_names)}, "
|
| 76 |
+
f"Columns of previous rows: {sorted(self._column_names)}."
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
# Initialize column names with the first added row.
|
| 80 |
+
self._column_names = item_column_names
|
| 81 |
+
|
| 82 |
+
for key, value in item.items():
|
| 83 |
+
if is_array_like(value) and not isinstance(value, np.ndarray):
|
| 84 |
+
value = np.array(value)
|
| 85 |
+
self._columns[key].append(value)
|
| 86 |
+
self._num_rows += 1
|
| 87 |
+
self._compact_if_needed()
|
| 88 |
+
self._uncompacted_size.add(item)
|
| 89 |
+
|
| 90 |
+
def add_block(self, block: Any) -> None:
|
| 91 |
+
if not isinstance(block, self._block_type):
|
| 92 |
+
raise TypeError(
|
| 93 |
+
f"Got a block of type {type(block)}, expected {self._block_type}."
|
| 94 |
+
"If you are mapping a function, ensure it returns an "
|
| 95 |
+
"object with the expected type. Block:\n"
|
| 96 |
+
f"{block}"
|
| 97 |
+
)
|
| 98 |
+
accessor = BlockAccessor.for_block(block)
|
| 99 |
+
self._tables.append(block)
|
| 100 |
+
self._num_rows += accessor.num_rows()
|
| 101 |
+
|
| 102 |
+
@staticmethod
def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
    """Build a table block from a dict of column name -> list of values.

    Subclasses implement this for their concrete table format.
    """
    raise NotImplementedError
|
| 105 |
+
|
| 106 |
+
@staticmethod
def _concat_tables(tables: List[Block]) -> Block:
    """Concatenate multiple table blocks into one (format-specific)."""
    raise NotImplementedError
|
| 109 |
+
|
| 110 |
+
@staticmethod
def _empty_table() -> Any:
    """Return an empty table of this builder's block type."""
    raise NotImplementedError
|
| 113 |
+
|
| 114 |
+
@staticmethod
def _concat_would_copy() -> bool:
    """Return whether concatenating tables copies data for this format."""
    raise NotImplementedError
|
| 117 |
+
|
| 118 |
+
def will_build_yield_copy(self) -> bool:
    """Return whether ``build()`` will copy data rather than view it."""
    if not self._columns:
        # No buffered rows remain: a copy only happens when several
        # compacted tables must be concatenated and concatenation copies
        # for this table format.
        return self._concat_would_copy() and len(self._tables) > 1
    # Building a table from a dict of list columns always creates a copy.
    return True
|
| 123 |
+
|
| 124 |
+
def build(self) -> Block:
    """Materialize all buffered rows and compacted tables into one block."""
    # Flush any still-buffered rows into a table first so they come before
    # the previously compacted tables in the concatenation order.
    pending = [self._table_from_pydict(self._columns)] if self._columns else []
    pending.extend(self._tables)
    if not pending:
        # Nothing was ever added: return the format's empty table.
        return self._empty_table()
    return self._concat_tables(pending)
|
| 136 |
+
|
| 137 |
+
def num_rows(self) -> int:
    """Return the total number of rows added to this builder."""
    return self._num_rows
|
| 139 |
+
|
| 140 |
+
def get_estimated_memory_usage(self) -> int:
    """Return the estimated in-memory size (bytes) of the buffered data.

    Sizes of already-compacted tables are computed once and cached using a
    cursor, so repeated calls only measure tables appended since the last
    call, plus the running estimate for uncompacted rows.
    """
    if self._num_rows == 0:
        return 0
    # Only measure tables added since the previous call (cursor onward).
    for table in self._tables[self._tables_size_cursor :]:
        self._tables_size_bytes += BlockAccessor.for_block(table).size_bytes()
    self._tables_size_cursor = len(self._tables)
    return self._tables_size_bytes + self._uncompacted_size.size_bytes()
|
| 147 |
+
|
| 148 |
+
def _compact_if_needed(self) -> None:
    """Compact buffered rows into a table block once they exceed the cap.

    No-op while the uncompacted size estimate is below
    ``MAX_UNCOMPACTED_SIZE_BYTES``. Callers must only invoke this with a
    non-empty row buffer (asserted below).
    """
    assert self._columns
    if self._uncompacted_size.size_bytes() < MAX_UNCOMPACTED_SIZE_BYTES:
        return
    block = self._table_from_pydict(self._columns)
    self.add_block(block)
    # Reset the row buffer and its size estimator after compaction.
    self._uncompacted_size = SizeEstimator()
    self._columns.clear()
    self._num_compactions += 1
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class TableBlockAccessor(BlockAccessor):
    """Shared accessor logic for table-formatted blocks (Arrow / pandas)."""

    # Row wrapper type; subclasses override with their format's TableRow.
    ROW_TYPE: TableRow = TableRow

    def __init__(self, table: Any):
        self._table = table

    def _get_row(self, index: int, copy: bool = False) -> Union[TableRow, np.ndarray]:
        """Return the row at ``index`` wrapped in this format's row type."""
        base_row = self.slice(index, index + 1, copy=copy)
        row = self.ROW_TYPE(base_row)
        return row

    @staticmethod
    def _munge_conflict(name, count):
        """Generate a deduplicated column name for a naming conflict."""
        return f"{name}_{count+1}"

    @staticmethod
    def _build_tensor_row(row: TableRow) -> np.ndarray:
        """Extract the tensor-column value from a single-row block."""
        raise NotImplementedError

    def to_default(self) -> Block:
        # Always promote Arrow blocks to pandas for consistency, since
        # we lazily convert pandas->Arrow internally for efficiency.
        default = self.to_pandas()
        return default

    def column_names(self) -> List[str]:
        """Return the block's column names (format-specific)."""
        raise NotImplementedError

    def append_column(self, name: str, data: Any) -> Block:
        """Return a new block with an extra column (format-specific)."""
        raise NotImplementedError

    def to_block(self) -> Block:
        """Return the underlying table object unchanged."""
        return self._table

    def iter_rows(
        self, public_row_format: bool
    ) -> Iterator[Union[Mapping, np.ndarray]]:
        """Iterate over rows; as plain dicts when ``public_row_format``."""
        outer = self

        # Hand-rolled iterator so each row is sliced lazily on demand.
        class Iter:
            def __init__(self):
                self._cur = -1

            def __iter__(self):
                return self

            def __next__(self):
                self._cur += 1
                if self._cur < outer.num_rows():
                    row = outer._get_row(self._cur)
                    if public_row_format and isinstance(row, TableRow):
                        return row.as_pydict()
                    else:
                        return row
                raise StopIteration

        return Iter()

    def _zip(self, acc: BlockAccessor) -> "Block":
        """Column-wise combine with another same-format block."""
        raise NotImplementedError

    def zip(self, other: "Block") -> "Block":
        """Zip this block with ``other`` column-wise.

        Blocks of different table formats are first normalized to the
        default format; row counts must match.
        """
        acc = BlockAccessor.for_block(other)
        if not isinstance(acc, type(self)):
            if isinstance(self, TableBlockAccessor) and isinstance(
                acc, TableBlockAccessor
            ):
                # If block types are different, but still both of TableBlock type, try
                # converting both to default block type before zipping.
                self_norm, other_norm = TableBlockAccessor.normalize_block_types(
                    [self._table, other],
                )
                return BlockAccessor.for_block(self_norm).zip(other_norm)
            else:
                raise ValueError(
                    "Cannot zip {} with block of type {}".format(
                        type(self), type(other)
                    )
                )
        if acc.num_rows() != self.num_rows():
            raise ValueError(
                "Cannot zip self (length {}) with block of length {}".format(
                    self.num_rows(), acc.num_rows()
                )
            )
        return self._zip(acc)

    @staticmethod
    def _empty_table() -> Any:
        """Return an empty table of this accessor's format."""
        raise NotImplementedError

    def _sample(self, n_samples: int, sort_key: "SortKey") -> Any:
        """Format-specific sampling of the sort-key columns."""
        raise NotImplementedError

    def sample(self, n_samples: int, sort_key: "SortKey") -> Any:
        """Sample up to ``n_samples`` rows of the sort-key columns."""
        if sort_key is None or callable(sort_key):
            raise NotImplementedError(
                f"Table sort key must be a column name, was: {sort_key}"
            )
        if self.num_rows() == 0:
            # If the pyarrow table is empty we may not have schema
            # so calling table.select() will raise an error.
            return self._empty_table()
        k = min(n_samples, self.num_rows())
        return self._sample(k, sort_key)

    @classmethod
    def normalize_block_types(
        cls,
        blocks: List[Block],
        normalize_type: Optional[str] = None,
    ) -> List[Block]:
        """Normalize input blocks to the specified `normalize_type`. If the blocks
        are already all of the same type, returns the original blocks.

        Args:
            blocks: A list of TableBlocks to be normalized.
            normalize_type: The type to normalize the blocks to ("arrow" or
                "pandas"). If None, ``to_default()`` is used — NOTE(review):
                ``to_default`` here converts to pandas; confirm the intended
                default format against callers.

        Returns:
            A list of blocks of the same type.
        """
        seen_types = set()
        for block in blocks:
            acc = BlockAccessor.for_block(block)
            if not isinstance(acc, TableBlockAccessor):
                raise ValueError(
                    "Block type normalization is only supported for TableBlock, "
                    f"but received block of type: {type(block)}."
                )
            seen_types.add(type(block))

        # Return original blocks if they are all of the same type.
        if len(seen_types) <= 1:
            return blocks

        if normalize_type == "arrow":
            results = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
        elif normalize_type == "pandas":
            results = [BlockAccessor.for_block(block).to_pandas() for block in blocks]
        else:
            results = [BlockAccessor.for_block(block).to_default() for block in blocks]

        # Sanity check: conversion must have yielded one uniform type.
        if any(not isinstance(block, type(results[0])) for block in results):
            raise ValueError(
                "Expected all blocks to be of the same type after normalization, but "
                f"got different types: {[type(b) for b in results]}. "
                "Try using blocks of the same type to avoid the issue "
                "with block normalization."
            )
        return results
|
.venv/lib/python3.11/site-packages/ray/data/_internal/torch_iterable_dataset.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import IterableDataset
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TorchIterableDataset(IterableDataset):
    """Thin ``IterableDataset`` wrapper around a generator factory.

    Each ``__iter__`` call invokes ``generator_func`` afresh, so the
    dataset can be iterated multiple times (e.g. once per epoch).
    """

    def __init__(self, generator_func):
        # Callable returning a fresh iterable of items per iteration pass.
        self.generator_func = generator_func

    def __iter__(self):
        for item in self.generator_func():
            yield item
|
.venv/lib/python3.11/site-packages/ray/data/_internal/util.py
ADDED
|
@@ -0,0 +1,1262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
import urllib.parse
|
| 10 |
+
from queue import Empty, Full, Queue
|
| 11 |
+
from types import ModuleType
|
| 12 |
+
from typing import (
|
| 13 |
+
TYPE_CHECKING,
|
| 14 |
+
Any,
|
| 15 |
+
Callable,
|
| 16 |
+
Generator,
|
| 17 |
+
Iterable,
|
| 18 |
+
Iterator,
|
| 19 |
+
List,
|
| 20 |
+
Optional,
|
| 21 |
+
Tuple,
|
| 22 |
+
TypeVar,
|
| 23 |
+
Union,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
import ray
|
| 29 |
+
from ray._private.utils import _get_pyarrow_version
|
| 30 |
+
from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext
|
| 31 |
+
|
| 32 |
+
if TYPE_CHECKING:
|
| 33 |
+
import pandas
|
| 34 |
+
import pyarrow
|
| 35 |
+
|
| 36 |
+
from ray.data._internal.compute import ComputeStrategy
|
| 37 |
+
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
|
| 38 |
+
from ray.data.block import Block, BlockMetadata, UserDefinedFunction
|
| 39 |
+
from ray.data.datasource import Datasource, Reader
|
| 40 |
+
from ray.util.placement_group import PlacementGroup
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Common byte-size units.
KiB = 1024  # bytes
MiB = 1024 * KiB
GiB = 1024 * MiB


# Generic unique marker for "no value supplied" (distinct from None).
SENTINEL = object()


# NOTE: Make sure that these lower and upper bounds stay in sync with version
# constraints given in python/setup.py.
# Inclusive minimum pyarrow version.
MIN_PYARROW_VERSION = "6.0.1"
# Env var name; setting it to "1" disables the pyarrow version check.
RAY_DISABLE_PYARROW_VERSION_CHECK = "RAY_DISABLE_PYARROW_VERSION_CHECK"
# Module-level flag so the pyarrow version is validated at most once.
_VERSION_VALIDATED = False
# Ray-specific custom URI schemes resolved by _resolve_custom_scheme().
_LOCAL_SCHEME = "local"
_EXAMPLE_SCHEME = "example"


# Lazily-imported module slot: None = not yet attempted, False = unavailable.
LazyModule = Union[None, bool, ModuleType]
_pyarrow_dataset: LazyModule = None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class _NullSentinel:
    """Sentinel value that sorts greater than any other value."""

    def __eq__(self, other):
        # Equal only to other sentinels.
        return isinstance(other, _NullSentinel)

    def __lt__(self, other):
        # Never less than anything.
        return False

    def __le__(self, other):
        # <= holds only against another sentinel (the equality case).
        return isinstance(other, _NullSentinel)

    def __gt__(self, other):
        # Strictly greater than every non-sentinel value.
        return True

    def __ge__(self, other):
        return True

    def __hash__(self):
        # Identity-based hash; defining __eq__ would otherwise make the
        # class unhashable.
        return id(self)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Shared singleton used to make null values sort after all non-null values.
NULL_SENTINEL = _NullSentinel()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _lazy_import_pyarrow_dataset() -> LazyModule:
    """Import ``pyarrow.dataset`` on first use, caching the result.

    Returns the module on success, or ``False`` (cached) if it is not
    installed, so subsequent calls skip the import attempt.
    """
    global _pyarrow_dataset
    if _pyarrow_dataset is None:
        try:
            from pyarrow import dataset as _pyarrow_dataset
        except ModuleNotFoundError:
            # If module is not found, set _pyarrow to False so we won't
            # keep trying to import it on every _lazy_import_pyarrow() call.
            _pyarrow_dataset = False
    return _pyarrow_dataset
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _check_pyarrow_version():
    """Check that pyarrow's version is within the supported bounds.

    Raises ``ImportError`` if the installed version is older than
    ``MIN_PYARROW_VERSION``; logs a warning if the version cannot be
    determined. The check runs at most once per process (module flag) and
    can be disabled via the ``RAY_DISABLE_PYARROW_VERSION_CHECK`` env var.
    """
    global _VERSION_VALIDATED

    if not _VERSION_VALIDATED:
        if os.environ.get(RAY_DISABLE_PYARROW_VERSION_CHECK, "0") == "1":
            _VERSION_VALIDATED = True
            return

        version = _get_pyarrow_version()
        if version is not None:
            from packaging.version import parse as parse_version

            if parse_version(version) < parse_version(MIN_PYARROW_VERSION):
                raise ImportError(
                    f"Dataset requires pyarrow >= {MIN_PYARROW_VERSION}, but "
                    f"{version} is installed. Reinstall with "
                    f'`pip install -U "pyarrow"`. '
                    "If you want to disable this pyarrow version check, set the "
                    f"environment variable {RAY_DISABLE_PYARROW_VERSION_CHECK}=1."
                )
        else:
            # Version unknown (e.g. pyarrow vendored by another package).
            logger.warning(
                "You are using the 'pyarrow' module, but the exact version is unknown "
                "(possibly carried as an internal component by another module). Please "
                f"make sure you are using pyarrow >= {MIN_PYARROW_VERSION} to ensure "
                "compatibility with Ray Dataset. "
                "If you want to disable this pyarrow version check, set the "
                f"environment variable {RAY_DISABLE_PYARROW_VERSION_CHECK}=1."
            )
        _VERSION_VALIDATED = True
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _autodetect_parallelism(
    parallelism: int,
    target_max_block_size: int,
    ctx: DataContext,
    datasource_or_legacy_reader: Optional[Union["Datasource", "Reader"]] = None,
    mem_size: Optional[int] = None,
    placement_group: Optional["PlacementGroup"] = None,
    avail_cpus: Optional[int] = None,
) -> Tuple[int, str, Optional[int]]:
    """Returns parallelism to use and the min safe parallelism to avoid OOMs.

    This detects parallelism using the following heuristics, applied in order:

    1) We start with the default value of 200. This can be overridden by
    setting the `read_op_min_num_blocks` attribute of
    :class:`~ray.data.context.DataContext`.
    2) Min block size. If the parallelism would make blocks smaller than this
    threshold, the parallelism is reduced to avoid the overhead of tiny blocks.
    3) Max block size. If the parallelism would make blocks larger than this
    threshold, the parallelism is increased to avoid OOMs during processing.
    4) Available CPUs. If the parallelism cannot make use of all the available
    CPUs in the cluster, the parallelism is increased until it can.

    Args:
        parallelism: The user-requested parallelism, or -1 for auto-detection.
        target_max_block_size: The target max block size to
            produce. We pass this separately from the
            DatasetContext because it may be set per-op instead of
            per-Dataset.
        ctx: The current Dataset context to use for configs.
        datasource_or_legacy_reader: The datasource or legacy reader, to be used for
            data size estimation.
        mem_size: If passed, then used to compute the parallelism according to
            target_max_block_size.
        placement_group: The placement group that this Dataset
            will execute inside, if any.
        avail_cpus: Override avail cpus detection (for testing only).

    Returns:
        Tuple of detected parallelism (only if -1 was specified), the reason
        for the detected parallelism (only if -1 was specified), and the estimated
        inmemory size of the dataset.
    """
    min_safe_parallelism = 1
    max_reasonable_parallelism = sys.maxsize
    if mem_size is None and datasource_or_legacy_reader:
        mem_size = datasource_or_legacy_reader.estimate_inmemory_data_size()
    if mem_size is not None and not np.isnan(mem_size):
        # Lower bound keeps blocks under target_max_block_size; upper bound
        # keeps blocks above target_min_block_size.
        min_safe_parallelism = max(1, int(mem_size / target_max_block_size))
        max_reasonable_parallelism = max(1, int(mem_size / ctx.target_min_block_size))

    reason = ""
    if parallelism < 0:
        if parallelism != -1:
            raise ValueError("`parallelism` must either be -1 or a positive integer.")

        # Honor the deprecated min_parallelism setting only when the new
        # read_op_min_num_blocks has been left at its default.
        if (
            ctx.min_parallelism is not None
            and ctx.min_parallelism != DEFAULT_READ_OP_MIN_NUM_BLOCKS
            and ctx.read_op_min_num_blocks == DEFAULT_READ_OP_MIN_NUM_BLOCKS
        ):
            logger.warning(
                "``DataContext.min_parallelism`` is deprecated in Ray 2.10. "
                "Please specify ``DataContext.read_op_min_num_blocks`` instead."
            )
            ctx.read_op_min_num_blocks = ctx.min_parallelism

        # Start with 2x the number of cores as a baseline, with a min floor.
        if placement_group is None:
            placement_group = ray.util.get_current_placement_group()
        avail_cpus = avail_cpus or _estimate_avail_cpus(placement_group)
        parallelism = max(
            min(ctx.read_op_min_num_blocks, max_reasonable_parallelism),
            min_safe_parallelism,
            avail_cpus * 2,
        )

        # Record which heuristic determined the final value, for logging.
        if parallelism == ctx.read_op_min_num_blocks:
            reason = (
                "DataContext.get_current().read_op_min_num_blocks="
                f"{ctx.read_op_min_num_blocks}"
            )
        elif parallelism == max_reasonable_parallelism:
            reason = (
                "output blocks of size at least "
                "DataContext.get_current().target_min_block_size="
                f"{ctx.target_min_block_size / (1024 * 1024)}MiB"
            )
        elif parallelism == min_safe_parallelism:
            reason = (
                "output blocks of size at most "
                "DataContext.get_current().target_max_block_size="
                f"{ctx.target_max_block_size / (1024 * 1024)}MiB"
            )
        else:
            reason = (
                "parallelism at least twice the available number "
                f"of CPUs ({avail_cpus})"
            )

        logger.debug(
            f"Autodetected parallelism={parallelism} based on "
            f"estimated_available_cpus={avail_cpus} and "
            f"estimated_data_size={mem_size}."
        )

    return parallelism, reason, mem_size
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _estimate_avail_cpus(cur_pg: Optional["PlacementGroup"]) -> int:
    """Estimates the available CPU parallelism for this Dataset in the cluster.

    If we aren't in a placement group, this is trivially the number of CPUs in the
    cluster. Otherwise, we try to calculate how large the placement group is relative
    to the size of the cluster.

    Args:
        cur_pg: The current placement group, if any.

    Returns:
        Estimated number of CPUs usable by this Dataset (at least capped by
        the cluster CPU count).
    """
    cluster_cpus = int(ray.cluster_resources().get("CPU", 1))
    cluster_gpus = int(ray.cluster_resources().get("GPU", 0))

    # If we're in a placement group, we shouldn't assume the entire cluster's
    # resources are available for us to use. Estimate an upper bound on what's
    # reasonable to assume is available for datasets to use.
    if cur_pg:
        pg_cpus = 0
        for bundle in cur_pg.bundle_specs:
            # Calculate the proportion of the cluster this placement group "takes up".
            # Then scale our cluster_cpus proportionally to avoid over-parallelizing
            # if there are many parallel Tune trials using the cluster.
            cpu_fraction = bundle.get("CPU", 0) / max(1, cluster_cpus)
            gpu_fraction = bundle.get("GPU", 0) / max(1, cluster_gpus)
            max_fraction = max(cpu_fraction, gpu_fraction)
            # Over-parallelize by up to a factor of 2, but no more than that. It's
            # preferrable to over-estimate than under-estimate.
            pg_cpus += 2 * int(max_fraction * cluster_cpus)

        return min(cluster_cpus, pg_cpus)

    return cluster_cpus
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def _estimate_available_parallelism() -> int:
    """Estimates the available CPU parallelism for this Dataset in the cluster.
    If we are currently in a placement group, take that into account."""
    # Delegate to the placement-group-aware estimator.
    return _estimate_avail_cpus(ray.util.get_current_placement_group())
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _warn_on_high_parallelism(requested_parallelism, num_read_tasks):
    """Log a warning when the read-task count far exceeds the cluster's CPUs.

    Warns only when parallelism was explicitly requested, exceeds 4x the
    currently-available CPU slots, and is at least 5000 tasks.
    """
    available_cpu_slots = ray.available_resources().get("CPU", 1)
    if (
        requested_parallelism
        and num_read_tasks > available_cpu_slots * 4
        and num_read_tasks >= 5000
    ):
        logger.warning(
            f"{WARN_PREFIX} The requested parallelism of {requested_parallelism} "
            "is more than 4x the number of available CPU slots in the cluster of "
            f"{available_cpu_slots}. This can "
            "lead to slowdowns during the data reading phase due to excessive "
            "task creation. Reduce the parallelism to match with the available "
            "CPU slots in the cluster, or set parallelism to -1 for Ray Data "
            "to automatically determine the parallelism. "
            "You can ignore this message if the cluster is expected to autoscale."
        )
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _check_import(obj, *, module: str, package: str) -> None:
|
| 307 |
+
"""Check if a required dependency is installed.
|
| 308 |
+
|
| 309 |
+
If `module` can't be imported, this function raises an `ImportError` instructing
|
| 310 |
+
the user to install `package` from PyPI.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
obj: The object that has a dependency.
|
| 314 |
+
module: The name of the module to import.
|
| 315 |
+
package: The name of the package on PyPI.
|
| 316 |
+
"""
|
| 317 |
+
try:
|
| 318 |
+
importlib.import_module(module)
|
| 319 |
+
except ImportError:
|
| 320 |
+
raise ImportError(
|
| 321 |
+
f"`{obj.__class__.__name__}` depends on '{package}', but '{package}' "
|
| 322 |
+
f"couldn't be imported. You can install '{package}' by running `pip "
|
| 323 |
+
f"install {package}`."
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _resolve_custom_scheme(path: str) -> str:
    """Returns the resolved path if the given path follows a Ray-specific custom
    scheme. Otherwise, returns the path unchanged.

    The supported custom schemes are: "local", "example".
    """
    parsed_uri = urllib.parse.urlparse(path)
    if parsed_uri.scheme == _LOCAL_SCHEME:
        # Strip the "local://" prefix, keeping the filesystem path.
        path = parsed_uri.netloc + parsed_uri.path
    elif parsed_uri.scheme == _EXAMPLE_SCHEME:
        # Resolve relative to the bundled examples/data directory.
        example_data_path = pathlib.Path(__file__).parent.parent / "examples" / "data"
        path = example_data_path / (parsed_uri.netloc + parsed_uri.path)
        path = str(path.resolve())
    return path
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _is_local_scheme(paths: Union[str, List[str]]) -> bool:
    """Returns True if the given paths are in local scheme.
    Note: The paths must be in same scheme, i.e. it's invalid and
    will raise error if paths are mixed with different schemes.
    """
    if isinstance(paths, str):
        paths = [paths]
    if isinstance(paths, pathlib.Path):
        paths = [str(paths)]
    elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths):
        raise ValueError("paths must be a path string or a list of path strings.")
    elif len(paths) == 0:
        raise ValueError("Must provide at least one path.")
    # Count how many paths use the "local" scheme; mixing is an error.
    num = sum(urllib.parse.urlparse(path).scheme == _LOCAL_SCHEME for path in paths)
    if num > 0 and num < len(paths):
        raise ValueError(
            "The paths must all be local-scheme or not local-scheme, "
            f"but found mixed {paths}"
        )
    return num == len(paths)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _truncated_repr(obj: Any) -> str:
|
| 366 |
+
"""Utility to return a truncated object representation for error messages."""
|
| 367 |
+
msg = str(obj)
|
| 368 |
+
if len(msg) > 200:
|
| 369 |
+
msg = msg[:200] + "..."
|
| 370 |
+
return msg
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def _insert_doc_at_pattern(
    obj,
    *,
    message: str,
    pattern: str,
    insert_after: bool = True,
    directive: Optional[str] = None,
    skip_matches: int = 0,
) -> None:
    """Insert ``message`` into ``obj.__doc__`` before/after ``pattern``.

    Mutates ``obj.__doc__`` in place and returns None. The message is
    indented to match the surrounding docstring text, optionally wrapped in
    a Sphinx ``.. <directive>::`` block, and padded with blank lines as
    needed.

    Args:
        obj: Object whose ``__doc__`` is rewritten. NOTE(review): assumes
            ``obj.__doc__`` is a non-None string — ``.strip()`` below would
            raise AttributeError otherwise; confirm callers guarantee this.
        message: Single-line text to insert (newlines are rejected).
        pattern: Substring of the docstring to anchor the insertion at. An
            empty pattern with ``insert_after=True`` appends to the end.
        insert_after: If True, insert just after the pattern; otherwise at
            the start of the line containing the pattern.
        directive: Optional Sphinx directive name to wrap the message in.
        skip_matches: Number of pattern occurrences to skip before inserting.

    Raises:
        ValueError: If ``message`` contains a newline, or the pattern isn't
            found after the requested number of skips.
    """
    if "\n" in message:
        raise ValueError(
            "message shouldn't contain any newlines, since this function will insert "
            f"its own linebreaks when text wrapping: {message}"
        )

    doc = obj.__doc__.strip()
    if not doc:
        doc = ""

    if pattern == "" and insert_after:
        # Empty pattern + insert_after means that we want to append the message to the
        # end of the docstring.
        head = doc
        tail = ""
    else:
        tail = doc
        i = tail.find(pattern)
        skip_matches_left = skip_matches
        # Scan forward through matches, splitting the doc into head/tail at
        # the chosen occurrence. The while/else raises when the pattern is
        # never found (loop exits without `break`).
        while i != -1:
            if insert_after:
                # Set offset to the first character after the pattern.
                offset = i + len(pattern)
            else:
                # Set offset to the first character in the matched line.
                offset = tail[:i].rfind("\n") + 1
            head = tail[:offset]
            tail = tail[offset:]
            skip_matches_left -= 1
            if skip_matches_left <= 0:
                break
            elif not insert_after:
                # Move past the found pattern, since we're skipping it.
                tail = tail[i - offset + len(pattern) :]
            i = tail.find(pattern)
        else:
            raise ValueError(
                f"Pattern {pattern} not found after {skip_matches} skips in docstring "
                f"{doc}"
            )
    # Get indentation of the to-be-inserted text.
    after_lines = list(filter(bool, tail.splitlines()))
    if len(after_lines) > 0:
        lines = after_lines
    else:
        # Nothing after the insertion point; borrow indentation from the
        # last non-empty line before it.
        lines = list(filter(bool, reversed(head.splitlines())))
    # Should always have at least one non-empty line in the docstring.
    assert len(lines) > 0
    indent = " " * (len(lines[0]) - len(lines[0].lstrip()))
    # Handle directive.
    message = message.strip("\n")
    if directive is not None:
        # Wrap the message in a Sphinx directive block, indenting its body
        # 4 spaces past the directive line.
        base = f"{indent}.. {directive}::\n"
        message = message.replace("\n", "\n" + indent + " " * 4)
        message = base + indent + " " * 4 + message
    else:
        message = indent + message.replace("\n", "\n" + indent)
    # Add two blank lines before/after message, if necessary.
    if insert_after ^ (pattern == "\n\n"):
        # Only two blank lines before message if:
        # 1. Inserting message after pattern and pattern is not two blank lines.
        # 2. Inserting message before pattern and pattern is two blank lines.
        message = "\n\n" + message
    if (not insert_after) ^ (pattern == "\n\n"):
        # Only two blank lines after message if:
        # 1. Inserting message before pattern and pattern is not two blank lines.
        # 2. Inserting message after pattern and pattern is two blank lines.
        message = message + "\n\n"

    # Insert message before/after pattern.
    parts = [head, message, tail]
    # Build new docstring.
    obj.__doc__ = "".join(parts)
def _consumption_api(
    if_more_than_read: bool = False,
    datasource_metadata: Optional[str] = None,
    extra_condition: Optional[str] = None,
    delegate: Optional[str] = None,
    pattern="Examples:",
    insert_after=False,
):
    """Annotate the function with an indication that it's a consumption API, and that it
    will trigger Dataset execution.
    """
    base = (
        " will trigger execution of the lazy transformations performed on "
        "this dataset."
    )
    # Pick the message subject: an explicit delegate, the unconditional
    # "This operation", or a conditional phrase built from the flags.
    if delegate:
        message = delegate + base
    elif not if_more_than_read:
        message = "This operation" + base
    else:
        pieces = ["If this dataset consists of more than a read, "]
        if datasource_metadata is not None:
            pieces.append(
                f"or if the {datasource_metadata} can't be determined from the "
                "metadata provided by the datasource, "
            )
        if extra_condition is not None:
            pieces.append(extra_condition + ", ")
        pieces.append("then this operation")
        message = "".join(pieces) + base

    def wrap(obj):
        # Splice the note into the decorated callable's docstring.
        _insert_doc_at_pattern(
            obj,
            message=message,
            pattern=pattern,
            insert_after=insert_after,
            directive="note",
        )
        return obj

    return wrap
def ConsumptionAPI(*args, **kwargs):
    """Annotate the function with an indication that it's a consumption API, and that it
    will trigger Dataset execution.
    """
    # Support both bare (``@ConsumptionAPI``) and parameterized
    # (``@ConsumptionAPI(...)``) decorator usage.
    used_bare = len(args) == 1 and len(kwargs) == 0 and callable(args[0])
    if used_bare:
        return _consumption_api()(args[0])
    return _consumption_api(*args, **kwargs)
def _all_to_all_api(*args, **kwargs):
    """Annotate the function with an indication that it's a all to all API, and that it
    is an operation that requires all inputs to be materialized in-memory to execute.
    """

    def wrap(obj):
        note = (
            "This operation requires all inputs to be "
            "materialized in object store for it to execute."
        )
        # Splice the note into the decorated callable's docstring, just
        # before its "Examples:" section.
        _insert_doc_at_pattern(
            obj,
            message=note,
            pattern="Examples:",
            insert_after=False,
            directive="note",
        )
        return obj

    return wrap
def AllToAllAPI(*args, **kwargs):
    """Annotate the function with an indication that it's a all to all API, and that it
    is an operation that requires all inputs to be materialized in-memory to execute.
    """
    # This should only be used as a decorator for dataset methods.
    assert len(args) == 1 and len(kwargs) == 0 and callable(args[0])
    (fn,) = args
    return _all_to_all_api()(fn)
def get_compute_strategy(
    fn: "UserDefinedFunction",
    fn_constructor_args: Optional[Iterable[Any]] = None,
    compute: Optional[Union[str, "ComputeStrategy"]] = None,
    concurrency: Optional[Union[int, Tuple[int, int]]] = None,
) -> "ComputeStrategy":
    """Get `ComputeStrategy` based on the function or class, and concurrency
    information.

    Args:
        fn: The function or generator to apply to a record batch, or a class type
            that can be instantiated to create such a callable.
        fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
        compute: Either "tasks" (default) to use Ray Tasks or an
            :class:`~ray.data.ActorPoolStrategy` to use an autoscaling actor pool.
        concurrency: The number of Ray workers to use concurrently.

    Returns:
        The `ComputeStrategy` for execution.

    Raises:
        ValueError: If ``fn_constructor_args`` is given for a non-class ``fn``,
            if ``compute`` and ``fn`` are inconsistent (task strategy with a
            callable class or actor strategy with a plain function), or if
            ``concurrency`` is missing/ill-typed for the given ``fn``.
    """
    # Lazily import these objects to avoid circular imports.
    from ray.data._internal.compute import ActorPoolStrategy, TaskPoolStrategy
    from ray.data.block import CallableClass

    if isinstance(fn, CallableClass):
        is_callable_class = True
    else:
        # TODO(chengsu): disallow object that is not a function. For example,
        # An object instance of class often indicates a bug in user code.
        is_callable_class = False
        if fn_constructor_args is not None:
            # Constructor args only make sense when ``fn`` is a class.
            raise ValueError(
                "``fn_constructor_args`` can only be specified if providing a "
                f"callable class instance for ``fn``, but got: {fn}."
            )

    if compute is not None:
        # Legacy code path to support `compute` argument.
        logger.warning(
            "The argument ``compute`` is deprecated in Ray 2.9. Please specify "
            "argument ``concurrency`` instead. For more information, see "
            "https://docs.ray.io/en/master/data/transforming-data.html#"
            "stateful-transforms."
        )
        # Reject strategies that are inconsistent with the kind of ``fn``:
        # callable classes need actors; plain functions can't use actors.
        if is_callable_class and (
            compute == "tasks" or isinstance(compute, TaskPoolStrategy)
        ):
            raise ValueError(
                "``compute`` must specify an actor compute strategy when using a "
                f"callable class, but got: {compute}. For example, use "
                "``compute=ray.data.ActorPoolStrategy(size=n)``."
            )
        elif not is_callable_class and (
            compute == "actors" or isinstance(compute, ActorPoolStrategy)
        ):
            raise ValueError(
                f"``compute`` is specified as the actor compute strategy: {compute}, "
                f"but ``fn`` is not a callable class: {fn}. Pass a callable class or "
                "use the default ``compute`` strategy."
            )
        return compute
    elif concurrency is not None:
        if isinstance(concurrency, tuple):
            # A (min, max) pair configures an autoscaling actor pool and is
            # only valid for callable classes.
            if (
                len(concurrency) == 2
                and isinstance(concurrency[0], int)
                and isinstance(concurrency[1], int)
            ):
                if is_callable_class:
                    return ActorPoolStrategy(
                        min_size=concurrency[0], max_size=concurrency[1]
                    )
                else:
                    raise ValueError(
                        "``concurrency`` is set as a tuple of integers, but ``fn`` "
                        f"is not a callable class: {fn}. Use ``concurrency=n`` to "
                        "control maximum number of workers to use."
                    )
            else:
                raise ValueError(
                    "``concurrency`` is expected to be set as a tuple of "
                    f"integers, but got: {concurrency}."
                )
        elif isinstance(concurrency, int):
            # A single int sizes either an actor pool (class) or a task
            # pool (function).
            if is_callable_class:
                return ActorPoolStrategy(size=concurrency)
            else:
                return TaskPoolStrategy(size=concurrency)
        else:
            raise ValueError(
                "``concurrency`` is expected to be set as an integer or a "
                f"tuple of integers, but got: {concurrency}."
            )
    else:
        # Neither ``compute`` nor ``concurrency`` given: classes require an
        # explicit pool size; functions default to an unbounded task pool.
        if is_callable_class:
            raise ValueError(
                "``concurrency`` must be specified when using a callable class. "
                "For example, use ``concurrency=n`` for a pool of ``n`` workers."
            )
        else:
            return TaskPoolStrategy()
def capfirst(s: str) -> str:
    """Capitalize the first letter of a string

    Args:
        s: String to capitalize

    Returns:
        Capitalized string; the empty string is returned unchanged.
    """
    # Guard against the empty string, for which ``s[0]`` would raise
    # an IndexError.
    if not s:
        return s
    return s[0].upper() + s[1:]
def capitalize(s: str) -> str:
    """Capitalize a string, removing '_' and keeping camelcase.

    Args:
        s: String to capitalize

    Returns:
        Capitalized string with no underscores.
    """
    # Capitalize each underscore-separated piece, then glue them together.
    pieces = s.split("_")
    return "".join(map(capfirst, pieces))
def pandas_df_to_arrow_block(df: "pandas.DataFrame") -> "Block":
    """Convert a pandas DataFrame into an Arrow block plus its BlockMetadata."""
    from ray.data.block import BlockAccessor, BlockExecStats

    arrow_block = BlockAccessor.for_block(df).to_arrow()
    stats_builder = BlockExecStats.builder()
    metadata = BlockAccessor.for_block(arrow_block).get_metadata(
        exec_stats=stats_builder.build()
    )
    return (arrow_block, metadata)
def ndarray_to_block(ndarray: np.ndarray, ctx: DataContext) -> "Block":
    """Wrap a numpy ndarray into a block (under the "data" column) with metadata."""
    from ray.data.block import BlockAccessor, BlockExecStats

    # Propagate the caller's DataContext into this (possibly remote) worker.
    DataContext._set_current(ctx)

    stats_builder = BlockExecStats.builder()
    new_block = BlockAccessor.batch_to_block({"data": ndarray})
    meta = BlockAccessor.for_block(new_block).get_metadata(
        exec_stats=stats_builder.build()
    )
    return new_block, meta
def get_table_block_metadata(
    table: Union["pyarrow.Table", "pandas.DataFrame"]
) -> "BlockMetadata":
    """Compute BlockMetadata (with exec stats) for an Arrow table or DataFrame."""
    from ray.data.block import BlockAccessor, BlockExecStats

    accessor = BlockAccessor.for_block(table)
    return accessor.get_metadata(exec_stats=BlockExecStats.builder().build())
def unify_block_metadata_schema(
    metadata: List["BlockMetadata"],
) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
    """For the input list of BlockMetadata, return a unified schema of the
    corresponding blocks. If the metadata have no valid schema, returns None.
    """
    # Some blocks could be empty, in which case we cannot get their schema.
    # TODO(ekl) validate schema is the same across different blocks.
    from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas

    # Gather schemas from blocks that have one and aren't known-empty.
    candidate_schemas = [
        m.schema
        for m in metadata
        if m.schema is not None and (m.num_rows is None or m.num_rows > 0)
    ]
    if not candidate_schemas:
        return None

    # Check valid pyarrow installation before attempting schema unification.
    try:
        import pyarrow as pa
    except ImportError:
        pa = None

    # If the result contains PyArrow schemas, unify them.
    if pa is not None and all(isinstance(s, pa.Schema) for s in candidate_schemas):
        return unify_schemas(candidate_schemas)
    # Otherwise, if the resulting schemas are simple types (e.g. int),
    # return the first schema.
    return candidate_schemas[0]
def find_partition_index(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
    desired: Tuple[Union[int, float]],
    sort_key: "SortKey",
) -> int:
    """For the given block, find the index where the desired value should be
    added, to maintain sorted order.

    We do this by iterating over each column, starting with the primary sort key,
    and binary searching for the desired value in the column. Each binary search
    shortens the "range" of indices (represented by ``left`` and ``right``, which
    are indices of rows) where the desired value could be inserted.

    Args:
        table: The block to search in.
        desired: A single tuple representing the boundary to partition at.
            ``len(desired)`` must be less than or equal to the number of columns
            being sorted.
        sort_key: The sort key to use for sorting, providing the columns to be
            sorted and their directions.

    Returns:
        The index where the desired value should be inserted to maintain sorted
        order.
    """
    columns = sort_key.get_columns()
    descending = sort_key.get_descending()

    # [left, right) is the candidate insertion range; it narrows with each
    # sort column considered.
    left, right = 0, len(table)
    for i in range(len(desired)):
        if left == right:
            # Range already collapsed to a single position.
            return right
        col_name = columns[i]
        col_vals = table[col_name].to_numpy()[left:right]
        desired_val = desired[i]

        # Handle null values - replace them with sentinel values
        # NOTE(review): NULL_SENTINEL is defined elsewhere in this module;
        # presumably it compares consistently against all non-null values —
        # confirm its ordering semantics there.
        if desired_val is None:
            desired_val = NULL_SENTINEL

        # Replace None/NaN values in col_vals with sentinel
        null_mask = col_vals == None  # noqa: E711
        if null_mask.any():
            col_vals = col_vals.copy()  # Make a copy to avoid modifying original
            col_vals[null_mask] = NULL_SENTINEL

        prevleft = left
        if descending[i] is True:
            # ``np.searchsorted`` expects the array to be sorted in ascending
            # order, so we pass ``sorter``, which is an array of integer indices
            # that sort ``col_vals`` into ascending order. The returned index
            # is an index into the ascending order of ``col_vals``, so we need
            # to subtract it from ``len(col_vals)`` to get the index in the
            # original descending order of ``col_vals``.
            left = prevleft + (
                len(col_vals)
                - np.searchsorted(
                    col_vals,
                    desired_val,
                    side="right",
                    sorter=np.arange(len(col_vals) - 1, -1, -1),
                )
            )
            right = prevleft + (
                len(col_vals)
                - np.searchsorted(
                    col_vals,
                    desired_val,
                    side="left",
                    sorter=np.arange(len(col_vals) - 1, -1, -1),
                )
            )
        else:
            # Ascending column: standard bisect-left/right narrows the range.
            left = prevleft + np.searchsorted(col_vals, desired_val, side="left")
            right = prevleft + np.searchsorted(col_vals, desired_val, side="right")
    # Tie on all compared columns: pick the side consistent with the primary
    # sort direction.
    return right if descending[0] is True else left
def find_partitions(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
    boundaries: List[Tuple[Union[int, float]]],
    sort_key: "SortKey",
):
    """Split a sorted block into ``len(boundaries) + 1`` contiguous slices."""
    # For each boundary value, count the number of items that are less
    # than it. Since the block is sorted, these counts partition the items
    # such that boundaries[i] <= x < boundaries[i + 1] for each x in
    # partition[i]. If `descending` is true, `boundaries` would also be
    # in descending order and we only need to count the number of items
    # *greater than* the boundary value instead.
    cut_points = [
        find_partition_index(table, boundary, sort_key) for boundary in boundaries
    ]

    partitions = []
    start = 0
    for end in cut_points:
        partitions.append(table[start:end])
        start = end
    partitions.append(table[start:])
    return partitions
def get_attribute_from_class_name(class_name: str) -> Any:
    """Get Python attribute from the provided class name.

    The caller needs to make sure the provided class name includes
    full module name, and can be imported successfully.
    """
    from importlib import import_module

    pieces = class_name.split(".")
    if len(pieces) < 2:
        # Without at least one dot there is no module component to import.
        raise ValueError(f"Cannot create object from {class_name}.")

    module_path = ".".join(pieces[:-1])
    return getattr(import_module(module_path), pieces[-1])
# Generic input/output element types used by `make_async_gen` below.
T = TypeVar("T")
U = TypeVar("U")
class _InterruptibleQueue(Queue):
    """Extension of Python's `queue.Queue` that allows blocked `get`/`put`
    callers in other threads to be interrupted via a shared `threading.Event`."""

    # How often (in seconds) blocked `get`/`put` calls wake up to check the
    # interruption flag.
    INTERRUPTION_CHECK_FREQUENCY_SEC = 0.5

    def __init__(
        self, max_size: int, interrupted_event: Optional[threading.Event] = None
    ):
        super().__init__(maxsize=max_size)
        # Event shared across queues so one signal interrupts all of them.
        self._interrupted_event = interrupted_event or threading.Event()

    def get(self, block=True, timeout=None):
        # Non-blocking or bounded waits can't hang forever; delegate as-is.
        if not block or timeout is not None:
            return super().get(block, timeout)

        # In case when the call is blocking and no timeout is specified (ie blocking
        # indefinitely) we apply the following protocol to make it interruptible:
        #
        # 1. `Queue.get` is invoked w/ 500ms timeout
        # 2. `Empty` exception is intercepted (will be raised upon timeout elapsing)
        # 3. If interrupted flag is set `InterruptedError` is raised
        # 4. Otherwise, protocol retried (until interrupted or queue
        # becoming non-empty)
        while True:
            if self._interrupted_event.is_set():
                raise InterruptedError()

            try:
                return super().get(
                    block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
                )
            except Empty:
                pass

    def put(self, item, block=True, timeout=None):
        # Non-blocking or bounded waits can't hang forever; delegate as-is.
        if not block or timeout is not None:
            super().put(item, block, timeout)
            return

        # In case when the call is blocking and no timeout is specified (ie blocking
        # indefinitely) we apply the following protocol to make it interruptible:
        #
        # 1. `Queue.put` is invoked w/ 500ms timeout
        # 2. `Full` exception is intercepted (will be raised upon timeout elapsing)
        # 3. If interrupted flag is set `InterruptedError` is raised
        # 4. Otherwise, protocol retried (until interrupted or queue
        # becomes non-full)
        while True:
            if self._interrupted_event.is_set():
                raise InterruptedError()

            try:
                super().put(
                    item, block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
                )
                return
            except Full:
                pass
def make_async_gen(
    base_iterator: Iterator[T],
    fn: Callable[[Iterator[T]], Iterator[U]],
    num_workers: int = 1,
    queue_buffer_size: int = 2,
) -> Generator[U, None, None]:
    """Returns a generator (iterator) mapping items from the
    provided iterator applying provided transformation in parallel (using a
    thread-pool).

    NOTE: Even though the mapping is performed in parallel across N
    threads, this method provides crucial guarantee of preserving the
    ordering of the source iterator, ie that

        iterator = [A1, A2, ... An]
        mapped iterator = [map(A1), map(A2), ..., map(An)]

    Preserving ordering is crucial to eliminate non-determinism in producing
    content of the blocks.

    Args:
        base_iterator: Iterator yielding elements to map
        fn: Transformation to apply to each element
        num_workers: The number of threads to use in the threadpool (defaults to 1)
        queue_buffer_size: Number of objects to be buffered in its input/output
            queues (per queue; defaults to 2). Total number of objects held
            in memory could be calculated as:

                num_workers * queue_buffer_size * 2 (input and output)

    Returns:
        An generator (iterator) of the elements corresponding to the source
        elements mapped by provided transformation (while *preserving the ordering*)
    """
    # NOTE: In the original version this docstring was placed *after* the
    # first statement, making it a no-op string expression rather than the
    # function's ``__doc__``; it has been moved to the top.

    # Random id used to disambiguate worker thread names across generators.
    gen_id = random.randint(0, 2**31 - 1)

    if num_workers < 1:
        raise ValueError("Size of threadpool must be at least 1.")

    # To apply transformations to elements in parallel *and* preserve the ordering
    # following invariants are established:
    # - Every worker is handled by standalone thread
    # - Every worker is assigned an input and an output queue
    #
    # And following protocol is implemented:
    # - Filling worker traverses input iterator round-robin'ing elements across
    # the input queues (in order!)
    # - Transforming workers traverse respective input queue in-order: de-queueing
    # element, applying transformation and enqueuing the result into the output
    # queue
    # - Generator (returned from this method) traverses output queues (in the same
    # order as input queues) dequeues 1 mapped element at a time from each output
    # queue and yields it
    #
    # Signal handler used to interrupt workers when terminating
    interrupted_event = threading.Event()

    input_queues = [
        _InterruptibleQueue(queue_buffer_size, interrupted_event)
        for _ in range(num_workers)
    ]
    output_queues = [
        _InterruptibleQueue(queue_buffer_size, interrupted_event)
        for _ in range(num_workers)
    ]

    # Filling worker
    def _run_filling_worker():
        try:
            # First, round-robin elements from the iterator into
            # corresponding input queues (one by one)
            for idx, item in enumerate(base_iterator):
                input_queues[idx % num_workers].put(item)

            # Enqueue sentinel objects to signal end of the line
            for idx in range(num_workers):
                input_queues[idx].put(SENTINEL)

        except InterruptedError:
            pass

        except Exception as e:
            logger.warning("Caught exception in filling worker!", exc_info=e)
            # In case of filling worker encountering an exception we have to propagate
            # it back to the (main) iterating thread. To achieve that we're traversing
            # output queues *backwards* relative to the order of iterator-thread such
            # that they are more likely to meet w/in a single iteration.
            for output_queue in reversed(output_queues):
                output_queue.put(e)

    # Transforming worker
    def _run_transforming_worker(worker_id: int):
        input_queue = input_queues[worker_id]
        output_queue = output_queues[worker_id]

        try:
            # Create iterator draining the queue, until it receives sentinel
            #
            # NOTE: `queue.get` is blocking!
            input_queue_iter = iter(input_queue.get, SENTINEL)

            mapped_iter = fn(input_queue_iter)
            for result in mapped_iter:
                # Enqueue result of the transformation
                output_queue.put(result)

            # Enqueue sentinel (to signal that transformations are completed)
            output_queue.put(SENTINEL)

        except InterruptedError:
            pass

        except Exception as e:
            logger.warning("Caught exception in transforming worker!", exc_info=e)
            # NOTE: In this case we simply enqueue the exception rather than
            # interrupting
            output_queue.put(e)

    # Start workers threads
    filling_worker_thread = threading.Thread(
        target=_run_filling_worker,
        name=f"map_tp_filling_worker-{gen_id}",
        daemon=True,
    )
    filling_worker_thread.start()

    transforming_worker_threads = [
        threading.Thread(
            target=_run_transforming_worker,
            name=f"map_tp_transforming_worker-{gen_id}-{worker_idx}",
            args=(worker_idx,),
            daemon=True,
        )
        for worker_idx in range(num_workers)
    ]

    for t in transforming_worker_threads:
        t.start()

    # Use main thread to yield output batches
    try:
        # Keep track of remaining non-empty output queues
        remaining_output_queues = output_queues

        while len(remaining_output_queues) > 0:
            # To provide deterministic ordering of the produced iterator we rely
            # on the following invariants:
            #
            # - Elements from the original iterator are round-robin'd into
            # input queues (in order)
            # - Individual workers drain their respective input queues populating
            # output queues with the results of applying transformation to the
            # original item (and hence preserving original ordering of the input
            # queue)
            # - To yield from the generator output queues are traversed in the same
            # order and one single element is dequeued (in a blocking way!) at a
            # time from every individual output queue
            #
            non_empty_queues = []
            empty_queues = []

            # At every iteration only remaining non-empty queues
            # are traversed (to prevent blocking on exhausted queue)
            for output_queue in remaining_output_queues:
                # NOTE: This is blocking!
                item = output_queue.get()

                if isinstance(item, Exception):
                    raise item

                if item is SENTINEL:
                    empty_queues.append(output_queue)
                else:
                    non_empty_queues.append(output_queue)
                    yield item

            assert (
                non_empty_queues + empty_queues == remaining_output_queues
            ), "Exhausted non-trailing queue!"

            remaining_output_queues = non_empty_queues

    finally:
        # Set flag to interrupt workers (to make sure no dangling
        # threads holding the objects are left behind)
        #
        # NOTE: Interrupted event is set to interrupt the running threads
        # that might be blocked otherwise waiting on inputs from respective
        # queues. However, even though we're interrupting the threads we can't
        # guarantee that threads will be interrupted in time (as this is
        # dependent on Python's GC finalizer to close the generator by raising
        # `GeneratorExit`) and hence we can't join on either filling or
        # transforming workers.
        interrupted_event.set()
def call_with_retry(
    f: Callable[[], Any],
    description: str,
    *,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Any:
    """Retry a function with exponential backoff.

    Args:
        f: The function to retry.
        description: An imperative description of the function being retried. For
            example, "open the file".
        match: A list of strings to match in the exception message. If ``None``, any
            error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.

    Returns:
        Whatever ``f`` returns on the first successful attempt.

    Raises:
        Exception: The last exception raised by ``f``, if the error isn't
            retryable or all attempts are exhausted.
    """
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    for i in range(max_attempts):
        try:
            return f()
        except Exception as e:
            # An error is retryable if no `match` patterns were provided, or if
            # any pattern appears in the exception message.
            is_retryable = match is None or any(pattern in str(e) for pattern in match)
            if is_retryable and i + 1 < max_attempts:
                # Retry with binary exponential backoff with random jitter.
                backoff = min((2 ** (i + 1)), max_backoff_s) * random.random()
                logger.debug(
                    f"Retrying {i+1} attempts to {description} after {backoff} seconds."
                )
                time.sleep(backoff)
            else:
                # `from None` suppresses exception chaining noise from the
                # retry loop itself.
                raise e from None
|
| 1148 |
+
|
| 1149 |
+
|
| 1150 |
+
def iterate_with_retry(
    iterable_factory: Callable[[], Iterable],
    description: str,
    *,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Any:
    """Iterate through an iterable with retries.

    If the iterable raises an exception, this function recreates and re-iterates
    through the iterable, while skipping the items that have already been yielded.

    Args:
        iterable_factory: A no-argument function that creates the iterable.
        description: An imperative description of the function being retried. For
            example, "open the file".
        match: A list of strings to match in the exception message. If ``None``, any
            error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.

    Yields:
        Items from the iterable created by ``iterable_factory``.

    Raises:
        Exception: The last exception raised during iteration, if the error
            isn't retryable or all attempts are exhausted.
    """
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    num_items_yielded = 0
    for attempt in range(max_attempts):
        try:
            iterable = iterable_factory()
            for item_index, item in enumerate(iterable):
                if item_index < num_items_yielded:
                    # Skip items that have already been yielded.
                    #
                    # NOTE: This assumes the factory produces the same items in
                    # the same order on every attempt; otherwise skipped items
                    # may not correspond to the ones previously yielded.
                    continue

                num_items_yielded += 1
                yield item
            return
        except Exception as e:
            # An error is retryable if no `match` patterns were provided, or if
            # any pattern appears in the exception message.
            is_retryable = match is None or any(pattern in str(e) for pattern in match)
            if is_retryable and attempt + 1 < max_attempts:
                # Retry with binary exponential backoff with random jitter.
                backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random()
                logger.debug(
                    f"Retrying {attempt+1} attempts to {description} "
                    f"after {backoff} seconds."
                )
                time.sleep(backoff)
            else:
                raise e from None
|
| 1200 |
+
|
| 1201 |
+
|
| 1202 |
+
def create_dataset_tag(dataset_name: Optional[str], *args) -> str:
    """Build a metrics/identification tag for a dataset.

    Starts from ``dataset_name`` (falling back to ``"dataset"`` when the name
    is ``None`` or empty) and appends each extra argument, underscore-separated.

    Args:
        dataset_name: Name of the dataset, or ``None``/empty for the default.
        *args: Extra components; each is formatted with ``str`` semantics.

    Returns:
        The underscore-joined tag string.
    """
    # Single `str.join` instead of repeated `+=` concatenation.
    return "_".join([dataset_name or "dataset", *(f"{arg}" for arg in args)])
|
| 1207 |
+
|
| 1208 |
+
|
| 1209 |
+
def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
    """Format a byte count as a rounded string in GB, MB, or KB.

    Anything below one megabyte is expressed in KB (so very small counts can
    round down to "0KB").
    """
    # Walk the unit table from largest to smallest and use the first that fits.
    for threshold, unit in ((1e9, "GB"), (1e6, "MB")):
        if num_bytes >= threshold:
            return f"{round(num_bytes / threshold)}{unit}"
    return f"{round(num_bytes / 1e3)}KB"
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
def _validate_rows_per_file_args(
|
| 1220 |
+
*, num_rows_per_file: Optional[int] = None, min_rows_per_file: Optional[int] = None
|
| 1221 |
+
) -> Optional[int]:
|
| 1222 |
+
"""Helper method to validate and handle rows per file arguments.
|
| 1223 |
+
|
| 1224 |
+
Args:
|
| 1225 |
+
num_rows_per_file: Deprecated parameter for number of rows per file
|
| 1226 |
+
min_rows_per_file: New parameter for minimum rows per file
|
| 1227 |
+
|
| 1228 |
+
Returns:
|
| 1229 |
+
The effective min_rows_per_file value to use
|
| 1230 |
+
"""
|
| 1231 |
+
if num_rows_per_file is not None:
|
| 1232 |
+
import warnings
|
| 1233 |
+
|
| 1234 |
+
warnings.warn(
|
| 1235 |
+
"`num_rows_per_file` is deprecated and will be removed in a future release. "
|
| 1236 |
+
"Use `min_rows_per_file` instead.",
|
| 1237 |
+
DeprecationWarning,
|
| 1238 |
+
stacklevel=3,
|
| 1239 |
+
)
|
| 1240 |
+
if min_rows_per_file is not None:
|
| 1241 |
+
raise ValueError(
|
| 1242 |
+
"Cannot specify both `num_rows_per_file` and `min_rows_per_file`. "
|
| 1243 |
+
"Use `min_rows_per_file` as `num_rows_per_file` is deprecated."
|
| 1244 |
+
)
|
| 1245 |
+
return num_rows_per_file
|
| 1246 |
+
return min_rows_per_file
|
| 1247 |
+
|
| 1248 |
+
|
| 1249 |
+
def is_nan(value):
    """Return whether ``value`` is a float NaN; never raises.

    Non-float values (including strings and None) are reported as not-NaN.
    """
    # Only floats can be NaN here; bail out early for everything else.
    if not isinstance(value, float):
        return False
    try:
        return np.isnan(value)
    except TypeError:
        # Defensive: mirror the original's behavior of treating an
        # un-checkable value as not-NaN.
        return False
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
def keys_equal(keys1, keys2):
    """Compare two key sequences elementwise, treating NaN as equal to NaN."""
    return len(keys1) == len(keys2) and all(
        (is_nan(a) and is_nan(b)) or a == b for a, b in zip(keys1, keys2)
    )
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.data._internal.datasource.sql_datasource import Connection
from ray.data.datasource.datasink import (
    Datasink,
    DummyOutputDatasink,
    WriteResult,
    WriteReturnType,
)
from ray.data.datasource.datasource import (
    Datasource,
    RandomIntRowDatasource,
    Reader,
    ReadTask,
)
from ray.data.datasource.file_based_datasource import (
    FileBasedDatasource,
    FileShuffleConfig,
    _S3FileSystemWrapper,
)
from ray.data.datasource.file_datasink import (
    BlockBasedFileDatasink,
    RowBasedFileDatasink,
)
from ray.data.datasource.file_meta_provider import (
    BaseFileMetadataProvider,
    DefaultFileMetadataProvider,
    FastFileMetadataProvider,
    FileMetadataProvider,
)
from ray.data.datasource.filename_provider import FilenameProvider
from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
from ray.data.datasource.partitioning import (
    Partitioning,
    PartitionStyle,
    PathPartitionFilter,
    PathPartitionParser,
)

# Note: HuggingFaceDatasource should NOT be imported here, because
# we want to only import the Hugging Face datasets library when we use
# ray.data.from_huggingface() or HuggingFaceDatasource() directly.
__all__ = [
    "BaseFileMetadataProvider",
    "BlockBasedFileDatasink",
    "Connection",
    "Datasink",
    "Datasource",
    # NOTE(review): "DeltaSharingDatasource" is exported here but never imported
    # in this module, so `from ray.data.datasource import *` would fail on this
    # entry with AttributeError — confirm whether an import is missing or the
    # name should be dropped from `__all__`.
    "DeltaSharingDatasource",
    "DefaultFileMetadataProvider",
    "DummyOutputDatasink",
    "FastFileMetadataProvider",
    "FileBasedDatasource",
    "FileShuffleConfig",
    "FileMetadataProvider",
    "FilenameProvider",
    "ParquetMetadataProvider",
    "PartitionStyle",
    "PathPartitionFilter",
    "PathPartitionParser",
    "Partitioning",
    "RandomIntRowDatasource",
    "ReadTask",
    "Reader",
    "RowBasedFileDatasink",
    "_S3FileSystemWrapper",
    "WriteResult",
    "WriteReturnType",
]
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasink.cpython-311.pyc
ADDED
|
Binary file (8.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasource.cpython-311.pyc
ADDED
|
Binary file (13.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_based_datasource.cpython-311.pyc
ADDED
|
Binary file (26.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_datasink.cpython-311.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_meta_provider.cpython-311.pyc
ADDED
|
Binary file (21.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/filename_provider.cpython-311.pyc
ADDED
|
Binary file (6.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/parquet_meta_provider.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/partitioning.cpython-311.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/path_util.cpython-311.pyc
ADDED
|
Binary file (9.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/datasource/file_datasink.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import posixpath
|
| 3 |
+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
| 4 |
+
from urllib.parse import urlparse
|
| 5 |
+
|
| 6 |
+
from ray._private.utils import _add_creatable_buckets_param_if_s3_uri
|
| 7 |
+
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
| 8 |
+
from ray.data._internal.execution.interfaces import TaskContext
|
| 9 |
+
from ray.data._internal.util import _is_local_scheme, call_with_retry
|
| 10 |
+
from ray.data.block import Block, BlockAccessor
|
| 11 |
+
from ray.data.context import DataContext
|
| 12 |
+
from ray.data.datasource.datasink import Datasink, WriteResult
|
| 13 |
+
from ray.data.datasource.filename_provider import (
|
| 14 |
+
FilenameProvider,
|
| 15 |
+
_DefaultFilenameProvider,
|
| 16 |
+
)
|
| 17 |
+
from ray.data.datasource.path_util import _resolve_paths_and_filesystem
|
| 18 |
+
from ray.util.annotations import DeveloperAPI
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
import pyarrow
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
WRITE_FILE_MAX_ATTEMPTS = 10
|
| 27 |
+
WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class _FileDatasink(Datasink[None]):
    """Base datasink that writes dataset blocks to files under a directory.

    Handles output path and filesystem resolution, optional output-directory
    creation, and cleanup of the directory when it was created but nothing was
    written. Subclasses implement ``write_block`` to serialize one block.
    """

    def __init__(
        self,
        path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        try_create_dir: bool = True,
        open_stream_args: Optional[Dict[str, Any]] = None,
        filename_provider: Optional[FilenameProvider] = None,
        dataset_uuid: Optional[str] = None,
        file_format: Optional[str] = None,
    ):
        """Initialize this datasink.

        Args:
            path: The folder to write files to.
            filesystem: The filesystem to write files to. If not provided, the
                filesystem is inferred from the path.
            try_create_dir: Whether to create the directory to write files to.
            open_stream_args: Arguments to pass to ``filesystem.open_output_stream``.
            filename_provider: A :class:`ray.data.datasource.FilenameProvider` that
                generates filenames for each row or block.
            dataset_uuid: The UUID of the dataset being written. If specified, it's
                included in the filename.
            file_format: The file extension. If specified, files are written with this
                extension.
        """
        if open_stream_args is None:
            open_stream_args = {}

        if filename_provider is None:
            filename_provider = _DefaultFilenameProvider(
                dataset_uuid=dataset_uuid, file_format=file_format
            )

        self.unresolved_path = path
        paths, self.filesystem = _resolve_paths_and_filesystem(path, filesystem)
        # A datasink is constructed with exactly one output directory, so path
        # resolution must yield exactly one path.
        assert len(paths) == 1, len(paths)
        self.path = paths[0]

        self.try_create_dir = try_create_dir
        self.open_stream_args = open_stream_args
        self.filename_provider = filename_provider
        self.dataset_uuid = dataset_uuid
        self.file_format = file_format

        # Set by `on_write_start` if it created the output directory; consulted
        # by `on_write_complete` to decide whether an empty directory may be
        # deleted.
        self.has_created_dir = False

    def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
        """Open an output stream at ``path`` with the configured stream args."""
        return self.filesystem.open_output_stream(path, **self.open_stream_args)

    def on_write_start(self) -> None:
        """Create the output directory (if configured) before any writes run."""
        self.has_created_dir = self._create_dir(self.path)

    def _create_dir(self, dest) -> bool:
        """Create a directory to write files to.

        If ``try_create_dir`` is ``False``, this method is a no-op.

        Returns:
            ``True`` if this call created the directory, otherwise ``False``.
        """
        from pyarrow.fs import FileType

        # We should skip creating directories in s3 unless the user specifically
        # overrides this behavior. PyArrow's s3fs implementation for create_dir
        # will attempt to check if the parent directory exists before trying to
        # create the directory (with recursive=True it will try to do this to
        # all of the directories until the root of the bucket). An IAM Policy that
        # restricts access to a subset of prefixes within the bucket might cause
        # the creation of the directory to fail even if the permissions should
        # allow the data can be written to the specified path. For example if a
        # a policy only allows users to write blobs prefixed with s3://bucket/foo
        # a call to create_dir for s3://bucket/foo/bar will fail even though it
        # should not.
        parsed_uri = urlparse(dest)
        is_s3_uri = parsed_uri.scheme == "s3"
        skip_create_dir_for_s3 = (
            is_s3_uri and not DataContext.get_current().s3_try_create_dir
        )

        if self.try_create_dir and not skip_create_dir_for_s3:
            if self.filesystem.get_file_info(dest).type is FileType.NotFound:
                # Arrow's S3FileSystem doesn't allow creating buckets by default, so we
                # add a query arg enabling bucket creation if an S3 URI is provided.
                tmp = _add_creatable_buckets_param_if_s3_uri(dest)
                self.filesystem.create_dir(tmp, recursive=True)
                return True

        return False

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Combine ``blocks`` into one block and delegate to ``write_block``.

        Zero-row output is skipped (with a warning) instead of producing an
        empty file.
        """
        builder = DelegatingBlockBuilder()
        for block in blocks:
            builder.add_block(block)
        block = builder.build()
        block_accessor = BlockAccessor.for_block(block)

        if block_accessor.num_rows() == 0:
            logger.warning(f"Skipped writing empty block to {self.path}")
            return

        self.write_block(block_accessor, 0, ctx)

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        """Write one block. Subclasses must override this."""
        raise NotImplementedError

    def on_write_complete(self, write_result: WriteResult[None]):
        # If no rows were written, we can delete the directory.
        if self.has_created_dir and write_result.num_rows == 0:
            self.filesystem.delete_dir(self.path)

    @property
    def supports_distributed_writes(self) -> bool:
        # Local-scheme paths are only visible to the node that resolved them,
        # so writes can't be fanned out across the cluster — presumably the
        # rationale for `_is_local_scheme`; confirm against its definition.
        return not _is_local_scheme(self.unresolved_path)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@DeveloperAPI
class RowBasedFileDatasink(_FileDatasink):
    """A datasink that writes one row to each file.

    Subclasses must implement ``write_row_to_file`` and call the superclass constructor.

    Examples:
        .. testcode::

            import io
            from typing import Any, Dict

            import pyarrow
            from PIL import Image

            from ray.data.datasource import RowBasedFileDatasink

            class ImageDatasink(RowBasedFileDatasink):
                def __init__(self, path: str, *, column: str, file_format: str = "png"):
                    super().__init__(path, file_format=file_format)
                    self._file_format = file_format
                    self._column = column

                def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
                    image = Image.fromarray(row[self._column])
                    buffer = io.BytesIO()
                    image.save(buffer, format=self._file_format)
                    file.write(buffer.getvalue())
    """  # noqa: E501

    def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
        """Write a row to a file.

        Args:
            row: The row to write.
            file: The file to write the row to.
        """
        raise NotImplementedError

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        # Write each row of the block to its own file, retrying transient I/O
        # errors per file.
        for row_index, row in enumerate(block.iter_rows(public_row_format=False)):
            filename = self.filename_provider.get_filename_for_row(
                row, ctx.task_idx, block_index, row_index
            )
            write_path = posixpath.join(self.path, filename)

            def write_row_to_path(row, write_path):
                with self.open_output_stream(write_path) as file:
                    self.write_row_to_file(row, file)

            logger.debug(f"Writing {write_path} file.")
            # `row`/`write_path` are bound as lambda defaults so each retried
            # callable captures this iteration's values rather than the loop
            # variables (which would otherwise be late-bound).
            call_with_retry(
                lambda row=row, write_path=write_path: write_row_to_path(
                    row, write_path
                ),
                description=f"write '{write_path}'",
                match=DataContext.get_current().retried_io_errors,
                max_attempts=WRITE_FILE_MAX_ATTEMPTS,
                max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
            )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@DeveloperAPI
class BlockBasedFileDatasink(_FileDatasink):
    """A datasink that writes multiple rows to each file.

    Subclasses must implement ``write_block_to_file`` and call the superclass
    constructor.

    Examples:
        .. testcode::

            class CSVDatasink(BlockBasedFileDatasink):
                def __init__(self, path: str):
                    super().__init__(path, file_format="csv")

                def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
                    from pyarrow import csv
                    csv.write_csv(block.to_arrow(), file)
    """  # noqa: E501

    def __init__(
        self, path, *, min_rows_per_file: Optional[int] = None, **file_datasink_kwargs
    ):
        """Initialize this datasink.

        Args:
            path: The folder to write files to.
            min_rows_per_file: Minimum number of rows per output file; surfaced
                to callers through the ``min_rows_per_write`` property.
            **file_datasink_kwargs: Forwarded to the ``_FileDatasink`` constructor.
        """
        super().__init__(path, **file_datasink_kwargs)

        self._min_rows_per_file = min_rows_per_file

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write a block of data to a file.

        Args:
            block: The block to write.
            file: The file to write the block to.
        """
        raise NotImplementedError

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        """Write one block to a single file, retrying transient I/O errors."""
        filename = self.filename_provider.get_filename_for_block(
            block, ctx.task_idx, block_index
        )
        write_path = posixpath.join(self.path, filename)

        def write_block_to_path():
            with self.open_output_stream(write_path) as file:
                self.write_block_to_file(block, file)

        logger.debug(f"Writing {write_path} file.")
        # Retry errors matching the current DataContext's retryable I/O error
        # patterns, with bounded exponential backoff.
        call_with_retry(
            write_block_to_path,
            description=f"write '{write_path}'",
            match=DataContext.get_current().retried_io_errors,
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )

    @property
    def min_rows_per_write(self) -> Optional[int]:
        # The `min_rows_per_file` value supplied at construction (or None).
        return self._min_rows_per_file
|
.venv/lib/python3.11/site-packages/ray/data/datasource/partitioning.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import posixpath
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
|
| 5 |
+
|
| 6 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
import pyarrow
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
PartitionDataType = Type[Union[int, float, str, bool]]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@DeveloperAPI
class PartitionStyle(str, Enum):
    """Supported dataset partition styles.

    Inherits from `str` to simplify plain text serialization/deserialization.

    Examples:
        >>> # Serialize to JSON text.
        >>> json.dumps(PartitionStyle.HIVE)  # doctest: +SKIP
        '"hive"'

        >>> # Deserialize from JSON text.
        >>> PartitionStyle(json.loads('"hive"'))  # doctest: +SKIP
        <PartitionStyle.HIVE: 'hive'>
    """

    # Hive-style `{key}={value}` partition directories.
    HIVE = "hive"
    # Directory-style, value-only partition directories; requires an ordered
    # list of field names to name the values.
    DIRECTORY = "dir"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@DeveloperAPI
@dataclass
class Partitioning:
    """Partition scheme used to describe path-based partitions.

    Path-based partition formats embed all partition keys and values directly in
    their dataset file paths.

    For example, to read a dataset with
    `Hive-style partitions <https://athena.guide/articles/hive-style-partitioning>`_:

    >>> import ray
    >>> from ray.data.datasource.partitioning import Partitioning
    >>> ds = ray.data.read_csv(
    ...     "s3://anonymous@ray-example-data/iris.csv",
    ...     partitioning=Partitioning("hive"),
    ... )

    Instead, if your files are arranged in a directory structure such as:

    .. code::

        root/dog/dog_0.jpeg
        root/dog/dog_1.jpeg
        ...

        root/cat/cat_0.jpeg
        root/cat/cat_1.jpeg
        ...

    Then you can use directory-based partitioning:

    >>> import ray
    >>> from ray.data.datasource.partitioning import Partitioning
    >>> root = "s3://anonymous@air-example-data/cifar-10/images"
    >>> partitioning = Partitioning("dir", field_names=["class"], base_dir=root)
    >>> ds = ray.data.read_images(root, partitioning=partitioning)
    """

    #: The partition style - may be either HIVE or DIRECTORY.
    style: PartitionStyle
    #: "/"-delimited base directory that all partitioned paths should
    #: exist under (exclusive). File paths either outside of, or at the first
    #: level of, this directory will be considered unpartitioned. Specify
    #: `None` or an empty string to search for partitions in all file path
    #: directories.
    base_dir: Optional[str] = None
    #: The partition key field names (i.e. column names for tabular
    #: datasets). When non-empty, the order and length of partition key
    #: field names must match the order and length of partition values.
    #: Required when parsing DIRECTORY partitioned paths or generating
    #: HIVE partitioned paths.
    field_names: Optional[List[str]] = None
    #: A dictionary that maps partition key names to their desired data type. If not
    #: provided, the data type defaults to string.
    field_types: Optional[Dict[str, PartitionDataType]] = None
    #: Filesystem that will be used for partition path file I/O.
    filesystem: Optional["pyarrow.fs.FileSystem"] = None

    def __post_init__(self):
        # Normalize unset optional fields to cheap concrete defaults so
        # downstream code doesn't have to handle `None`.
        if self.base_dir is None:
            self.base_dir = ""

        if self.field_types is None:
            self.field_types = {}

        # Computed lazily by `_normalize_base_dir` on first access of the
        # `normalized_base_dir` / `resolved_filesystem` properties.
        self._normalized_base_dir = None
        self._resolved_filesystem = None

    @property
    def normalized_base_dir(self) -> str:
        """Returns the base directory normalized for compatibility with a filesystem."""
        if self._normalized_base_dir is None:
            self._normalize_base_dir()
        return self._normalized_base_dir

    @property
    def resolved_filesystem(self) -> "pyarrow.fs.FileSystem":
        """Returns the filesystem resolved for compatibility with a base directory."""
        if self._resolved_filesystem is None:
            self._normalize_base_dir()
        return self._resolved_filesystem

    def _normalize_base_dir(self):
        """Normalizes the partition base directory for compatibility with the
        given filesystem.

        This should be called once a filesystem has been resolved to ensure that this
        base directory is correctly discovered at the root of all partitioned file
        paths.

        Sets both ``self._normalized_base_dir`` and ``self._resolved_filesystem``.
        """
        # Imported locally rather than at module top level — presumably to
        # avoid a circular import; confirm before hoisting.
        from ray.data.datasource.path_util import _resolve_paths_and_filesystem

        paths, self._resolved_filesystem = _resolve_paths_and_filesystem(
            self.base_dir,
            self.filesystem,
        )
        assert (
            len(paths) == 1
        ), f"Expected 1 normalized base directory, but found {len(paths)}"
        normalized_base_dir = paths[0]
        # Ensure a trailing "/" on non-empty base dirs so the base directory
        # ends on a path-component boundary.
        if len(normalized_base_dir) and not normalized_base_dir.endswith("/"):
            normalized_base_dir += "/"
        self._normalized_base_dir = normalized_base_dir
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@DeveloperAPI
class PathPartitionParser:
    """Partition parser for path-based partition formats.

    Path-based partition formats embed all partition keys and values directly in
    their dataset file paths.

    Two path partition formats are currently supported - `HIVE` and `DIRECTORY`.

    For `HIVE` Partitioning, all partition directories under the base directory
    will be discovered based on `{key1}={value1}/{key2}={value2}` naming
    conventions. Key/value pairs do not need to be presented in the same
    order across all paths. Directory names nested under the base directory that
    don't follow this naming condition will be considered unpartitioned. If a
    partition filter is defined, then it will be called with an empty input
    dictionary for each unpartitioned file.

    For `DIRECTORY` Partitioning, all directories under the base directory will
    be interpreted as partition values of the form `{value1}/{value2}`. An
    accompanying ordered list of partition field names must also be provided,
    where the order and length of all partition values must match the order and
    length of field names. Files stored directly in the base directory will
    be considered unpartitioned. If a partition filter is defined, then it will
    be called with an empty input dictionary for each unpartitioned file. For
    example, if the base directory is `"foo"`, then `"foo.csv"` and `"foo/bar.csv"`
    would be considered unpartitioned files but `"foo/bar/baz.csv"` would be associated
    with partition `"bar"`. If the base directory is undefined, then `"foo.csv"` would
    be unpartitioned, `"foo/bar.csv"` would be associated with partition `"foo"`, and
    "foo/bar/baz.csv" would be associated with partition `("foo", "bar")`.
    """

    @staticmethod
    def of(
        style: PartitionStyle = PartitionStyle.HIVE,
        base_dir: Optional[str] = None,
        field_names: Optional[List[str]] = None,
        field_types: Optional[Dict[str, PartitionDataType]] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    ) -> "PathPartitionParser":
        """Creates a path-based partition parser using a flattened argument list.

        Args:
            style: The partition style - may be either HIVE or DIRECTORY.
            base_dir: "/"-delimited base directory to start searching for partitions
                (exclusive). File paths outside of this directory will be considered
                unpartitioned. Specify `None` or an empty string to search for
                partitions in all file path directories.
            field_names: The partition key names. Required for DIRECTORY partitioning.
                Optional for HIVE partitioning. When non-empty, the order and length of
                partition key field names must match the order and length of partition
                directories discovered. Partition key field names are not required to
                exist in the dataset schema.
            field_types: A dictionary that maps partition key names to their desired
                data type. If not provided, the data type defaults to string.
            filesystem: Filesystem that will be used for partition path file I/O.

        Returns:
            The new path-based partition parser.
        """
        scheme = Partitioning(style, base_dir, field_names, field_types, filesystem)
        return PathPartitionParser(scheme)

    def __init__(self, partitioning: Partitioning):
        """Creates a path-based partition parser.

        Args:
            partitioning: The path-based partition scheme. The parser starts
                searching for partitions from this scheme's base directory. File paths
                outside the base directory will be considered unpartitioned. If the
                base directory is `None` or an empty string then this will search for
                partitions in all file path directories. Field names are required for
                DIRECTORY partitioning, and optional for HIVE partitioning. When
                non-empty, the order and length of partition key field names must match
                the order and length of partition directories discovered.

        Raises:
            ValueError: If DIRECTORY partitioning is requested without field
                names, or if the partition style is unsupported.
        """
        style = partitioning.style
        field_names = partitioning.field_names
        if style == PartitionStyle.DIRECTORY and not field_names:
            raise ValueError(
                "Directory partitioning requires a corresponding list of "
                "partition key field names. Please retry your request with one "
                "or more field names specified."
            )
        # Dispatch table keyed by partition style; unknown styles are rejected.
        parsers = {
            PartitionStyle.HIVE: self._parse_hive_path,
            PartitionStyle.DIRECTORY: self._parse_dir_path,
        }
        self._parser_fn: Callable[[str], Dict[str, str]] = parsers.get(style)
        if self._parser_fn is None:
            raise ValueError(
                f"Unsupported partition style: {style}. "
                f"Supported styles: {parsers.keys()}"
            )
        self._scheme = partitioning

    def __call__(self, path: str) -> Dict[str, str]:
        """Parses partition keys and values from a single file path.

        Args:
            path: Input file path to parse.

        Returns:
            Dictionary mapping directory partition keys to values from the input file
            path. Returns an empty dictionary for unpartitioned files. Values
            are cast per the scheme's ``field_types`` when present, otherwise
            left as strings.
        """
        dir_path = self._dir_path_trim_base(path)
        if dir_path is None:
            # Path lies outside the base directory: unpartitioned.
            return {}
        partitions: Dict[str, str] = self._parser_fn(dir_path)

        for field, data_type in self._scheme.field_types.items():
            # Bug fix: unpartitioned files (and HIVE paths that don't carry a
            # typed key) produce no entry for `field`; unconditionally indexing
            # `partitions[field]` previously raised KeyError for them.
            if field in partitions:
                partitions[field] = _cast_value(partitions[field], data_type)

        return partitions

    @property
    def scheme(self) -> Partitioning:
        """Returns the partitioning for this parser."""
        return self._scheme

    def _dir_path_trim_base(self, path: str) -> Optional[str]:
        """Trims the normalized base directory and returns the directory path.

        Returns None if the path does not start with the normalized base directory.
        Simply returns the directory path if the base directory is undefined.
        """
        if not path.startswith(self._scheme.normalized_base_dir):
            return None
        # Drop the base-dir prefix, then drop the file name component.
        path = path[len(self._scheme.normalized_base_dir) :]
        return posixpath.dirname(path)

    def _parse_hive_path(self, dir_path: str) -> Dict[str, str]:
        """Hive partition path parser.

        Returns a dictionary mapping partition keys to values given a hive-style
        partition path of the form "{key1}={value1}/{key2}={value2}/..." or an empty
        dictionary for unpartitioned files.

        Raises:
            ValueError: If the scheme declares field names and the discovered
                key/value pairs don't match them in count or order.
        """
        # Only directory components with exactly one "=" count as partitions.
        dirs = [d for d in dir_path.split("/") if d and (d.count("=") == 1)]
        kv_pairs = [d.split("=") for d in dirs] if dirs else []
        field_names = self._scheme.field_names
        if field_names and kv_pairs:
            if len(kv_pairs) != len(field_names):
                raise ValueError(
                    f"Expected {len(field_names)} partition value(s) but found "
                    f"{len(kv_pairs)}: {kv_pairs}."
                )
            # Declared field names must match discovered keys positionally.
            for i, field_name in enumerate(field_names):
                if kv_pairs[i][0] != field_name:
                    raise ValueError(
                        f"Expected partition key {field_name} but found "
                        f"{kv_pairs[i][0]}"
                    )
        return dict(kv_pairs)

    def _parse_dir_path(self, dir_path: str) -> Dict[str, str]:
        """Directory partition path parser.

        Returns a dictionary mapping directory partition keys to values from a
        partition path of the form "{value1}/{value2}/..." or an empty dictionary for
        unpartitioned files.

        Requires a corresponding ordered list of partition key field names to map the
        correct key to each value.

        Raises:
            ValueError: If the number of path components doesn't match the
                number of declared field names.
        """
        dirs = [d for d in dir_path.split("/") if d]
        field_names = self._scheme.field_names

        if dirs and len(dirs) != len(field_names):
            raise ValueError(
                f"Expected {len(field_names)} partition value(s) but found "
                f"{len(dirs)}: {dirs}."
            )

        if not dirs:
            return {}
        # A None field name skips (ignores) that positional directory value.
        return {
            field: directory
            for field, directory in zip(field_names, dirs)
            if field is not None
        }
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@PublicAPI(stability="beta")
|
| 325 |
+
class PathPartitionFilter:
|
| 326 |
+
"""Partition filter for path-based partition formats.
|
| 327 |
+
|
| 328 |
+
Used to explicitly keep or reject files based on a custom filter function that
|
| 329 |
+
takes partition keys and values parsed from the file's path as input.
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
@staticmethod
|
| 333 |
+
def of(
|
| 334 |
+
filter_fn: Callable[[Dict[str, str]], bool],
|
| 335 |
+
style: PartitionStyle = PartitionStyle.HIVE,
|
| 336 |
+
base_dir: Optional[str] = None,
|
| 337 |
+
field_names: Optional[List[str]] = None,
|
| 338 |
+
field_types: Optional[Dict[str, PartitionDataType]] = None,
|
| 339 |
+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 340 |
+
) -> "PathPartitionFilter":
|
| 341 |
+
"""Creates a path-based partition filter using a flattened argument list.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
filter_fn: Callback used to filter partitions. Takes a dictionary mapping
|
| 345 |
+
partition keys to values as input. Unpartitioned files are denoted with
|
| 346 |
+
an empty input dictionary. Returns `True` to read a file for that
|
| 347 |
+
partition or `False` to skip it. Partition keys and values are always
|
| 348 |
+
strings read from the filesystem path. For example, this removes all
|
| 349 |
+
unpartitioned files:
|
| 350 |
+
|
| 351 |
+
.. code:: python
|
| 352 |
+
|
| 353 |
+
lambda d: True if d else False
|
| 354 |
+
|
| 355 |
+
This raises an assertion error for any unpartitioned file found:
|
| 356 |
+
|
| 357 |
+
.. code:: python
|
| 358 |
+
|
| 359 |
+
def do_assert(val, msg):
|
| 360 |
+
assert val, msg
|
| 361 |
+
|
| 362 |
+
lambda d: do_assert(d, "Expected all files to be partitioned!")
|
| 363 |
+
|
| 364 |
+
And this only reads files from January, 2022 partitions:
|
| 365 |
+
|
| 366 |
+
.. code:: python
|
| 367 |
+
|
| 368 |
+
lambda d: d["month"] == "January" and d["year"] == "2022"
|
| 369 |
+
|
| 370 |
+
style: The partition style - may be either HIVE or DIRECTORY.
|
| 371 |
+
base_dir: "/"-delimited base directory to start searching for partitions
|
| 372 |
+
(exclusive). File paths outside of this directory will be considered
|
| 373 |
+
unpartitioned. Specify `None` or an empty string to search for
|
| 374 |
+
partitions in all file path directories.
|
| 375 |
+
field_names: The partition key names. Required for DIRECTORY partitioning.
|
| 376 |
+
Optional for HIVE partitioning. When non-empty, the order and length of
|
| 377 |
+
partition key field names must match the order and length of partition
|
| 378 |
+
directories discovered. Partition key field names are not required to
|
| 379 |
+
exist in the dataset schema.
|
| 380 |
+
field_types: A dictionary that maps partition key names to their desired
|
| 381 |
+
data type. If not provided, the data type defaults to string.
|
| 382 |
+
filesystem: Filesystem that will be used for partition path file I/O.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
The new path-based partition filter.
|
| 386 |
+
"""
|
| 387 |
+
scheme = Partitioning(style, base_dir, field_names, field_types, filesystem)
|
| 388 |
+
path_partition_parser = PathPartitionParser(scheme)
|
| 389 |
+
return PathPartitionFilter(path_partition_parser, filter_fn)
|
| 390 |
+
|
| 391 |
+
def __init__(
|
| 392 |
+
self,
|
| 393 |
+
path_partition_parser: PathPartitionParser,
|
| 394 |
+
filter_fn: Callable[[Dict[str, str]], bool],
|
| 395 |
+
):
|
| 396 |
+
"""Creates a new path-based partition filter based on a parser.
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
path_partition_parser: The path-based partition parser.
|
| 400 |
+
filter_fn: Callback used to filter partitions. Takes a dictionary mapping
|
| 401 |
+
partition keys to values as input. Unpartitioned files are denoted with
|
| 402 |
+
an empty input dictionary. Returns `True` to read a file for that
|
| 403 |
+
partition or `False` to skip it. Partition keys and values are always
|
| 404 |
+
strings read from the filesystem path. For example, this removes all
|
| 405 |
+
unpartitioned files:
|
| 406 |
+
``lambda d: True if d else False``
|
| 407 |
+
This raises an assertion error for any unpartitioned file found:
|
| 408 |
+
``lambda d: assert d, "Expected all files to be partitioned!"``
|
| 409 |
+
And this only reads files from January, 2022 partitions:
|
| 410 |
+
``lambda d: d["month"] == "January" and d["year"] == "2022"``
|
| 411 |
+
"""
|
| 412 |
+
self._parser = path_partition_parser
|
| 413 |
+
self._filter_fn = filter_fn
|
| 414 |
+
|
| 415 |
+
def __call__(self, paths: List[str]) -> List[str]:
|
| 416 |
+
"""Returns all paths that pass this partition scheme's partition filter.
|
| 417 |
+
|
| 418 |
+
If no partition filter is set, then returns all input paths. If a base
|
| 419 |
+
directory is set, then only paths under this base directory will be parsed
|
| 420 |
+
for partitions. All paths outside of this base directory will automatically
|
| 421 |
+
be considered unpartitioned, and passed into the filter function as empty
|
| 422 |
+
dictionaries.
|
| 423 |
+
|
| 424 |
+
Also normalizes the partition base directory for compatibility with the
|
| 425 |
+
given filesystem before applying the filter.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
paths: Paths to pass through the partition filter function. All
|
| 429 |
+
paths should be normalized for compatibility with the given
|
| 430 |
+
filesystem.
|
| 431 |
+
Returns:
|
| 432 |
+
List of paths that pass the partition filter, or all paths if no
|
| 433 |
+
partition filter is defined.
|
| 434 |
+
"""
|
| 435 |
+
filtered_paths = paths
|
| 436 |
+
if self._filter_fn is not None:
|
| 437 |
+
filtered_paths = [
|
| 438 |
+
path for path in paths if self._filter_fn(self._parser(path))
|
| 439 |
+
]
|
| 440 |
+
return filtered_paths
|
| 441 |
+
|
| 442 |
+
@property
|
| 443 |
+
def parser(self) -> PathPartitionParser:
|
| 444 |
+
"""Returns the path partition parser for this filter."""
|
| 445 |
+
return self._parser
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _cast_value(value: str, data_type: PartitionDataType) -> Any:
    """Convert a raw partition value string to the requested data type.

    ``bool`` compares case-insensitively against the literal string "true"
    (anything else becomes ``False``); ``int`` and ``float`` use the builtin
    constructors; any other requested type returns the string unchanged.
    """
    if data_type is bool:
        return value.lower() == "true"
    if data_type in (int, float):
        return data_type(value)
    return value
|
.venv/lib/python3.11/site-packages/ray/data/datasource/path_util.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
import sys
|
| 3 |
+
import urllib
|
| 4 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
from ray.data._internal.util import _resolve_custom_scheme
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
import pyarrow
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _has_file_extension(path: str, extensions: Optional[List[str]]) -> bool:
|
| 13 |
+
"""Check if a path has a file extension in the provided list.
|
| 14 |
+
|
| 15 |
+
Examples:
|
| 16 |
+
>>> _has_file_extension("foo.csv", ["csv"])
|
| 17 |
+
True
|
| 18 |
+
>>> _has_file_extension("foo.CSV", ["csv"])
|
| 19 |
+
True
|
| 20 |
+
>>> _has_file_extension("foo.csv", ["json", "jsonl"])
|
| 21 |
+
False
|
| 22 |
+
>>> _has_file_extension("foo.csv", None)
|
| 23 |
+
True
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
path: The path to check.
|
| 27 |
+
extensions: A list of extensions to check against. If `None`, any extension is
|
| 28 |
+
considered valid.
|
| 29 |
+
"""
|
| 30 |
+
assert extensions is None or isinstance(extensions, list), type(extensions)
|
| 31 |
+
|
| 32 |
+
if extensions is None:
|
| 33 |
+
return True
|
| 34 |
+
|
| 35 |
+
# The user-specified extensions don't contain a leading dot, so we add it here.
|
| 36 |
+
extensions = [f".{ext.lower()}" for ext in extensions]
|
| 37 |
+
return any(path.lower().endswith(ext) for ext in extensions)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _resolve_paths_and_filesystem(
    paths: Union[str, List[str]],
    filesystem: "pyarrow.fs.FileSystem" = None,
) -> Tuple[List[str], "pyarrow.fs.FileSystem"]:
    """
    Resolves and normalizes all provided paths, infers a filesystem from the
    paths and ensures that all paths use the same filesystem.

    Args:
        paths: A single file/directory path or a list of file/directory paths.
            A list of paths can contain both files and directories.
        filesystem: The filesystem implementation that should be used for
            reading these files. If None, a filesystem will be inferred. If not
            None, the provided filesystem will still be validated against all
            filesystems inferred from the provided paths to ensure
            compatibility.

    Returns:
        A tuple of the resolved/normalized paths and the single resolved
        ``pyarrow.fs.FileSystem`` shared by all of them.

    Raises:
        ValueError: If ``paths`` is not a str/Path/list[str], or is empty.
        TypeError: If ``filesystem`` is neither a pyarrow nor an fsspec
            filesystem.
        ImportError: If an HTTP(S) path is given but fsspec is not installed.
    """
    # pyarrow imported locally rather than at module top level — presumably to
    # keep it a lazily-loaded dependency; confirm before hoisting.
    import pyarrow as pa
    from pyarrow.fs import (
        FileSystem,
        FSSpecHandler,
        PyFileSystem,
        _resolve_filesystem_and_path,
    )

    # Normalize `paths` into a non-empty list[str].
    if isinstance(paths, str):
        paths = [paths]
    if isinstance(paths, pathlib.Path):
        paths = [str(paths)]
    elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths):
        raise ValueError(
            "Expected `paths` to be a `str`, `pathlib.Path`, or `list[str]`, but got "
            f"`{paths}`."
        )
    elif len(paths) == 0:
        raise ValueError("Must provide at least one path.")

    need_unwrap_path_protocol = True
    if filesystem and not isinstance(filesystem, FileSystem):
        # Not a native pyarrow filesystem: accept fsspec filesystems by
        # wrapping them in a pyarrow-compatible handler; reject anything else.
        err_msg = (
            f"The filesystem passed must either conform to "
            f"pyarrow.fs.FileSystem, or "
            f"fsspec.spec.AbstractFileSystem. The provided "
            f"filesystem was: {filesystem}"
        )
        try:
            import fsspec
            from fsspec.implementations.http import HTTPFileSystem
        except ModuleNotFoundError:
            # If filesystem is not a pyarrow filesystem and fsspec isn't
            # installed, then filesystem is neither a pyarrow filesystem nor
            # an fsspec filesystem, so we raise a TypeError.
            raise TypeError(err_msg) from None
        if not isinstance(filesystem, fsspec.spec.AbstractFileSystem):
            raise TypeError(err_msg) from None
        if isinstance(filesystem, HTTPFileSystem):
            # If filesystem is fsspec HTTPFileSystem, the protocol/scheme of paths
            # should not be unwrapped/removed, because HTTPFileSystem expects full file
            # paths including protocol/scheme. This is different behavior compared to
            # file systems implementation in pyarrow.fs.FileSystem.
            need_unwrap_path_protocol = False

        filesystem = PyFileSystem(FSSpecHandler(filesystem))

    resolved_paths = []
    for path in paths:
        path = _resolve_custom_scheme(path)
        try:
            resolved_filesystem, resolved_path = _resolve_filesystem_and_path(
                path, filesystem
            )
        except pa.lib.ArrowInvalid as e:
            if "Cannot parse URI" in str(e):
                # Retry with a percent-encoded URI (e.g. spaces in the path),
                # then decode back so callers see the original characters.
                resolved_filesystem, resolved_path = _resolve_filesystem_and_path(
                    _encode_url(path), filesystem
                )
                resolved_path = _decode_url(resolved_path)
            elif "Unrecognized filesystem type in URI" in str(e):
                scheme = urllib.parse.urlparse(path, allow_fragments=False).scheme
                if scheme in ["http", "https"]:
                    # If scheme of path is HTTP and filesystem is not resolved,
                    # try to use fsspec HTTPFileSystem. This expects fsspec is
                    # installed.
                    try:
                        from fsspec.implementations.http import HTTPFileSystem
                    except ModuleNotFoundError:
                        raise ImportError(
                            "Please install fsspec to read files from HTTP."
                        ) from None

                    resolved_filesystem = PyFileSystem(FSSpecHandler(HTTPFileSystem()))
                    resolved_path = path
                    need_unwrap_path_protocol = False
                else:
                    raise
            else:
                raise
        if filesystem is None:
            # First successfully resolved path fixes the filesystem used for
            # every remaining path.
            filesystem = resolved_filesystem
        elif need_unwrap_path_protocol:
            resolved_path = _unwrap_protocol(resolved_path)
        resolved_path = filesystem.normalize_path(resolved_path)
        resolved_paths.append(resolved_path)

    return resolved_paths, filesystem
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _unwrap_protocol(path):
|
| 148 |
+
"""
|
| 149 |
+
Slice off any protocol prefixes on path.
|
| 150 |
+
"""
|
| 151 |
+
if sys.platform == "win32" and _is_local_windows_path(path):
|
| 152 |
+
# Represent as posix path such that downstream functions properly handle it.
|
| 153 |
+
# This is executed when 'file://' is NOT included in the path.
|
| 154 |
+
return pathlib.Path(path).as_posix()
|
| 155 |
+
|
| 156 |
+
parsed = urllib.parse.urlparse(path, allow_fragments=False) # support '#' in path
|
| 157 |
+
query = "?" + parsed.query if parsed.query else "" # support '?' in path
|
| 158 |
+
netloc = parsed.netloc
|
| 159 |
+
if parsed.scheme == "s3" and "@" in parsed.netloc:
|
| 160 |
+
# If the path contains an @, it is assumed to be an anonymous
|
| 161 |
+
# credentialed path, and we need to strip off the credentials.
|
| 162 |
+
netloc = parsed.netloc.split("@")[-1]
|
| 163 |
+
|
| 164 |
+
parsed_path = parsed.path
|
| 165 |
+
# urlparse prepends the path with a '/'. This does not work on Windows
|
| 166 |
+
# so if this is the case strip the leading slash.
|
| 167 |
+
if (
|
| 168 |
+
sys.platform == "win32"
|
| 169 |
+
and not netloc
|
| 170 |
+
and len(parsed_path) >= 3
|
| 171 |
+
and parsed_path[0] == "/" # The problematic leading slash
|
| 172 |
+
and parsed_path[1].isalpha() # Ensure it is a drive letter.
|
| 173 |
+
and parsed_path[2:4] in (":", ":/")
|
| 174 |
+
):
|
| 175 |
+
parsed_path = parsed_path[1:]
|
| 176 |
+
|
| 177 |
+
return netloc + parsed_path + query
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _is_url(path) -> bool:
|
| 181 |
+
return urllib.parse.urlparse(path).scheme != ""
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _is_local_windows_path(path: str) -> bool:
|
| 185 |
+
"""Determines if path is a Windows file-system location."""
|
| 186 |
+
if sys.platform != "win32":
|
| 187 |
+
return False
|
| 188 |
+
|
| 189 |
+
if len(path) >= 1 and path[0] == "\\":
|
| 190 |
+
return True
|
| 191 |
+
if (
|
| 192 |
+
len(path) >= 3
|
| 193 |
+
and path[1] == ":"
|
| 194 |
+
and (path[2] == "/" or path[2] == "\\")
|
| 195 |
+
and path[0].isalpha()
|
| 196 |
+
):
|
| 197 |
+
return True
|
| 198 |
+
return False
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _encode_url(path):
|
| 202 |
+
return urllib.parse.quote(path, safe="/:")
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _decode_url(path):
|
| 206 |
+
return urllib.parse.unquote(path)
|
.venv/lib/python3.11/site-packages/ray/data/extensions/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.util.tensor_extensions.arrow import (
|
| 2 |
+
ArrowTensorTypeV2,
|
| 3 |
+
get_arrow_extension_tensor_types,
|
| 4 |
+
)
|
| 5 |
+
from ray.data.extensions.object_extension import (
|
| 6 |
+
ArrowPythonObjectArray,
|
| 7 |
+
ArrowPythonObjectScalar,
|
| 8 |
+
ArrowPythonObjectType,
|
| 9 |
+
PythonObjectArray,
|
| 10 |
+
PythonObjectDtype,
|
| 11 |
+
_object_extension_type_allowed,
|
| 12 |
+
)
|
| 13 |
+
from ray.data.extensions.tensor_extension import (
|
| 14 |
+
ArrowConversionError,
|
| 15 |
+
ArrowTensorArray,
|
| 16 |
+
ArrowTensorType,
|
| 17 |
+
ArrowVariableShapedTensorArray,
|
| 18 |
+
ArrowVariableShapedTensorType,
|
| 19 |
+
TensorArray,
|
| 20 |
+
TensorArrayElement,
|
| 21 |
+
TensorDtype,
|
| 22 |
+
column_needs_tensor_extension,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
# Tensor array extension.
|
| 27 |
+
"TensorDtype",
|
| 28 |
+
"TensorArray",
|
| 29 |
+
"TensorArrayElement",
|
| 30 |
+
"ArrowTensorType",
|
| 31 |
+
"ArrowTensorTypeV2",
|
| 32 |
+
"ArrowTensorArray",
|
| 33 |
+
"ArrowVariableShapedTensorType",
|
| 34 |
+
"ArrowVariableShapedTensorArray",
|
| 35 |
+
"column_needs_tensor_extension",
|
| 36 |
+
"ArrowConversionError",
|
| 37 |
+
# Object array extension
|
| 38 |
+
"ArrowPythonObjectArray",
|
| 39 |
+
"ArrowPythonObjectType",
|
| 40 |
+
"ArrowPythonObjectScalar",
|
| 41 |
+
"PythonObjectArray",
|
| 42 |
+
"PythonObjectDtype",
|
| 43 |
+
"_object_extension_type_allowed",
|
| 44 |
+
"get_arrow_extension_tensor_types",
|
| 45 |
+
]
|
.venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/object_extension.cpython-311.pyc
ADDED
|
Binary file (595 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/tensor_extension.cpython-311.pyc
ADDED
|
Binary file (842 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/data/extensions/object_extension.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.util.object_extensions.arrow import ( # noqa: F401
|
| 2 |
+
ArrowPythonObjectArray,
|
| 3 |
+
ArrowPythonObjectScalar,
|
| 4 |
+
ArrowPythonObjectType,
|
| 5 |
+
_object_extension_type_allowed,
|
| 6 |
+
)
|
| 7 |
+
from ray.air.util.object_extensions.pandas import ( # noqa: F401
|
| 8 |
+
PythonObjectArray,
|
| 9 |
+
PythonObjectDtype,
|
| 10 |
+
)
|
.venv/lib/python3.11/site-packages/ray/data/extensions/tensor_extension.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.util.tensor_extensions.arrow import ( # noqa: F401
|
| 2 |
+
ArrowConversionError,
|
| 3 |
+
ArrowTensorArray,
|
| 4 |
+
ArrowTensorType,
|
| 5 |
+
ArrowTensorTypeV2,
|
| 6 |
+
ArrowVariableShapedTensorArray,
|
| 7 |
+
ArrowVariableShapedTensorType,
|
| 8 |
+
)
|
| 9 |
+
from ray.air.util.tensor_extensions.pandas import ( # noqa: F401
|
| 10 |
+
TensorArray,
|
| 11 |
+
TensorArrayElement,
|
| 12 |
+
TensorDtype,
|
| 13 |
+
column_needs_tensor_extension,
|
| 14 |
+
)
|
| 15 |
+
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray # noqa: F401
|
.venv/lib/python3.11/site-packages/ray/data/preprocessors/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.data.preprocessors.chain import Chain
|
| 2 |
+
from ray.data.preprocessors.concatenator import Concatenator
|
| 3 |
+
from ray.data.preprocessors.discretizer import (
|
| 4 |
+
CustomKBinsDiscretizer,
|
| 5 |
+
UniformKBinsDiscretizer,
|
| 6 |
+
)
|
| 7 |
+
from ray.data.preprocessors.encoder import (
|
| 8 |
+
Categorizer,
|
| 9 |
+
LabelEncoder,
|
| 10 |
+
MultiHotEncoder,
|
| 11 |
+
OneHotEncoder,
|
| 12 |
+
OrdinalEncoder,
|
| 13 |
+
)
|
| 14 |
+
from ray.data.preprocessors.hasher import FeatureHasher
|
| 15 |
+
from ray.data.preprocessors.imputer import SimpleImputer
|
| 16 |
+
from ray.data.preprocessors.normalizer import Normalizer
|
| 17 |
+
from ray.data.preprocessors.scaler import (
|
| 18 |
+
MaxAbsScaler,
|
| 19 |
+
MinMaxScaler,
|
| 20 |
+
RobustScaler,
|
| 21 |
+
StandardScaler,
|
| 22 |
+
)
|
| 23 |
+
from ray.data.preprocessors.tokenizer import Tokenizer
|
| 24 |
+
from ray.data.preprocessors.torch import TorchVisionPreprocessor
|
| 25 |
+
from ray.data.preprocessors.transformer import PowerTransformer
|
| 26 |
+
from ray.data.preprocessors.vectorizer import CountVectorizer, HashingVectorizer
|
| 27 |
+
|
| 28 |
+
__all__ = [
|
| 29 |
+
"Categorizer",
|
| 30 |
+
"CountVectorizer",
|
| 31 |
+
"Chain",
|
| 32 |
+
"FeatureHasher",
|
| 33 |
+
"HashingVectorizer",
|
| 34 |
+
"LabelEncoder",
|
| 35 |
+
"MaxAbsScaler",
|
| 36 |
+
"MinMaxScaler",
|
| 37 |
+
"MultiHotEncoder",
|
| 38 |
+
"Normalizer",
|
| 39 |
+
"OneHotEncoder",
|
| 40 |
+
"OrdinalEncoder",
|
| 41 |
+
"PowerTransformer",
|
| 42 |
+
"RobustScaler",
|
| 43 |
+
"SimpleImputer",
|
| 44 |
+
"StandardScaler",
|
| 45 |
+
"Concatenator",
|
| 46 |
+
"Tokenizer",
|
| 47 |
+
"TorchVisionPreprocessor",
|
| 48 |
+
"CustomKBinsDiscretizer",
|
| 49 |
+
"UniformKBinsDiscretizer",
|
| 50 |
+
]
|