koichi12 commited on
Commit
110275e
·
verified ·
1 Parent(s): 98ca408

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json +3 -0
  3. .venv/lib/python3.11/site-packages/ray/data/_internal/__init__.py +0 -0
  4. .venv/lib/python3.11/site-packages/ray/data/_internal/aggregate.py +411 -0
  5. .venv/lib/python3.11/site-packages/ray/data/_internal/arrow_block.py +649 -0
  6. .venv/lib/python3.11/site-packages/ray/data/_internal/batcher.py +325 -0
  7. .venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/block_batching.py +60 -0
  8. .venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/interfaces.py +47 -0
  9. .venv/lib/python3.11/site-packages/ray/data/_internal/block_builder.py +39 -0
  10. .venv/lib/python3.11/site-packages/ray/data/_internal/block_list.py +98 -0
  11. .venv/lib/python3.11/site-packages/ray/data/_internal/compute.py +143 -0
  12. .venv/lib/python3.11/site-packages/ray/data/_internal/delegating_block_builder.py +76 -0
  13. .venv/lib/python3.11/site-packages/ray/data/_internal/equalize.py +142 -0
  14. .venv/lib/python3.11/site-packages/ray/data/_internal/logging.py +208 -0
  15. .venv/lib/python3.11/site-packages/ray/data/_internal/memory_tracing.py +147 -0
  16. .venv/lib/python3.11/site-packages/ray/data/_internal/null_aggregate.py +276 -0
  17. .venv/lib/python3.11/site-packages/ray/data/_internal/numpy_support.py +233 -0
  18. .venv/lib/python3.11/site-packages/ray/data/_internal/output_buffer.py +109 -0
  19. .venv/lib/python3.11/site-packages/ray/data/_internal/pandas_block.py +728 -0
  20. .venv/lib/python3.11/site-packages/ray/data/_internal/plan.py +602 -0
  21. .venv/lib/python3.11/site-packages/ray/data/_internal/progress_bar.py +217 -0
  22. .venv/lib/python3.11/site-packages/ray/data/_internal/remote_fn.py +80 -0
  23. .venv/lib/python3.11/site-packages/ray/data/_internal/row.py +42 -0
  24. .venv/lib/python3.11/site-packages/ray/data/_internal/size_estimator.py +92 -0
  25. .venv/lib/python3.11/site-packages/ray/data/_internal/split.py +297 -0
  26. .venv/lib/python3.11/site-packages/ray/data/_internal/stats.py +1495 -0
  27. .venv/lib/python3.11/site-packages/ray/data/_internal/table_block.py +310 -0
  28. .venv/lib/python3.11/site-packages/ray/data/_internal/torch_iterable_dataset.py +10 -0
  29. .venv/lib/python3.11/site-packages/ray/data/_internal/util.py +1262 -0
  30. .venv/lib/python3.11/site-packages/ray/data/datasource/__init__.py +67 -0
  31. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/__init__.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasink.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasource.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_based_datasource.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_datasink.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_meta_provider.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/filename_provider.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/parquet_meta_provider.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/partitioning.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/path_util.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/data/datasource/file_datasink.py +266 -0
  42. .venv/lib/python3.11/site-packages/ray/data/datasource/partitioning.py +456 -0
  43. .venv/lib/python3.11/site-packages/ray/data/datasource/path_util.py +206 -0
  44. .venv/lib/python3.11/site-packages/ray/data/extensions/__init__.py +45 -0
  45. .venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/__init__.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/object_extension.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/tensor_extension.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/data/extensions/object_extension.py +10 -0
  49. .venv/lib/python3.11/site-packages/ray/data/extensions/tensor_extension.py +15 -0
  50. .venv/lib/python3.11/site-packages/ray/data/preprocessors/__init__.py +50 -0
.gitattributes CHANGED
@@ -151,3 +151,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
151
  .venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
152
  .venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
153
  .venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
151
  .venv/lib/python3.11/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
152
  .venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
153
  .venv/lib/python3.11/site-packages/torchgen/__pycache__/model.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
154
+ .venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:386b1f98fba69b38c3de512a4eb602dc69a95dae0e54e6ce048ea3e29a2627a8
3
+ size 19280967
.venv/lib/python3.11/site-packages/ray/data/_internal/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/data/_internal/aggregate.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
3
+
4
+ from ray.data._internal.null_aggregate import (
5
+ _null_wrap_accumulate_block,
6
+ _null_wrap_accumulate_row,
7
+ _null_wrap_finalize,
8
+ _null_wrap_init,
9
+ _null_wrap_merge,
10
+ )
11
+ from ray.data._internal.planner.exchange.sort_task_spec import SortKey
12
+ from ray.data.aggregate import AggregateFn
13
+ from ray.data.block import AggType, Block, BlockAccessor
14
+
15
+ if TYPE_CHECKING:
16
+ import pyarrow as pa
17
+
18
+
19
class _AggregateOnKeyBase(AggregateFn):
    """Base class for aggregations that target a single key column."""

    def _set_key_fn(self, on: str):
        # Remember the target column so later schema validation can use it.
        self._key_fn = on

    def _validate(self, schema: Optional[Union[type, "pa.lib.Schema"]]) -> None:
        """Verify that the configured key column is present in ``schema``."""
        SortKey(self._key_fn).validate_schema(schema)
25
+
26
+
27
class Count(AggregateFn):
    """Defines count aggregation (total number of rows)."""

    def __init__(self):
        def _accumulate(acc, block):
            # Add this block's row count to the running total.
            return acc + BlockAccessor.for_block(block).num_rows()

        super().__init__(
            init=lambda k: 0,
            accumulate_block=_accumulate,
            merge=lambda a1, a2: a1 + a2,
            name="count()",
        )
39
+
40
+
41
class Sum(_AggregateOnKeyBase):
    """Defines sum aggregation."""

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        # Result column name: explicit alias when given, else "sum(<col>)".
        self._rs_name = alias_name if alias_name else f"sum({str(on)})"

        null_merge = _null_wrap_merge(ignore_nulls, lambda a1, a2: a1 + a2)

        def _block_sum(block: Block):
            # Delegate to the block accessor's vectorized sum.
            return BlockAccessor.for_block(block).sum(on, ignore_nulls)

        super().__init__(
            init=_null_wrap_init(lambda k: 0),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, _block_sum, null_merge
            ),
            finalize=_null_wrap_finalize(lambda a: a),
            name=self._rs_name,
        )
69
+
70
+
71
class Min(_AggregateOnKeyBase):
    """Defines min aggregation."""

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"min({str(on)})"

        null_merge = _null_wrap_merge(ignore_nulls, min)

        def _block_min(block: Block):
            # Delegate to the block accessor's vectorized min.
            return BlockAccessor.for_block(block).min(on, ignore_nulls)

        super().__init__(
            # +inf is the identity element for min.
            init=_null_wrap_init(lambda k: float("inf")),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, _block_min, null_merge
            ),
            finalize=_null_wrap_finalize(lambda a: a),
            name=self._rs_name,
        )
99
+
100
+
101
class Max(_AggregateOnKeyBase):
    """Defines max aggregation."""

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"max({str(on)})"

        null_merge = _null_wrap_merge(ignore_nulls, max)

        def _block_max(block: Block):
            # Delegate to the block accessor's vectorized max.
            return BlockAccessor.for_block(block).max(on, ignore_nulls)

        super().__init__(
            # -inf is the identity element for max.
            init=_null_wrap_init(lambda k: float("-inf")),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, _block_max, null_merge
            ),
            finalize=_null_wrap_finalize(lambda a: a),
            name=self._rs_name,
        )
129
+
130
+
131
class Mean(_AggregateOnKeyBase):
    """Defines mean aggregation."""

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"mean({str(on)})"

        # The accumulator is a [sum, count] pair; merging adds element-wise.
        null_merge = _null_wrap_merge(
            ignore_nulls, lambda a1, a2: [a1[0] + a2[0], a1[1] + a2[1]]
        )

        def vectorized_mean(block: Block) -> AggType:
            block_acc = BlockAccessor.for_block(block)
            count = block_acc.count(on)
            if count == 0 or count is None:
                # Empty or all null.
                return None
            sum_ = block_acc.sum(on, ignore_nulls)
            if sum_ is None:
                # ignore_nulls=False and at least one null.
                return None
            return [sum_, count]

        super().__init__(
            init=_null_wrap_init(lambda k: [0, 0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls, vectorized_mean, null_merge
            ),
            # Final mean = total sum / total count.
            finalize=_null_wrap_finalize(lambda a: a[0] / a[1]),
            name=self._rs_name,
        )
173
+
174
+
175
class Std(_AggregateOnKeyBase):
    """Defines standard deviation aggregation.

    Uses Welford's online method for an accumulator-style computation of the
    standard deviation. This method was chosen due to its numerical
    stability, and it being computable in a single pass.
    This may give different (but more accurate) results than NumPy, Pandas,
    and sklearn, which use a less numerically stable two-pass algorithm.
    See
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ddof: int = 1,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        if alias_name:
            self._rs_name = alias_name
        else:
            self._rs_name = f"std({str(on)})"

        # Accumulator layout throughout: [M2, mean, count], where M2 is the
        # running sum of squared differences from the current mean.
        def merge(a: List[float], b: List[float]):
            # Merges two accumulations into one.
            # See
            # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            M2_a, mean_a, count_a = a
            M2_b, mean_b, count_b = b
            delta = mean_b - mean_a
            count = count_a + count_b
            # NOTE: We use this mean calculation since it's more numerically
            # stable than mean_a + delta * count_b / count, which actually
            # deviates from Pandas in the ~15th decimal place and causes our
            # exact comparison tests to fail.
            mean = (mean_a * count_a + mean_b * count_b) / count
            # Update the sum of squared differences.
            M2 = M2_a + M2_b + (delta**2) * count_a * count_b / count
            return [M2, mean, count]

        null_merge = _null_wrap_merge(ignore_nulls, merge)

        def vectorized_std(block: Block) -> AggType:
            # Compute a whole-block partial accumulation in one pass using the
            # block accessor's vectorized aggregations.
            block_acc = BlockAccessor.for_block(block)
            count = block_acc.count(on)
            if count == 0 or count is None:
                # Empty or all null.
                return None
            sum_ = block_acc.sum(on, ignore_nulls)
            if sum_ is None:
                # ignore_nulls=False and at least one null.
                return None
            mean = sum_ / count
            M2 = block_acc.sum_of_squared_diffs_from_mean(on, ignore_nulls, mean)
            return [M2, mean, count]

        def finalize(a: List[float]):
            # Compute the final standard deviation from the accumulated
            # sum of squared differences from current mean and the count.
            # ddof is the delta degrees of freedom used in the divisor.
            M2, mean, count = a
            if count < 2:
                return 0.0
            return math.sqrt(M2 / (count - ddof))

        super().__init__(
            init=_null_wrap_init(lambda k: [0, 0, 0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls,
                vectorized_std,
                null_merge,
            ),
            finalize=_null_wrap_finalize(finalize),
            name=(self._rs_name),
        )
252
+
253
+
254
class AbsMax(_AggregateOnKeyBase):
    """Defines absolute max aggregation."""

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        on_fn = _to_on_fn(on)
        self._rs_name = alias_name if alias_name else f"abs_max({str(on)})"

        def _acc_row(acc, row_val):
            # Track the largest magnitude seen so far.
            return max(acc, abs(row_val))

        super().__init__(
            init=_null_wrap_init(lambda k: 0),
            merge=_null_wrap_merge(ignore_nulls, max),
            accumulate_row=_null_wrap_accumulate_row(ignore_nulls, on_fn, _acc_row),
            finalize=_null_wrap_finalize(lambda a: a),
            name=self._rs_name,
        )
279
+
280
+
281
+ def _to_on_fn(on: Optional[str]):
282
+ if on is None:
283
+ return lambda r: r
284
+ elif isinstance(on, str):
285
+ return lambda r: r[on]
286
+ else:
287
+ return on
288
+
289
+
290
class Quantile(_AggregateOnKeyBase):
    """Defines Quantile aggregation.

    Collects all values of the target column into a list during accumulation
    and computes the q-th quantile (with linear interpolation, rounded to 5
    decimal places) at finalization time.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        q: float = 0.5,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._q = q
        if alias_name:
            self._rs_name = alias_name
        else:
            self._rs_name = f"quantile({str(on)})"

        def merge(a: List[Any], b: List[Any]):
            # Merge two partial accumulations. Either side may be a list of
            # collected values or a bare scalar; normalize to one list while
            # dropping None/"" scalars.
            # FIX: use the builtin ``list`` in isinstance checks; the previous
            # code tested against ``typing.List``, which is meant only for
            # annotations and is deprecated for runtime isinstance use.
            if isinstance(a, list) and isinstance(b, list):
                a.extend(b)
                return a
            if isinstance(a, list) and (not isinstance(b, list)):
                if b is not None and b != "":
                    a.append(b)
                return a
            if isinstance(b, list) and (not isinstance(a, list)):
                if a is not None and a != "":
                    b.append(a)
                return b

            ls = []
            if a is not None and a != "":
                ls.append(a)
            if b is not None and b != "":
                ls.append(b)
            return ls

        null_merge = _null_wrap_merge(ignore_nulls, merge)

        def block_row_ls(block: Block) -> AggType:
            # Collect the raw column values from one block.
            block_acc = BlockAccessor.for_block(block)
            ls = []
            for row in block_acc.iter_rows(public_row_format=False):
                ls.append(row.get(on))
            return ls

        # FIX: removed a redundant nested ``import math`` here; ``math`` is
        # already imported at module scope.
        def percentile(input_values, key: Optional[Callable[[Any], Any]] = None):
            # Linear-interpolated quantile of the collected values.
            if not input_values:
                return None

            if key is None:
                key = lambda x: x  # noqa: E731

            input_values = sorted(input_values)
            k = (len(input_values) - 1) * self._q
            f = math.floor(k)
            c = math.ceil(k)
            if f == c:
                # k is an exact index; no interpolation needed.
                return key(input_values[int(k)])
            d0 = key(input_values[int(f)]) * (c - k)
            d1 = key(input_values[int(c)]) * (k - f)
            return round(d0 + d1, 5)

        super().__init__(
            init=_null_wrap_init(lambda k: [0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls,
                block_row_ls,
                null_merge,
            ),
            finalize=_null_wrap_finalize(percentile),
            name=(self._rs_name),
        )
366
+
367
+
368
class Unique(_AggregateOnKeyBase):
    """Defines unique aggregation."""

    def __init__(
        self,
        on: Optional[str] = None,
        alias_name: Optional[str] = None,
    ):
        self._set_key_fn(on)
        self._rs_name = alias_name if alias_name else f"unique({str(on)})"

        def to_set(x):
            # Coerce a scalar, list, or set into a set.
            if isinstance(x, set):
                return x
            if isinstance(x, list):
                return set(x)
            return {x}

        def block_row_unique(block: Block) -> AggType:
            # Vectorized distinct over the column via Arrow compute.
            import pyarrow.compute as pac

            col = BlockAccessor.for_block(block).to_arrow().column(on)
            return pac.unique(col).to_pylist()

        def merge(a, b):
            return to_set(a) | to_set(b)

        # Note: ignore_nulls is hard-coded to False for this aggregation.
        null_merge = _null_wrap_merge(False, merge)

        super().__init__(
            init=_null_wrap_init(lambda x: set()),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                False,
                block_row_unique,
                null_merge,
            ),
            name=self._rs_name,
            finalize=_null_wrap_finalize(lambda x: x),
        )
.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_block.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import heapq
3
+ import logging
4
+ import random
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Iterator,
11
+ List,
12
+ Optional,
13
+ Sequence,
14
+ Tuple,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+
19
+ import numpy as np
20
+
21
+ from ray._private.utils import _get_pyarrow_version
22
+ from ray.air.constants import TENSOR_COLUMN_NAME
23
+ from ray.air.util.tensor_extensions.arrow import (
24
+ convert_to_pyarrow_array,
25
+ pyarrow_table_from_pydict,
26
+ )
27
+ from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
28
+ from ray.data._internal.numpy_support import convert_to_numpy
29
+ from ray.data._internal.row import TableRow
30
+ from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
31
+ from ray.data._internal.util import NULL_SENTINEL, find_partitions, keys_equal
32
+ from ray.data.block import (
33
+ Block,
34
+ BlockAccessor,
35
+ BlockExecStats,
36
+ BlockMetadata,
37
+ BlockType,
38
+ KeyType,
39
+ U,
40
+ )
41
+ from ray.data.context import DataContext
42
+
43
+ try:
44
+ import pyarrow
45
+ except ImportError:
46
+ pyarrow = None
47
+
48
+
49
+ if TYPE_CHECKING:
50
+ import pandas
51
+
52
+ from ray.data._internal.planner.exchange.sort_task_spec import SortKey
53
+ from ray.data.aggregate import AggregateFn
54
+
55
+
56
+ T = TypeVar("T")
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ # We offload some transformations to polars for performance.
61
def get_sort_transform(context: DataContext) -> Callable:
    """Return the sort implementation: polars when enabled, else pyarrow."""
    return transform_polars.sort if context.use_polars else transform_pyarrow.sort
66
+
67
+
68
def get_concat_and_sort_transform(context: DataContext) -> Callable:
    """Return the concat-and-sort implementation, honoring the polars flag."""
    if context.use_polars:
        return transform_polars.concat_and_sort
    return transform_pyarrow.concat_and_sort
73
+
74
+
75
class ArrowRow(TableRow):
    """
    Row of a tabular Dataset backed by a Arrow Table block.
    """

    def __getitem__(self, key: Union[str, List[str]]) -> Any:
        """Look up one column value (str key) or several (list of keys).

        Returns None when the selection is empty; tensor-extension columns are
        materialized via ArrowBlockAccessor._build_tensor_row.
        """
        from ray.data.extensions import get_arrow_extension_tensor_types

        tensor_arrow_extension_types = get_arrow_extension_tensor_types()

        def get_item(keys: List[str]) -> Any:
            schema = self._row.schema
            # Only the first key's type is inspected: if it is a tensor
            # extension type, all requested keys are treated as tensors.
            if isinstance(schema.field(keys[0]).type, tensor_arrow_extension_types):
                # Build a tensor row.
                return tuple(
                    [
                        ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
                        for key in keys
                    ]
                )

            table = self._row.select(keys)
            if len(table) == 0:
                return None

            # First (only) element of each selected column.
            items = [col[0] for col in table.columns]
            try:
                # Try to interpret this as a pyarrow.Scalar value.
                return tuple([item.as_py() for item in items])

            except AttributeError:
                # Assume that this row is an element of an extension array, and
                # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
                return items

        is_single_item = isinstance(key, str)
        keys = [key] if is_single_item else key

        items = get_item(keys)

        if items is None:
            return None
        elif is_single_item:
            # Unwrap the 1-tuple for a scalar lookup.
            return items[0]
        else:
            return items

    def __iter__(self) -> Iterator:
        """Iterate over the row's column names (dict-like behavior)."""
        for k in self._row.column_names:
            yield k

    def __len__(self):
        """Number of columns in the row."""
        return self._row.num_columns
128
+
129
+
130
class ArrowBlockBuilder(TableBlockBuilder):
    """Block builder that produces pyarrow.Table blocks."""

    def __init__(self):
        if pyarrow is None:
            raise ImportError("Run `pip install pyarrow` for Arrow support")
        super().__init__((pyarrow.Table, bytes))

    @staticmethod
    def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
        # Convert each column to numpy first, then to an Arrow array.
        arrays: Dict[str, pyarrow.Array] = {
            name: convert_to_pyarrow_array(convert_to_numpy(vals), name)
            for name, vals in columns.items()
        }
        return pyarrow_table_from_pydict(arrays)

    @staticmethod
    def _concat_tables(tables: List[Block]) -> Block:
        return transform_pyarrow.concat(tables)

    @staticmethod
    def _concat_would_copy() -> bool:
        return False

    @staticmethod
    def _empty_table() -> "pyarrow.Table":
        return pyarrow_table_from_pydict({})

    def block_type(self) -> BlockType:
        return BlockType.ARROW
161
+
162
+
163
+ class ArrowBlockAccessor(TableBlockAccessor):
164
+ ROW_TYPE = ArrowRow
165
+
166
+ def __init__(self, table: "pyarrow.Table"):
167
+ if pyarrow is None:
168
+ raise ImportError("Run `pip install pyarrow` for Arrow support")
169
+ super().__init__(table)
170
+
171
+ def column_names(self) -> List[str]:
172
+ return self._table.column_names
173
+
174
+ def append_column(self, name: str, data: Any) -> Block:
175
+ assert name not in self._table.column_names
176
+
177
+ if any(isinstance(item, np.ndarray) for item in data):
178
+ raise NotImplementedError(
179
+ f"`{self.__class__.__name__}.append_column()` doesn't support "
180
+ "array-like data."
181
+ )
182
+
183
+ return self._table.append_column(name, [data])
184
+
185
+ @classmethod
186
+ def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":
187
+ reader = pyarrow.ipc.open_stream(data)
188
+ return cls(reader.read_all())
189
+
190
    @staticmethod
    def _build_tensor_row(
        row: ArrowRow, col_name: str = TENSOR_COLUMN_NAME
    ) -> np.ndarray:
        """Extract the tensor element of ``row``'s ``col_name`` as an ndarray.

        Handles the differing element representations across pyarrow
        versions (< 8.0.0, 8.x, and >= 9.0.0).
        """
        from packaging.version import parse as parse_version

        element = row[col_name][0]
        # TODO(Clark): Reduce this to np.asarray(element) once we only support Arrow
        # 9.0.0+.
        pyarrow_version = _get_pyarrow_version()
        if pyarrow_version is not None:
            pyarrow_version = parse_version(pyarrow_version)
        # A None version means the version could not be determined; treat it
        # like a recent Arrow.
        if pyarrow_version is None or pyarrow_version >= parse_version("8.0.0"):
            assert isinstance(element, pyarrow.ExtensionScalar)
            if pyarrow_version is None or pyarrow_version >= parse_version("9.0.0"):
                # For Arrow 9.0.0+, accessing an element in a chunked tensor array
                # produces an ArrowTensorScalar, which we convert to an ndarray using
                # .as_py().
                element = element.as_py()
            else:
                # For Arrow 8.*, accessing an element in a chunked tensor array produces
                # an ExtensionScalar, which we convert to an ndarray using our custom
                # method.
                element = element.type._extension_scalar_to_ndarray(element)
        # For Arrow < 8.0.0, accessing an element in a chunked tensor array produces an
        # ndarray, which we return directly.
        assert isinstance(element, np.ndarray), type(element)
        return element
218
+
219
+ def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
220
+ view = self._table.slice(start, end - start)
221
+ if copy:
222
+ view = transform_pyarrow.combine_chunks(view)
223
+ return view
224
+
225
+ def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
226
+ # TODO(swang): Creating this np.array index can add a lot of memory
227
+ # pressure when there are a large number of small rows. Investigate
228
+ # random shuffling in place to reduce memory pressure.
229
+ # See https://github.com/ray-project/ray/issues/42146.
230
+ random = np.random.RandomState(random_seed)
231
+ return self.take(random.permutation(self.num_rows()))
232
+
233
+ def schema(self) -> "pyarrow.lib.Schema":
234
+ return self._table.schema
235
+
236
+ def to_pandas(self) -> "pandas.DataFrame":
237
+ from ray.air.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays
238
+
239
+ df = self._table.to_pandas()
240
+ ctx = DataContext.get_current()
241
+ if ctx.enable_tensor_extension_casting:
242
+ df = _cast_tensor_columns_to_ndarrays(df)
243
+ return df
244
+
245
+ def to_numpy(
246
+ self, columns: Optional[Union[str, List[str]]] = None
247
+ ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
248
+ if columns is None:
249
+ columns = self._table.column_names
250
+ should_be_single_ndarray = False
251
+ elif isinstance(columns, list):
252
+ should_be_single_ndarray = False
253
+ else:
254
+ columns = [columns]
255
+ should_be_single_ndarray = True
256
+
257
+ column_names_set = set(self._table.column_names)
258
+ for column in columns:
259
+ if column not in column_names_set:
260
+ raise ValueError(
261
+ f"Cannot find column {column}, available columns: "
262
+ f"{column_names_set}"
263
+ )
264
+
265
+ column_values_ndarrays = []
266
+
267
+ for col_name in columns:
268
+ col = self._table[col_name]
269
+
270
+ # Combine columnar values arrays to make these contiguous
271
+ # (making them compatible with numpy format)
272
+ combined_array = transform_pyarrow.combine_chunked_array(col)
273
+
274
+ column_values_ndarrays.append(
275
+ transform_pyarrow.to_numpy(combined_array, zero_copy_only=False)
276
+ )
277
+
278
+ if should_be_single_ndarray:
279
+ assert len(columns) == 1
280
+ return column_values_ndarrays[0]
281
+ else:
282
+ return dict(zip(columns, column_values_ndarrays))
283
+
284
+ def to_arrow(self) -> "pyarrow.Table":
285
+ return self._table
286
+
287
+ def num_rows(self) -> int:
288
+ # Arrow may represent an empty table via an N > 0 row, 0-column table, e.g. when
289
+ # slicing an empty table, so we return 0 if num_columns == 0.
290
+ return self._table.num_rows if self._table.num_columns > 0 else 0
291
+
292
+ def size_bytes(self) -> int:
293
+ return self._table.nbytes
294
+
295
+ def _zip(self, acc: BlockAccessor) -> "Block":
296
+ r = self.to_arrow()
297
+ s = acc.to_arrow()
298
+ for col_name in s.column_names:
299
+ col = s.column(col_name)
300
+ # Ensure the column names are unique after zip.
301
+ if col_name in r.column_names:
302
+ i = 1
303
+ new_name = col_name
304
+ while new_name in r.column_names:
305
+ new_name = "{}_{}".format(col_name, i)
306
+ i += 1
307
+ col_name = new_name
308
+ r = r.append_column(col_name, col)
309
+ return r
310
+
311
+ @staticmethod
312
+ def builder() -> ArrowBlockBuilder:
313
+ return ArrowBlockBuilder()
314
+
315
+ @staticmethod
316
+ def _empty_table() -> "pyarrow.Table":
317
+ return ArrowBlockBuilder._empty_table()
318
+
319
+ def take(
320
+ self,
321
+ indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
322
+ ) -> "pyarrow.Table":
323
+ """Select rows from the underlying table.
324
+
325
+ This method is an alternative to pyarrow.Table.take(), which breaks for
326
+ extension arrays.
327
+ """
328
+ return transform_pyarrow.take_table(self._table, indices)
329
+
330
+ def select(self, columns: List[str]) -> "pyarrow.Table":
331
+ if not all(isinstance(col, str) for col in columns):
332
+ raise ValueError(
333
+ "Columns must be a list of column name strings when aggregating on "
334
+ f"Arrow blocks, but got: {columns}."
335
+ )
336
+ return self._table.select(columns)
337
+
338
+ def rename_columns(self, columns_rename: Dict[str, str]) -> "pyarrow.Table":
339
+ return self._table.rename_columns(columns_rename)
340
+
341
+ def _sample(self, n_samples: int, sort_key: "SortKey") -> "pyarrow.Table":
342
+ indices = random.sample(range(self._table.num_rows), n_samples)
343
+ table = self._table.select(sort_key.get_columns())
344
+ return transform_pyarrow.take_table(table, indices)
345
+
346
+ def count(self, on: str) -> Optional[U]:
347
+ """Count the number of non-null values in the provided column."""
348
+ import pyarrow.compute as pac
349
+
350
+ if not isinstance(on, str):
351
+ raise ValueError(
352
+ "on must be a string when aggregating on Arrow blocks, but got:"
353
+ f"{type(on)}."
354
+ )
355
+
356
+ if self.num_rows() == 0:
357
+ return None
358
+
359
+ col = self._table[on]
360
+ return pac.count(col).as_py()
361
+
362
+ def _apply_arrow_compute(
363
+ self, compute_fn: Callable, on: str, ignore_nulls: bool
364
+ ) -> Optional[U]:
365
+ """Helper providing null handling around applying an aggregation to a column."""
366
+ import pyarrow as pa
367
+
368
+ if not isinstance(on, str):
369
+ raise ValueError(
370
+ "on must be a string when aggregating on Arrow blocks, but got:"
371
+ f"{type(on)}."
372
+ )
373
+
374
+ if self.num_rows() == 0:
375
+ return None
376
+
377
+ col = self._table[on]
378
+ if pa.types.is_null(col.type):
379
+ return None
380
+ else:
381
+ return compute_fn(col, skip_nulls=ignore_nulls).as_py()
382
+
383
+ def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
384
+ import pyarrow.compute as pac
385
+
386
+ return self._apply_arrow_compute(pac.sum, on, ignore_nulls)
387
+
388
+ def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
389
+ import pyarrow.compute as pac
390
+
391
+ return self._apply_arrow_compute(pac.min, on, ignore_nulls)
392
+
393
+ def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
394
+ import pyarrow.compute as pac
395
+
396
+ return self._apply_arrow_compute(pac.max, on, ignore_nulls)
397
+
398
+ def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
399
+ import pyarrow.compute as pac
400
+
401
+ return self._apply_arrow_compute(pac.mean, on, ignore_nulls)
402
+
403
+ def sum_of_squared_diffs_from_mean(
404
+ self,
405
+ on: str,
406
+ ignore_nulls: bool,
407
+ mean: Optional[U] = None,
408
+ ) -> Optional[U]:
409
+ import pyarrow.compute as pac
410
+
411
+ if mean is None:
412
+ # If precomputed mean not given, we compute it ourselves.
413
+ mean = self.mean(on, ignore_nulls)
414
+ if mean is None:
415
+ return None
416
+ return self._apply_arrow_compute(
417
+ lambda col, skip_nulls: pac.sum(
418
+ pac.power(pac.subtract(col, mean), 2),
419
+ skip_nulls=skip_nulls,
420
+ ),
421
+ on,
422
+ ignore_nulls,
423
+ )
424
+
425
+ def sort_and_partition(
426
+ self, boundaries: List[T], sort_key: "SortKey"
427
+ ) -> List["Block"]:
428
+ if self._table.num_rows == 0:
429
+ # If the pyarrow table is empty we may not have schema
430
+ # so calling sort_indices() will raise an error.
431
+ return [self._empty_table() for _ in range(len(boundaries) + 1)]
432
+
433
+ context = DataContext.get_current()
434
+ sort = get_sort_transform(context)
435
+
436
+ table = sort(self._table, sort_key)
437
+ if len(boundaries) == 0:
438
+ return [table]
439
+ return find_partitions(table, boundaries, sort_key)
440
+
441
+ def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
442
+ """Combine rows with the same key into an accumulator.
443
+
444
+ This assumes the block is already sorted by key in ascending order.
445
+
446
+ Args:
447
+ sort_key: A column name or list of column names.
448
+ If this is ``None``, place all rows in a single group.
449
+
450
+ aggs: The aggregations to do.
451
+
452
+ Returns:
453
+ A sorted block of [k, v_1, ..., v_n] columns where k is the groupby
454
+ key and v_i is the partially combined accumulator for the ith given
455
+ aggregation.
456
+ If key is None then the k column is omitted.
457
+ """
458
+ keys: List[str] = sort_key.get_columns()
459
+
460
+ def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
461
+ """Creates an iterator over zero-copy group views."""
462
+ if not keys:
463
+ # Global aggregation consists of a single "group", so we short-circuit.
464
+ yield tuple(), self.to_block()
465
+ return
466
+
467
+ start = end = 0
468
+ iter = self.iter_rows(public_row_format=False)
469
+ next_row = None
470
+ while True:
471
+ try:
472
+ if next_row is None:
473
+ next_row = next(iter)
474
+ next_keys = next_row[keys]
475
+ while keys_equal(next_row[keys], next_keys):
476
+ end += 1
477
+ try:
478
+ next_row = next(iter)
479
+ except StopIteration:
480
+ next_row = None
481
+ break
482
+ yield next_keys, self.slice(start, end)
483
+ start = end
484
+ except StopIteration:
485
+ break
486
+
487
+ builder = ArrowBlockBuilder()
488
+ for group_keys, group_view in iter_groups():
489
+ # Aggregate.
490
+ init_vals = group_keys
491
+ if len(group_keys) == 1:
492
+ init_vals = group_keys[0]
493
+
494
+ accumulators = [agg.init(init_vals) for agg in aggs]
495
+ for i in range(len(aggs)):
496
+ accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)
497
+
498
+ # Build the row.
499
+ row = {}
500
+ if keys:
501
+ for k, gk in zip(keys, group_keys):
502
+ row[k] = gk
503
+
504
+ count = collections.defaultdict(int)
505
+ for agg, accumulator in zip(aggs, accumulators):
506
+ name = agg.name
507
+ # Check for conflicts with existing aggregation name.
508
+ if count[name] > 0:
509
+ name = self._munge_conflict(name, count[name])
510
+ count[name] += 1
511
+ row[name] = accumulator
512
+
513
+ builder.add(row)
514
+
515
+ return builder.build()
516
+
517
+ @staticmethod
518
+ def merge_sorted_blocks(
519
+ blocks: List[Block], sort_key: "SortKey"
520
+ ) -> Tuple[Block, BlockMetadata]:
521
+ stats = BlockExecStats.builder()
522
+ blocks = [b for b in blocks if b.num_rows > 0]
523
+ if len(blocks) == 0:
524
+ ret = ArrowBlockAccessor._empty_table()
525
+ else:
526
+ # Handle blocks of different types.
527
+ blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")
528
+ concat_and_sort = get_concat_and_sort_transform(DataContext.get_current())
529
+ ret = concat_and_sort(blocks, sort_key)
530
+ return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())
531
+
532
    @staticmethod
    def aggregate_combined_blocks(
        blocks: List[Block],
        sort_key: "SortKey",
        aggs: Tuple["AggregateFn"],
        finalize: bool,
    ) -> Tuple[Block, BlockMetadata]:
        """Aggregate sorted, partially combined blocks with the same key range.

        This assumes blocks are already sorted by key in ascending order,
        so we can do merge sort to get all the rows with the same key.

        Args:
            blocks: A list of partially combined and sorted blocks.
            sort_key: The column name of key or None for global aggregation.
            aggs: The aggregations to do.
            finalize: Whether to finalize the aggregation. This is used as an
                optimization for cases where we repeatedly combine partially
                aggregated groups.

        Returns:
            A block of [k, v_1, ..., v_n] columns and its metadata where k is
            the groupby key and v_i is the corresponding aggregation result for
            the ith given aggregation.
            If key is None then the k column is omitted.
        """

        stats = BlockExecStats.builder()
        keys = sort_key.get_columns()

        def key_fn(r):
            # Global aggregation (no keys) maps every row to one shared group.
            if keys:
                return tuple(r[keys])
            else:
                return (0,)

        # Replace Nones with NULL_SENTINEL to ensure safe sorting.
        def key_fn_with_null_sentinel(r):
            values = key_fn(r)
            return [NULL_SENTINEL if v is None else v for v in values]

        # Handle blocks of different types.
        blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")

        # Merge-sort the per-block row iterators so that rows with equal keys
        # become adjacent across all input blocks.
        # NOTE(review): the local name `iter` shadows the builtin `iter`.
        iter = heapq.merge(
            *[
                ArrowBlockAccessor(block).iter_rows(public_row_format=False)
                for block in blocks
            ],
            key=key_fn_with_null_sentinel,
        )
        next_row = None
        builder = ArrowBlockBuilder()
        while True:
            try:
                if next_row is None:
                    next_row = next(iter)
                next_keys = key_fn(next_row)
                next_key_columns = keys

                def gen():
                    # Yields the run of rows sharing `next_keys`; on exit,
                    # `next_row` holds the first row of the following group
                    # (or None when the merged iterator is exhausted).
                    nonlocal iter
                    nonlocal next_row
                    while keys_equal(key_fn(next_row), next_keys):
                        yield next_row
                        try:
                            next_row = next(iter)
                        except StopIteration:
                            next_row = None
                            break

                # Merge.
                first = True
                accumulators = [None] * len(aggs)
                resolved_agg_names = [None] * len(aggs)
                for r in gen():
                    if first:
                        # First row of the group: resolve (possibly
                        # conflicting) aggregation column names and seed the
                        # accumulators from this row's values.
                        count = collections.defaultdict(int)
                        for i in range(len(aggs)):
                            name = aggs[i].name
                            # Check for conflicts with existing aggregation
                            # name.
                            if count[name] > 0:
                                name = ArrowBlockAccessor._munge_conflict(
                                    name, count[name]
                                )
                            count[name] += 1
                            resolved_agg_names[i] = name
                            accumulators[i] = r[name]
                        first = False
                    else:
                        # Subsequent rows: merge their partial accumulators in.
                        for i in range(len(aggs)):
                            accumulators[i] = aggs[i].merge(
                                accumulators[i], r[resolved_agg_names[i]]
                            )
                # Build the row.
                row = {}
                if keys:
                    for col_name, next_key in zip(next_key_columns, next_keys):
                        row[col_name] = next_key

                for agg, agg_name, accumulator in zip(
                    aggs, resolved_agg_names, accumulators
                ):
                    if finalize:
                        row[agg_name] = agg.finalize(accumulator)
                    else:
                        row[agg_name] = accumulator

                builder.add(row)
            except StopIteration:
                break

        ret = builder.build()
        return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())
647
+
648
    def block_type(self) -> BlockType:
        """Return the block type of this accessor (always ``BlockType.ARROW``)."""
        return BlockType.ARROW
.venv/lib/python3.11/site-packages/ray/data/_internal/batcher.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from ray.data._internal.arrow_block import ArrowBlockAccessor
4
+ from ray.data._internal.arrow_ops import transform_pyarrow
5
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
6
+ from ray.data.block import Block, BlockAccessor
7
+
8
# pyarrow.Table.slice is slow when the table has many chunks
# so we combine chunks into a single one to make slice faster
# with the cost of an extra copy.
# See https://github.com/ray-project/ray/issues/31108 for more details.
# TODO(jjyao): remove this once
# https://github.com/apache/arrow/issues/35126 is resolved.
MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS = 10

# Delay compaction until the shuffle buffer has reached this ratio over the min
# shuffle buffer size. Setting this to 1 minimizes memory usage, at the cost of
# frequent compactions. Setting this to higher values increases memory usage but
# reduces compaction frequency.
SHUFFLE_BUFFER_COMPACTION_RATIO = 1.5
21
+
22
+
23
class BatcherInterface:
    """Interface for batchers that buffer blocks and emit them as batches."""

    def add(self, block: Block):
        """Add a block to the block buffer.

        Args:
            block: Block to add to the block buffer.
        """
        raise NotImplementedError()

    def done_adding(self) -> None:
        """Indicate to the batcher that no more blocks will be added to the buffer."""
        # NOTE(review): return annotation corrected from `-> bool`;
        # implementations of this method return nothing.
        raise NotImplementedError()

    def has_batch(self) -> bool:
        """Whether this Batcher has any full batches."""
        raise NotImplementedError()

    def has_any(self) -> bool:
        """Whether this Batcher has any data."""
        raise NotImplementedError()

    def next_batch(self) -> Block:
        """Get the next batch from the block buffer.

        Returns:
            A batch represented as a Block.
        """
        raise NotImplementedError()
51
+
52
+
53
class Batcher(BatcherInterface):
    """Chunks blocks into batches."""

    # Implementation Note: When there are multiple batches per block, this batcher will
    # slice off and return each batch and add the remaining block back to the buffer
    # instead of optimally slicing and returning all batches from the block at once.
    # This will result in extra (and nested) block slicing. However, since slices are
    # zero-copy views, we sacrifice what should be a small performance hit for better
    # readability.

    def __init__(self, batch_size: Optional[int], ensure_copy: bool = False):
        """
        Construct a batcher that yields batches of batch_sizes rows.

        Args:
            batch_size: The size of batches to yield.
            ensure_copy: Whether batches are always copied from the underlying base
                blocks (not zero-copy views).
        """
        self._batch_size = batch_size
        # Buffered blocks, in arrival order.
        self._buffer = []
        # Total number of rows currently held across all buffered blocks.
        self._buffer_size = 0
        self._done_adding = False
        self._ensure_copy = ensure_copy

    def add(self, block: Block):
        """Add a block to the block buffer.

        Note empty block is not added to buffer.

        Args:
            block: Block to add to the block buffer.
        """
        # Hoisted: compute num_rows once instead of constructing the
        # BlockAccessor twice as the original did.
        num_rows = BlockAccessor.for_block(block).num_rows()
        if num_rows > 0:
            self._buffer.append(block)
            self._buffer_size += num_rows

    def done_adding(self) -> None:
        """Indicate to the batcher that no more blocks will be added to the batcher."""
        # NOTE: return annotation fixed from `-> bool`; this returns nothing.
        self._done_adding = True

    def has_batch(self) -> bool:
        """Whether this Batcher has any full batches."""
        return self.has_any() and (
            self._batch_size is None or self._buffer_size >= self._batch_size
        )

    def has_any(self) -> bool:
        """Whether this Batcher has any data."""
        return self._buffer_size > 0

    def next_batch(self) -> Block:
        """Get the next batch from the block buffer.

        Returns:
            A batch represented as a Block.
        """
        # A full batch must be available, unless this is the final (possibly
        # short) batch after done_adding().
        assert self.has_batch() or (self._done_adding and self.has_any())
        needs_copy = self._ensure_copy
        # If no batch size, short-circuit.
        if self._batch_size is None:
            assert len(self._buffer) == 1
            block = self._buffer[0]
            if needs_copy:
                # Copy block if needing to ensure fresh batch copy.
                block = BlockAccessor.for_block(block)
                block = block.slice(0, block.num_rows(), copy=True)
            self._buffer = []
            self._buffer_size = 0
            return block
        output = DelegatingBlockBuilder()
        leftover = []
        needed = self._batch_size
        for block in self._buffer:
            accessor = BlockAccessor.for_block(block)
            if needed <= 0:
                # We already have a full batch, so add this block to
                # the leftovers.
                leftover.append(block)
            elif accessor.num_rows() <= needed:
                output.add_block(accessor.to_block())
                needed -= accessor.num_rows()
            else:
                if (
                    isinstance(accessor, ArrowBlockAccessor)
                    and block.num_columns > 0
                    and block.column(0).num_chunks
                    >= MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS
                ):
                    # Slicing a heavily-chunked Arrow table is slow; combine
                    # chunks first (one extra copy, faster slicing).
                    accessor = BlockAccessor.for_block(
                        transform_pyarrow.combine_chunks(block)
                    )
                # We only need part of the block to fill out a batch.
                output.add_block(accessor.slice(0, needed, copy=False))
                # Add the rest of the block to the leftovers.
                leftover.append(accessor.slice(needed, accessor.num_rows(), copy=False))
                needed = 0

        # Move the leftovers into the block buffer so they're the first
        # blocks consumed on the next batch extraction.
        self._buffer = leftover
        # FIX: clamp at zero. The final batch after done_adding() may hold
        # fewer than batch_size rows; the original unclamped subtraction left
        # this counter negative in that case, corrupting the row count if the
        # batcher were (incorrectly) reused.
        self._buffer_size = max(0, self._buffer_size - self._batch_size)
        needs_copy = needs_copy and not output.will_build_yield_copy()
        batch = output.build()
        if needs_copy:
            # Need to ensure that the batch is a fresh copy.
            batch = BlockAccessor.for_block(batch)
            batch = batch.slice(0, batch.num_rows(), copy=True)
        return batch
162
+
163
+
164
class ShufflingBatcher(BatcherInterface):
    """Chunks blocks into shuffled batches, using a local in-memory shuffle buffer."""

    # Implementation Note:
    #
    # This shuffling batcher lazily builds a shuffle buffer from added blocks, and once
    # a batch is requested via .next_batch(), it concatenates the blocks into a concrete
    # shuffle buffer and randomly shuffles the entire buffer.
    #
    # Adding of more blocks can be intermixed with retrieving batches, but it should be
    # noted that we can end up performing two expensive operations on each retrieval:
    # 1. Build added blocks into a concrete shuffle buffer.
    # 2. Shuffling the entire buffer.
    # To amortize the overhead of this process, we only shuffle the blocks after a
    # delay designated by SHUFFLE_BUFFER_COMPACTION_RATIO.
    #
    # Similarly, adding blocks is very cheap. Each added block will be appended to a
    # list, with concatenation of the underlying data delayed until the next batch
    # compaction.

    def __init__(
        self,
        batch_size: Optional[int],
        shuffle_buffer_min_size: int,
        shuffle_seed: Optional[int] = None,
    ):
        """Constructs a random-shuffling block batcher.

        Args:
            batch_size: Record batch size.
            shuffle_buffer_min_size: Minimum number of rows that must be in the local
                in-memory shuffle buffer in order to yield a batch. When there are no
                more rows to be added to the buffer, the number of rows in the buffer
                *will* decrease below this value while yielding the remaining batches,
                and the final batch may have less than ``batch_size`` rows. Increasing
                this will improve the randomness of the shuffle but may increase the
                latency to the first batch.
            shuffle_seed: The seed to use for the local random shuffle.

        Raises:
            ValueError: If ``batch_size`` is None.
        """
        if batch_size is None:
            raise ValueError("Must specify a batch_size if using a local shuffle.")
        self._batch_size = batch_size
        self._shuffle_seed = shuffle_seed
        if shuffle_buffer_min_size < batch_size:
            # Round it up internally to `batch_size` since our algorithm requires it.
            # This is harmless since it only offers extra randomization.
            shuffle_buffer_min_size = batch_size
        self._buffer_min_size = shuffle_buffer_min_size
        # Holds newly added (not yet shuffled) blocks.
        self._builder = DelegatingBlockBuilder()
        # The materialized, already-shuffled buffer (None until first compaction).
        self._shuffle_buffer: Block = None
        # Read cursor into the materialized buffer; rows before it were yielded.
        self._batch_head = 0
        self._done_adding = False

    def add(self, block: Block):
        """Add a block to the shuffle buffer.

        Note empty block is not added to buffer.

        Args:
            block: Block to add to the shuffle buffer.
        """
        if BlockAccessor.for_block(block).num_rows() > 0:
            self._builder.add_block(block)

    def done_adding(self) -> None:
        """Indicate to the batcher that no more blocks will be added to the batcher.

        No more blocks should be added to the batcher after calling this.
        """
        # NOTE(review): return annotation corrected from `-> bool`.
        self._done_adding = True

    def has_any(self) -> bool:
        """Whether this batcher has any data."""
        return self._buffer_size() > 0

    def has_batch(self) -> bool:
        """Whether this batcher has any batches."""
        buffer_size = self._buffer_size()

        if not self._done_adding:
            # Delay pulling of batches until the buffer is large enough in order to
            # amortize compaction overhead.
            return self._materialized_buffer_size() >= self._buffer_min_size or (
                buffer_size - self._batch_size
                >= self._buffer_min_size * SHUFFLE_BUFFER_COMPACTION_RATIO
            )
        else:
            # Input is exhausted: drain whatever remains, batch by batch.
            return buffer_size >= self._batch_size

    def _buffer_size(self) -> int:
        """Return shuffle buffer size (pending builder rows + materialized rows)."""
        buffer_size = self._builder.num_rows()
        buffer_size += self._materialized_buffer_size()
        return buffer_size

    def _materialized_buffer_size(self) -> int:
        """Return materialized (compacted portion of) shuffle buffer size."""
        if self._shuffle_buffer is None:
            return 0
        # The size of the concrete (materialized) shuffle buffer, adjusting
        # for the batch head position, which also serves as a counter of the number
        # of already-yielded rows from the current concrete shuffle buffer.
        return max(
            0,
            BlockAccessor.for_block(self._shuffle_buffer).num_rows() - self._batch_head,
        )

    def next_batch(self) -> Block:
        """Get the next shuffled batch from the shuffle buffer.

        Returns:
            A batch represented as a Block.
        """
        assert self.has_batch() or (self._done_adding and self.has_any())
        # Add rows in the builder to the shuffle buffer. Note that we delay compaction
        # as much as possible to amortize the concatenation overhead. Compaction is
        # only necessary when the materialized buffer size falls below the min size.
        if self._builder.num_rows() > 0 and (
            self._done_adding
            or self._materialized_buffer_size() <= self._buffer_min_size
        ):
            if self._shuffle_buffer is not None:
                if self._batch_head > 0:
                    # Compact the materialized shuffle buffer.
                    block = BlockAccessor.for_block(self._shuffle_buffer)
                    self._shuffle_buffer = block.slice(
                        self._batch_head, block.num_rows()
                    )
                # Add the unyielded rows from the existing shuffle buffer.
                self._builder.add_block(self._shuffle_buffer)
            # Build the new shuffle buffer.
            self._shuffle_buffer = self._builder.build()
            self._shuffle_buffer = BlockAccessor.for_block(
                self._shuffle_buffer
            ).random_shuffle(self._shuffle_seed)
            if self._shuffle_seed is not None:
                # Advance the seed so subsequent shuffles differ deterministically.
                self._shuffle_seed += 1
            if (
                isinstance(
                    BlockAccessor.for_block(self._shuffle_buffer), ArrowBlockAccessor
                )
                and self._shuffle_buffer.num_columns > 0
                and self._shuffle_buffer.column(0).num_chunks
                >= MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS
            ):
                # Combine chunks so subsequent slice() calls stay fast.
                self._shuffle_buffer = transform_pyarrow.combine_chunks(
                    self._shuffle_buffer
                )
            # Reset the builder.
            self._builder = DelegatingBlockBuilder()
            self._batch_head = 0

        assert self._shuffle_buffer is not None
        buffer_size = BlockAccessor.for_block(self._shuffle_buffer).num_rows()
        # Truncate the batch to the buffer size, if necessary.
        batch_size = min(self._batch_size, buffer_size)
        slice_start = self._batch_head
        self._batch_head += batch_size
        # Yield the shuffled batch.
        return BlockAccessor.for_block(self._shuffle_buffer).slice(
            slice_start, self._batch_head
        )
.venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/block_batching.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+ from typing import Callable, Iterator, Optional, TypeVar
3
+
4
+ from ray.data._internal.block_batching.util import (
5
+ blocks_to_batches,
6
+ collate,
7
+ extract_data_from_batch,
8
+ format_batches,
9
+ )
10
+ from ray.data._internal.stats import DatasetStats
11
+ from ray.data.block import Block, DataBatch
12
+
13
+ T = TypeVar("T")
14
+
15
+
16
+ def batch_blocks(
17
+ blocks: Iterator[Block],
18
+ *,
19
+ stats: Optional[DatasetStats] = None,
20
+ batch_size: Optional[int] = None,
21
+ batch_format: str = "default",
22
+ drop_last: bool = False,
23
+ collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None,
24
+ shuffle_buffer_min_size: Optional[int] = None,
25
+ shuffle_seed: Optional[int] = None,
26
+ ensure_copy: bool = False,
27
+ ) -> Iterator[DataBatch]:
28
+ """Create formatted batches of data from 1 or more blocks.
29
+
30
+ This function takes in an iterator of already fetched blocks. Consequently, this
31
+ function doesn't support block prefetching.
32
+ """
33
+
34
+ def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]:
35
+ batch_iter = format_batches(
36
+ blocks_to_batches(
37
+ block_iter=base_iterator,
38
+ stats=stats,
39
+ batch_size=batch_size,
40
+ drop_last=drop_last,
41
+ shuffle_buffer_min_size=shuffle_buffer_min_size,
42
+ shuffle_seed=shuffle_seed,
43
+ ensure_copy=ensure_copy,
44
+ ),
45
+ batch_format=batch_format,
46
+ stats=stats,
47
+ )
48
+
49
+ if collate_fn is not None:
50
+ batch_iter = collate(batch_iter, collate_fn=collate_fn, stats=stats)
51
+
52
+ batch_iter = extract_data_from_batch(batch_iter)
53
+ yield from batch_iter
54
+
55
+ batch_iter = _iterator_fn(blocks)
56
+
57
+ for formatted_batch in batch_iter:
58
+ user_timer = stats.iter_user_s.timer() if stats else nullcontext()
59
+ with user_timer:
60
+ yield formatted_batch
.venv/lib/python3.11/site-packages/ray/data/_internal/block_batching/interfaces.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from dataclasses import dataclass
3
+ from typing import Any, List
4
+
5
+ from ray.data.block import Block, DataBatch
6
+ from ray.types import ObjectRef
7
+
8
+
9
@dataclass
class Batch:
    """A batch of data with a corresponding index.

    Attributes:
        batch_idx: The global index of this batch so that downstream operations can
            maintain ordering.
        data: The batch of data.
    """

    # Global, order-preserving index of this batch.
    batch_idx: int
    # The batch payload itself.
    data: DataBatch
21
+
22
+
23
class CollatedBatch(Batch):
    """A batch of collated data with a corresponding index.

    Attributes:
        batch_idx: The global index of this batch so that downstream operations can
            maintain ordering.
        data: The batch of data which is the output of a user provided collate_fn
            Therefore, the type of this data can be Any.
    """

    # Global, order-preserving index of this batch.
    batch_idx: int
    # Output of the user collate_fn; may be any type.
    data: Any
35
+
36
+
37
class BlockPrefetcher(metaclass=abc.ABCMeta):
    """Interface for prefetching blocks."""

    @abc.abstractmethod
    def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
        """Prefetch the provided blocks to this node.

        Args:
            blocks: Object references of the blocks to prefetch.
        """
        pass

    def stop(self):
        """Stop prefetching and release resources.

        Default implementation is a no-op; subclasses may override.
        """
        pass
.venv/lib/python3.11/site-packages/ray/data/_internal/block_builder.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Generic
2
+
3
+ from ray.data.block import Block, BlockAccessor, BlockType, T
4
+
5
+
6
class BlockBuilder(Generic[T]):
    """A builder class for blocks.

    Abstract interface; every method other than ``for_block`` raises
    NotImplementedError and must be provided by a concrete builder.
    """

    @staticmethod
    def for_block(block: Block) -> "BlockBuilder":
        """Return a builder matching the concrete type of ``block``."""
        return BlockAccessor.for_block(block).builder()

    def add(self, item: T) -> None:
        """Append a single row to the block being built."""
        raise NotImplementedError

    def add_block(self, block: Block) -> None:
        """Append an entire block to the block being built."""
        raise NotImplementedError

    def will_build_yield_copy(self) -> bool:
        """Whether building this block will yield a new block copy."""
        raise NotImplementedError

    def build(self) -> Block:
        """Build the block."""
        raise NotImplementedError

    def num_rows(self) -> int:
        """Return the number of rows added in the block."""
        raise NotImplementedError

    def get_estimated_memory_usage(self) -> int:
        """Return the estimated memory usage so far in bytes."""
        raise NotImplementedError

    def block_type(self) -> BlockType:
        """Return the block type."""
        raise NotImplementedError
.venv/lib/python3.11/site-packages/ray/data/_internal/block_list.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterator, List, Tuple
2
+
3
+ from ray.data._internal.memory_tracing import trace_allocation
4
+ from ray.data.block import Block, BlockMetadata
5
+ from ray.types import ObjectRef
6
+
7
+
8
class BlockList:
    """A list of blocks that may be computed or pending computation.

    All blocks are known ahead of time
    """

    def __init__(
        self,
        blocks: List[ObjectRef[Block]],
        metadata: List[BlockMetadata],
        *,
        owned_by_consumer: bool,
    ):
        assert len(blocks) == len(metadata), (blocks, metadata)
        for block_ref in blocks:
            trace_allocation(block_ref, "BlockList.__init__")
        self._blocks: List[ObjectRef[Block]] = blocks
        self._num_blocks = len(blocks)
        self._metadata: List[BlockMetadata] = metadata
        # Whether the block list is owned by consuming APIs, and if so it can be
        # eagerly deleted after read by the consumer.
        self._owned_by_consumer = owned_by_consumer
        # This field can be set to indicate the number of estimated output blocks,
        # since each read task may produce multiple output blocks after splitting.
        self._estimated_num_blocks = None

    def __repr__(self):
        return f"BlockList(owned_by_consumer={self._owned_by_consumer})"

    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
        """Get the metadata for all blocks."""
        return list(self._metadata)

    def copy(self) -> "BlockList":
        """Perform a shallow copy of this BlockList."""
        return BlockList(
            self._blocks,
            self._metadata,
            owned_by_consumer=self._owned_by_consumer,
        )

    def clear(self) -> None:
        """Erase references to the tasks tracked by the BlockList."""
        self._blocks = None

    def is_cleared(self) -> bool:
        """Whether this BlockList has been cleared."""
        return self._blocks is None

    def _check_if_cleared(self) -> None:
        """Raise an error if this BlockList has been previously cleared."""
        if not self.is_cleared():
            return
        raise ValueError(
            "This Dataset's blocks have been moved, which means that you "
            "can no longer use this Dataset."
        )

    def get_blocks(self) -> List[ObjectRef[Block]]:
        """Get list of the blocks of this block list.

        This blocks on the execution of the tasks generating block outputs.
        The length of this iterator is not known until execution.
        """
        self._check_if_cleared()
        return list(self._blocks)

    def get_blocks_with_metadata(self) -> List[Tuple[ObjectRef[Block], BlockMetadata]]:
        """Bulk version of iter_blocks_with_metadata().

        Prefer calling this instead of the iter form for performance if you
        don't need lazy evaluation.
        """
        self.get_blocks()
        return list(self.iter_blocks_with_metadata())

    def iter_blocks_with_metadata(
        self,
    ) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
        """Iterate over the blocks along with their runtime metadata.

        This blocks on the execution of the tasks generating block outputs.
        The length of this iterator is not known until execution.
        """
        self._check_if_cleared()
        return zip(self._blocks, self._metadata)

    def initial_num_blocks(self) -> int:
        """Returns the number of blocks of this BlockList."""
        return self._num_blocks

    def estimated_num_blocks(self) -> int:
        """Estimate of number of output blocks, without triggering actual execution."""
        return self._estimated_num_blocks or self._num_blocks
.venv/lib/python3.11/site-packages/ray/data/_internal/compute.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, UserDefinedFunction
from ray.util.annotations import DeveloperAPI

logger = logging.getLogger(__name__)

# Generic input/output element types used by transform-related signatures.
T = TypeVar("T")
U = TypeVar("U")


# Block transform function applied by task and actor pools.
BlockTransform = Union[
    # TODO(Clark): Once Ray only supports Python 3.8+, use protocol to constrain block
    # transform type.
    # Callable[[Block, ...], Iterable[Block]]
    # Callable[[Block, UserDefinedFunction, ...], Iterable[Block]],
    Callable[[Iterable[Block], TaskContext], Iterable[Block]],
    Callable[[Iterable[Block], TaskContext, UserDefinedFunction], Iterable[Block]],
    # Catch-all for transforms taking extra positional/keyword arguments.
    Callable[..., Iterable[Block]],
]
24
+
25
+
26
@DeveloperAPI
class ComputeStrategy:
    """Base marker class for Dataset transform compute strategies."""

    pass
29
+
30
+
31
@DeveloperAPI
class TaskPoolStrategy(ComputeStrategy):
    def __init__(
        self,
        size: Optional[int] = None,
    ):
        """Construct TaskPoolStrategy for a Dataset transform.

        Args:
            size: Specify the maximum size of the task pool.

        Raises:
            ValueError: If ``size`` is given and is less than 1.
        """
        if size is not None and size < 1:
            raise ValueError("`size` must be >= 1", size)
        self.size = size

    def __eq__(self, other: Any) -> bool:
        # Equal to another TaskPoolStrategy with the same size, or to the
        # legacy string spec "tasks" when no explicit size was given.
        if isinstance(other, TaskPoolStrategy):
            return self.size == other.size
        return other == "tasks" and self.size is None
51
+
52
+
53
class ActorPoolStrategy(ComputeStrategy):
    """Specify the compute strategy for a Dataset transform.

    ActorPoolStrategy specifies that an autoscaling pool of actors should be used
    for a given Dataset transform. This is useful for stateful setup of callable
    classes.

    For a fixed-sized pool of size ``n``, specify ``compute=ActorPoolStrategy(size=n)``.
    To autoscale from ``m`` to ``n`` actors, specify
    ``ActorPoolStrategy(min_size=m, max_size=n)``.

    To increase opportunities for pipelining task dependency prefetching with
    computation and avoiding actor startup delays, set max_tasks_in_flight_per_actor
    to 2 or greater; to try to decrease the delay due to queueing of tasks on the worker
    actors, set max_tasks_in_flight_per_actor to 1.
    """

    def __init__(
        self,
        *,
        size: Optional[int] = None,
        min_size: Optional[int] = None,
        max_size: Optional[int] = None,
        max_tasks_in_flight_per_actor: Optional[int] = None,
    ):
        """Construct ActorPoolStrategy for a Dataset transform.

        Args:
            size: Specify a fixed size actor pool of this size. It is an error to
                specify both `size` and `min_size` or `max_size`.
            min_size: The minimum size of the actor pool.
            max_size: The maximum size of the actor pool.
            max_tasks_in_flight_per_actor: The maximum number of tasks to concurrently
                send to a single actor worker. Increasing this will increase
                opportunities for pipelining task dependency prefetching with
                computation and avoiding actor startup delays, but will also increase
                queueing delay.

        Raises:
            ValueError: On invalid or conflicting sizing arguments.
        """
        # A fixed `size` is shorthand for min_size == max_size == size.
        if size is not None:
            if size < 1:
                raise ValueError("size must be >= 1", size)
            if max_size is not None or min_size is not None:
                raise ValueError(
                    "min_size and max_size cannot be set at the same time as `size`"
                )
            min_size = size
            max_size = size
        if min_size is not None and min_size < 1:
            raise ValueError("min_size must be >= 1", min_size)
        if max_size is not None:
            if min_size is None:
                min_size = 1  # Legacy default.
            if min_size > max_size:
                raise ValueError("min_size must be <= max_size", min_size, max_size)
        if (
            max_tasks_in_flight_per_actor is not None
            and max_tasks_in_flight_per_actor < 1
        ):
            raise ValueError(
                "max_tasks_in_flight_per_actor must be >= 1, got: ",
                max_tasks_in_flight_per_actor,
            )
        # Unset bounds default to [1, inf).
        self.min_size = min_size if min_size else 1
        self.max_size = max_size if max_size else float("inf")
        self.max_tasks_in_flight_per_actor = max_tasks_in_flight_per_actor
        self.num_workers = 0
        self.ready_to_total_workers_ratio = 0.8

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, ActorPoolStrategy):
            return False
        return (
            self.min_size == other.min_size
            and self.max_size == other.max_size
            and self.max_tasks_in_flight_per_actor
            == other.max_tasks_in_flight_per_actor
        )
128
+
129
+
130
def get_compute(compute_spec: Union[str, ComputeStrategy]) -> ComputeStrategy:
    """Validate and canonicalize a compute spec into a ``ComputeStrategy``.

    Args:
        compute_spec: A ``TaskPoolStrategy`` or ``ActorPoolStrategy`` instance.
            String specs ("tasks"/"actors") were removed in Ray 2.5 and now
            raise.

    Returns:
        The validated compute strategy.

    Raises:
        ValueError: If ``compute_spec`` is not a supported strategy instance.
    """
    if not isinstance(compute_spec, (TaskPoolStrategy, ActorPoolStrategy)):
        raise ValueError(
            "In Ray 2.5, the compute spec must be either "
            f"TaskPoolStrategy or ActorPoolStrategy, was: {compute_spec}."
        )
    # NOTE: TaskPoolStrategy(size=None) compares equal to "tasks"; preserve the
    # historical behavior of returning a fresh default strategy in that case.
    # (The original "actors"-string and trailing else branches were unreachable
    # after the isinstance() guard above and have been removed.)
    if not compute_spec or compute_spec == "tasks":
        return TaskPoolStrategy()
    return compute_spec
.venv/lib/python3.11/site-packages/ray/data/_internal/delegating_block_builder.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from typing import Any, Mapping, Optional
3
+
4
+ from ray.data._internal.arrow_block import ArrowBlockBuilder
5
+ from ray.data._internal.block_builder import BlockBuilder
6
+ from ray.data.block import Block, BlockAccessor, BlockType, DataBatch
7
+
8
+
9
+ class DelegatingBlockBuilder(BlockBuilder):
10
+ def __init__(self):
11
+ self._builder = None
12
+ self._empty_block = None
13
+
14
+ @property
15
+ def _inferred_block_type(self) -> Optional[BlockType]:
16
+ """The block type inferred from the first item added to the builder."""
17
+ if self._builder is not None:
18
+ return self._builder.block_type()
19
+ return None
20
+
21
+ def add(self, item: Mapping[str, Any]) -> None:
22
+ assert isinstance(item, collections.abc.Mapping), item
23
+
24
+ if self._builder is None:
25
+ self._builder = ArrowBlockBuilder()
26
+
27
+ self._builder.add(item)
28
+
29
+ def add_batch(self, batch: DataBatch):
30
+ """Add a user-facing data batch to the builder.
31
+
32
+ This data batch will be converted to an internal block and then added to the
33
+ underlying builder.
34
+ """
35
+ block = BlockAccessor.batch_to_block(batch, self._inferred_block_type)
36
+ return self.add_block(block)
37
+
38
+ def add_block(self, block: Block):
39
+ accessor = BlockAccessor.for_block(block)
40
+ if accessor.num_rows() == 0:
41
+ # Don't infer types of empty lists. Store the block and use it if no
42
+ # other data is added. https://github.com/ray-project/ray/issues/20290
43
+ self._empty_block = block
44
+ return
45
+ if self._builder is None:
46
+ self._builder = accessor.builder()
47
+ else:
48
+ block_type = accessor.block_type()
49
+ assert block_type == self._inferred_block_type, (
50
+ block_type,
51
+ self._inferred_block_type,
52
+ )
53
+
54
+ self._builder.add_block(accessor.to_block())
55
+
56
+ def will_build_yield_copy(self) -> bool:
57
+ if self._builder is None:
58
+ return True
59
+ return self._builder.will_build_yield_copy()
60
+
61
+ def build(self) -> Block:
62
+ if self._builder is None:
63
+ if self._empty_block is not None:
64
+ self._builder = BlockAccessor.for_block(self._empty_block).builder()
65
+ self._builder.add_block(self._empty_block)
66
+ else:
67
+ self._builder = ArrowBlockBuilder()
68
+ return self._builder.build()
69
+
70
+ def num_rows(self) -> int:
71
+ return self._builder.num_rows() if self._builder is not None else 0
72
+
73
+ def get_estimated_memory_usage(self) -> int:
74
+ if self._builder is None:
75
+ return 0
76
+ return self._builder.get_estimated_memory_usage()
.venv/lib/python3.11/site-packages/ray/data/_internal/equalize.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ from ray.data._internal.execution.interfaces import RefBundle
4
+ from ray.data._internal.split import _calculate_blocks_rows, _split_at_indices
5
+ from ray.data.block import Block, BlockMetadata, BlockPartition
6
+ from ray.types import ObjectRef
7
+
8
+
9
def _equalize(
    per_split_bundles: List[RefBundle],
    owned_by_consumer: bool,
) -> List[RefBundle]:
    """Equalize split ref bundles into equal number of rows.

    Args:
        per_split_bundles: ref bundles to equalize.
        owned_by_consumer: forwarded as ``owns_blocks`` on every produced
            ``RefBundle`` (presumably whether the consumer may destroy the
            blocks -- TODO confirm against callers).
    Returns:
        the equalized ref bundles.
    """
    if len(per_split_bundles) == 0:
        return per_split_bundles
    per_split_blocks_with_metadata = [bundle.blocks for bundle in per_split_bundles]
    per_split_num_rows: List[List[int]] = [
        _calculate_blocks_rows(split) for split in per_split_blocks_with_metadata
    ]
    total_rows = sum([sum(blocks_rows) for blocks_rows in per_split_num_rows])
    # Every output split gets exactly this many rows (floor division, so up to
    # len(splits) - 1 surplus rows end up unassigned).
    target_split_size = total_rows // len(per_split_blocks_with_metadata)

    # phase 1: shave the current splits by dropping blocks (into leftovers)
    # and calculate num rows needed to meet the target.
    shaved_splits, per_split_needed_rows, leftovers = _shave_all_splits(
        per_split_blocks_with_metadata, per_split_num_rows, target_split_size
    )

    # validate invariants
    for shaved_split, split_needed_row in zip(shaved_splits, per_split_needed_rows):
        num_shaved_rows = sum([meta.num_rows for _, meta in shaved_split])
        assert num_shaved_rows <= target_split_size
        assert num_shaved_rows + split_needed_row == target_split_size

    # phase 2: based on the num rows needed for each shaved split, split the leftovers
    # in the shape that exactly matches the rows needed.
    leftover_bundle = RefBundle(leftovers, owns_blocks=owned_by_consumer)
    leftover_splits = _split_leftovers(leftover_bundle, per_split_needed_rows)

    # phase 3: merge the shaved_splits and leftover splits and return.
    for i, leftover_split in enumerate(leftover_splits):
        shaved_splits[i].extend(leftover_split)

        # validate invariants: each merged split now has exactly the target rows.
        num_shaved_rows = sum([meta.num_rows for _, meta in shaved_splits[i]])
        assert num_shaved_rows == target_split_size

    # Compose the result back to RefBundle
    equalized_ref_bundles: List[RefBundle] = []
    for split in shaved_splits:
        equalized_ref_bundles.append(RefBundle(split, owns_blocks=owned_by_consumer))
    return equalized_ref_bundles
59
+
60
+
61
def _shave_one_split(
    split: BlockPartition, num_rows_per_block: List[int], target_size: int
) -> Tuple[BlockPartition, int, BlockPartition]:
    """Shave a block list down to (at most) the target size.

    Args:
        split: the block list to shave.
        num_rows_per_block: num rows for each block in the list.
        target_size: the upper bound target size of the shaved list.
    Returns:
        A tuple of:
            - shaved block list.
            - num of rows needed for the block list to meet the target size.
            - leftover blocks.
    """
    kept = []
    spilled = []
    kept_rows = 0
    # Keep each block whose rows still fit under the target; note this is not
    # a prefix scan -- a later, smaller block may still be kept after a larger
    # one was spilled.
    for block_with_meta, block_rows in zip(split, num_rows_per_block):
        if kept_rows + block_rows > target_size:
            spilled.append(block_with_meta)
        else:
            kept.append(block_with_meta)
            kept_rows += block_rows
    return kept, target_size - kept_rows, spilled
89
+
90
+
91
def _shave_all_splits(
    input_splits: List[BlockPartition],
    per_split_num_rows: List[List[int]],
    target_size: int,
) -> Tuple[List[BlockPartition], List[int], BlockPartition]:
    """Shave every block list down to the target size.

    Args:
        input_splits: all block lists to shave.
        per_split_num_rows: num rows (per block) for each block list.
        target_size: the upper bound target size of the shaved lists.
    Returns:
        A tuple of:
            - all shaved block lists.
            - num of rows needed for each block list to meet the target size.
            - leftover blocks pooled from every split.
    """
    shaved_splits = []
    per_split_needed_rows = []
    leftovers = []
    for split, split_rows in zip(input_splits, per_split_num_rows):
        shaved, needed, spilled = _shave_one_split(split, split_rows, target_size)
        shaved_splits.append(shaved)
        per_split_needed_rows.append(needed)
        leftovers.extend(spilled)
    return shaved_splits, per_split_needed_rows, leftovers
121
+
122
+
123
def _split_leftovers(
    leftovers: RefBundle, per_split_needed_rows: List[int]
) -> List[BlockPartition]:
    """Split leftover blocks by the num of rows needed per split."""
    num_splits = len(per_split_needed_rows)
    # Turn per-split row counts into cumulative split indices.
    split_indices = []
    cumulative = 0
    for num_rows_needed in per_split_needed_rows:
        cumulative += num_rows_needed
        split_indices.append(cumulative)
    split_result: Tuple[
        List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
    ] = _split_at_indices(
        leftovers.blocks,
        split_indices,
        leftovers.owns_blocks,
    )
    # Pair each split's block refs with its metadata; drop the trailing
    # remainder split produced beyond the requested count.
    paired = [list(zip(refs, metas)) for refs, metas in zip(*split_result)]
    return paired[:num_splits]
.venv/lib/python3.11/site-packages/ray/data/_internal/logging.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import logging.config
3
+ import os
4
+ from typing import Optional
5
+
6
+ import yaml
7
+
8
+ import ray
9
+
10
+ DEFAULT_CONFIG = {
11
+ "version": 1,
12
+ "disable_existing_loggers": False,
13
+ "formatters": {
14
+ "ray": {
15
+ "format": "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" # noqa: E501
16
+ },
17
+ "ray_json": {"class": "ray._private.ray_logging.formatters.JSONFormatter"},
18
+ },
19
+ "filters": {
20
+ "console_filter": {"()": "ray.data._internal.logging.HiddenRecordFilter"},
21
+ "core_context_filter": {
22
+ "()": "ray._private.ray_logging.filters.CoreContextFilter"
23
+ },
24
+ },
25
+ "handlers": {
26
+ "file": {
27
+ "class": "ray.data._internal.logging.SessionFileHandler",
28
+ "formatter": "ray",
29
+ "filename": "ray-data.log",
30
+ },
31
+ "file_json": {
32
+ "class": "ray.data._internal.logging.SessionFileHandler",
33
+ "formatter": "ray_json",
34
+ "filename": "ray-data.log",
35
+ "filters": ["core_context_filter"],
36
+ },
37
+ "console": {
38
+ "class": "ray._private.log.PlainRayHandler",
39
+ "formatter": "ray",
40
+ "level": "INFO",
41
+ "filters": ["console_filter"],
42
+ },
43
+ },
44
+ "loggers": {
45
+ "ray.data": {
46
+ "level": "DEBUG",
47
+ "handlers": ["file", "console"],
48
+ "propagate": False,
49
+ },
50
+ "ray.air.util.tensor_extensions": {
51
+ "level": "DEBUG",
52
+ "handlers": ["file", "console"],
53
+ "propagate": False,
54
+ },
55
+ },
56
+ }
57
+
58
+ # Dictionary of substitutions to be performed when using JSON mode. Handlers with names
59
+ # corresponding to keys will be replaced by those corresponding to values.
60
+ RAY_DATA_LOG_HANDLER_JSON_SUBSTITUTIONS = {"file": "file_json"}
61
+
62
+ # Env. variable to specify the encoding of the file logs when using the default config.
63
+ RAY_DATA_LOG_ENCODING_ENV_VAR_NAME = "RAY_DATA_LOG_ENCODING"
64
+
65
+ # Env. variable to specify the logging config path use defaults if not set
66
+ RAY_DATA_LOGGING_CONFIG_ENV_VAR_NAME = "RAY_DATA_LOGGING_CONFIG"
67
+
68
+ # To facilitate debugging, Ray Data writes debug logs to a file. However, if Ray Data
69
+ # logs every scheduler loop, logging might impact performance. So, we add a "TRACE"
70
+ # level where logs aren't written by default.
71
+ #
72
+ # Use the following code to log a message at the "TRACE" level:
73
+ # ```
74
+ # logger.log(logging.getLevelName("TRACE"), "Your message here.")
75
+ # ````
76
+ logging.addLevelName(logging.DEBUG - 1, "TRACE")
77
+
78
+
79
class HiddenRecordFilter:
    """Filters out log records with the "hide" attribute set to True.

    This filter allows you to override default logging behavior. For example, if errors
    are printed by default, and you don't want to print a specific error, you can set
    the "hide" attribute to avoid printing the message.

    .. testcode::

        import logging
        logger = logging.getLogger("ray.data.spam")

        # This warning won't be printed to the console.
        logger.warning("ham", extra={"hide": True})
    """

    def filter(self, record):
        # Records without the attribute are shown by default.
        hidden = getattr(record, "hide", False)
        return not hidden
97
+
98
+
99
class SessionFileHandler(logging.Handler):
    """A handler that writes to a log file in the Ray session directory.

    The Ray session directory isn't available until Ray is initialized, so this
    handler lazily creates the underlying file handler on the first emit.

    Args:
        filename: The name of the log file. The file is created in the 'logs'
            directory of the Ray session directory.
    """

    def __init__(self, filename: str):
        super().__init__()
        self._filename = filename
        self._handler = None
        self._formatter = None
        self._path = None

    def emit(self, record):
        if self._handler is None:
            self._try_create_handler()
        if self._handler is None:
            # Ray isn't initialized yet, so there's nowhere to write; drop it.
            return
        self._handler.emit(record)

    def setFormatter(self, fmt: logging.Formatter) -> None:
        # Remember the formatter so it can be applied once the delegate exists.
        self._formatter = fmt
        if self._handler is not None:
            self._handler.setFormatter(fmt)

    def _try_create_handler(self):
        assert self._handler is None
        log_directory = get_log_directory()
        if log_directory is None:
            # Ray hasn't been initialized; retry on a later emit.
            return
        os.makedirs(log_directory, exist_ok=True)
        self._path = os.path.join(log_directory, self._filename)
        self._handler = logging.FileHandler(self._path)
        if self._formatter is not None:
            self._handler.setFormatter(self._formatter)
141
+
142
+
143
def configure_logging() -> None:
    """Configure the Python logger named 'ray.data'.

    This function loads the configration YAML specified by "RAY_DATA_LOGGING_CONFIG"
    environment variable. If the variable isn't set, this function loads the default
    config defined in this module.

    If "RAY_DATA_LOG_ENCODING" is specified as "JSON" we will enable JSON logging mode
    if using the default logging config.
    """
    import copy

    def _load_logging_config(config_path: str):
        with open(config_path) as file:
            return yaml.safe_load(file)

    # Dynamically load env vars
    config_path = os.environ.get(RAY_DATA_LOGGING_CONFIG_ENV_VAR_NAME)
    log_encoding = os.environ.get(RAY_DATA_LOG_ENCODING_ENV_VAR_NAME)

    if config_path is not None:
        config = _load_logging_config(config_path)
    else:
        # BUGFIX: deep-copy before the JSON substitutions below. The original
        # code mutated the handler lists of the shared DEFAULT_CONFIG constant
        # in place, so a second call raised ValueError on list.remove() and
        # corrupted the defaults for every other caller.
        config = copy.deepcopy(DEFAULT_CONFIG)
        if log_encoding is not None and log_encoding.upper() == "JSON":
            for logger_config in config["loggers"].values():
                for (
                    old_handler_name,
                    new_handler_name,
                ) in RAY_DATA_LOG_HANDLER_JSON_SUBSTITUTIONS.items():
                    logger_config["handlers"].remove(old_handler_name)
                    logger_config["handlers"].append(new_handler_name)

    logging.config.dictConfig(config)

    # After configuring logger, warn if RAY_DATA_LOGGING_CONFIG is used with
    # RAY_DATA_LOG_ENCODING, because they are not both supported together.
    if config_path is not None and log_encoding is not None:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Using `RAY_DATA_LOG_ENCODING` is not supported with "
            + "`RAY_DATA_LOGGING_CONFIG`"
        )
186
+
187
+
188
def reset_logging() -> None:
    """Reset the logger named 'ray.data' to its initial state.

    Used for testing.
    """
    data_logger = logging.getLogger("ray.data")
    data_logger.handlers.clear()
    data_logger.setLevel(logging.NOTSET)
196
+
197
+
198
def get_log_directory() -> Optional[str]:
    """Return the directory where Ray Data writes log files.

    If Ray isn't initialized, this function returns ``None``.
    """
    global_node = ray._private.worker._global_node
    if global_node is not None:
        session_dir = global_node.get_session_dir_path()
        return os.path.join(session_dir, "logs", "ray-data")
    return None
.venv/lib/python3.11/site-packages/ray/data/_internal/memory_tracing.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility for debugging object store memory eager deletion in Datasets.
2
+
3
+ NOTE: the performance overhead of tracing object allocation is fairly substantial.
4
+ This is meant to use in unit test for debugging. Please do not enable in production,
5
+ without performance optimization.
6
+
7
+ Enable with RAY_DATA_TRACE_ALLOCATIONS=1.
8
+
9
+ Basic usage is to call `trace_allocation` each time a new object is created, and call
10
+ `trace_deallocation` when an object should be disposed of. When the workload is
11
+ complete, call `leak_report` to view possibly leaked objects.
12
+
13
+ Note that so called "leaked" objects will be reclaimed eventually by reference counting
14
+ in Ray. This is just to debug the eager deletion protocol which is more efficient.
15
+ """
16
+
17
+ from io import StringIO
18
+ from typing import Dict, List
19
+
20
+ import ray
21
+ from ray.data.context import DataContext
22
+
23
+
24
def trace_allocation(ref: ray.ObjectRef, loc: str) -> None:
    """Record that an object has been created.

    Args:
        ref: The object created.
        loc: A human-readable string identifying the call site.
    """
    if not DataContext.get_current().trace_allocations:
        # Tracing disabled: this is a no-op to keep the hot path cheap.
        return
    tracer = _get_mem_actor()
    # TODO: it would be nice to determine loc automatically based on the stack.
    ray.get(tracer.trace_alloc.remote([ref], loc))
36
+
37
+
38
def trace_deallocation(ref: ray.ObjectRef, loc: str, free: bool = True) -> None:
    """Record that an object has been deleted (and delete if free=True).

    Args:
        ref: The object we no longer need.
        loc: A human-readable string identifying the call site.
        free: Whether to eagerly destroy the object instead of waiting for Ray
            reference counting to kick in.
    """
    if free:
        # Eagerly reclaim the object instead of waiting on reference counting.
        ray._private.internal_api.free(ref, local_only=False)
    if DataContext.get_current().trace_allocations:
        tracer = _get_mem_actor()
        ray.get(tracer.trace_dealloc.remote([ref], loc, free))
53
+
54
+
55
def leak_report() -> str:
    """Fetch the tracer actor's formatted report of leaked and freed objects."""
    tracer = _get_mem_actor()
    report_ref = tracer.leak_report.remote()
    return ray.get(report_ref)
58
+
59
+
60
@ray.remote(num_cpus=0)
class _MemActor:
    """Singleton actor that records object (de)allocations for debugging."""

    def __init__(self):
        # Live objects: ref -> {"size_bytes": int, "loc": str}.
        self.allocated: Dict[ray.ObjectRef, dict] = {}
        # Freed objects: same entries, plus a "dealloc_loc" key.
        self.deallocated: Dict[ray.ObjectRef, dict] = {}
        # Objects whose eager free was skipped: ref -> call site.
        self.skip_dealloc: Dict[ray.ObjectRef, str] = {}
        self.peak_mem = 0
        self.cur_mem = 0

    def trace_alloc(self, ref: List[ray.ObjectRef], loc: str):
        """Record an allocation observed at call site ``loc``."""
        ref = ref[0]  # Avoid Ray materializing the ref.
        if ref in self.allocated:
            # Already traced; keep the first record.
            return
        meta = ray.experimental.get_object_locations([ref])
        # BUGFIX: get_object_locations() returns {ref: info}; the size must be
        # read from the per-ref info dict. The original `meta.get("object_size")`
        # looked the key up on the outer mapping, always missed, and forced the
        # expensive pickle fallback below.
        size_bytes = meta.get(ref, {}).get("object_size", 0)
        if not size_bytes:
            size_bytes = -1
            from ray import cloudpickle as pickle

            try:
                # Fall back to measuring the pickled object (slow; debug-only).
                obj = ray.get(ref, timeout=5.0)
                size_bytes = len(pickle.dumps(obj))
            except Exception:
                print("[mem_tracing] ERROR getting size")
                size_bytes = -1
        print(f"[mem_tracing] Allocated {size_bytes} bytes at {loc}: {ref}")
        entry = {
            "size_bytes": size_bytes,
            "loc": loc,
        }
        self.allocated[ref] = entry
        self.cur_mem += size_bytes
        self.peak_mem = max(self.cur_mem, self.peak_mem)

    def trace_dealloc(self, ref: List[ray.ObjectRef], loc: str, freed: bool):
        """Record a deallocation (or a skipped eager free) at ``loc``."""
        ref = ref[0]  # Avoid Ray materializing the ref.
        size_bytes = self.allocated.get(ref, {}).get("size_bytes", 0)
        if not freed:
            print(f"[mem_tracing] Skipped freeing {size_bytes} bytes at {loc}: {ref}")
            self.skip_dealloc[ref] = loc
            return
        print(f"[mem_tracing] Freed {size_bytes} bytes at {loc}: {ref}")
        if ref in self.allocated:
            # Move the entry from the live table to the freed table.
            self.cur_mem -= size_bytes
            entry = self.allocated.pop(ref)
            entry["dealloc_loc"] = loc
            self.deallocated[ref] = entry
        elif ref not in self.deallocated:
            # Never seen before: its allocation was not traced.
            print(f"[mem_tracing] WARNING: allocation of {ref} was not traced!")

    def leak_report(self) -> str:
        """Format a report of still-allocated (leaked) and freed objects."""
        output = StringIO()
        output.write("[mem_tracing] ===== Leaked objects =====\n")
        for ref, entry in self.allocated.items():
            size_bytes = entry.get("size_bytes")
            loc = entry.get("loc")
            if ref in self.skip_dealloc:
                dealloc_loc = self.skip_dealloc[ref]
                output.write(
                    f"[mem_tracing] Leaked object, created at {loc}, size "
                    f"{size_bytes}, skipped dealloc at {dealloc_loc}: {ref}\n"
                )
            else:
                output.write(
                    f"[mem_tracing] Leaked object, created at {loc}, "
                    f"size {size_bytes}: {ref}\n"
                )
        output.write("[mem_tracing] ===== End leaked objects =====\n")
        output.write("[mem_tracing] ===== Freed objects =====\n")
        for ref, entry in self.deallocated.items():
            size_bytes = entry.get("size_bytes")
            loc = entry.get("loc")
            dealloc_loc = entry.get("dealloc_loc")
            output.write(
                f"[mem_tracing] Freed object from {loc} at {dealloc_loc}, "
                f"size {size_bytes}: {ref}\n"
            )
        output.write("[mem_tracing] ===== End freed objects =====\n")
        output.write(f"[mem_tracing] Peak size bytes {self.peak_mem}\n")
        output.write(f"[mem_tracing] Current size bytes {self.cur_mem}\n")
        return output.getvalue()
142
+
143
+
144
def _get_mem_actor():
    """Get (or lazily create) the named, detached memory-tracing actor."""
    handle = _MemActor.options(
        name="mem_tracing_actor", get_if_exists=True, lifetime="detached"
    )
    return handle.remote()
.venv/lib/python3.11/site-packages/ray/data/_internal/null_aggregate.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import ModuleType
2
+ from typing import Any, Callable, Tuple, Union
3
+
4
+ import numpy as np
5
+
6
+ from ray.data.block import AggType, Block, KeyType, T, U
7
+
8
# An accumulation paired with a 0/1 flag marking whether it holds real data.
WrappedAggType = Tuple[AggType, int]
9
+
10
+
11
+ # This module contains aggregation helpers for handling nulls.
12
+ # The null handling policy is:
13
+ # 1. Mix of values and nulls - ignore_nulls=True: Ignore the nulls, return
14
+ # aggregation of non-null values.
15
+ # 2. Mix of values and nulls - ignore_nulls=False: Return None.
16
+ # 3. All nulls: Return None.
17
+ # 4. Empty dataset: Return None.
18
+ #
19
+ # This is accomplished by checking rows for null values and by propagating nulls
20
+ # if found AND if we're not ignoring them. If not ignoring nulls, in order to delineate
21
+ # between found null rows and an empty block accumulation when merging (the latter of
22
+ # which we want to propagate; the former of which we do not), we attach a boolean flag
23
+ # indicating whether or not an accumulation contains valid data to intermediate block
24
+ # accumulations via _wrap_acc() and _unwrap_acc(). This allows us to properly merge
25
+ # intermediate block accumulations under a streaming constraint.
26
+
27
+
28
def _wrap_acc(a: AggType, has_data: bool) -> WrappedAggType:
    """
    Wrap accumulation with a numeric boolean flag indicating whether or not
    this accumulation contains real data; if it doesn't, we consider it to be
    empty.

    Args:
        a: The accumulation value.
        has_data: Whether the accumulation contains real data.

    Returns:
        An AggType list with the last element being a numeric boolean flag indicating
        whether or not this accumulation contains real data. If the input a has length
        n, the returned AggType has length n + 1.
    """
    values = a if isinstance(a, list) else [a]
    flag = 1 if has_data else 0
    # Concatenation builds a new list, so the caller's value is never mutated.
    return values + [flag]
46
+
47
+
48
def _unwrap_acc(a: WrappedAggType) -> Tuple[AggType, bool]:
    """
    Unwrap the accumulation, which we assume has been wrapped (via _wrap_acc) with a
    numeric boolean flag indicating whether or not this accumulation contains real data.

    Args:
        a: The wrapped accumulation value that we wish to unwrap.

    Returns:
        A tuple containing the unwrapped accumulation value and a boolean indicating
        whether the accumulation contains real data.
    """
    *values, flag = a
    has_data = flag == 1
    # A single-element accumulation is unwrapped to a scalar value.
    if len(values) == 1:
        return values[0], has_data
    return values, has_data
65
+
66
+
67
def _null_wrap_init(
    init: Callable[[KeyType], AggType]
) -> Callable[[KeyType], WrappedAggType]:
    """
    Wraps an accumulation initializer with null handling.

    The returned initializer function adds on a has_data field that the accumulator
    uses to track whether an aggregation is empty.

    Args:
        init: The core init function to wrap.

    Returns:
        A new accumulation initializer function that can handle nulls.
    """

    def _init(k: KeyType) -> AggType:
        # A fresh accumulation holds no real data yet, hence has_data=False.
        return _wrap_acc(init(k), has_data=False)

    return _init
90
+
91
+
92
def _null_wrap_merge(
    ignore_nulls: bool,
    merge: Callable[[AggType, AggType], AggType],
) -> Callable[[WrappedAggType, WrappedAggType], WrappedAggType]:
    """
    Wrap merge function with null handling.

    The returned merge function expects a1 and a2 to be either None or of the form:
    a = [acc_data_1, ..., acc_data_2, has_data].

    This merges two accumulations subject to the following null rules:
    1. If a1 is empty and a2 is empty, return empty accumulation.
    2. If a1 (a2) is empty and a2 (a1) is None, return None.
    3. If a1 (a2) is empty and a2 (a1) is non-None, return a2 (a1).
    4. If a1 (a2) is None, return a2 (a1) if ignoring nulls, None otherwise.
    5. If a1 and a2 are both non-null, return merge(a1, a2).

    Args:
        ignore_nulls: Whether nulls should be ignored or cause a None result.
        merge: The core merge function to wrap.

    Returns:
        A new merge function that handles nulls.
    """

    def _merge(a1: WrappedAggType, a2: WrappedAggType) -> WrappedAggType:
        # NOTE: the order of these checks matters -- the None check on a1 must
        # precede the unwrap, and an empty a1 propagates a2 unconditionally
        # (even a None a2, regardless of ignore_nulls), per rule 2/3 above.
        if a1 is None:
            # If we're ignoring nulls, propagate a2; otherwise, propagate None.
            return a2 if ignore_nulls else None
        unwrapped_a1, a1_has_data = _unwrap_acc(a1)
        if not a1_has_data:
            # If a1 is empty, propagate a2.
            # No matter whether a2 is a real value, empty, or None,
            # propagating each of these is correct if a1 is empty.
            return a2
        if a2 is None:
            # If we're ignoring nulls, propagate a1; otherwise, propagate None.
            return a1 if ignore_nulls else None
        unwrapped_a2, a2_has_data = _unwrap_acc(a2)
        if not a2_has_data:
            # If a2 is empty, propagate a1.
            return a1
        # Both sides hold real data: delegate to the core merge and re-wrap.
        a = merge(unwrapped_a1, unwrapped_a2)
        return _wrap_acc(a, has_data=True)

    return _merge
138
+
139
+
140
+ def _null_wrap_accumulate_row(
141
+ ignore_nulls: bool,
142
+ on_fn: Callable[[T], T],
143
+ accum: Callable[[AggType, T], AggType],
144
+ ) -> Callable[[WrappedAggType, T], WrappedAggType]:
145
+ """
146
+ Wrap accumulator function with null handling.
147
+
148
+ The returned accumulate function expects a to be either None or of the form:
149
+ a = [acc_data_1, ..., acc_data_n, has_data].
150
+
151
+ This performs an accumulation subject to the following null rules:
152
+ 1. If r is null and ignore_nulls=False, return None.
153
+ 2. If r is null and ignore_nulls=True, return a.
154
+ 3. If r is non-null and a is None, return None.
155
+ 4. If r is non-null and a is non-None, return accum(a[:-1], r).
156
+
157
+ Args:
158
+ ignore_nulls: Whether nulls should be ignored or cause a None result.
159
+ on_fn: Function selecting a subset of the row to apply the aggregation.
160
+ accum: The core accumulator function to wrap.
161
+
162
+ Returns:
163
+ A new accumulator function that handles nulls.
164
+ """
165
+
166
+ def _accum(a: WrappedAggType, r: T) -> WrappedAggType:
167
+ r = on_fn(r)
168
+ if _is_null(r):
169
+ if ignore_nulls:
170
+ # Ignoring nulls, return the current accumulation, ignoring r.
171
+ return a
172
+ else:
173
+ # Not ignoring nulls, so propagate the null.
174
+ return None
175
+ else:
176
+ if a is None:
177
+ # Accumulation is None so (1) a previous row must have been null, and
178
+ # (2) we must be propagating nulls, so continue to pragate this null.
179
+ return None
180
+ else:
181
+ # Row is non-null and accumulation is non-null, so we now apply the core
182
+ # accumulation.
183
+ a, _ = _unwrap_acc(a)
184
+ a = accum(a, r)
185
+ return _wrap_acc(a, has_data=True)
186
+
187
+ return _accum
188
+
189
+
190
+ def _null_wrap_accumulate_block(
191
+ ignore_nulls: bool,
192
+ accum_block: Callable[[Block], AggType],
193
+ null_merge: Callable[[WrappedAggType, WrappedAggType], WrappedAggType],
194
+ ) -> Callable[[WrappedAggType, Block], WrappedAggType]:
195
+ """
196
+ Wrap vectorized aggregate function with null handling.
197
+
198
+ This performs a block accumulation subject to the following null rules:
199
+ 1. If any row is null and ignore_nulls=False, return None.
200
+ 2. If at least one row is not null and ignore_nulls=True, return the block
201
+ accumulation.
202
+ 3. If all rows are null and ignore_nulls=True, return the base accumulation.
203
+ 4. If all rows non-null, return the block accumulation.
204
+
205
+ Args:
206
+ ignore_nulls: Whether nulls should be ignored or cause a None result.
207
+ accum_block: The core vectorized aggregate function to wrap.
208
+ null_merge: A null-handling merge, as returned from _null_wrap_merge().
209
+
210
+ Returns:
211
+ A new vectorized aggregate function that handles nulls.
212
+ """
213
+
214
+ def _accum_block_null(a: WrappedAggType, block: Block) -> WrappedAggType:
215
+ ret = accum_block(block)
216
+ if ret is not None:
217
+ ret = _wrap_acc(ret, has_data=True)
218
+ elif ignore_nulls:
219
+ # This can happen if we're ignoring nulls but the entire block only consists
220
+ # of nulls. We treat the block as if it were empty in this case.
221
+ ret = a
222
+ return null_merge(a, ret)
223
+
224
+ return _accum_block_null
225
+
226
+
227
+ def _null_wrap_finalize(
228
+ finalize: Callable[[AggType], AggType]
229
+ ) -> Callable[[WrappedAggType], U]:
230
+ """
231
+ Wrap finalizer with null handling.
232
+
233
+ If the accumulation is empty or None, the returned finalizer returns None.
234
+
235
+ Args:
236
+ finalize: The core finalizing function to wrap.
237
+
238
+ Returns:
239
+ A new finalizing function that handles nulls.
240
+ """
241
+
242
+ def _finalize(a: AggType) -> U:
243
+ if a is None:
244
+ return None
245
+ a, has_data = _unwrap_acc(a)
246
+ if not has_data:
247
+ return None
248
+ return finalize(a)
249
+
250
+ return _finalize
251
+
252
+
253
+ LazyModule = Union[None, bool, ModuleType]
254
+ _pandas: LazyModule = None
255
+
256
+
257
+ def _lazy_import_pandas() -> LazyModule:
258
+ global _pandas
259
+ if _pandas is None:
260
+ try:
261
+ import pandas as _pandas
262
+ except ModuleNotFoundError:
263
+ # If module is not found, set _pandas to False so we won't
264
+ # keep trying to import it on every _lazy_import_pandas() call.
265
+ _pandas = False
266
+ return _pandas
267
+
268
+
269
+ def _is_null(r: Any):
270
+ pd = _lazy_import_pandas()
271
+ if pd:
272
+ return pd.isnull(r)
273
+ try:
274
+ return np.isnan(r)
275
+ except TypeError:
276
+ return r is None
.venv/lib/python3.11/site-packages/ray/data/_internal/numpy_support.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ from datetime import datetime
4
+ from typing import Any, Dict, List, Union
5
+
6
+ import numpy as np
7
+
8
+ from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
9
+ from ray.data._internal.util import _truncated_repr
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def is_array_like(value: Any) -> bool:
15
+ """Checks whether objects are array-like, excluding numpy scalars."""
16
+
17
+ return hasattr(value, "__array__") and hasattr(value, "__len__")
18
+
19
+
20
+ def is_valid_udf_return(udf_return_col: Any) -> bool:
21
+ """Check whether a UDF column is valid.
22
+
23
+ Valid columns must either be a list of elements, or an array-like object.
24
+ """
25
+
26
+ return isinstance(udf_return_col, list) or is_array_like(udf_return_col)
27
+
28
+
29
+ def is_nested_list(udf_return_col: List[Any]) -> bool:
30
+ for e in udf_return_col:
31
+ if isinstance(e, list):
32
+ return True
33
+ return False
34
+
35
+
36
+ def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) -> None:
37
+ if not isinstance(batch, collections.abc.Mapping) or any(
38
+ not is_valid_udf_return(col) for col in batch.values()
39
+ ):
40
+ raise ValueError(
41
+ "Batch must be an ndarray or dictionary of ndarrays when converting "
42
+ f"a numpy batch to a block, got: {type(batch)} "
43
+ f"({_truncated_repr(batch)})"
44
+ )
45
+
46
+
47
+ def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str:
48
+ """Detect the highest precision for a list of datetime objects.
49
+
50
+ Args:
51
+ datetime_list: List of datetime objects.
52
+
53
+ Returns:
54
+ A string representing the highest precision among the datetime objects
55
+ ('D', 's', 'ms', 'us', 'ns').
56
+ """
57
+ # Define precision hierarchy
58
+ precision_hierarchy = ["D", "s", "ms", "us", "ns"]
59
+ highest_precision_index = 0 # Start with the lowest precision ("D")
60
+
61
+ for dt in datetime_list:
62
+ # Safely get the nanosecond value using getattr for backward compatibility
63
+ nanosecond = getattr(dt, "nanosecond", 0)
64
+ if nanosecond != 0:
65
+ current_precision = "ns"
66
+ elif dt.microsecond != 0:
67
+ # Check if the microsecond precision is exactly millisecond
68
+ if dt.microsecond % 1000 == 0:
69
+ current_precision = "ms"
70
+ else:
71
+ current_precision = "us"
72
+ elif dt.second != 0 or dt.minute != 0 or dt.hour != 0:
73
+ # pyarrow does not support h or m, use s for those cases to
74
+ current_precision = "s"
75
+ else:
76
+ current_precision = "D"
77
+
78
+ # Update highest_precision_index based on the hierarchy
79
+ current_index = precision_hierarchy.index(current_precision)
80
+ highest_precision_index = max(highest_precision_index, current_index)
81
+
82
+ # Stop early if highest possible precision is reached
83
+ if highest_precision_index == len(precision_hierarchy) - 1:
84
+ break
85
+
86
+ return precision_hierarchy[highest_precision_index]
87
+
88
+
89
+ def _convert_to_datetime64(dt: datetime, precision: str) -> np.datetime64:
90
+ """
91
+ Converts a datetime object to a numpy datetime64 object with the specified
92
+ precision.
93
+
94
+ Args:
95
+ dt: A datetime object to be converted.
96
+ precision: The desired precision for the datetime64 conversion. Possible
97
+ values are 'D', 's', 'ms', 'us', 'ns'.
98
+
99
+ Returns:
100
+ np.datetime64: A numpy datetime64 object with the specified precision.
101
+ """
102
+ if precision == "ns":
103
+ # Calculate nanoseconds from microsecond and nanosecond
104
+ microseconds_as_ns = dt.microsecond * 1000
105
+ # Use getattr for backward compatibility where nanosecond attribute may not
106
+ # exist
107
+ nanoseconds = getattr(dt, "nanosecond", 0)
108
+ total_nanoseconds = microseconds_as_ns + nanoseconds
109
+ # Create datetime64 from base datetime with microsecond precision
110
+ base_dt = np.datetime64(dt, "us")
111
+ # Add remaining nanoseconds as timedelta
112
+ return base_dt + np.timedelta64(total_nanoseconds - microseconds_as_ns, "ns")
113
+ else:
114
+ return np.datetime64(dt).astype(f"datetime64[{precision}]")
115
+
116
+
117
+ def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray:
118
+ """Convert a list of datetime objects to a NumPy array of datetime64 with proper
119
+ precision.
120
+
121
+ Args:
122
+ datetime_list (List[datetime]): A list of `datetime` objects to be converted.
123
+ Each `datetime` object represents a specific point in time.
124
+
125
+ Returns:
126
+ np.ndarray: A NumPy array containing the `datetime64` values of the datetime
127
+ objects from the input list, with the appropriate precision (e.g., nanoseconds,
128
+ microseconds, milliseconds, etc.).
129
+ """
130
+ # Detect the highest precision for the datetime objects
131
+ precision = _detect_highest_datetime_precision(datetime_list)
132
+
133
+ # Convert each datetime to the corresponding numpy datetime64 with the appropriate
134
+ # precision
135
+ return np.array([_convert_to_datetime64(dt, precision) for dt in datetime_list])
136
+
137
+
138
+ def convert_to_numpy(column_values: Any) -> np.ndarray:
139
+ """Convert UDF columns (output of map_batches) to numpy, if possible.
140
+
141
+ This includes lists of scalars, objects supporting the array protocol, and lists
142
+ of objects supporting the array protocol, such as `[1, 2, 3]`, `Tensor([1, 2, 3])`,
143
+ and `[array(1), array(2), array(3)]`.
144
+
145
+ Returns:
146
+ The input as an np.ndarray if possible, otherwise the original input.
147
+
148
+ Raises:
149
+ ValueError if an input was array-like but we failed to convert it to an array.
150
+ """
151
+
152
+ if isinstance(column_values, np.ndarray):
153
+ # No copy/conversion needed, just keep it verbatim.
154
+ return column_values
155
+
156
+ elif isinstance(column_values, list):
157
+ if len(column_values) == 1 and isinstance(column_values[0], np.ndarray):
158
+ # Optimization to avoid conversion overhead from list to np.array.
159
+ return np.expand_dims(column_values[0], axis=0)
160
+
161
+ if all(isinstance(elem, datetime) for elem in column_values):
162
+ return _convert_datetime_list_to_array(column_values)
163
+
164
+ # Try to convert list values into an numpy array via
165
+ # np.array(), so users don't need to manually cast.
166
+ # NOTE: we don't cast generic iterables, since types like
167
+ # `str` are also Iterable.
168
+ try:
169
+ # Convert array-like objects (like torch.Tensor) to `np.ndarray`s
170
+ if all(is_array_like(e) for e in column_values):
171
+ # Use np.asarray() instead of np.array() to avoid copying if possible.
172
+ column_values = [np.asarray(e) for e in column_values]
173
+
174
+ shapes = set()
175
+ has_object = False
176
+ for e in column_values:
177
+ if isinstance(e, np.ndarray):
178
+ shapes.add((e.dtype, e.shape))
179
+ elif isinstance(e, bytes):
180
+ # Don't convert variable length binary data to Numpy arrays as it
181
+ # treats zero encoding as termination by default.
182
+ # Per recommendation from
183
+ # https://github.com/apache/arrow/issues/26470,
184
+ # we use object dtype.
185
+ # https://github.com/ray-project/ray/issues/35586#issuecomment-1558148261
186
+ has_object = True
187
+ elif not np.isscalar(e):
188
+ has_object = True
189
+
190
+ # When column values are
191
+ # - Arrays of heterogeneous shapes
192
+ # - Byte-strings (viewed as arrays of heterogeneous shapes)
193
+ # - Non-scalar objects (tuples, lists, arbitrary object types)
194
+ #
195
+ # Custom "ragged ndarray" is created, represented as an array of
196
+ # references (ie ndarray with dtype=object)
197
+ if has_object or len(shapes) > 1:
198
+ # This util works around some limitations of np.array(dtype=object).
199
+ return create_ragged_ndarray(column_values)
200
+ else:
201
+ return np.array(column_values)
202
+
203
+ except Exception as e:
204
+ logger.error(
205
+ f"Failed to convert column values to numpy array: "
206
+ f"{_truncated_repr(column_values)}",
207
+ exc_info=e,
208
+ )
209
+
210
+ raise ValueError(
211
+ "Failed to convert column values to numpy array: "
212
+ f"({_truncated_repr(column_values)}): {e}."
213
+ ) from e
214
+
215
+ elif is_array_like(column_values):
216
+ # Converts other array-like objects such as torch.Tensor.
217
+ try:
218
+ # Use np.asarray() instead of np.array() to avoid copying if possible.
219
+ return np.asarray(column_values)
220
+ except Exception as e:
221
+ logger.error(
222
+ f"Failed to convert column values to numpy array: "
223
+ f"{_truncated_repr(column_values)}",
224
+ exc_info=e,
225
+ )
226
+
227
+ raise ValueError(
228
+ "Failed to convert column values to numpy array: "
229
+ f"({_truncated_repr(column_values)}): {e}."
230
+ ) from e
231
+
232
+ else:
233
+ return column_values
.venv/lib/python3.11/site-packages/ray/data/_internal/output_buffer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
4
+ from ray.data.block import Block, BlockAccessor, DataBatch
5
+ from ray.data.context import MAX_SAFE_BLOCK_SIZE_FACTOR
6
+
7
+
8
+ class BlockOutputBuffer:
9
+ """Generates output blocks of a given size given a stream of inputs.
10
+
11
+ This class is used to turn a stream of items / blocks of arbitrary size
12
+ into a stream of blocks of ``target_max_block_size``. The caller should
13
+ check ``has_next()`` after each ``add()`` call, and call ``next()`` to get
14
+ the next block when ``has_next()`` returns True.
15
+
16
+ When all items have been added, the caller must call ``finalize()`` and
17
+ then check ``has_next()`` one last time.
18
+
19
+ Examples:
20
+ >>> from ray.data._internal.output_buffer import BlockOutputBuffer
21
+ >>> udf = ... # doctest: +SKIP
22
+ >>> generator = ... # doctest: +SKIP
23
+ >>> # Yield a stream of output blocks.
24
+ >>> output = BlockOutputBuffer(udf, 500 * 1024 * 1024) # doctest: +SKIP
25
+ >>> for item in generator(): # doctest: +SKIP
26
+ ... output.add(item) # doctest: +SKIP
27
+ ... if output.has_next(): # doctest: +SKIP
28
+ ... yield output.next() # doctest: +SKIP
29
+ >>> output.finalize() # doctest: +SKIP
30
+ >>> if output.has_next() # doctest: +SKIP
31
+ ... yield output.next() # doctest: +SKIP
32
+ """
33
+
34
+ def __init__(self, target_max_block_size: int):
35
+ self._target_max_block_size = target_max_block_size
36
+ self._buffer = DelegatingBlockBuilder()
37
+ self._returned_at_least_one_block = False
38
+ self._finalized = False
39
+
40
+ def add(self, item: Any) -> None:
41
+ """Add a single item to this output buffer."""
42
+ assert not self._finalized
43
+ self._buffer.add(item)
44
+
45
+ def add_batch(self, batch: DataBatch) -> None:
46
+ """Add a data batch to this output buffer."""
47
+ assert not self._finalized
48
+ self._buffer.add_batch(batch)
49
+
50
+ def add_block(self, block: Block) -> None:
51
+ """Add a data block to this output buffer."""
52
+ assert not self._finalized
53
+ self._buffer.add_block(block)
54
+
55
+ def finalize(self) -> None:
56
+ """Must be called once all items have been added."""
57
+ assert not self._finalized
58
+ self._finalized = True
59
+
60
+ def has_next(self) -> bool:
61
+ """Returns true when a complete output block is produced."""
62
+ if self._finalized:
63
+ return not self._returned_at_least_one_block or self._buffer.num_rows() > 0
64
+ else:
65
+ return (
66
+ self._buffer.get_estimated_memory_usage() > self._target_max_block_size
67
+ )
68
+
69
+ def next(self) -> Block:
70
+ """Returns the next complete output block."""
71
+ assert self.has_next()
72
+
73
+ block_to_yield = self._buffer.build()
74
+ block_remainder = None
75
+ block = BlockAccessor.for_block(block_to_yield)
76
+ if (
77
+ block.size_bytes()
78
+ >= MAX_SAFE_BLOCK_SIZE_FACTOR * self._target_max_block_size
79
+ ):
80
+ # Slice a block to respect the target max block size. We only do
81
+ # this if we are more than 50% above the target block size, because
82
+ # this ensures that the last block produced will be at least half
83
+ # the block size.
84
+ num_bytes_per_row = block.size_bytes() // block.num_rows()
85
+ target_num_rows = max(1, self._target_max_block_size // num_bytes_per_row)
86
+
87
+ if target_num_rows < block.num_rows():
88
+ # NOTE: We're maintaining following protocol of slicing underlying block
89
+ # into appropriately sized ones:
90
+ #
91
+ # - (Finalized) Target blocks sliced from the original one
92
+ # and are *copied* to avoid referencing original blocks
93
+ # - Temporary remainder of the block should *NOT* be copied
94
+ # such as to avoid repeatedly copying the remainder bytes
95
+ # of the block, resulting in O(M * N) total bytes being
96
+ # copied, where N is the total number of bytes in the original
97
+ # block and M is the number of blocks that will be produced by
98
+ # this iterator
99
+ block_to_yield = block.slice(0, target_num_rows, copy=True)
100
+ block_remainder = block.slice(
101
+ target_num_rows, block.num_rows(), copy=False
102
+ )
103
+
104
+ self._buffer = DelegatingBlockBuilder()
105
+ if block_remainder is not None:
106
+ self._buffer.add_block(block_remainder)
107
+
108
+ self._returned_at_least_one_block = True
109
+ return block_to_yield
.venv/lib/python3.11/site-packages/ray/data/_internal/pandas_block.py ADDED
@@ -0,0 +1,728 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import heapq
3
+ import logging
4
+ import sys
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Iterator,
11
+ List,
12
+ Optional,
13
+ Sequence,
14
+ Tuple,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+
19
+ import numpy as np
20
+
21
+ from ray.air.constants import TENSOR_COLUMN_NAME
22
+ from ray.air.util.tensor_extensions.utils import _is_ndarray_tensor
23
+ from ray.data._internal.numpy_support import convert_to_numpy, validate_numpy_batch
24
+ from ray.data._internal.row import TableRow
25
+ from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
26
+ from ray.data._internal.util import find_partitions, keys_equal
27
+ from ray.data.block import (
28
+ Block,
29
+ BlockAccessor,
30
+ BlockExecStats,
31
+ BlockMetadata,
32
+ BlockType,
33
+ KeyType,
34
+ U,
35
+ )
36
+ from ray.data.context import DataContext
37
+
38
+ if TYPE_CHECKING:
39
+ import pandas
40
+ import pyarrow
41
+
42
+ from ray.data._internal.planner.exchange.sort_task_spec import SortKey
43
+ from ray.data.aggregate import AggregateFn
44
+
45
+ T = TypeVar("T")
46
+ # Max number of samples used to estimate the Pandas block size.
47
+ _PANDAS_SIZE_BYTES_MAX_SAMPLE_COUNT = 50
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+ _pandas = None
52
+
53
+
54
+ def lazy_import_pandas():
55
+ global _pandas
56
+ if _pandas is None:
57
+ import pandas
58
+
59
+ _pandas = pandas
60
+ return _pandas
61
+
62
+
63
+ class PandasRow(TableRow):
64
+ """
65
+ Row of a tabular Dataset backed by a Pandas DataFrame block.
66
+ """
67
+
68
+ def __getitem__(self, key: Union[str, List[str]]) -> Any:
69
+ from ray.data.extensions import TensorArrayElement
70
+
71
+ pd = lazy_import_pandas()
72
+
73
+ def get_item(keys: List[str]) -> Any:
74
+ col = self._row[keys]
75
+ if len(col) == 0:
76
+ return None
77
+
78
+ items = col.iloc[0]
79
+ if isinstance(items.iloc[0], TensorArrayElement):
80
+ # Getting an item in a Pandas tensor column may return
81
+ # a TensorArrayElement, which we have to convert to an ndarray.
82
+ return pd.Series(item.to_numpy() for item in items)
83
+
84
+ try:
85
+ # Try to interpret this as a numpy-type value.
86
+ # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types. # noqa: E501
87
+ return pd.Series(item.as_py() for item in items)
88
+
89
+ except (AttributeError, ValueError):
90
+ # Fallback to the original form.
91
+ return items
92
+
93
+ is_single_item = isinstance(key, str)
94
+ keys = [key] if is_single_item else key
95
+
96
+ items = get_item(keys)
97
+
98
+ if items is None:
99
+ return None
100
+ elif is_single_item:
101
+ return items.iloc[0]
102
+ else:
103
+ return items
104
+
105
+ def __iter__(self) -> Iterator:
106
+ for k in self._row.columns:
107
+ yield k
108
+
109
+ def __len__(self):
110
+ return self._row.shape[1]
111
+
112
+
113
+ class PandasBlockBuilder(TableBlockBuilder):
114
+ def __init__(self):
115
+ pandas = lazy_import_pandas()
116
+ super().__init__(pandas.DataFrame)
117
+
118
+ @staticmethod
119
+ def _table_from_pydict(columns: Dict[str, List[Any]]) -> "pandas.DataFrame":
120
+ pandas = lazy_import_pandas()
121
+
122
+ pd_columns: Dict[str, Any] = {}
123
+
124
+ for col_name, col_vals in columns.items():
125
+ np_col_vals = convert_to_numpy(col_vals)
126
+
127
+ if col_name == TENSOR_COLUMN_NAME or _is_ndarray_tensor(np_col_vals):
128
+ from ray.data.extensions.tensor_extension import TensorArray
129
+
130
+ pd_columns[col_name] = TensorArray(np_col_vals)
131
+ else:
132
+ pd_columns[col_name] = np_col_vals
133
+
134
+ return pandas.DataFrame(pd_columns)
135
+
136
+ @staticmethod
137
+ def _concat_tables(tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
138
+ pandas = lazy_import_pandas()
139
+ from ray.air.util.data_batch_conversion import (
140
+ _cast_ndarray_columns_to_tensor_extension,
141
+ )
142
+
143
+ if len(tables) > 1:
144
+ df = pandas.concat(tables, ignore_index=True)
145
+ df.reset_index(drop=True, inplace=True)
146
+ else:
147
+ df = tables[0]
148
+ ctx = DataContext.get_current()
149
+ if ctx.enable_tensor_extension_casting:
150
+ df = _cast_ndarray_columns_to_tensor_extension(df)
151
+ return df
152
+
153
+ @staticmethod
154
+ def _concat_would_copy() -> bool:
155
+ return True
156
+
157
+ @staticmethod
158
+ def _empty_table() -> "pandas.DataFrame":
159
+ pandas = lazy_import_pandas()
160
+ return pandas.DataFrame()
161
+
162
+ def block_type(self) -> BlockType:
163
+ return BlockType.PANDAS
164
+
165
+
166
+ # This is to be compatible with pyarrow.lib.schema
167
+ # TODO (kfstorm): We need a format-independent way to represent schema.
168
+ PandasBlockSchema = collections.namedtuple("PandasBlockSchema", ["names", "types"])
169
+
170
+
171
+ class PandasBlockAccessor(TableBlockAccessor):
172
+ ROW_TYPE = PandasRow
173
+
174
+ def __init__(self, table: "pandas.DataFrame"):
175
+ super().__init__(table)
176
+
177
+ def column_names(self) -> List[str]:
178
+ return self._table.columns.tolist()
179
+
180
+ def append_column(self, name: str, data: Any) -> Block:
181
+ assert name not in self._table.columns
182
+
183
+ if any(isinstance(item, np.ndarray) for item in data):
184
+ raise NotImplementedError(
185
+ f"`{self.__class__.__name__}.append_column()` doesn't support "
186
+ "array-like data."
187
+ )
188
+
189
+ table = self._table.copy()
190
+ table[name] = data
191
+ return table
192
+
193
+ @staticmethod
194
+ def _build_tensor_row(row: PandasRow) -> np.ndarray:
195
+ from ray.data.extensions import TensorArrayElement
196
+
197
+ tensor = row[TENSOR_COLUMN_NAME].iloc[0]
198
+ if isinstance(tensor, TensorArrayElement):
199
+ # Getting an item in a Pandas tensor column may return a TensorArrayElement,
200
+ # which we have to convert to an ndarray.
201
+ tensor = tensor.to_numpy()
202
+ return tensor
203
+
204
+ def slice(self, start: int, end: int, copy: bool = False) -> "pandas.DataFrame":
205
+ view = self._table[start:end]
206
+ view.reset_index(drop=True, inplace=True)
207
+ if copy:
208
+ view = view.copy(deep=True)
209
+ return view
210
+
211
+ def take(self, indices: List[int]) -> "pandas.DataFrame":
212
+ table = self._table.take(indices)
213
+ table.reset_index(drop=True, inplace=True)
214
+ return table
215
+
216
+ def select(self, columns: List[str]) -> "pandas.DataFrame":
217
+ if not all(isinstance(col, str) for col in columns):
218
+ raise ValueError(
219
+ "Columns must be a list of column name strings when aggregating on "
220
+ f"Pandas blocks, but got: {columns}."
221
+ )
222
+ return self._table[columns]
223
+
224
+ def rename_columns(self, columns_rename: Dict[str, str]) -> "pandas.DataFrame":
225
+ return self._table.rename(columns=columns_rename, inplace=False, copy=False)
226
+
227
+ def random_shuffle(self, random_seed: Optional[int]) -> "pandas.DataFrame":
228
+ table = self._table.sample(frac=1, random_state=random_seed)
229
+ table.reset_index(drop=True, inplace=True)
230
+ return table
231
+
232
+ def schema(self) -> PandasBlockSchema:
233
+ dtypes = self._table.dtypes
234
+ schema = PandasBlockSchema(
235
+ names=dtypes.index.tolist(), types=dtypes.values.tolist()
236
+ )
237
+ # Column names with non-str types of a pandas DataFrame is not
238
+ # supported by Ray Dataset.
239
+ if any(not isinstance(name, str) for name in schema.names):
240
+ raise ValueError(
241
+ "A Pandas DataFrame with column names of non-str types"
242
+ " is not supported by Ray Dataset. Column names of this"
243
+ f" DataFrame: {schema.names!r}."
244
+ )
245
+ return schema
246
+
247
+ def to_pandas(self) -> "pandas.DataFrame":
248
+ from ray.air.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays
249
+
250
+ ctx = DataContext.get_current()
251
+ table = self._table
252
+ if ctx.enable_tensor_extension_casting:
253
+ table = _cast_tensor_columns_to_ndarrays(table)
254
+ return table
255
+
256
+ def to_numpy(
257
+ self, columns: Optional[Union[str, List[str]]] = None
258
+ ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
259
+ if columns is None:
260
+ columns = self._table.columns.tolist()
261
+ should_be_single_ndarray = False
262
+ elif isinstance(columns, list):
263
+ should_be_single_ndarray = False
264
+ else:
265
+ columns = [columns]
266
+ should_be_single_ndarray = True
267
+
268
+ column_names_set = set(self._table.columns)
269
+ for column in columns:
270
+ if column not in column_names_set:
271
+ raise ValueError(
272
+ f"Cannot find column {column}, available columns: "
273
+ f"{self._table.columns.tolist()}"
274
+ )
275
+
276
+ arrays = []
277
+ for column in columns:
278
+ arrays.append(self._table[column].to_numpy())
279
+
280
+ if should_be_single_ndarray:
281
+ arrays = arrays[0]
282
+ else:
283
+ arrays = dict(zip(columns, arrays))
284
+ return arrays
285
+
286
+ def to_arrow(self) -> "pyarrow.Table":
287
+ import pyarrow
288
+
289
+ # Set `preserve_index=False` so that Arrow doesn't add a '__index_level_0__'
290
+ # column to the resulting table.
291
+ return pyarrow.Table.from_pandas(self._table, preserve_index=False)
292
+
293
+ @staticmethod
294
+ def numpy_to_block(
295
+ batch: Union[Dict[str, np.ndarray], Dict[str, list]],
296
+ ) -> "pandas.DataFrame":
297
+ validate_numpy_batch(batch)
298
+
299
+ block = PandasBlockBuilder._table_from_pydict(batch)
300
+ return block
301
+
302
+ def num_rows(self) -> int:
303
+ return self._table.shape[0]
304
+
305
def size_bytes(self) -> int:
    """Estimate the in-memory size of this block in bytes.

    ``memory_usage(deep=True)`` under-counts object-dtype columns holding
    nested containers, so those columns (and tensor-extension columns) are
    re-estimated by deep-sizing a bounded random sample of values and
    scaling up to the full column length.
    """
    from pandas.api.types import is_object_dtype

    from ray.air.util.tensor_extensions.pandas import TensorArray
    from ray.data.extensions import TensorArrayElement, TensorDtype

    pd = lazy_import_pandas()

    def get_deep_size(obj):
        """Calculates the memory size of objects,
        including nested objects using an iterative approach."""
        seen = set()
        total_size = 0
        objects = collections.deque([obj])
        while objects:
            current = objects.pop()

            # Skip interning-eligible immutable objects
            if isinstance(current, (str, bytes, int, float)):
                size = sys.getsizeof(current)
                total_size += size
                continue

            # Check if the object has been seen before
            # i.e. a = np.ndarray([1,2,3]), b = [a,a]
            # The pattern above will have only one memory copy
            if id(current) in seen:
                continue
            seen.add(id(current))

            try:
                size = sys.getsizeof(current)
            except TypeError:
                size = 0
            total_size += size

            # Handle specific cases
            if isinstance(current, np.ndarray):
                total_size += current.nbytes - size  # Avoid double counting
            elif isinstance(current, pd.DataFrame):
                total_size += (
                    current.memory_usage(index=True, deep=True).sum() - size
                )
            elif isinstance(current, (list, tuple, set)):
                objects.extend(current)
            elif isinstance(current, dict):
                objects.extend(current.keys())
                objects.extend(current.values())
            elif isinstance(current, TensorArrayElement):
                objects.extend(current.to_numpy())
        return total_size

    # Get initial memory usage including deep introspection
    memory_usage = self._table.memory_usage(index=True, deep=True)

    # TensorDtype for ray.air.util.tensor_extensions.pandas.TensorDtype
    object_need_check = (TensorDtype,)
    max_sample_count = _PANDAS_SIZE_BYTES_MAX_SAMPLE_COUNT

    # Handle object columns separately
    for column in self._table.columns:
        # Check pandas object dtype and the extension dtype
        if is_object_dtype(self._table[column].dtype) or isinstance(
            self._table[column].dtype, object_need_check
        ):
            total_size = len(self._table[column])

            # Determine the sample size based on max_sample_count
            sample_size = min(total_size, max_sample_count)
            # Following code also handles the case sample_size == total_size
            # (sampling is done without replacement).
            sampled_data = self._table[column].sample(n=sample_size).values

            try:
                # Numeric tensor columns expose nbytes directly; skip the
                # (much slower) per-element deep sizing for them.
                if isinstance(sampled_data, TensorArray) and np.issubdtype(
                    sampled_data[0].numpy_dtype, np.number
                ):
                    column_memory_sample = sampled_data.nbytes
                else:
                    vectorized_size_calc = np.vectorize(lambda x: get_deep_size(x))
                    column_memory_sample = np.sum(
                        vectorized_size_calc(sampled_data)
                    )
                # Scale back to the full column size if we sampled
                column_memory = column_memory_sample * (total_size / sample_size)
                memory_usage[column] = int(column_memory)
            except Exception as e:
                # Best-effort: fall back to the shallow estimate for this
                # column rather than failing the whole size computation.
                logger.warning(f"Error calculating size for column '{column}': {e}")

    # Sum up total memory usage
    total_memory_usage = memory_usage.sum()

    return int(total_memory_usage)
398
+
399
def _zip(self, acc: BlockAccessor) -> "pandas.DataFrame":
    """Horizontally concatenate this block with ``acc``'s block.

    Columns from ``acc`` whose names collide with existing columns are
    renamed to ``<name>_<i>`` with the smallest ``i`` that is unique.
    """
    left = self.to_pandas().copy(deep=False)
    right = acc.to_pandas()
    for name in right.columns:
        series = right[name]
        existing = list(left.columns)
        # Ensure the column names are unique after zip.
        unique_name = name
        suffix = 1
        while unique_name in existing:
            unique_name = "{}_{}".format(name, suffix)
            suffix += 1
        left[unique_name] = series
    return left
415
+
416
@staticmethod
def builder() -> PandasBlockBuilder:
    """Create a new, empty builder for pandas blocks."""
    block_builder = PandasBlockBuilder()
    return block_builder
419
+
420
@staticmethod
def _empty_table() -> "pandas.DataFrame":
    """Return an empty pandas block (delegates to the builder)."""
    empty_df = PandasBlockBuilder._empty_table()
    return empty_df
423
+
424
def _sample(self, n_samples: int, sort_key: "SortKey") -> "pandas.DataFrame":
    """Randomly sample ``n_samples`` rows of the sort-key columns."""
    key_columns = sort_key.get_columns()
    return self._table[key_columns].sample(n_samples, ignore_index=True)
426
+
427
def _apply_agg(
    self, agg_fn: Callable[["pandas.Series", bool], U], on: str
) -> Optional[U]:
    """Helper providing null handling around applying an aggregation to a column."""
    pd = lazy_import_pandas()
    if on is not None and not isinstance(on, str):
        raise ValueError(
            "on must be a string or None when aggregating on Pandas blocks, but "
            f"got: {type(on)}."
        )

    if self.num_rows() == 0:
        return None

    column = self._table[on]
    try:
        result = agg_fn(column)
    except TypeError as e:
        # An all-null object column (e.g. produced by converting an all-null
        # Arrow column to pandas) raises TypeError on most binary operations;
        # detect that case and propagate a null instead of erroring.
        if np.issubdtype(column.dtype, np.object_) and column.isnull().all():
            return None
        raise e from None
    return None if pd.isnull(result) else result
455
+
456
def count(self, on: str) -> Optional[U]:
    """Return the number of non-null entries in column ``on``."""

    def _count_non_null(col):
        return col.count()

    return self._apply_agg(_count_non_null, on)
458
+
459
def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Return the sum of column ``on``, or ``None`` for empty/all-null columns.

    Args:
        on: Column name to aggregate; must be a string (or None, in which
            case ``_apply_agg`` raises/returns consistently with the other
            aggregations).
        ignore_nulls: Whether nulls are skipped when summing.

    Returns:
        The column sum, or ``None`` when the block is empty or the column is
        entirely null.
    """

    def _sum_with_all_null_check(col):
        if col.isnull().all():
            # Short-circuit on an all-null column, returning None. This is
            # required for sum() since it would otherwise return 0 when
            # summing an all-null column, which is not what we want.
            return None
        return col.sum(skipna=ignore_nulls)

    # Delegate the "on" type validation, empty-block handling, and null
    # propagation to the shared helper for consistency with min/max/mean.
    return self._apply_agg(_sum_with_all_null_check, on)
480
+
481
def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Return the minimum of column ``on`` (``None`` if empty/all-null)."""

    def _column_min(col):
        return col.min(skipna=ignore_nulls)

    return self._apply_agg(_column_min, on)
483
+
484
def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Return the maximum of column ``on`` (``None`` if empty/all-null)."""

    def _column_max(col):
        return col.max(skipna=ignore_nulls)

    return self._apply_agg(_column_max, on)
486
+
487
def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Return the mean of column ``on`` (``None`` if empty/all-null)."""

    def _column_mean(col):
        return col.mean(skipna=ignore_nulls)

    return self._apply_agg(_column_mean, on)
489
+
490
def sum_of_squared_diffs_from_mean(
    self,
    on: str,
    ignore_nulls: bool,
    mean: Optional[U] = None,
) -> Optional[U]:
    """Sum of squared deviations of column ``on`` from ``mean``.

    When ``mean`` is not provided, it is first computed from the column.
    """
    if mean is None:
        mean = self.mean(on, ignore_nulls)

    def _squared_deviation_sum(col):
        deviations = col - mean
        return (deviations**2).sum(skipna=ignore_nulls)

    return self._apply_agg(_squared_deviation_sum, on)
502
+
503
def sort_and_partition(
    self, boundaries: List[T], sort_key: "SortKey"
) -> List[Block]:
    """Sort this block by ``sort_key`` and split it at ``boundaries``.

    Returns ``len(boundaries) + 1`` blocks; with no boundaries, a single
    sorted block is returned.
    """
    if self._table.shape[0] == 0:
        # An empty table may have no schema, in which case sorting would
        # raise; return the right number of empty partitions instead.
        return [self._empty_table() for _ in range(len(boundaries) + 1)]

    sort_columns, sort_ascending = sort_key.to_pandas_sort_args()
    sorted_table = self._table.sort_values(by=sort_columns, ascending=sort_ascending)
    if not boundaries:
        return [sorted_table]

    return find_partitions(sorted_table, boundaries, sort_key)
517
+
518
# TODO (srinathk) Needs to handle None types correctly.
def combine(
    self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]
) -> "pandas.DataFrame":
    """Combine rows with the same key into an accumulator.

    This assumes the block is already sorted by key in ascending order.

    Args:
        sort_key: A SortKey object which holds column names/keys.
            If this is ``None``, place all rows in a single group.
        aggs: The aggregations to do.

    Returns:
        A sorted block of [k, v_1, ..., v_n] columns where k is the groupby
        key and v_i is the partially combined accumulator for the ith given
        aggregation.
        If key is None then the k column is omitted.
    """
    keys: List[str] = sort_key.get_columns()
    pd = lazy_import_pandas()

    def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
        """Creates an iterator over zero-copy group views."""
        if not keys:
            # Global aggregation consists of a single "group", so we
            # short-circuit.
            yield tuple(), self.to_block()
            return

        start = end = 0
        # Renamed from `iter` so the builtin isn't shadowed.
        row_iter = self.iter_rows(public_row_format=False)
        next_row = None
        while True:
            try:
                if next_row is None:
                    next_row = next(row_iter)
                next_keys = next_row[keys]
                # Advance until the key changes to find the group boundary.
                while keys_equal(next_row[keys], next_keys):
                    end += 1
                    try:
                        next_row = next(row_iter)
                    except StopIteration:
                        next_row = None
                        break
                if isinstance(next_keys, pd.Series):
                    next_keys = next_keys.values
                # Zero-copy view of the rows belonging to this group.
                yield next_keys, self.slice(start, end, copy=False)
                start = end
            except StopIteration:
                break

    builder = PandasBlockBuilder()
    for group_keys, group_view in iter_groups():
        # Aggregate.
        init_vals = group_keys
        if len(group_keys) == 1:
            init_vals = group_keys[0]
        accumulators = [agg.init(init_vals) for agg in aggs]
        for i in range(len(aggs)):
            accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

        # Build the row.
        row = {}
        if keys:
            for k, gk in zip(keys, group_keys):
                row[k] = gk

        count = collections.defaultdict(int)
        for agg, accumulator in zip(aggs, accumulators):
            name = agg.name
            # Check for conflicts with existing aggregation name.
            if count[name] > 0:
                name = self._munge_conflict(name, count[name])
            count[name] += 1
            row[name] = accumulator

        builder.add(row)

    return builder.build()
598
+
599
@staticmethod
def merge_sorted_blocks(
    blocks: List[Block], sort_key: "SortKey"
) -> Tuple["pandas.DataFrame", BlockMetadata]:
    """Concatenate pre-sorted blocks and re-sort the combined result.

    Empty blocks are dropped first; if nothing remains, an empty table is
    returned. Returns the merged block together with its metadata.
    """
    pd = lazy_import_pandas()
    exec_stats = BlockExecStats.builder()
    non_empty = [block for block in blocks if block.shape[0] > 0]
    if not non_empty:
        merged = PandasBlockAccessor._empty_table()
    else:
        # Handle blocks of different types.
        normalized = TableBlockAccessor.normalize_block_types(non_empty, "pandas")
        merged = pd.concat(normalized, ignore_index=True)
        sort_columns, sort_ascending = sort_key.to_pandas_sort_args()
        merged = merged.sort_values(by=sort_columns, ascending=sort_ascending)
    accessor = PandasBlockAccessor(merged)
    return merged, accessor.get_metadata(exec_stats=exec_stats.build())
615
+
616
@staticmethod
def aggregate_combined_blocks(
    blocks: List["pandas.DataFrame"],
    sort_key: "SortKey",
    aggs: Tuple["AggregateFn"],
    finalize: bool,
) -> Tuple["pandas.DataFrame", BlockMetadata]:
    """Aggregate sorted, partially combined blocks with the same key range.

    This assumes blocks are already sorted by key in ascending order,
    so we can do merge sort to get all the rows with the same key.

    Args:
        blocks: A list of partially combined and sorted blocks.
        sort_key: The column name of key or None for global aggregation.
        aggs: The aggregations to do.
        finalize: Whether to finalize the aggregation. This is used as an
            optimization for cases where we repeatedly combine partially
            aggregated groups.

    Returns:
        A block of [k, v_1, ..., v_n] columns and its metadata where k is
        the groupby key and v_i is the corresponding aggregation result for
        the ith given aggregation.
        If key is None then the k column is omitted.
    """

    stats = BlockExecStats.builder()
    keys = sort_key.get_columns()

    def key_fn(r):
        # Global aggregation uses a constant key so all rows form one group.
        if keys:
            return tuple(r[keys])
        else:
            return (0,)

    # Handle blocks of different types.
    blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas")

    # Merge-sort the per-block row iterators (renamed from `iter` so the
    # builtin isn't shadowed).
    merged_iter = heapq.merge(
        *[
            PandasBlockAccessor(block).iter_rows(public_row_format=False)
            for block in blocks
        ],
        key=key_fn,
    )
    next_row = None
    builder = PandasBlockBuilder()
    while True:
        try:
            if next_row is None:
                next_row = next(merged_iter)
            next_keys = key_fn(next_row)
            next_key_columns = keys

            def gen():
                # Yield all rows sharing the current key, leaving `next_row`
                # pointing at the first row of the following group.
                nonlocal merged_iter
                nonlocal next_row
                while keys_equal(key_fn(next_row), next_keys):
                    yield next_row
                    try:
                        next_row = next(merged_iter)
                    except StopIteration:
                        next_row = None
                        break

            # Merge.
            first = True
            accumulators = [None] * len(aggs)
            resolved_agg_names = [None] * len(aggs)
            for r in gen():
                if first:
                    count = collections.defaultdict(int)
                    for i in range(len(aggs)):
                        name = aggs[i].name
                        # Check for conflicts with existing aggregation
                        # name.
                        if count[name] > 0:
                            name = PandasBlockAccessor._munge_conflict(
                                name, count[name]
                            )
                        count[name] += 1
                        resolved_agg_names[i] = name
                        accumulators[i] = r[name]
                    first = False
                else:
                    for i in range(len(aggs)):
                        accumulators[i] = aggs[i].merge(
                            accumulators[i], r[resolved_agg_names[i]]
                        )
            # Build the row.
            row = {}
            if keys:
                for col_name, next_key in zip(next_key_columns, next_keys):
                    row[col_name] = next_key

            for agg, agg_name, accumulator in zip(
                aggs, resolved_agg_names, accumulators
            ):
                if finalize:
                    row[agg_name] = agg.finalize(accumulator)
                else:
                    row[agg_name] = accumulator

            builder.add(row)
        except StopIteration:
            break

    ret = builder.build()
    return ret, PandasBlockAccessor(ret).get_metadata(exec_stats=stats.build())
726
+
727
def block_type(self) -> BlockType:
    """Return ``BlockType.PANDAS``, identifying this accessor's block format."""
    return BlockType.PANDAS
.venv/lib/python3.11/site-packages/ray/data/_internal/plan.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import logging
4
+ from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union
5
+
6
+ import pyarrow
7
+
8
+ import ray
9
+ from ray._private.internal_api import get_memory_info_reply, get_state_from_address
10
+ from ray.data._internal.execution.interfaces import RefBundle
11
+ from ray.data._internal.logical.interfaces.logical_operator import LogicalOperator
12
+ from ray.data._internal.logical.interfaces.logical_plan import LogicalPlan
13
+ from ray.data._internal.logical.operators.from_operators import AbstractFrom
14
+ from ray.data._internal.logical.operators.input_data_operator import InputData
15
+ from ray.data._internal.logical.operators.read_operator import Read
16
+ from ray.data._internal.stats import DatasetStats
17
+ from ray.data._internal.util import create_dataset_tag, unify_block_metadata_schema
18
+ from ray.data.block import BlockMetadata
19
+ from ray.data.context import DataContext
20
+ from ray.data.exceptions import omit_traceback_stdout
21
+ from ray.util.debug import log_once
22
+
23
+ if TYPE_CHECKING:
24
+
25
+ from ray.data._internal.execution.interfaces import Executor
26
+ from ray.data.dataset import Dataset
27
+
28
+
29
+ # Scheduling strategy can be inherited from prev operator if not specified.
30
+ INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class ExecutionPlan:
37
+ """A lazy execution plan for a Dataset.
38
+
39
+ This lazy execution plan builds up a chain of ``List[RefBundle]`` -->
40
+ ``List[RefBundle]`` operators. Prior to execution, we apply a set of logical
41
+ plan optimizations, such as operator fusion, in order to reduce Ray task
42
+ overhead and data copies.
43
+
44
+ Internally, the execution plan holds a snapshot of a computed list of
45
+ blocks and their associated metadata under ``self._snapshot_bundle``,
46
+ where this snapshot is the cached output of executing the operator chain."""
47
+
48
def __init__(
    self,
    stats: DatasetStats,
    *,
    data_context: Optional[DataContext] = None,
):
    """Create a plan with no transformation operators.

    Args:
        stats: Stats for the base blocks.
        data_context: :class:`~ray.data.context.DataContext`
            object to use for execution.
    """
    self._in_stats = stats
    # A computed snapshot of some prefix of operators and their corresponding
    # output blocks and stats.
    self._snapshot_operator: Optional[LogicalOperator] = None
    self._snapshot_stats = None
    self._snapshot_bundle = None
    # Snapshot of only metadata corresponding to the final operator's
    # output bundles, used as the source of truth for the Dataset's schema
    # and count. This is calculated and cached when the plan is executed as an
    # iterator (`execute_to_iterator()`), and avoids caching
    # all of the output blocks in memory like in `self.snapshot_bundle`.
    # TODO(scottjlee): To keep the caching logic consistent, update `execute()`
    # to also store the metadata in `_snapshot_metadata` instead of
    # `_snapshot_bundle`. For example, we could store the blocks in
    # `self._snapshot_blocks` and the metadata in `self._snapshot_metadata`.
    self._snapshot_metadata: Optional[BlockMetadata] = None

    # Cached schema.
    self._schema = None
    # Set when a Dataset is constructed with this plan
    self._dataset_uuid = None
    # Optional human-readable dataset name; used in repr and metrics tags.
    self._dataset_name = None
    # True once execute()/execute_to_iterator() has been entered.
    self._has_started_execution = False

    if data_context is None:
        # Snapshot the current context, so that the config of Datasets is always
        # determined by the config at the time it was created.
        self._context = copy.deepcopy(DataContext.get_current())
    else:
        self._context = data_context
93
+
94
def __repr__(self) -> str:
    """Terse debug representation showing the dataset uuid and snapshot operator."""
    fields = [
        f"dataset_uuid={self._dataset_uuid}",
        f"snapshot_operator={self._snapshot_operator}",
    ]
    return "ExecutionPlan({})".format(", ".join(fields))
101
+
102
def get_plan_as_string(self, dataset_cls: Type["Dataset"]) -> str:
    """Create a cosmetic string representation of this execution plan.

    Returns:
        The string representation of this execution plan.
    """
    # NOTE: this is used for Dataset.__repr__ to give a user-facing string
    # representation. Ideally ExecutionPlan.__repr__ should be replaced with this
    # method as well.

    from ray.data.dataset import MaterializedDataset

    # Do not force execution for schema, as this method is expected to be very
    # cheap.
    plan_str = ""
    plan_max_depth = 0
    if not self.has_computed_output():

        def generate_logical_plan_string(
            op: LogicalOperator,
            curr_str: str = "",
            depth: int = 0,
        ):
            """Traverse (DFS) the LogicalPlan DAG and
            return a string representation of the operators."""
            # Source operators terminate the printed chain.
            if isinstance(op, (Read, InputData, AbstractFrom)):
                return curr_str, depth

            curr_max_depth = depth
            op_name = op.name
            if depth == 0:
                curr_str += f"{op_name}\n"
            else:
                trailing_space = " " * ((depth - 1) * 3)
                curr_str += f"{trailing_space}+- {op_name}\n"

            for input in op.input_dependencies:
                curr_str, input_max_depth = generate_logical_plan_string(
                    input, curr_str, depth + 1
                )
                curr_max_depth = max(curr_max_depth, input_max_depth)
            return curr_str, curr_max_depth

        plan_str, plan_max_depth = generate_logical_plan_string(
            self._logical_plan.dag
        )

        if self._snapshot_bundle is not None:
            # This plan has executed some but not all operators.
            schema = unify_block_metadata_schema(self._snapshot_bundle.metadata)
            count = self._snapshot_bundle.num_rows()
        elif self._snapshot_metadata is not None:
            # Metadata-only snapshot from a previous iterator execution.
            schema = self._snapshot_metadata.schema
            count = self._snapshot_metadata.num_rows
        else:
            # This plan hasn't executed any operators.
            sources = self._logical_plan.sources()
            # TODO(@bveeramani): Handle schemas for n-ary operators like `Union`.
            if len(sources) > 1:
                # Multiple sources, cannot determine schema.
                schema = None
                count = None
            else:
                assert len(sources) == 1
                # Build a throwaway plan over just the source to get its
                # schema/count without executing this plan.
                plan = ExecutionPlan(DatasetStats(metadata={}, parent=None))
                plan.link_logical_plan(LogicalPlan(sources[0], plan._context))
                schema = plan.schema()
                count = plan.meta_count()
    else:
        # Get schema of output blocks.
        schema = self.schema(fetch_if_missing=False)
        count = self._snapshot_bundle.num_rows()

    if schema is None:
        schema_str = "Unknown schema"
    elif isinstance(schema, type):
        schema_str = str(schema)
    else:
        # Render as "{name: type, ...}" using short type names where possible.
        schema_str = []
        for n, t in zip(schema.names, schema.types):
            if hasattr(t, "__name__"):
                t = t.__name__
            schema_str.append(f"{n}: {t}")
        schema_str = ", ".join(schema_str)
        schema_str = "{" + schema_str + "}"

    if count is None:
        count = "?"

    num_blocks = None
    if dataset_cls == MaterializedDataset:
        num_blocks = self.initial_num_blocks()
        assert num_blocks is not None

    name_str = (
        "name={}, ".format(self._dataset_name)
        if self._dataset_name is not None
        else ""
    )
    num_blocks_str = f"num_blocks={num_blocks}, " if num_blocks else ""

    dataset_str = "{}({}{}num_rows={}, schema={})".format(
        dataset_cls.__name__,
        name_str,
        num_blocks_str,
        count,
        schema_str,
    )

    # If the resulting string representation fits in one line, use it directly.
    SCHEMA_LINE_CHAR_LIMIT = 80
    MIN_FIELD_LENGTH = 10
    INDENT_STR = " " * 3
    trailing_space = INDENT_STR * plan_max_depth

    if len(dataset_str) > SCHEMA_LINE_CHAR_LIMIT:
        # If the resulting string representation exceeds the line char limit,
        # first try breaking up each `Dataset` parameter into its own line
        # and check if each line fits within the line limit. We check the
        # `schema` param's length, since this is likely the longest string.
        schema_str_on_new_line = f"{trailing_space}{INDENT_STR}schema={schema_str}"
        if len(schema_str_on_new_line) > SCHEMA_LINE_CHAR_LIMIT:
            # If the schema cannot fit on a single line, break up each field
            # into its own line.
            schema_str = []
            for n, t in zip(schema.names, schema.types):
                if hasattr(t, "__name__"):
                    t = t.__name__
                col_str = f"{trailing_space}{INDENT_STR * 2}{n}: {t}"
                # If the field line exceeds the char limit, abbreviate
                # the field name to fit while maintaining the full type
                if len(col_str) > SCHEMA_LINE_CHAR_LIMIT:
                    shortened_suffix = f"...: {str(t)}"
                    # Show at least 10 characters of the field name, even if
                    # we have already hit the line limit with the type.
                    chars_left_for_col_name = max(
                        SCHEMA_LINE_CHAR_LIMIT - len(shortened_suffix),
                        MIN_FIELD_LENGTH,
                    )
                    col_str = (
                        f"{col_str[:chars_left_for_col_name]}{shortened_suffix}"
                    )
                schema_str.append(col_str)
            schema_str = ",\n".join(schema_str)
            schema_str = (
                "{\n" + schema_str + f"\n{trailing_space}{INDENT_STR}" + "}"
            )
        name_str = (
            f"\n{trailing_space}{INDENT_STR}name={self._dataset_name},"
            if self._dataset_name is not None
            else ""
        )
        num_blocks_str = (
            f"\n{trailing_space}{INDENT_STR}num_blocks={num_blocks},"
            if num_blocks
            else ""
        )
        dataset_str = (
            f"{dataset_cls.__name__}("
            f"{name_str}"
            f"{num_blocks_str}"
            f"\n{trailing_space}{INDENT_STR}num_rows={count},"
            f"\n{trailing_space}{INDENT_STR}schema={schema_str}"
            f"\n{trailing_space})"
        )

    # Attach the dataset summary at the bottom of the operator chain.
    if plan_max_depth == 0:
        plan_str += dataset_str
    else:
        plan_str += f"{INDENT_STR * (plan_max_depth - 1)}+- {dataset_str}"
    return plan_str
274
+
275
def link_logical_plan(self, logical_plan: "LogicalPlan"):
    """Link the logical plan into this execution plan.

    This is used for triggering execution for optimizer code path in this legacy
    execution plan.
    """
    self._logical_plan = logical_plan
    # Propagate this plan's snapshotted DataContext onto the logical plan.
    logical_plan._context = self._context
283
+
284
def copy(self) -> "ExecutionPlan":
    """Create a shallow copy of this execution plan.

    This copy can be executed without mutating the original, but clearing the copy
    will also clear the original.

    Returns:
        A shallow copy of this execution plan.
    """
    shallow = ExecutionPlan(self._in_stats, data_context=self._context)
    if self._snapshot_bundle is not None:
        # Share (don't duplicate) the cached snapshot with the copy.
        shallow._snapshot_bundle = self._snapshot_bundle
        shallow._snapshot_operator = self._snapshot_operator
        shallow._snapshot_stats = self._snapshot_stats
    shallow._dataset_name = self._dataset_name
    return shallow
304
+
305
def deep_copy(self) -> "ExecutionPlan":
    """Create a deep copy of this execution plan.

    This copy can be executed AND cleared without mutating the original.

    Returns:
        A deep copy of this execution plan.
    """
    # NOTE(review): unlike `copy()`, no `data_context` is forwarded here, so
    # the new plan re-snapshots `DataContext.get_current()` — confirm this is
    # intentional.
    duplicate = ExecutionPlan(copy.copy(self._in_stats))
    if self._snapshot_bundle:
        # Shallow-copy each snapshot piece so clearing one plan doesn't
        # clear the other.
        duplicate._snapshot_bundle = copy.copy(self._snapshot_bundle)
        duplicate._snapshot_operator = copy.copy(self._snapshot_operator)
        duplicate._snapshot_stats = copy.copy(self._snapshot_stats)
    duplicate._dataset_name = self._dataset_name
    return duplicate
321
+
322
def initial_num_blocks(self) -> Optional[int]:
    """Get the estimated number of blocks from the logical plan
    after applying execution plan optimizations, but prior to
    fully executing the dataset."""
    dag = self._logical_plan.dag
    return dag.estimated_num_outputs()
327
+
328
def schema(
    self, fetch_if_missing: bool = False
) -> Union[type, "pyarrow.lib.Schema"]:
    """Get the schema after applying all execution plan optimizations,
    but prior to fully executing the dataset
    (unless `fetch_if_missing` is set to True).

    Args:
        fetch_if_missing: Whether to execute the plan to fetch the schema.

    Returns:
        The schema of the output dataset.
    """
    if self._schema is not None:
        return self._schema

    schema = None
    if self.has_computed_output():
        schema = unify_block_metadata_schema(self._snapshot_bundle.metadata)
    # Bind the aggregated schema with the walrus operator so
    # `aggregate_output_metadata()` (which walks the operator DAG) is
    # computed once instead of twice.
    elif (
        metadata_schema := self._logical_plan.dag.aggregate_output_metadata().schema
    ) is not None:
        schema = metadata_schema
    elif fetch_if_missing:
        iter_ref_bundles, _, _ = self.execute_to_iterator()
        for ref_bundle in iter_ref_bundles:
            for metadata in ref_bundle.metadata:
                # Prefer the schema of a non-empty (or unknown-size) block.
                if metadata.schema is not None and (
                    metadata.num_rows is None or metadata.num_rows > 0
                ):
                    schema = metadata.schema
                    break
            # NOTE(review): only the inner loop breaks above; iteration over
            # the remaining bundles continues — confirm whether an outer
            # break was intended.
    elif self.is_read_only():
        # For consistency with the previous implementation, we fetch the schema if
        # the plan is read-only even if `fetch_if_missing` is False.
        iter_ref_bundles, _, _ = self.execute_to_iterator()
        try:
            ref_bundle = next(iter(iter_ref_bundles))
            for metadata in ref_bundle.metadata:
                if metadata.schema is not None:
                    schema = metadata.schema
                    break
        except StopIteration:  # Empty dataset.
            schema = None

    # Cache the result so subsequent calls are free.
    self._schema = schema
    return self._schema
373
+
374
def cache_schema(self, schema: Union[type, "pyarrow.lib.Schema"]):
    """Store ``schema`` so later ``schema()`` calls return it without any work."""
    self._schema = schema
376
+
377
def input_files(self) -> Optional[List[str]]:
    """Get the input files of the dataset, if available."""
    # Sourced from the logical DAG's aggregated output metadata; may be None.
    return self._logical_plan.dag.aggregate_output_metadata().input_files
380
+
381
def meta_count(self) -> Optional[int]:
    """Get the number of rows after applying all plan optimizations, if possible.

    This method will never trigger any computation.

    Returns:
        The number of records of the result Dataset, or None.
    """
    if self.has_computed_output():
        num_rows = sum(m.num_rows for m in self._snapshot_bundle.metadata)
    # Bind the aggregated row count with the walrus operator so
    # `aggregate_output_metadata()` (which walks the operator DAG) is
    # computed once instead of twice.
    elif (
        metadata_num_rows := self._logical_plan.dag.aggregate_output_metadata().num_rows
    ) is not None:
        num_rows = metadata_num_rows
    else:
        num_rows = None
    return num_rows
396
+
397
@omit_traceback_stdout
def execute_to_iterator(
    self,
) -> Tuple[Iterator[RefBundle], DatasetStats, Optional["Executor"]]:
    """Execute this plan, returning an iterator.

    This will use streaming execution to generate outputs.

    Returns:
        Tuple of iterator over output RefBundles, DatasetStats, and the executor.
    """
    self._has_started_execution = True

    # Always use the saved context for execution.
    ctx = self._context

    if self.has_computed_output():
        # Cached result: return it directly; no executor is involved, hence
        # the None third element.
        bundle = self.execute()
        return iter([bundle]), self._snapshot_stats, None

    from ray.data._internal.execution.legacy_compat import (
        execute_to_legacy_bundle_iterator,
    )
    from ray.data._internal.execution.streaming_executor import StreamingExecutor

    metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
    executor = StreamingExecutor(ctx, metrics_tag)
    bundle_iter = execute_to_legacy_bundle_iterator(executor, self)
    # Since the generator doesn't run any code until we try to fetch the first
    # value, force execution of one bundle before we call get_stats().
    gen = iter(bundle_iter)
    try:
        # Re-prepend the eagerly-fetched bundle so callers still see it.
        bundle_iter = itertools.chain([next(gen)], gen)
    except StopIteration:
        pass
    self._snapshot_stats = executor.get_stats()
    return bundle_iter, self._snapshot_stats, executor
434
+
435
@omit_traceback_stdout
def execute(
    self,
    preserve_order: bool = False,
) -> RefBundle:
    """Execute this plan.

    Args:
        preserve_order: Whether to preserve order in execution.

    Returns:
        The blocks of the output dataset.
    """
    self._has_started_execution = True

    # Always used the saved context for execution.
    context = self._context

    # Warn (once per process) when the cluster has no free CPUs, since
    # execution would otherwise hang without any indication of why.
    if not ray.available_resources().get("CPU"):
        if log_once("cpu_warning"):
            logger.warning(
                "Warning: The Ray cluster currently does not have "
                "any available CPUs. The Dataset job will hang unless more CPUs "
                "are freed up. A common reason is that cluster resources are "
                "used by Actors or Tune trials; see the following link "
                "for more details: "
                "https://docs.ray.io/en/latest/data/data-internals.html#ray-data-and-tune"  # noqa: E501
            )
    if not self.has_computed_output():
        from ray.data._internal.execution.legacy_compat import (
            _get_initial_stats_from_plan,
            execute_to_legacy_block_list,
        )

        if self._logical_plan.dag.output_data() is not None:
            # If the data is already materialized (e.g., `from_pandas`), we can
            # skip execution and directly return the output data. This avoids
            # recording unnecessary metrics for an empty plan execution.
            stats = _get_initial_stats_from_plan(self)

            # TODO(@bveeramani): Make `ExecutionPlan.execute()` return
            # `List[RefBundle]` instead of `RefBundle`. Among other reasons, it'd
            # allow us to remove the unwrapping logic below.
            output_bundles = self._logical_plan.dag.output_data()
            # The flattened bundle only owns its blocks if every source
            # bundle owned its own.
            owns_blocks = all(bundle.owns_blocks for bundle in output_bundles)
            bundle = RefBundle(
                [
                    (block, metadata)
                    for bundle in output_bundles
                    for block, metadata in bundle.blocks
                ],
                owns_blocks=owns_blocks,
            )
        else:
            from ray.data._internal.execution.streaming_executor import (
                StreamingExecutor,
            )

            metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
            executor = StreamingExecutor(
                context,
                metrics_tag,
            )
            blocks = execute_to_legacy_block_list(
                executor,
                self,
                dataset_uuid=self._dataset_uuid,
                preserve_order=preserve_order,
            )
            bundle = RefBundle(
                tuple(blocks.iter_blocks_with_metadata()),
                owns_blocks=blocks._owned_by_consumer,
            )
            stats = executor.get_stats()
            stats_summary_string = stats.to_summary().to_string(
                include_parent=False
            )
            if context.enable_auto_log_stats:
                logger.info(stats_summary_string)

        # Retrieve memory-related stats from ray.
        # Best-effort: stats collection must never fail the execution itself.
        try:
            reply = get_memory_info_reply(
                get_state_from_address(ray.get_runtime_context().gcs_address)
            )
            if reply.store_stats.spill_time_total_s > 0:
                stats.global_bytes_spilled = int(
                    reply.store_stats.spilled_bytes_total
                )
            if reply.store_stats.restore_time_total_s > 0:
                stats.global_bytes_restored = int(
                    reply.store_stats.restored_bytes_total
                )
        except Exception as e:
            logger.debug(
                "Skipping recording memory spilled and restored statistics due to "
                f"exception: {e}"
            )

        stats.dataset_bytes_spilled = 0

        # Sum `obj_store_mem_spilled` over this stats object and all of its
        # ancestors (recursive walk up the `parents` chain).
        def collect_stats(cur_stats):
            stats.dataset_bytes_spilled += cur_stats.extra_metrics.get(
                "obj_store_mem_spilled", 0
            )
            for parent in cur_stats.parents:
                collect_stats(parent)

        collect_stats(stats)

        # Set the snapshot to the output of the final operator.
        self._snapshot_bundle = bundle
        self._snapshot_operator = self._logical_plan.dag
        self._snapshot_stats = stats
        self._snapshot_stats.dataset_uuid = self._dataset_uuid

    return self._snapshot_bundle
552
+
553
@property
def has_started_execution(self) -> bool:
    """``True`` once this plan has been partially or fully executed."""
    return self._has_started_execution
557
+
558
def clear_snapshot(self) -> None:
    """Reset the cached snapshot (stats, operator, and bundle) to its
    initial empty state."""
    self._snapshot_stats = None
    self._snapshot_operator = None
    self._snapshot_bundle = None
563
+
564
def stats(self) -> DatasetStats:
    """Return stats for this plan.

    If the plan isn't executed, an empty stats object will be returned.
    """
    # Fall back to a fresh, empty stats object when no snapshot exists.
    return self._snapshot_stats or DatasetStats(metadata={}, parent=None)
572
+
573
def has_lazy_input(self) -> bool:
    """Return whether this plan has lazy input blocks."""
    for source_op in self._logical_plan.sources():
        if not isinstance(source_op, Read):
            return False
    return True
576
+
577
def is_read_only(self, root_op: Optional[LogicalOperator] = None) -> bool:
    """Return whether the LogicalPlan corresponding to `root_op`
    contains only a Read op. By default, the last operator of
    the LogicalPlan is used."""
    op = root_op if root_op is not None else self._logical_plan.dag
    return isinstance(op, Read) and len(op.input_dependencies) == 0
584
+
585
def has_computed_output(self) -> bool:
    """Whether this plan has a computed snapshot for the final operator, i.e. for
    the output of this plan.
    """
    if self._snapshot_bundle is None:
        return False
    # The snapshot is only valid if it was taken at the plan's final operator.
    return self._snapshot_operator == self._logical_plan.dag
593
+
594
def require_preserve_order(self) -> bool:
    """Whether this plan requires to preserve order."""
    from ray.data._internal.logical.operators.all_to_all_operator import Sort
    from ray.data._internal.logical.operators.n_ary_operator import Zip

    # Zip and Sort are the only operators whose output ordering matters.
    return any(
        isinstance(op, (Zip, Sort))
        for op in self._logical_plan.dag.post_order_iter()
    )
.venv/lib/python3.11/site-packages/ray/data/_internal/progress_bar.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import threading
3
+ from typing import Any, List, Optional
4
+
5
+ import ray
6
+ from ray.experimental import tqdm_ray
7
+ from ray.types import ObjectRef
8
+ from ray.util.debug import log_once
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
try:
    import tqdm

    # tqdm is available; no warning needed.
    needs_warning = False
except ImportError:
    # Soft dependency: disable bars and warn once on first use.
    tqdm = None
    needs_warning = True

# Used as a signal to cancel execution: threads registered here stop waiting.
_canceled_threads = set()
_canceled_threads_lock = threading.Lock()
23
+
24
+
25
def extract_num_rows(result: Any) -> int:
    """Extract the number of rows from a result object.

    Args:
        result: The result object from which to extract the number of rows.

    Returns:
        The number of rows, defaulting to 1 if it cannot be determined.
    """
    if hasattr(result, "num_rows"):
        return result.num_rows
    if hasattr(result, "__len__"):
        # e.g. DataFrame outputs such as those produced by sort_sample.
        return len(result)
    # Opaque result: count it as a single row.
    return 1
41
+
42
+
43
class ProgressBar:
    """Thin wrapper around tqdm to handle soft imports.

    If `total` is `None` (for example, because no tasks have finished
    yet), doesn't display the full progress bar. Still displays basic
    progress stats from tqdm."""

    # If the name/description of the progress bar exceeds this length,
    # it will be truncated.
    MAX_NAME_LENGTH = 100

    def __init__(
        self,
        name: str,
        total: Optional[int],
        unit: str,
        position: int = 0,
        enabled: Optional[bool] = None,
    ):
        """Create a progress bar.

        Args:
            name: Description shown next to the bar (possibly truncated).
            total: Expected total progress, or ``None`` if unknown.
            unit: Unit label for the progress counter.
            position: tqdm position for stacked bars.
            enabled: Whether to display the bar; defaults to
                ``DataContext.enable_progress_bars``.
        """
        self._desc = self._truncate_name(name)
        self._progress = 0
        # Prepend a space to the unit for better formatting. Use startswith so
        # an empty unit string doesn't raise IndexError (the previous
        # `unit[0]` indexing crashed on "").
        if not unit.startswith(" "):
            unit = " " + unit

        if enabled is None:
            from ray.data import DataContext

            enabled = DataContext.get_current().enable_progress_bars
        if not enabled:
            self._bar = None
        elif tqdm:
            ctx = ray.data.context.DataContext.get_current()
            if ctx.use_ray_tqdm:
                # Distributed-friendly tqdm variant.
                self._bar = tqdm_ray.tqdm(total=total, unit=unit, position=position)
            else:
                self._bar = tqdm.tqdm(
                    total=total or 0,
                    position=position,
                    dynamic_ncols=True,
                    unit=unit,
                    unit_scale=True,
                )
            self._bar.set_description(self._desc)
        else:
            global needs_warning
            if needs_warning:
                print("[dataset]: Run `pip install tqdm` to enable progress reporting.")
                needs_warning = False
            self._bar = None

    def _truncate_name(self, name: str) -> str:
        """Truncate an operator-chain name to ~``MAX_NAME_LENGTH`` characters.

        Keeps the first and last operator names and replaces the middle with
        "..." once the limit would be exceeded.
        """
        ctx = ray.data.context.DataContext.get_current()
        if (
            not ctx.enable_progress_bar_name_truncation
            or len(name) <= self.MAX_NAME_LENGTH
        ):
            return name

        op_names = name.split("->")
        if len(op_names) == 1:
            return op_names[0]

        # Include as many operators as possible without approximately
        # exceeding `MAX_NAME_LENGTH`. Always include the first and
        # last operator names so it is easy to identify the DAG.
        truncated_op_names = [op_names[0]]
        for op_name in op_names[1:-1]:
            if (
                len("->".join(truncated_op_names))
                + len("->")
                + len(op_name)
                + len("->")
                + len(op_names[-1])
            ) > self.MAX_NAME_LENGTH:
                truncated_op_names.append("...")
                if log_once("ray_data_truncate_operator_name"):
                    # NOTE(review): the attribute named in this message differs
                    # from `enable_progress_bar_name_truncation` checked above;
                    # the user-facing text is preserved as-is.
                    logger.warning(
                        f"Truncating long operator name to {self.MAX_NAME_LENGTH} "
                        "characters. To disable this behavior, set "
                        "`ray.data.DataContext.get_current()."
                        "DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`."
                    )
                break
            truncated_op_names.append(op_name)
        truncated_op_names.append(op_names[-1])
        return "->".join(truncated_op_names)

    def block_until_complete(self, remaining: List[ObjectRef]) -> None:
        """Wait for all refs to finish, updating progress as they complete.

        Exits early if this thread has been registered in
        ``_canceled_threads``.
        """
        t = threading.current_thread()
        while remaining:
            done, remaining = ray.wait(
                remaining, num_returns=len(remaining), fetch_local=False, timeout=0.1
            )
            total_rows_processed = 0
            for _, result in zip(done, ray.get(done)):
                num_rows = extract_num_rows(result)
                total_rows_processed += num_rows
            self.update(total_rows_processed)

            with _canceled_threads_lock:
                if t in _canceled_threads:
                    break

    def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]:
        """Fetch all refs, updating progress; return results in input order."""
        ref_to_result = {}
        remaining = refs
        t = threading.current_thread()
        # Triggering fetch_local redundantly for the same object is slower.
        # We only need to trigger the fetch_local once for each object,
        # raylet will persist these fetch requests even after ray.wait returns.
        # See https://github.com/ray-project/ray/issues/30375.
        fetch_local = True
        while remaining:
            done, remaining = ray.wait(
                remaining,
                num_returns=len(remaining),
                fetch_local=fetch_local,
                timeout=0.1,
            )
            if fetch_local:
                fetch_local = False
            total_rows_processed = 0
            for ref, result in zip(done, ray.get(done)):
                ref_to_result[ref] = result
                num_rows = extract_num_rows(result)
                total_rows_processed += num_rows
            self.update(total_rows_processed)

            with _canceled_threads_lock:
                if t in _canceled_threads:
                    break

        return [ref_to_result[ref] for ref in refs]

    def set_description(self, name: str) -> None:
        """Update the bar's description (truncated) if it changed."""
        name = self._truncate_name(name)
        if self._bar and name != self._desc:
            self._desc = name
            self._bar.set_description(self._desc)

    def get_description(self) -> str:
        """Return the current (possibly truncated) description."""
        return self._desc

    def refresh(self):
        """Force a redraw of the underlying bar, if any."""
        if self._bar:
            self._bar.refresh()

    def update(self, i: int = 0, total: Optional[int] = None) -> None:
        """Advance progress by ``i`` and optionally reset the total."""
        if self._bar and (i != 0 or self._bar.total != total):
            self._progress += i
            if total is not None:
                self._bar.total = total
            if self._bar.total is not None and self._progress > self._bar.total:
                # If the progress goes over 100%, update the total.
                self._bar.total = self._progress
            self._bar.update(i)

    def close(self):
        """Finalize and release the underlying bar (idempotent)."""
        if self._bar:
            if self._bar.total is not None and self._progress != self._bar.total:
                # If the progress is not complete, update the total.
                self._bar.total = self._progress
                self._bar.refresh()
            self._bar.close()
            self._bar = None

    def __del__(self):
        self.close()

    def __getstate__(self):
        # Serialize nothing: bars are process-local.
        return {}

    def __setstate__(self, state):
        self._bar = None  # Progress bar is disabled on remote nodes.
.venv/lib/python3.11/site-packages/ray/data/_internal/remote_fn.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Hashable, List
2
+
3
+ import ray
4
+
5
+ CACHED_FUNCTIONS = {}
6
+
7
+
8
def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
    """Lazily defines a ray.remote function.

    This is used in Datasets to avoid circular import issues with ray.remote.
    (ray imports ray.data in order to allow ``ray.data.read_foo()`` to work,
    which means ray.remote cannot be used top-level in ray.data).

    NOTE: Dynamic arguments should not be passed in directly,
    and should be set with ``options`` instead:
    ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
    """

    # Cache key: the function plus a deterministic hash of the remote args.
    # `_make_hashable` sorts dict items so equal arg dicts hash identically,
    # guaranteeing we cache the complete instantiation of the remote method.
    key = (fn, hash(_make_hashable(ray_remote_args)))

    if key not in CACHED_FUNCTIONS:
        resolved_args = {
            # Use the default scheduling strategy for all tasks so that we will
            # not inherit a placement group from the caller, if there is one.
            # The caller of this function may override the scheduling strategy
            # as needed.
            "scheduling_strategy": "DEFAULT",
            "max_retries": -1,
        }
        resolved_args.update(ray_remote_args)
        _add_system_error_to_retry_exceptions(resolved_args)

        CACHED_FUNCTIONS[key] = ray.remote(**resolved_args)(fn)

    return CACHED_FUNCTIONS[key]
45
+
46
+
47
+ def _make_hashable(obj):
48
+ if isinstance(obj, (List, tuple)):
49
+ return tuple([_make_hashable(o) for o in obj])
50
+ elif isinstance(obj, Dict):
51
+ converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()]
52
+ return tuple(sorted(converted, key=lambda t: t[0]))
53
+ elif isinstance(obj, Hashable):
54
+ return obj
55
+ else:
56
+ raise ValueError(f"Type {type(obj)} is not hashable")
57
+
58
+
59
def _add_system_error_to_retry_exceptions(ray_remote_args) -> None:
    """Modify the remote args so that Ray retries `RaySystemError`s.

    Ray typically automatically retries system errors. However, in some cases, Ray won't
    retry system errors if they're raised from task code. To ensure that Ray Data is
    fault tolerant to those errors, we need to add `RaySystemError` to the
    `retry_exceptions` list.

    TODO: Fix this in Ray Core. See https://github.com/ray-project/ray/pull/45079.
    """
    retry_exceptions = ray_remote_args.get("retry_exceptions", False)
    assert isinstance(retry_exceptions, (list, bool))

    if isinstance(retry_exceptions, bool):
        # False means "no retries"; upgrade to retrying system errors only.
        # True ("retry everything") already covers RaySystemError.
        if not retry_exceptions:
            retry_exceptions = [ray.exceptions.RaySystemError]
    elif ray.exceptions.RaySystemError not in retry_exceptions:
        # Extend the caller's explicit exception list in place.
        retry_exceptions.append(ray.exceptions.RaySystemError)

    ray_remote_args["retry_exceptions"] = retry_exceptions
.venv/lib/python3.11/site-packages/ray/data/_internal/row.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Mapping
2
+ from typing import Any
3
+
4
+
5
class TableRow(Mapping):
    """
    A dict-like row of a tabular ``Dataset``.

    This implements the dictionary mapping interface, but provides more
    efficient access with less data copying than converting Arrow Tables
    or Pandas DataFrames into per-row dicts. This class must be subclassed,
    with subclasses implementing ``__getitem__``, ``__iter__``, and ``__len__``.

    Concrete subclasses include ``ray.data._internal.arrow_block.ArrowRow`` and
    ``ray.data._internal.pandas_block.PandasRow``.
    """

    def __init__(self, row: Any):
        """
        Construct a ``TableRow`` (internal API).

        Args:
            row: The tabular row that backs this row mapping.
        """
        self._row = row

    def as_pydict(self) -> dict:
        """
        Convert to a normal Python dict. This will create a new copy of the row."""
        return {key: value for key, value in self.items()}

    def __repr__(self):
        # Delegate to __str__ so both render identically.
        return str(self)

    def __str__(self):
        return str(self.as_pydict())

    def _repr_pretty_(self, p, cycle):
        # IPython rich display: render like a plain dict.
        from IPython.lib.pretty import _dict_pprinter_factory

        pretty_print = _dict_pprinter_factory("{", "}")
        return pretty_print(self, p, cycle)
.venv/lib/python3.11/site-packages/ray/data/_internal/size_estimator.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+
3
+ import ray
4
+ from ray import cloudpickle
5
+
6
+ _ray_initialized = False
7
+
8
+
9
class SizeEstimator:
    """Efficiently estimates the Ray serialized size of a stream of items.

    For efficiency, this only samples a fraction of the added items for real
    Ray-serialization.
    """

    def __init__(self):
        # Weighted running mean of the sampled per-item serialized sizes.
        self._running_mean = RunningMean()
        # Total number of items observed (sampled or not).
        self._count = 0

    def add(self, item: Any) -> None:
        """Record one item, serializing it only at a decaying sample rate.

        The first 10 items are all measured (weight 1); through item 100
        every 10th is measured (weight 10); afterwards every 100th
        (weight 100).
        """
        self._count += 1
        if self._count <= 10:
            self._running_mean.add(self._real_size(item), weight=1)
        elif self._count <= 100:
            if self._count % 10 == 0:
                self._running_mean.add(self._real_size(item), weight=10)
        elif self._count % 100 == 0:
            self._running_mean.add(self._real_size(item), weight=100)

    def add_block(self, block: List[Any]) -> None:
        """Record a whole list of items at once.

        Mirrors the sampling schedule of ``add`` but picks the sampled
        indices arithmetically rather than looping item by item.
        """
        if self._count < 10:
            # Fill out the first 10 fully-sampled items.
            for i in range(min(10 - self._count, len(block))):
                self._running_mean.add(self._real_size(block[i]), weight=1)
        if self._count < 100:
            # Sample every 10th item up to the 100th overall.
            for i in range(
                10 - (self._count % 10), min(100 - self._count, len(block)), 10
            ):
                self._running_mean.add(self._real_size(block[i]), weight=10)
        if (len(block) + (self._count % 100)) // 100 > 1:
            # Sample every 100th item beyond the 100th overall.
            for i in range(100 - (self._count % 100), len(block), 100):
                self._running_mean.add(self._real_size(block[i]), weight=100)
        self._count += len(block)

    def size_bytes(self) -> int:
        """Return the estimated total serialized size of all items seen."""
        return int(self._running_mean.mean * self._count)

    def _real_size(self, item: Any) -> int:
        """Return the actual Ray-serialized size of ``item`` in bytes."""
        is_client = ray.util.client.ray.is_connected()
        # In client mode, fallback to using Ray cloudpickle instead of the
        # real serializer.
        if is_client:
            return len(cloudpickle.dumps(item))

        # We're using an internal Ray API, and have to ensure it's
        # initialized # by calling a public API.
        global _ray_initialized
        if not _ray_initialized:
            _ray_initialized = True
            ray.put(None)
        return (
            ray._private.worker.global_worker.get_serialization_context()
            .serialize(item)
            .total_bytes
        )
65
+
66
+
67
+ # Adapted from the RLlib MeanStdFilter.
68
class RunningMean:
    """Incrementally maintained weighted mean.

    Adapted from the RLlib MeanStdFilter.
    """

    def __init__(self):
        self._weight = 0
        self._mean = 0

    def add(self, x: int, weight: int = 1) -> None:
        """Fold ``x`` (carrying the given weight) into the running mean."""
        if weight == 0:
            return
        total_weight = self._weight + weight
        # Weighted average of the previous mean and the new sample.
        updated_mean = (self._weight * self._mean + weight * x) / total_weight
        self._weight = total_weight
        self._mean = updated_mean

    @property
    def n(self) -> int:
        """Total weight accumulated so far."""
        return self._weight

    @property
    def mean(self) -> float:
        """Current weighted mean."""
        return self._mean

    def __repr__(self):
        return "(n={}, mean={})".format(self.n, self.mean)
.venv/lib/python3.11/site-packages/ray/data/_internal/split.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import logging
3
+ from typing import Iterable, List, Tuple, Union
4
+
5
+ import ray
6
+ from ray.data._internal.memory_tracing import trace_deallocation
7
+ from ray.data._internal.remote_fn import cached_remote_fn
8
+ from ray.data.block import (
9
+ Block,
10
+ BlockAccessor,
11
+ BlockExecStats,
12
+ BlockMetadata,
13
+ BlockPartition,
14
+ )
15
+ from ray.types import ObjectRef
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def _calculate_blocks_rows(
    blocks_with_metadata: BlockPartition,
) -> List[int]:
    """Calculate the number of rows for a list of blocks with metadata."""
    get_num_rows = cached_remote_fn(_get_num_rows)
    row_counts = []
    for block_ref, metadata in blocks_with_metadata:
        if metadata.num_rows is None:
            # Row count unknown: fetch it remotely and cache it on the
            # metadata so later passes don't refetch.
            metadata.num_rows = ray.get(get_num_rows.remote(block_ref))
        row_counts.append(metadata.num_rows)
    return row_counts
35
+
36
+
37
+ def _generate_valid_indices(
38
+ num_rows_per_block: List[int],
39
+ split_indices: List[int],
40
+ ) -> List[int]:
41
+ """Generate valid split indices by apply min(index, total_num_rows)
42
+ to every index."""
43
+ total_rows = sum(num_rows_per_block)
44
+ return [min(index, total_rows) for index in split_indices]
45
+
46
+
47
+ def _generate_per_block_split_indices(
48
+ num_rows_per_block: List[int],
49
+ split_indices: List[int],
50
+ ) -> List[List[int]]:
51
+ """Given num rows per block and valid split indices, generate per block split indices.
52
+
53
+ Args:
54
+ num_rows_per_block: num of rows per block.
55
+ split_indices: The (global) indices at which to split the blocks.
56
+ Returns:
57
+ Per block split indices indicates each input block's split point(s).
58
+ """
59
+ # for each split index, we iterate though the currnet input block
60
+ # to see if the index falls into this block. if the index
61
+ # falls into this block, we push it back to the current block's
62
+ # split indices. Otherwise, we move on to the next block.
63
+ per_block_split_indices = []
64
+ current_input_block_id = 0
65
+ current_block_split_indices = []
66
+ current_block_global_offset = 0
67
+ current_index_id = 0
68
+
69
+ while current_index_id < len(split_indices):
70
+ split_index = split_indices[current_index_id]
71
+ current_block_row = num_rows_per_block[current_input_block_id]
72
+ if split_index - current_block_global_offset <= current_block_row:
73
+ current_block_split_indices.append(
74
+ split_index - current_block_global_offset
75
+ )
76
+ current_index_id += 1
77
+ continue
78
+ per_block_split_indices.append(current_block_split_indices)
79
+ current_block_split_indices = []
80
+ current_block_global_offset += num_rows_per_block[current_input_block_id]
81
+ current_input_block_id += 1
82
+
83
+ # we might finished all the indices but there are still blocks left, also
84
+ # current_block_split_indices might not be added yet.
85
+ while len(per_block_split_indices) < len(num_rows_per_block):
86
+ per_block_split_indices.append(current_block_split_indices)
87
+ current_block_split_indices = []
88
+ return per_block_split_indices
89
+
90
+
91
def _split_single_block(
    block_id: int,
    block: Block,
    meta: BlockMetadata,
    split_indices: List[int],
) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
    """Split the provided block at the given indices.

    Args:
        block_id: the id of this block in the block list.
        block: block to be split.
        meta: metadata of the block; ``meta.num_rows`` must be valid.
        split_indices: the indices where the block should be split.
    Returns:
        returns block_id, split blocks metadata, and a list of blocks
        in the following form. We return blocks in this way
        so that the owner of blocks could be the caller(driver)
        instead of worker itself.
        Tuple(block_id, split_blocks_meta), block0, block1 ...
    """
    split_meta = []
    split_blocks = []
    block_accessor = BlockAccessor.for_block(block)
    prev_index = 0
    # Add the block's row count as a final boundary so the tail slice needs
    # no special case. Build a new list rather than appending to the
    # caller-provided `split_indices` (the original mutated its argument).
    boundaries = list(split_indices) + [meta.num_rows]
    for index in boundaries:
        logger.debug(f"slicing block {prev_index}:{index}")
        stats = BlockExecStats.builder()
        split_block = block_accessor.slice(prev_index, index)
        accessor = BlockAccessor.for_block(split_block)
        _meta = BlockMetadata(
            num_rows=accessor.num_rows(),
            size_bytes=accessor.size_bytes(),
            schema=meta.schema,
            input_files=meta.input_files,
            exec_stats=stats.build(),
        )
        split_meta.append(_meta)
        split_blocks.append(split_block)
        prev_index = index
    # First element carries the id + metadata; the split blocks follow so
    # the caller (driver) owns the returned block objects.
    results = [(block_id, split_meta)]
    results.extend(split_blocks)
    return tuple(results)
136
+
137
+
138
+ def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
139
+ """drop split indices that creates empty block split. This could happen when there
140
+ are duplicated indices, or index equal to 0 (start of the block) or num_block_rows
141
+ (end of the block).
142
+ """
143
+ prev_index = -1
144
+ optimized_indices = []
145
+ for index in block_split_indices:
146
+ if index == 0 or index == num_rows:
147
+ continue
148
+ if index == prev_index:
149
+ continue
150
+ optimized_indices.append(index)
151
+ prev_index = index
152
+ return optimized_indices
153
+
154
+
155
def _split_all_blocks(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    per_block_split_indices: List[List[int]],
    owned_by_consumer: bool,
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
    """Split all the input blocks based on the split indices.

    Blocks with no effective split point are passed through untouched;
    the rest are split by remote `_split_single_block` tasks.
    """
    split_single_block = cached_remote_fn(_split_single_block)

    # One result slot per input block, filled with either the original
    # (block_ref, meta) pair or the remote split results.
    all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)

    per_block_split_metadata_futures = []
    per_block_split_block_refs = []

    # tracking splitted blocks for gc.
    blocks_splitted = []
    for block_id, block_split_indices in enumerate(per_block_split_indices):
        (block_ref, meta) = blocks_with_metadata[block_id]
        block_row = meta.num_rows
        block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
        if len(block_split_indices) == 0:
            # optimization: if no split is needed, we just need to add it to the
            # result
            all_blocks_split_results[block_id] = [(block_ref, meta)]
        else:
            # otherwise call split remote function.
            # num_returns: 1 metadata tuple + (len(indices) + 1) split blocks.
            object_refs = split_single_block.options(
                scheduling_strategy="SPREAD", num_returns=2 + len(block_split_indices)
            ).remote(
                block_id,
                block_ref,
                meta,
                block_split_indices,
            )
            per_block_split_metadata_futures.append(object_refs[0])
            per_block_split_block_refs.append(object_refs[1:])

            blocks_splitted.append(block_ref)

    if per_block_split_metadata_futures:
        # only get metadata.
        per_block_split_metadata = ray.get(per_block_split_metadata_futures)
        for (block_id, meta), block_refs in zip(
            per_block_split_metadata, per_block_split_block_refs
        ):
            assert len(meta) == len(block_refs)
            # NOTE: a lazy zip iterator is stored here; it is consumed exactly
            # once by the chain() at the end of this function.
            all_blocks_split_results[block_id] = zip(block_refs, meta)

    # We make a copy for the blocks that have been splitted, so the input blocks
    # can be cleared if they are owned by consumer (consumer-owned blocks will
    # only be consumed by the owner).
    if owned_by_consumer:
        for b in blocks_splitted:
            trace_deallocation(b, "split._split_all_blocks")
    else:
        for b in blocks_splitted:
            trace_deallocation(b, "split._split_all_blocks", free=False)

    return itertools.chain.from_iterable(all_blocks_split_results)
213
+
214
+
215
+ def _generate_global_split_results(
216
+ all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
217
+ global_split_sizes: List[int],
218
+ ) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
219
+ """Reassemble per block's split result into final split result."""
220
+ result_blocks = []
221
+ result_metas = []
222
+
223
+ current_blocks = []
224
+ current_meta = []
225
+ current_split_size = 0
226
+ current_split_id = 0
227
+
228
+ while current_split_id < len(global_split_sizes):
229
+ if current_split_size >= global_split_sizes[current_split_id]:
230
+ assert current_split_size == global_split_sizes[current_split_id]
231
+ result_blocks.append(current_blocks)
232
+ result_metas.append(current_meta)
233
+
234
+ current_blocks = []
235
+ current_meta = []
236
+ current_split_size = 0
237
+ current_split_id += 1
238
+ else:
239
+ (block_ref, meta) = next(all_blocks_split_results)
240
+ current_blocks.append(block_ref)
241
+ current_meta.append(meta)
242
+ current_split_size += meta.num_rows
243
+
244
+ return result_blocks, result_metas
245
+
246
+
247
def _split_at_indices(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    indices: List[int],
    owned_by_consumer: bool = True,
    block_rows: List[int] = None,
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
    """Split blocks at the provided indices.

    Args:
        blocks_with_metadata: Block futures to split, including the associated metadata.
        indices: The (global) indices at which to split the blocks.
        owned_by_consumer: Whether the provided blocks are owned by the consumer.
        block_rows: The number of rows for each block, in case it has already been
            computed.

    Returns:
        The block split futures and their metadata. If an index split is empty, the
        corresponding block split will be empty .
    """

    # We implement the split in 3 phases.
    # phase 1: calculate the per block split indices.
    blocks_with_metadata = list(blocks_with_metadata)
    if len(blocks_with_metadata) == 0:
        # No input blocks: every one of the (len(indices) + 1) splits is empty.
        return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1))
    if block_rows is None:
        block_rows = _calculate_blocks_rows(blocks_with_metadata)
    # Clamp out-of-range indices to the total row count before assigning
    # them to individual blocks.
    valid_indices = _generate_valid_indices(block_rows, indices)
    per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
        block_rows, valid_indices
    )

    # phase 2: split each block based on the indices from previous step.
    all_blocks_split_results: Iterable[
        Tuple[ObjectRef[Block], BlockMetadata]
    ] = _split_all_blocks(
        blocks_with_metadata, per_block_split_indices, owned_by_consumer
    )

    # phase 3: generate the final split.

    # first calculate the size for each split.
    # Consecutive differences over [0, *valid_indices, total_rows] give the
    # row count of each output split.
    helper = [0] + valid_indices + [sum(block_rows)]
    split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))]

    return _generate_global_split_results(all_blocks_split_results, split_sizes)
293
+
294
+
295
def _get_num_rows(block: Block) -> int:
    """Get the number of rows contained in the provided block."""
    accessor = BlockAccessor.for_block(block)
    return accessor.num_rows()
.venv/lib/python3.11/site-packages/ray/data/_internal/stats.py ADDED
@@ -0,0 +1,1495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import threading
4
+ import time
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
8
+ from uuid import uuid4
9
+
10
+ import numpy as np
11
+
12
+ import ray
13
+ from ray.actor import ActorHandle
14
+ from ray.data._internal.block_list import BlockList
15
+ from ray.data._internal.execution.interfaces.op_runtime_metrics import (
16
+ MetricsGroup,
17
+ OpRuntimeMetrics,
18
+ )
19
+ from ray.data._internal.util import capfirst
20
+ from ray.data.block import BlockMetadata
21
+ from ray.data.context import DataContext
22
+ from ray.util.annotations import DeveloperAPI
23
+ from ray.util.metrics import Gauge
24
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ STATS_ACTOR_NAME = "datasets_stats_actor"
29
+ STATS_ACTOR_NAMESPACE = "_dataset_stats_actor"
30
+
31
+
32
+ StatsDict = Dict[str, List[BlockMetadata]]
33
+
34
+
35
def fmt(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string.

    Uses seconds above 1s, milliseconds above 1ms, and microseconds below.
    Values are rounded to two decimal places.
    """
    if seconds > 1:
        value, unit = seconds, "s"
    elif seconds > 0.001:
        value, unit = seconds * 1000, "ms"
    else:
        value, unit = seconds * 1000 * 1000, "us"
    return str(round(value, 2)) + unit
42
+
43
+
44
def leveled_indent(lvl: int = 0, spaces_per_indent: int = 3) -> str:
    """Return a string of spaces containing ``lvl`` indents.

    Each indent is ``spaces_per_indent`` spaces wide, so e.g.
    ``leveled_indent(2, 3)`` returns six spaces.
    """
    return " " * (spaces_per_indent * lvl)
51
+
52
+
53
class Timer:
    """Tracks accumulated time (in seconds) plus min/max/average samples."""

    def __init__(self):
        self._value: float = 0
        self._min: float = float("inf")
        self._max: float = 0
        self._total_count: float = 0

    @contextmanager
    def timer(self) -> None:
        """Context manager that records the elapsed wall-clock time on exit."""
        start = time.perf_counter()
        try:
            yield
        finally:
            self.add(time.perf_counter() - start)

    def add(self, value: float) -> None:
        """Record a single measurement, updating total, min, max and count."""
        self._value += value
        self._min = min(self._min, value)
        self._max = max(self._max, value)
        self._total_count += 1

    def get(self) -> float:
        """Return the accumulated total time."""
        return self._value

    def min(self) -> float:
        """Return the smallest recorded measurement (inf if none)."""
        return self._min

    def max(self) -> float:
        """Return the largest recorded measurement (0 if none)."""
        return self._max

    def avg(self) -> float:
        """Return the average measurement, or inf if nothing was recorded."""
        return self._value / self._total_count if self._total_count else float("inf")
89
+
90
+
91
class _DatasetStatsBuilder:
    """Helper class for building dataset stats.

    The start time is recorded at construction; when ``build()`` (or
    ``build_multioperator()``) is called with the final blocks of the new
    dataset, the elapsed time is stored on the resulting stats object.
    """

    def __init__(
        self,
        operator_name: str,
        parent: "DatasetStats",
        override_start_time: Optional[float],
    ):
        self.operator_name = operator_name
        self.parent = parent
        self.start_time = override_start_time or time.perf_counter()

    def build_multioperator(self, metadata: StatsDict) -> "DatasetStats":
        """Build stats for a (possibly fused) multi-operator stage.

        With multiple metadata entries, the first entry is prefixed with the
        full operator name and later entries with only the last "->" segment;
        with a single entry the operator name is used unmodified.
        """
        last_segment = self.operator_name.split("->")[-1]
        op_metadata = {}
        for i, (k, v) in enumerate(metadata.items()):
            if len(metadata) == 1:
                op_metadata[self.operator_name] = v
            else:
                prefix = self.operator_name if i == 0 else last_segment
                op_metadata[prefix + capfirst(k)] = v
        stats = DatasetStats(
            metadata=op_metadata,
            parent=self.parent,
            base_name=self.operator_name,
        )
        stats.time_total_s = time.perf_counter() - self.start_time
        return stats

    def build(self, final_blocks: BlockList) -> "DatasetStats":
        """Build stats for a single-operator stage from its final blocks."""
        stats = DatasetStats(
            metadata={self.operator_name: final_blocks.get_metadata()},
            parent=self.parent,
        )
        stats.time_total_s = time.perf_counter() - self.start_time
        return stats
134
+
135
+
136
@ray.remote(num_cpus=0)
class _StatsActor:
    """Actor holding stats for blocks created by LazyBlockList.

    This actor is shared across all datasets created in the same cluster.
    In order to cap memory usage, we set a max number of stats to keep
    in the actor. When this limit is exceeded, the stats will be garbage
    collected in FIFO order.

    TODO(ekl) we should consider refactoring LazyBlockList so stats can be
    extracted without using an out-of-band actor."""

    def __init__(self, max_stats=1000):
        # Mapping from uuid -> (task_id -> list of blocks statistics).
        self.metadata = collections.defaultdict(dict)
        # uuid -> timestamp of the most recent record_task() call.
        self.last_time = {}
        # uuid -> timestamp of the record_start() call.
        self.start_time = {}
        # Max number of datasets whose stats are retained (FIFO eviction).
        self.max_stats = max_stats
        # Insertion-ordered uuids, used for FIFO eviction in record_start().
        self.fifo_queue = []

        # Assign dataset uuids with a global counter.
        self.next_dataset_id = 0
        # Dataset metadata to be queried directly by DashboardHead api.
        self.datasets: Dict[str, Any] = {}

        # Ray Data dashboard metrics
        # Everything is a gauge because we need to reset all of
        # a dataset's metrics to 0 after each finishes execution.
        op_tags_keys = ("dataset", "operator")

        # TODO(scottjlee): move these overview metrics as fields in a
        # separate dataclass, similar to OpRuntimeMetrics.
        self.spilled_bytes = Gauge(
            "data_spilled_bytes",
            description="""Bytes spilled by dataset operators.
            DataContext.enable_get_object_locations_for_metrics
            must be set to True to report this metric""",
            tag_keys=op_tags_keys,
        )
        # NOTE(review): allocated_bytes is created here but is not set in
        # update_execution_metrics below — presumably set elsewhere; confirm.
        self.allocated_bytes = Gauge(
            "data_allocated_bytes",
            description="Bytes allocated by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.freed_bytes = Gauge(
            "data_freed_bytes",
            description="Bytes freed by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.current_bytes = Gauge(
            "data_current_bytes",
            description="Bytes currently in memory store used by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.cpu_usage_cores = Gauge(
            "data_cpu_usage_cores",
            description="CPUs allocated to dataset operators",
            tag_keys=op_tags_keys,
        )
        self.gpu_usage_cores = Gauge(
            "data_gpu_usage_cores",
            description="GPUs allocated to dataset operators",
            tag_keys=op_tags_keys,
        )
        self.output_bytes = Gauge(
            "data_output_bytes",
            description="Bytes outputted by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.output_rows = Gauge(
            "data_output_rows",
            description="Rows outputted by dataset operators",
            tag_keys=op_tags_keys,
        )

        # === Metrics from OpRuntimeMetrics ===
        # Each group below maps OpRuntimeMetrics field name -> Gauge.
        # Inputs-related metrics
        self.execution_metrics_inputs = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.INPUTS,
                tag_keys=op_tags_keys,
            )
        )

        # Outputs-related metrics
        self.execution_metrics_outputs = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.OUTPUTS,
                tag_keys=op_tags_keys,
            )
        )

        # Task-related metrics
        self.execution_metrics_tasks = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.TASKS,
                tag_keys=op_tags_keys,
            )
        )

        # Object store memory-related metrics
        self.execution_metrics_obj_store_memory = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.OBJECT_STORE_MEMORY,
                tag_keys=op_tags_keys,
            )
        )

        # Miscellaneous metrics
        self.execution_metrics_misc = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.MISC,
                tag_keys=op_tags_keys,
            )
        )

        # Iteration metrics are tagged only by dataset (no operator).
        iter_tag_keys = ("dataset",)
        self.iter_total_blocked_s = Gauge(
            "data_iter_total_blocked_seconds",
            description="Seconds user thread is blocked by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_user_s = Gauge(
            "data_iter_user_seconds",
            description="Seconds spent in user code",
            tag_keys=iter_tag_keys,
        )
        self.iter_initialize_s = Gauge(
            "data_iter_initialize_seconds",
            description="Seconds spent in iterator initialization code",
            tag_keys=iter_tag_keys,
        )

    def _create_prometheus_metrics_for_execution_metrics(
        self, metrics_group: MetricsGroup, tag_keys: Tuple[str, ...]
    ) -> Dict[str, Gauge]:
        """Create one Gauge per OpRuntimeMetrics field in ``metrics_group``.

        Returns a mapping from the metric's field name to its Gauge; the
        exported metric name is the field name prefixed with "data_".
        """
        metrics = {}
        for metric in OpRuntimeMetrics.get_metrics():
            if not metric.metrics_group == metrics_group:
                continue
            metric_name = f"data_{metric.name}"
            metric_description = metric.description
            metrics[metric.name] = Gauge(
                metric_name,
                description=metric_description,
                tag_keys=tag_keys,
            )
        return metrics

    def record_start(self, stats_uuid):
        """Mark the start of stats collection for ``stats_uuid``.

        Also evicts the oldest tracked dataset's stats (FIFO order) once
        ``max_stats`` is exceeded.
        """
        self.start_time[stats_uuid] = time.perf_counter()
        self.fifo_queue.append(stats_uuid)
        # Purge the oldest stats if the limit is exceeded.
        if len(self.fifo_queue) > self.max_stats:
            uuid = self.fifo_queue.pop(0)
            if uuid in self.start_time:
                del self.start_time[uuid]
            if uuid in self.last_time:
                del self.last_time[uuid]
            if uuid in self.metadata:
                del self.metadata[uuid]

    def record_task(
        self, stats_uuid: str, task_idx: int, blocks_metadata: List[BlockMetadata]
    ):
        """Record per-task block metadata for a started stats collection.

        Silently ignored if record_start() was never called (or the uuid was
        already evicted).
        """
        # Null out the schema to keep the stats size small.
        # TODO(chengsu): ideally schema should be null out on caller side.
        for metadata in blocks_metadata:
            metadata.schema = None
        if stats_uuid in self.start_time:
            self.metadata[stats_uuid][task_idx] = blocks_metadata
            self.last_time[stats_uuid] = time.perf_counter()

    def get(self, stats_uuid):
        """Return (task_idx -> metadata, elapsed seconds) for ``stats_uuid``.

        Returns ({}, 0.0) if the uuid is unknown or already evicted.
        """
        if stats_uuid not in self.metadata:
            return {}, 0.0
        return (
            self.metadata[stats_uuid],
            self.last_time[stats_uuid] - self.start_time[stats_uuid],
        )

    def _get_stats_dict_size(self):
        """Return sizes of internal dicts (used for testing/debugging)."""
        return len(self.start_time), len(self.last_time), len(self.metadata)

    def get_dataset_id(self):
        """Return a new globally-unique (per actor) dataset id string."""
        dataset_id = str(self.next_dataset_id)
        self.next_dataset_id += 1
        return dataset_id

    def update_metrics(self, execution_metrics, iteration_metrics):
        """Apply batched execution and iteration metric updates."""
        for metrics in execution_metrics:
            self.update_execution_metrics(*metrics)
        for metrics in iteration_metrics:
            self.update_iteration_metrics(*metrics)

    def update_execution_metrics(
        self,
        dataset_tag: str,
        op_metrics: List[Dict[str, Union[int, float]]],
        operator_tags: List[str],
        state: Dict[str, Any],
    ):
        """Set per-operator gauges from operator metric dicts.

        ``op_metrics`` and ``operator_tags`` are zipped pairwise; missing
        metric keys default to 0.
        """
        for stats, operator_tag in zip(op_metrics, operator_tags):
            tags = self._create_tags(dataset_tag, operator_tag)

            self.spilled_bytes.set(stats.get("obj_store_mem_spilled", 0), tags)
            self.freed_bytes.set(stats.get("obj_store_mem_freed", 0), tags)
            self.current_bytes.set(stats.get("obj_store_mem_used", 0), tags)
            self.output_bytes.set(stats.get("bytes_task_outputs_generated", 0), tags)
            self.output_rows.set(stats.get("rows_task_outputs_generated", 0), tags)
            self.cpu_usage_cores.set(stats.get("cpu_usage", 0), tags)
            self.gpu_usage_cores.set(stats.get("gpu_usage", 0), tags)

            for field_name, prom_metric in self.execution_metrics_inputs.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for field_name, prom_metric in self.execution_metrics_outputs.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for field_name, prom_metric in self.execution_metrics_tasks.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for (
                field_name,
                prom_metric,
            ) in self.execution_metrics_obj_store_memory.items():
                prom_metric.set(stats.get(field_name, 0), tags)

            for field_name, prom_metric in self.execution_metrics_misc.items():
                prom_metric.set(stats.get(field_name, 0), tags)

        # This update is called from a dataset's executor,
        # so all tags should contain the same dataset
        self.update_dataset(dataset_tag, state)

    def update_iteration_metrics(
        self,
        stats: "DatasetStats",
        dataset_tag,
    ):
        """Set iteration gauges from a DatasetStats snapshot."""
        tags = self._create_tags(dataset_tag)
        self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
        self.iter_user_s.set(stats.iter_user_s.get(), tags)
        self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)

    def register_dataset(self, job_id: str, dataset_tag: str, operator_tags: List[str]):
        """Register a dataset (and its operators) as RUNNING for the dashboard."""
        self.datasets[dataset_tag] = {
            "job_id": job_id,
            "state": "RUNNING",
            "progress": 0,
            "total": 0,
            "start_time": time.time(),
            "end_time": None,
            "operators": {
                operator: {
                    "state": "RUNNING",
                    "progress": 0,
                    "total": 0,
                }
                for operator in operator_tags
            },
        }

    def update_dataset(self, dataset_tag, state):
        """Merge ``state`` into the registered dataset's dashboard entry.

        NOTE(review): raises KeyError if the dataset was never registered —
        presumably register_dataset() always runs first; confirm.
        """
        self.datasets[dataset_tag].update(state)

    def get_datasets(self, job_id: Optional[str] = None):
        """Return dashboard entries, optionally filtered by job id."""
        if not job_id:
            return self.datasets
        return {k: v for k, v in self.datasets.items() if v["job_id"] == job_id}

    def _create_tags(self, dataset_tag: str, operator_tag: Optional[str] = None):
        """Build the gauge tag dict for a dataset (and optional operator)."""
        tags = {"dataset": dataset_tag}
        if operator_tag is not None:
            tags["operator"] = operator_tag
        return tags
412
+
413
+
414
# Creating/getting an actor from multiple threads is not safe.
# https://github.com/ray-project/ray/issues/41324
# Guards _get_or_create_stats_actor() below.
_stats_actor_lock: threading.RLock = threading.RLock()
417
+
418
+
419
def _get_or_create_stats_actor():
    """Return a handle to the shared, detached _StatsActor.

    Creates the actor on first use. Outside of Ray client mode the actor is
    pinned to the local node so that it fate-shares with the driver.
    """
    ctx = DataContext.get_current()
    if ray.util.client.ray.is_connected():
        scheduling_strategy = ctx.scheduling_strategy
    else:
        # Pin the stats actor to the local node
        # so it fate-shares with the driver.
        scheduling_strategy = NodeAffinitySchedulingStrategy(
            ray.get_runtime_context().get_node_id(),
            soft=False,
        )
    with _stats_actor_lock:
        return _StatsActor.options(
            name=STATS_ACTOR_NAME,
            namespace=STATS_ACTOR_NAMESPACE,
            get_if_exists=True,
            lifetime="detached",
            scheduling_strategy=scheduling_strategy,
        ).remote()
437
+
438
+
439
class _StatsManager:
    """A Class containing util functions that manage remote calls to _StatsActor.

    This class collects stats from execution and iteration codepaths and keeps
    track of the latest snapshot.

    An instance of this class runs a single background thread that periodically
    forwards the latest execution/iteration stats to the _StatsActor.

    This thread will terminate itself after being inactive (meaning that there are
    no active executors or iterators) for STATS_ACTOR_UPDATE_THREAD_INACTIVITY_LIMIT
    iterations. After terminating, a new thread will start if more calls are made
    to this class.
    """

    # Interval for making remote calls to the _StatsActor.
    STATS_ACTOR_UPDATE_INTERVAL_SECONDS = 5

    # After this many iterations of inactivity,
    # _StatsManager._update_thread will close itself.
    UPDATE_THREAD_INACTIVITY_LIMIT = 5

    def __init__(self):
        # Lazily get stats actor handle to avoid circular import.
        self._stats_actor_handle: Optional[ActorHandle] = None
        # Cluster id the cached handle belongs to; invalidates the handle
        # when a new cluster is started.
        self._stats_actor_cluster_id = None

        # Last execution stats snapshots for all executing datasets
        self._last_execution_stats = {}
        # Last iteration stats snapshots for all running iterators
        # NOTE(review): values stored by update_iteration_metrics() are
        # (DatasetStats, dataset_tag) tuples — the annotation below looks
        # out of date; confirm against callers.
        self._last_iteration_stats: Dict[
            str, Tuple[Dict[str, str], "DatasetStats"]
        ] = {}
        # Lock for updating stats snapshots
        self._stats_lock: threading.Lock = threading.Lock()

        # Background thread to make remote calls to _StatsActor
        self._update_thread: Optional[threading.Thread] = None
        self._update_thread_lock: threading.Lock = threading.Lock()

    def _stats_actor(self, create_if_not_exists=True) -> Optional[ActorHandle]:
        """Return the (possibly cached) _StatsActor handle.

        With create_if_not_exists=False, returns None when the actor does not
        exist instead of creating it. Raises RuntimeError if Ray's global
        node is not initialized.
        """
        if ray._private.worker._global_node is None:
            raise RuntimeError("Global node is not initialized.")
        current_cluster_id = ray._private.worker._global_node.cluster_id
        # Refresh the handle if we have none or the cluster changed.
        if (
            self._stats_actor_handle is None
            or self._stats_actor_cluster_id != current_cluster_id
        ):
            if create_if_not_exists:
                self._stats_actor_handle = _get_or_create_stats_actor()
            else:
                try:
                    self._stats_actor_handle = ray.get_actor(
                        name=STATS_ACTOR_NAME, namespace=STATS_ACTOR_NAMESPACE
                    )
                except ValueError:
                    # Actor does not exist; caller opted out of creation.
                    return None
            self._stats_actor_cluster_id = current_cluster_id
        return self._stats_actor_handle

    def _start_thread_if_not_running(self):
        # Start background update thread if not running.
        with self._update_thread_lock:
            if self._update_thread is None or not self._update_thread.is_alive():

                def _run_update_loop():
                    # Periodically push the latest snapshots to the actor;
                    # exits after UPDATE_THREAD_INACTIVITY_LIMIT idle loops
                    # or on any remote-call error.
                    iter_stats_inactivity = 0
                    while True:
                        if self._last_iteration_stats or self._last_execution_stats:
                            try:
                                # Do not create _StatsActor if it doesn't exist because
                                # this thread can be running even after the cluster is
                                # shutdown. Creating an actor will automatically start
                                # a new cluster.
                                stats_actor = self._stats_actor(
                                    create_if_not_exists=False
                                )
                                if stats_actor is None:
                                    # NOTE(review): `continue` skips the sleep
                                    # at the bottom of the loop, so this spins
                                    # without delay while the actor is absent
                                    # — confirm whether this is intentional.
                                    continue
                                stats_actor.update_metrics.remote(
                                    execution_metrics=list(
                                        self._last_execution_stats.values()
                                    ),
                                    iteration_metrics=list(
                                        self._last_iteration_stats.values()
                                    ),
                                )
                                iter_stats_inactivity = 0
                            except Exception:
                                logger.debug(
                                    "Error occurred during remote call to _StatsActor.",
                                    exc_info=True,
                                )
                                return
                        else:
                            iter_stats_inactivity += 1
                            if (
                                iter_stats_inactivity
                                >= _StatsManager.UPDATE_THREAD_INACTIVITY_LIMIT
                            ):
                                logger.debug(
                                    "Terminating StatsManager thread due to inactivity."
                                )
                                return
                        # NOTE: accesses the class attribute through the
                        # module-level `StatsManager` singleton instance.
                        time.sleep(StatsManager.STATS_ACTOR_UPDATE_INTERVAL_SECONDS)

                self._update_thread = threading.Thread(
                    target=_run_update_loop, daemon=True
                )
                self._update_thread.start()

    # Execution methods

    def update_execution_metrics(
        self,
        dataset_tag: str,
        op_metrics: List[OpRuntimeMetrics],
        operator_tags: List[str],
        state: Dict[str, Any],
        force_update: bool = False,
    ):
        """Snapshot execution metrics for a dataset.

        With force_update=True the metrics are pushed to the actor
        synchronously; otherwise they are cached for the background thread.
        """
        op_metrics_dicts = [metric.as_dict() for metric in op_metrics]
        args = (dataset_tag, op_metrics_dicts, operator_tags, state)
        if force_update:
            self._stats_actor().update_execution_metrics.remote(*args)
        else:
            with self._stats_lock:
                self._last_execution_stats[dataset_tag] = args
            self._start_thread_if_not_running()

    def clear_last_execution_stats(self, dataset_tag: str):
        # After dataset completes execution, remove cached execution stats.
        # Marks the dataset as finished on job page's Ray Data Overview.
        with self._stats_lock:
            if dataset_tag in self._last_execution_stats:
                del self._last_execution_stats[dataset_tag]

    # Iteration methods

    def update_iteration_metrics(self, stats: "DatasetStats", dataset_tag: str):
        """Snapshot iteration metrics; pushed later by the background thread."""
        with self._stats_lock:
            self._last_iteration_stats[dataset_tag] = (stats, dataset_tag)
        self._start_thread_if_not_running()

    def clear_iteration_metrics(self, dataset_tag: str):
        # Delete the last iteration stats so that update thread will have
        # a chance to terminate.
        # Note we don't reset the actual metric values through the StatsActor
        # since the value is essentially a counter value. See
        # https://github.com/ray-project/ray/pull/48618 for more context.
        with self._stats_lock:
            if dataset_tag in self._last_iteration_stats:
                del self._last_iteration_stats[dataset_tag]

    # Other methods

    def register_dataset_to_stats_actor(self, dataset_tag, operator_tags):
        """Register a dataset with the stats actor for dashboard display."""
        self._stats_actor().register_dataset.remote(
            ray.get_runtime_context().get_job_id(),
            dataset_tag,
            operator_tags,
        )

    def get_dataset_id_from_stats_actor(self) -> str:
        """Return a unique dataset id, falling back to uuid4 on failure."""
        try:
            return ray.get(self._stats_actor().get_dataset_id.remote())
        except Exception:
            # Getting dataset id from _StatsActor may fail, in this case
            # fall back to uuid4
            return uuid4().hex
609
+
610
+
611
+ StatsManager = _StatsManager()
612
+
613
+
614
class DatasetStats:
    """Holds the execution times for a given Dataset.

    This object contains a reference to the parent Dataset's stats as well,
    but not the Dataset object itself, to allow its blocks to be dropped from
    memory."""

    def __init__(
        self,
        *,
        metadata: StatsDict,
        parent: Union[Optional["DatasetStats"], List["DatasetStats"]],
        needs_stats_actor: bool = False,
        stats_uuid: str = None,
        base_name: str = None,
    ):
        """Create dataset stats.

        Args:
            metadata: Dict of operators used to create this Dataset from the
                previous one. Typically one entry, e.g., {"map": [...]}.
            parent: Reference to parent Dataset's stats, or a list of parents
                if there are multiple.
            needs_stats_actor: Whether this Dataset's stats needs a stats actor for
                stats collection. This is currently only used for Datasets using a
                lazy datasource (i.e. a LazyBlockList).
            stats_uuid: The uuid for the stats, used to fetch the right stats
                from the stats actor.
            base_name: The name of the base operation for a multi-operator operation.
        """

        self.metadata: StatsDict = metadata
        # Normalize a single parent into a one-element list.
        if parent is not None and not isinstance(parent, list):
            parent = [parent]
        self.parents: List["DatasetStats"] = parent or []
        # Position of this stage in the lineage: one past the deepest parent.
        self.number: int = (
            0 if not self.parents else max(p.number for p in self.parents) + 1
        )
        self.base_name = base_name
        # TODO(ekl) deprecate and remove the notion of dataset UUID once we move
        # fully to streaming execution.
        self.dataset_uuid: str = "unknown_uuid"
        self.time_total_s: float = 0
        self.needs_stats_actor = needs_stats_actor
        self.stats_uuid = stats_uuid

        # Streaming executor stats
        self.streaming_exec_schedule_s: Timer = Timer()

        # Iteration stats, filled out if the user iterates over the dataset.
        self.iter_wait_s: Timer = Timer()
        self.iter_get_s: Timer = Timer()
        self.iter_next_batch_s: Timer = Timer()
        self.iter_format_batch_s: Timer = Timer()
        self.iter_collate_batch_s: Timer = Timer()
        self.iter_finalize_batch_s: Timer = Timer()
        self.iter_total_blocked_s: Timer = Timer()
        self.iter_user_s: Timer = Timer()
        self.iter_initialize_s: Timer = Timer()
        self.iter_total_s: Timer = Timer()
        self.extra_metrics = {}

        # Block fetch stats during iteration.
        # These are stats about locations of blocks when the iterator is trying to
        # consume them. The iteration performance will be affected depending on
        # whether the block is in the local object store of the node where the
        # iterator is running.
        # This serves as an indicator of block prefetching effectiveness.
        self.iter_blocks_local: int = 0
        self.iter_blocks_remote: int = 0
        self.iter_unknown_location: int = 0

        # Memory usage stats
        self.global_bytes_spilled: int = 0
        self.global_bytes_restored: int = 0
        self.dataset_bytes_spilled: int = 0

        # Streaming split coordinator stats (dataset level)
        self.streaming_split_coordinator_s: Timer = Timer()

    @property
    def stats_actor(self):
        # Lazily resolves (and creates if needed) the shared stats actor.
        return _get_or_create_stats_actor()

    def child_builder(
        self, name: str, override_start_time: Optional[float] = None
    ) -> _DatasetStatsBuilder:
        """Start recording stats for an op of the given name (e.g., map)."""
        return _DatasetStatsBuilder(name, self, override_start_time)

    def to_summary(self) -> "DatasetStatsSummary":
        """Generate a `DatasetStatsSummary` object from the given `DatasetStats`
        object, which can be used to generate a summary string."""
        if self.needs_stats_actor:
            ac = self.stats_actor
            # TODO(chengsu): this is a super hack, clean it up.
            # Mutates self.time_total_s with the actor-reported elapsed time.
            stats_map, self.time_total_s = ray.get(ac.get.remote(self.stats_uuid))
            # Only populate stats when stats from all read tasks are ready at
            # stats actor.
            if len(stats_map.items()) == len(self.metadata["Read"]):
                self.metadata["Read"] = []
                # Flatten per-task block metadata in task-index order.
                for _, blocks_metadata in sorted(stats_map.items()):
                    self.metadata["Read"] += blocks_metadata

        operators_stats = []
        # More than one metadata entry means this stage has sub-operators.
        is_sub_operator = len(self.metadata) > 1
        for name, meta in self.metadata.items():
            operators_stats.append(
                OperatorStatsSummary.from_block_metadata(
                    name,
                    meta,
                    is_sub_operator=is_sub_operator,
                )
            )

        iter_stats = IterStatsSummary(
            self.iter_wait_s,
            self.iter_get_s,
            self.iter_next_batch_s,
            self.iter_format_batch_s,
            self.iter_collate_batch_s,
            self.iter_finalize_batch_s,
            self.iter_total_blocked_s,
            self.iter_user_s,
            self.iter_initialize_s,
            self.iter_total_s,
            self.streaming_split_coordinator_s,
            self.iter_blocks_local,
            self.iter_blocks_remote,
            self.iter_unknown_location,
        )
        # Recursively summarize the parent lineage.
        stats_summary_parents = []
        if self.parents is not None:
            stats_summary_parents = [p.to_summary() for p in self.parents]
        streaming_exec_schedule_s = (
            self.streaming_exec_schedule_s.get()
            if self.streaming_exec_schedule_s
            else 0
        )
        return DatasetStatsSummary(
            operators_stats,
            iter_stats,
            stats_summary_parents,
            self.number,
            self.dataset_uuid,
            self.time_total_s,
            self.base_name,
            self.extra_metrics,
            self.global_bytes_spilled,
            self.global_bytes_restored,
            self.dataset_bytes_spilled,
            streaming_exec_schedule_s,
        )

    def runtime_metrics(self) -> str:
        """Generate a string representing the runtime metrics of a Dataset. This is
        a high level summary of the time spent in Ray Data code broken down by operator.
        It also includes the time spent in the scheduler. Times are shown as the total
        time for each operator and percentages of time are shown as a fraction of the
        total time for the whole dataset."""
        return self.to_summary().runtime_metrics()
775
+
776
+
777
@DeveloperAPI
@dataclass
class DatasetStatsSummary:
    """Pre-computed, renderable summary of a Dataset's execution statistics.

    Aggregates per-operator stats, iteration stats, memory-spill counters,
    and (recursively, via ``parents``) the summaries of upstream datasets.
    """

    # Per-operator (or sub-operator) execution summaries for this stage.
    operators_stats: List["OperatorStatsSummary"]
    # Stats collected while iterating over the dataset's batches/rows.
    iter_stats: "IterStatsSummary"
    # Summaries of upstream datasets; enables recursive reporting.
    parents: List["DatasetStatsSummary"]
    # Position of this stage within the overall plan.
    number: int
    dataset_uuid: str
    # Total wall-clock time for this stage, in seconds.
    time_total_s: float
    base_name: str
    # Additional operator metrics; only rendered with verbose stats logging.
    extra_metrics: Dict[str, Any]
    # Cluster-wide object-store bytes spilled/restored during execution.
    global_bytes_spilled: int
    global_bytes_restored: int
    # Bytes spilled that are attributable to this dataset specifically.
    dataset_bytes_spilled: int
    # Time spent in the streaming executor's scheduling loop, in seconds.
    streaming_exec_schedule_s: float

    def to_string(
        self,
        already_printed: Optional[Set[str]] = None,
        include_parent: bool = True,
        add_global_stats=True,
    ) -> str:
        """Return a human-readable summary of this Dataset's stats.

        Args:
            already_printed: Set of operator IDs that have already had its stats
                printed out.
            include_parent: If true, also include parent stats summary; otherwise,
                only log stats of the latest operator.
            add_global_stats: If true, includes global stats to this summary.
        Returns:
            String with summary statistics for executing the Dataset.
        """
        if already_printed is None:
            already_printed = set()

        out = ""
        if self.parents and include_parent:
            # Render ancestors first so the report reads in execution order.
            for p in self.parents:
                parent_sum = p.to_string(already_printed, add_global_stats=False)
                if parent_sum:
                    out += parent_sum
                    out += "\n"
        operators_stats_summary = None
        if len(self.operators_stats) == 1:
            operators_stats_summary = self.operators_stats[0]
            operator_name = operators_stats_summary.operator_name
            operator_uuid = self.dataset_uuid + operator_name
            out += "Operator {} {}: ".format(self.number, operator_name)
            if operator_uuid in already_printed:
                # Stats for this operator were already emitted by an earlier call.
                out += "[execution cached]\n"
            else:
                already_printed.add(operator_uuid)
                out += str(operators_stats_summary)
        elif len(self.operators_stats) > 1:
            rounded_total = round(self.time_total_s, 2)
            if rounded_total <= 0:
                # Handle -0.0 case.
                rounded_total = 0
            out += "Operator {} {}: executed in {}s\n".format(
                self.number, self.base_name, rounded_total
            )
            for n, operators_stats_summary in enumerate(self.operators_stats):
                operator_name = operators_stats_summary.operator_name
                operator_uuid = self.dataset_uuid + operator_name
                out += "\n"
                out += "\tSuboperator {} {}: ".format(n, operator_name)
                if operator_uuid in already_printed:
                    out += "\t[execution cached]\n"
                else:
                    already_printed.add(operator_uuid)
                    out += str(operators_stats_summary)
        verbose_stats_logs = DataContext.get_current().verbose_stats_logs
        if verbose_stats_logs and self.extra_metrics:
            indent = (
                "\t"
                if operators_stats_summary and operators_stats_summary.is_sub_operator
                else ""
            )
            out += indent
            out += "* Extra metrics: " + str(self.extra_metrics) + "\n"
        out += str(self.iter_stats)

        if len(self.operators_stats) > 0 and add_global_stats:
            mb_spilled = round(self.global_bytes_spilled / 1e6)
            mb_restored = round(self.global_bytes_restored / 1e6)
            if mb_spilled or mb_restored:
                out += "\nCluster memory:\n"
                out += "* Spilled to disk: {}MB\n".format(mb_spilled)
                out += "* Restored from disk: {}MB\n".format(mb_restored)

            dataset_mb_spilled = round(self.dataset_bytes_spilled / 1e6)
            if dataset_mb_spilled:
                out += "\nDataset memory:\n"
                out += "* Spilled to disk: {}MB\n".format(dataset_mb_spilled)

            # For throughput, we compute both an observed Ray Data dataset throughput
            # and an estimated single node dataset throughput.

            # The observed dataset throughput is computed by dividing the total number
            # of rows produced by the total wall time of the dataset (i.e. from start to
            # finish how long did the dataset take to be processed). With the recursive
            # nature of the DatasetStatsSummary, we use get_total_wall_time to determine
            # the total wall time (this finds the difference between the earliest start
            # and latest end for any block in any operator).

            # The estimated single node dataset throughput is computed by dividing the
            # total number of rows produced the sum of the wall times across all blocks
            # of all operators. This assumes that on a single node the work done would
            # be equivalent, with no concurrency.
            output_num_rows = self.operators_stats[-1].output_num_rows
            total_num_out_rows = output_num_rows["sum"] if output_num_rows else 0
            wall_time = self.get_total_wall_time()
            total_time_all_blocks = self.get_total_time_all_blocks()
            if total_num_out_rows and wall_time and total_time_all_blocks:
                out += "\n"
                out += "Dataset throughput:\n"
                out += (
                    "\t* Ray Data throughput:"
                    f" {total_num_out_rows / wall_time} "
                    "rows/s\n"
                )
                out += (
                    "\t* Estimated single node throughput:"
                    f" {total_num_out_rows / total_time_all_blocks} "
                    "rows/s\n"
                )
        if verbose_stats_logs and add_global_stats:
            out += "\n" + self.runtime_metrics()

        return out

    @staticmethod
    def _collect_dataset_stats_summaries(
        curr: "DatasetStatsSummary",
    ) -> List["DatasetStatsSummary"]:
        """Flatten ``curr`` and its ancestry into a list, ancestors first.

        NOTE(review): parents that have no parents of their own are skipped
        entirely (only ``curr`` and ancestors reached through non-leaf parents
        are collected) — confirm this pruning of leaf parents is intended.
        """
        summs = []
        # TODO: Do operators ever have multiple parents? Do we need to deduplicate?
        for p in curr.parents:
            if p and p.parents:
                summs.extend(DatasetStatsSummary._collect_dataset_stats_summaries(p))
        return summs + [curr]

    @staticmethod
    def _find_start_and_end(summ: "DatasetStatsSummary") -> Tuple[float, float]:
        """Return (earliest block start, latest block end) across all operators
        of ``summ``. Assumes ``summ.operators_stats`` is non-empty."""
        earliest_start = min(ops.earliest_start_time for ops in summ.operators_stats)
        latest_end = max(ops.latest_end_time for ops in summ.operators_stats)
        return earliest_start, latest_end

    def runtime_metrics(self) -> str:
        """Render per-operator runtime, scheduler time and total wall time,
        each also expressed as a percentage of the total wall time."""
        total_wall_time = self.get_total_wall_time()

        def fmt_line(name: str, time: float) -> str:
            # NOTE(review): divides by total_wall_time; a zero total would raise
            # ZeroDivisionError — confirm callers only reach this after execution.
            return f"* {name}: {fmt(time)} ({time / total_wall_time * 100:.3f}%)\n"

        summaries = DatasetStatsSummary._collect_dataset_stats_summaries(self)
        out = "Runtime Metrics:\n"
        for summ in summaries:
            if len(summ.operators_stats) > 0:
                earliest_start, latest_end = DatasetStatsSummary._find_start_and_end(
                    summ
                )
                op_total_time = latest_end - earliest_start
                out += fmt_line(summ.base_name, op_total_time)
        out += fmt_line("Scheduling", self.streaming_exec_schedule_s)
        out += fmt_line("Total", total_wall_time)
        return out

    def __repr__(self, level=0) -> str:
        """Nested, indented debug representation; ``level`` controls the
        indentation depth for recursive child/parent rendering."""
        indent = leveled_indent(level)
        operators_stats = "\n".join(
            [ss.__repr__(level + 2) for ss in self.operators_stats]
        )
        parent_stats = "\n".join([ps.__repr__(level + 2) for ps in self.parents])
        extra_metrics = "\n".join(
            f"{leveled_indent(level + 2)}{k}: {v},"
            for k, v in self.extra_metrics.items()
        )

        # Handle formatting case for empty outputs.
        operators_stats = (
            f"\n{operators_stats},\n{indent} " if operators_stats else ""
        )
        parent_stats = f"\n{parent_stats},\n{indent} " if parent_stats else ""
        extra_metrics = f"\n{extra_metrics}\n{indent} " if extra_metrics else ""
        return (
            f"{indent}DatasetStatsSummary(\n"
            f"{indent} dataset_uuid={self.dataset_uuid},\n"
            f"{indent} base_name={self.base_name},\n"
            f"{indent} number={self.number},\n"
            f"{indent} extra_metrics={{{extra_metrics}}},\n"
            f"{indent} operators_stats=[{operators_stats}],\n"
            f"{indent} iter_stats={self.iter_stats.__repr__(level+1)},\n"
            f"{indent} global_bytes_spilled={self.global_bytes_spilled / 1e6}MB,\n"
            f"{indent} global_bytes_restored={self.global_bytes_restored / 1e6}MB,\n"
            f"{indent} dataset_bytes_spilled={self.dataset_bytes_spilled / 1e6}MB,\n"
            f"{indent} parents=[{parent_stats}],\n"
            f"{indent})"
        )

    def get_total_wall_time(self) -> float:
        """Calculate the total wall time for the dataset, this is done by finding
        the earliest start time and latest end time for any block in any operator.
        The wall time is the difference of these two times.
        """
        start_ends = [
            DatasetStatsSummary._find_start_and_end(summ)
            for summ in DatasetStatsSummary._collect_dataset_stats_summaries(self)
            if len(summ.operators_stats) > 0
        ]
        if len(start_ends) == 0:
            return 0
        else:
            earliest_start = min(start_end[0] for start_end in start_ends)
            latest_end = max(start_end[1] for start_end in start_ends)
            return latest_end - earliest_start

    def get_total_time_all_blocks(self) -> float:
        """Calculate the sum of the wall times across all blocks of all operators."""
        summaries = DatasetStatsSummary._collect_dataset_stats_summaries(self)
        return sum(
            (
                sum(
                    ops.wall_time.get("sum", 0) if ops.wall_time else 0
                    for ops in summ.operators_stats
                )
            )
            for summ in summaries
        )

    def get_total_cpu_time(self) -> float:
        """Sum of per-block CPU time across this summary and all ancestors.

        NOTE(review): assumes ``cpu_time`` is non-None for every operator here
        (unlike wall_time above, there is no None guard) — confirm.
        """
        parent_sum = sum(p.get_total_cpu_time() for p in self.parents)
        return parent_sum + sum(
            ss.cpu_time.get("sum", 0) for ss in self.operators_stats
        )

    def get_max_heap_memory(self) -> float:
        """Maximum per-block heap memory observed across this summary and all
        ancestors (units follow the ``memory`` stat dicts)."""
        parent_memory = [p.get_max_heap_memory() for p in self.parents]
        parent_max = max(parent_memory) if parent_memory else 0
        if not self.operators_stats:
            return parent_max

        return max(
            parent_max,
            *[ss.memory.get("max", 0) for ss in self.operators_stats],
        )
1023
+
1024
+
1025
@dataclass
class OperatorStatsSummary:
    """Pre-computed execution statistics for a single operator (or
    sub-operator), aggregated over the block metadata it produced."""

    operator_name: str
    # Whether the operator associated with this OperatorStatsSummary object
    # is a suboperator
    is_sub_operator: bool
    # This is the total walltime of the entire operator, typically obtained from
    # `DatasetStats.time_total_s`. An important distinction is that this is the
    # overall runtime of the operator, pulled from the stats actor, whereas the
    # computed walltimes in `self.wall_time` are calculated on a operator level.
    time_total_s: float
    earliest_start_time: float
    latest_end_time: float
    # String summarizing high-level statistics from executing the operator
    block_execution_summary_str: str
    # The fields below are dicts with stats aggregated across blocks
    # processed in this operator. For example:
    # {"min": ..., "max": ..., "mean": ..., "sum": ...}
    wall_time: Optional[Dict[str, float]] = None
    cpu_time: Optional[Dict[str, float]] = None
    udf_time: Optional[Dict[str, float]] = None
    # memory: no "sum" stat
    memory: Optional[Dict[str, float]] = None
    output_num_rows: Optional[Dict[str, float]] = None
    output_size_bytes: Optional[Dict[str, float]] = None
    # node_count: "count" stat instead of "sum"
    node_count: Optional[Dict[str, float]] = None
    task_rows: Optional[Dict[str, float]] = None

    @classmethod
    def from_block_metadata(
        cls,
        operator_name: str,
        block_metas: List[BlockMetadata],
        is_sub_operator: bool,
    ) -> "OperatorStatsSummary":
        """Calculate the stats for a operator from a given list of blocks,
        and generates a `OperatorStatsSummary` object with the results.

        Args:
            block_metas: List of `BlockMetadata` to calculate stats of
            operator_name: Name of operator associated with `blocks`
            is_sub_operator: Whether this set of blocks belongs to a sub operator.
        Returns:
            A `OperatorStatsSummary` object initialized with the calculated statistics
        """
        exec_stats = [m.exec_stats for m in block_metas if m.exec_stats is not None]
        rounded_total = 0
        time_total_s = 0
        earliest_start_time, latest_end_time = 0, 0

        if exec_stats:
            # Calculate the total execution time of operator as
            # the difference between the latest end time and
            # the earliest start time of all blocks in the operator.
            earliest_start_time = min(s.start_time_s for s in exec_stats)
            latest_end_time = max(s.end_time_s for s in exec_stats)
            time_total_s = latest_end_time - earliest_start_time

        if is_sub_operator:
            exec_summary_str = "{} blocks produced\n".format(len(exec_stats))
        else:
            if exec_stats:
                rounded_total = round(time_total_s, 2)
                if rounded_total <= 0:
                    # Handle -0.0 case.
                    rounded_total = 0
                exec_summary_str = "{} blocks produced in {}s".format(
                    len(exec_stats), rounded_total
                )
            else:
                exec_summary_str = ""
            exec_summary_str += "\n"

        # Aggregate output rows per task: a single task may emit several blocks,
        # so sum row counts keyed by the producing task's index.
        task_rows = collections.defaultdict(int)
        for meta in block_metas:
            if meta.num_rows is not None and meta.exec_stats is not None:
                task_rows[meta.exec_stats.task_idx] += meta.num_rows
        task_rows_stats = None
        if len(task_rows) > 0:
            task_rows_stats = {
                "min": min(task_rows.values()),
                "max": max(task_rows.values()),
                "mean": int(np.mean(list(task_rows.values()))),
                "count": len(task_rows),
            }
            exec_summary_str = "{} tasks executed, {}".format(
                len(task_rows), exec_summary_str
            )

        wall_time_stats, cpu_stats, memory_stats, udf_stats = None, None, None, None
        if exec_stats:
            wall_time_stats = {
                "min": min([e.wall_time_s for e in exec_stats]),
                "max": max([e.wall_time_s for e in exec_stats]),
                "mean": np.mean([e.wall_time_s for e in exec_stats]),
                "sum": sum([e.wall_time_s for e in exec_stats]),
            }
            cpu_stats = {
                "min": min([e.cpu_time_s for e in exec_stats]),
                "max": max([e.cpu_time_s for e in exec_stats]),
                "mean": np.mean([e.cpu_time_s for e in exec_stats]),
                "sum": sum([e.cpu_time_s for e in exec_stats]),
            }

            # Peak RSS per block, converted to MiB for reporting.
            memory_stats_mb = [
                round(e.max_rss_bytes / (1024 * 1024), 2) for e in exec_stats
            ]
            memory_stats = {
                "min": min(memory_stats_mb),
                "max": max(memory_stats_mb),
                "mean": int(np.mean(memory_stats_mb)),
            }

            udf_stats = {
                "min": min([e.udf_time_s for e in exec_stats]),
                "max": max([e.udf_time_s for e in exec_stats]),
                "mean": np.mean([e.udf_time_s for e in exec_stats]),
                "sum": sum([e.udf_time_s for e in exec_stats]),
            }

        output_num_rows_stats = None
        output_num_rows = [m.num_rows for m in block_metas if m.num_rows is not None]
        if output_num_rows:
            output_num_rows_stats = {
                "min": min(output_num_rows),
                "max": max(output_num_rows),
                "mean": int(np.mean(output_num_rows)),
                "sum": sum(output_num_rows),
            }

        output_size_bytes_stats = None
        output_size_bytes = [
            m.size_bytes for m in block_metas if m.size_bytes is not None
        ]
        if output_size_bytes:
            output_size_bytes_stats = {
                "min": min(output_size_bytes),
                "max": max(output_size_bytes),
                "mean": int(np.mean(output_size_bytes)),
                "sum": sum(output_size_bytes),
            }

        node_counts_stats = None
        if exec_stats:
            # Count distinct tasks per node to show how work spread across
            # the cluster.
            node_tasks = collections.defaultdict(set)
            for s in exec_stats:
                node_tasks[s.node_id].add(s.task_idx)

            node_counts = {node: len(tasks) for node, tasks in node_tasks.items()}
            node_counts_stats = {
                "min": min(node_counts.values()),
                "max": max(node_counts.values()),
                "mean": int(np.mean(list(node_counts.values()))),
                "count": len(node_counts),
            }

        return OperatorStatsSummary(
            operator_name=operator_name,
            is_sub_operator=is_sub_operator,
            time_total_s=time_total_s,
            earliest_start_time=earliest_start_time,
            latest_end_time=latest_end_time,
            block_execution_summary_str=exec_summary_str,
            wall_time=wall_time_stats,
            cpu_time=cpu_stats,
            udf_time=udf_stats,
            memory=memory_stats,
            output_num_rows=output_num_rows_stats,
            output_size_bytes=output_size_bytes_stats,
            node_count=node_counts_stats,
            task_rows=task_rows_stats,
        )

    def __str__(self) -> str:
        """For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from
        `OperatorStatsSummary.from_block_metadata()`), returns a human-friendly string
        that summarizes operator execution statistics.

        Returns:
            String with summary statistics for executing the given operator.
        """
        indent = "\t" if self.is_sub_operator else ""
        out = self.block_execution_summary_str

        wall_time_stats = self.wall_time
        if wall_time_stats:
            out += indent
            out += "* Remote wall time: {} min, {} max, {} mean, {} total\n".format(
                fmt(wall_time_stats["min"]),
                fmt(wall_time_stats["max"]),
                fmt(wall_time_stats["mean"]),
                fmt(wall_time_stats["sum"]),
            )

        cpu_stats = self.cpu_time
        if cpu_stats:
            out += indent
            out += "* Remote cpu time: {} min, {} max, {} mean, {} total\n".format(
                fmt(cpu_stats["min"]),
                fmt(cpu_stats["max"]),
                fmt(cpu_stats["mean"]),
                fmt(cpu_stats["sum"]),
            )

        udf_stats = self.udf_time
        if udf_stats:
            out += indent
            out += "* UDF time: {} min, {} max, {} mean, {} total\n".format(
                fmt(udf_stats["min"]),
                fmt(udf_stats["max"]),
                fmt(udf_stats["mean"]),
                fmt(udf_stats["sum"]),
            )

        memory_stats = self.memory
        if memory_stats:
            out += indent
            out += "* Peak heap memory usage (MiB): {} min, {} max, {} mean\n".format(
                memory_stats["min"],
                memory_stats["max"],
                memory_stats["mean"],
            )

        output_num_rows_stats = self.output_num_rows
        if output_num_rows_stats:
            out += indent
            out += (
                "* Output num rows per block: {} min, {} max, {} mean, {} total\n"
            ).format(
                output_num_rows_stats["min"],
                output_num_rows_stats["max"],
                output_num_rows_stats["mean"],
                output_num_rows_stats["sum"],
            )

        output_size_bytes_stats = self.output_size_bytes
        if output_size_bytes_stats:
            out += indent
            out += (
                "* Output size bytes per block: {} min, {} max, {} mean, {} total\n"
            ).format(
                output_size_bytes_stats["min"],
                output_size_bytes_stats["max"],
                output_size_bytes_stats["mean"],
                output_size_bytes_stats["sum"],
            )

        task_rows = self.task_rows
        if task_rows:
            out += indent
            out += (
                "* Output rows per task: {} min, {} max, {} mean, {} tasks used\n"
            ).format(
                task_rows["min"],
                task_rows["max"],
                task_rows["mean"],
                task_rows["count"],
            )

        node_count_stats = self.node_count
        if node_count_stats:
            out += indent
            out += "* Tasks per node: {} min, {} max, {} mean; {} nodes used\n".format(
                node_count_stats["min"],
                node_count_stats["max"],
                node_count_stats["mean"],
                node_count_stats["count"],
            )
        if output_num_rows_stats and self.time_total_s and wall_time_stats:
            # For throughput, we compute both an observed Ray Data operator throughput
            # and an estimated single node operator throughput.

            # The observed Ray Data operator throughput is computed by dividing the
            # total number of rows produced by the wall time of the operator,
            # time_total_s.

            # The estimated single node operator throughput is computed by dividing the
            # total number of rows produced by the the sum of the wall times across all
            # blocks of the operator. This assumes that on a single node the work done
            # would be equivalent, with no concurrency.
            total_num_out_rows = output_num_rows_stats["sum"]
            out += indent
            out += "* Operator throughput:\n"
            out += (
                indent + "\t* Ray Data throughput:"
                f" {total_num_out_rows / self.time_total_s} "
                "rows/s\n"
            )
            out += (
                indent + "\t* Estimated single node throughput:"
                f" {total_num_out_rows / wall_time_stats['sum']} "
                "rows/s\n"
            )
        return out

    def __repr__(self, level=0) -> str:
        """For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from
        `OperatorStatsSummary.from_block_metadata()`), returns a human-friendly string
        that summarizes operator execution statistics.

        Returns:
            String with summary statistics for executing the given operator.
        """
        indent = leveled_indent(level)
        indent += leveled_indent(1) if self.is_sub_operator else ""

        wall_time_stats = {k: fmt(v) for k, v in (self.wall_time or {}).items()}
        cpu_stats = {k: fmt(v) for k, v in (self.cpu_time or {}).items()}
        memory_stats = {k: fmt(v) for k, v in (self.memory or {}).items()}
        output_num_rows_stats = {
            k: fmt(v) for k, v in (self.output_num_rows or {}).items()
        }
        output_size_bytes_stats = {
            k: fmt(v) for k, v in (self.output_size_bytes or {}).items()
        }
        # NOTE(review): "node_conut_stats" is a misspelling of "node_count_stats";
        # harmless (local name only) but worth renaming in a follow-up.
        node_conut_stats = {k: fmt(v) for k, v in (self.node_count or {}).items()}
        out = (
            f"{indent}OperatorStatsSummary(\n"
            f"{indent} operator_name='{self.operator_name}',\n"
            f"{indent} is_suboperator={self.is_sub_operator},\n"
            f"{indent} time_total_s={fmt(self.time_total_s)},\n"
            # block_execution_summary_str already ends with \n
            f"{indent} block_execution_summary_str={self.block_execution_summary_str}"
            f"{indent} wall_time={wall_time_stats or None},\n"
            f"{indent} cpu_time={cpu_stats or None},\n"
            f"{indent} memory={memory_stats or None},\n"
            f"{indent} output_num_rows={output_num_rows_stats or None},\n"
            f"{indent} output_size_bytes={output_size_bytes_stats or None},\n"
            f"{indent} node_count={node_conut_stats or None},\n"
            f"{indent})"
        )
        return out
1358
+
1359
+
1360
@dataclass
class IterStatsSummary:
    """Timing and locality statistics gathered while iterating a Dataset
    (batch prefetching, formatting, collation, and user-code time)."""

    # Time spent in actor based prefetching, in seconds.
    wait_time: Timer
    # Time spent in `ray.get()`, in seconds
    get_time: Timer
    # Time spent in batch building, in seconds
    next_time: Timer
    # Time spent in `_format_batch_()`, in seconds
    format_time: Timer
    # Time spent in collate fn, in seconds
    collate_time: Timer
    # Time spent in finalize_fn, in seconds
    finalize_batch_time: Timer
    # Total time user thread is blocked by iter_batches
    block_time: Timer
    # Time spent in user code, in seconds
    user_time: Timer
    # Time spent initializing the iterator, in seconds.
    initialize_time: Timer
    # Total time taken by Dataset iterator, in seconds
    total_time: Timer
    # Time spent in streaming split coordinator
    streaming_split_coord_time: Timer
    # Num of blocks that are in local object store
    iter_blocks_local: int
    # Num of blocks that are in remote node and have to fetch locally
    iter_blocks_remote: int
    # Num of blocks with unknown locations
    iter_unknown_location: int

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self) -> str:
        """Render the iterator timing breakdown; returns "" when no iteration
        time has been recorded at all."""
        out = ""
        if (
            self.block_time.get()
            or self.total_time.get()
            or self.get_time.get()
            or self.next_time.get()
            or self.format_time.get()
            or self.collate_time.get()
            or self.finalize_batch_time.get()
        ):
            out += "\nDataset iterator time breakdown:\n"
            if self.total_time.get():
                out += "* Total time overall: {}\n".format(fmt(self.total_time.get()))
                if self.initialize_time.get():
                    out += (
                        "    * Total time in Ray Data iterator initialization code: "
                        "{}\n".format(fmt(self.initialize_time.get()))
                    )
                if self.block_time.get():
                    out += (
                        "    * Total time user thread is blocked by Ray Data iter_batches: "
                        "{}\n".format(fmt(self.block_time.get()))
                    )
                if self.user_time.get():
                    out += "    * Total execution time for user thread: {}\n".format(
                        fmt(self.user_time.get())
                    )
            out += (
                "* Batch iteration time breakdown (summed across prefetch threads):\n"
            )
            if self.get_time.get():
                out += "    * In ray.get(): {} min, {} max, {} avg, {} total\n".format(
                    fmt(self.get_time.min()),
                    fmt(self.get_time.max()),
                    fmt(self.get_time.avg()),
                    fmt(self.get_time.get()),
                )
            if self.next_time.get():
                batch_creation_str = (
                    "    * In batch creation: {} min, {} max, " "{} avg, {} total\n"
                )
                out += batch_creation_str.format(
                    fmt(self.next_time.min()),
                    fmt(self.next_time.max()),
                    fmt(self.next_time.avg()),
                    fmt(self.next_time.get()),
                )
            if self.format_time.get():
                format_str = (
                    "    * In batch formatting: {} min, {} max, " "{} avg, {} total\n"
                )
                out += format_str.format(
                    fmt(self.format_time.min()),
                    fmt(self.format_time.max()),
                    fmt(self.format_time.avg()),
                    fmt(self.format_time.get()),
                )
            if self.collate_time.get():
                out += "    * In collate_fn: {} min, {} max, {} avg, {} total\n".format(
                    fmt(self.collate_time.min()),
                    fmt(self.collate_time.max()),
                    fmt(self.collate_time.avg()),
                    fmt(self.collate_time.get()),
                )
            if self.finalize_batch_time.get():
                format_str = (
                    "    * In host->device transfer: {} min, {} max, {} avg, {} total\n"
                )
                out += format_str.format(
                    fmt(self.finalize_batch_time.min()),
                    fmt(self.finalize_batch_time.max()),
                    fmt(self.finalize_batch_time.avg()),
                    fmt(self.finalize_batch_time.get()),
                )
            if DataContext.get_current().enable_get_object_locations_for_metrics:
                out += "Block locations:\n"
                out += "    * Num blocks local: {}\n".format(self.iter_blocks_local)
                out += "    * Num blocks remote: {}\n".format(self.iter_blocks_remote)
                out += "    * Num blocks unknown location: {}\n".format(
                    self.iter_unknown_location
                )
            if self.streaming_split_coord_time.get() != 0:
                out += "Streaming split coordinator overhead time: "
                out += f"{fmt(self.streaming_split_coord_time.get())}\n"

        return out

    def __repr__(self, level=0) -> str:
        """Compact debug representation; ``level`` controls nesting indent."""
        indent = leveled_indent(level)
        return (
            f"IterStatsSummary(\n"
            f"{indent} wait_time={fmt(self.wait_time.get()) or None},\n"
            f"{indent} get_time={fmt(self.get_time.get()) or None},\n"
            f"{indent} iter_blocks_local={self.iter_blocks_local or None},\n"
            f"{indent} iter_blocks_remote={self.iter_blocks_remote or None},\n"
            f"{indent} iter_unknown_location={self.iter_unknown_location or None},\n"
            f"{indent} next_time={fmt(self.next_time.get()) or None},\n"
            f"{indent} format_time={fmt(self.format_time.get()) or None},\n"
            f"{indent} user_time={fmt(self.user_time.get()) or None},\n"
            f"{indent} total_time={fmt(self.total_time.get()) or None},\n"
            f"{indent})"
        )
.venv/lib/python3.11/site-packages/ray/data/_internal/table_block.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from typing import (
3
+ TYPE_CHECKING,
4
+ Any,
5
+ Dict,
6
+ Iterator,
7
+ List,
8
+ Mapping,
9
+ Optional,
10
+ TypeVar,
11
+ Union,
12
+ )
13
+
14
+ import numpy as np
15
+
16
+ from ray.air.constants import TENSOR_COLUMN_NAME
17
+ from ray.data._internal.block_builder import BlockBuilder
18
+ from ray.data._internal.numpy_support import is_array_like
19
+ from ray.data._internal.row import TableRow
20
+ from ray.data._internal.size_estimator import SizeEstimator
21
+ from ray.data._internal.util import MiB
22
+ from ray.data.block import Block, BlockAccessor
23
+
24
+ if TYPE_CHECKING:
25
+ from ray.data._internal.planner.exchange.sort_task_spec import SortKey
26
+
27
+
28
+ T = TypeVar("T")
29
+
30
+ # The max size of Python tuples to buffer before compacting them into a
31
+ # table in the BlockBuilder.
32
+ MAX_UNCOMPACTED_SIZE_BYTES = 50 * MiB
33
+
34
+
35
class TableBlockBuilder(BlockBuilder):
    """Shared builder for table-backed blocks.

    Rows added one at a time are buffered as Python column lists and
    periodically compacted into tables once their estimated size exceeds
    ``MAX_UNCOMPACTED_SIZE_BYTES``; whole tables can also be appended
    directly. ``build()`` concatenates everything into a single block.
    """

    def __init__(self, block_type):
        # Buffered (not yet compacted) Python values, keyed by column name.
        self._columns = collections.defaultdict(list)
        # Column names of the first buffered row; later rows must match.
        self._column_names = None
        # Compacted/added tables accumulated so far.
        self._tables: List[Any] = []
        # Index of the first table whose size hasn't been folded into
        # _tables_size_bytes yet; size computation is deferred because it
        # can be expensive (e.g. for pandas DataFrames).
        self._tables_size_cursor = 0
        # Accumulated sizes of tables before _tables_size_cursor.
        self._tables_size_bytes = 0
        # Estimator for the in-memory footprint of the uncompacted rows.
        self._uncompacted_size = SizeEstimator()
        self._num_rows = 0
        self._num_compactions = 0
        self._block_type = block_type

    def add(self, item: Union[dict, TableRow, np.ndarray]) -> None:
        """Buffer a single row, normalizing it to a plain mapping first."""
        if isinstance(item, TableRow):
            item = item.as_pydict()
        elif isinstance(item, np.ndarray):
            # Bare ndarrays become single-column tensor rows.
            item = {TENSOR_COLUMN_NAME: item}
        if not isinstance(item, collections.abc.Mapping):
            raise ValueError(
                "Returned elements of an TableBlock must be of type `dict`, "
                "got {} (type {}).".format(item, type(item))
            )

        item_column_names = item.keys()
        if self._column_names is None:
            # First row seen defines the schema for all subsequent rows.
            self._column_names = item_column_names
        elif item_column_names != self._column_names:
            raise ValueError(
                "Current row has different columns compared to previous rows. "
                f"Columns of current row: {sorted(item_column_names)}, "
                f"Columns of previous rows: {sorted(self._column_names)}."
            )

        for name, value in item.items():
            if is_array_like(value) and not isinstance(value, np.ndarray):
                value = np.array(value)
            self._columns[name].append(value)
        self._num_rows += 1
        self._compact_if_needed()
        self._uncompacted_size.add(item)

    def add_block(self, block: Any) -> None:
        """Append an already-built table of the expected block type."""
        if not isinstance(block, self._block_type):
            raise TypeError(
                f"Got a block of type {type(block)}, expected {self._block_type}."
                "If you are mapping a function, ensure it returns an "
                "object with the expected type. Block:\n"
                f"{block}"
            )
        self._num_rows += BlockAccessor.for_block(block).num_rows()
        self._tables.append(block)

    @staticmethod
    def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
        """Build a table from column lists (subclass responsibility)."""
        raise NotImplementedError

    @staticmethod
    def _concat_tables(tables: List[Block]) -> Block:
        """Concatenate tables into one block (subclass responsibility)."""
        raise NotImplementedError

    @staticmethod
    def _empty_table() -> Any:
        """Return an empty table (subclass responsibility)."""
        raise NotImplementedError

    @staticmethod
    def _concat_would_copy() -> bool:
        """Whether concatenation copies data (subclass responsibility)."""
        raise NotImplementedError

    def will_build_yield_copy(self) -> bool:
        """Return True if build() will produce a copy of the input data."""
        if self._columns:
            # Compacting buffered dict-of-list columns always copies.
            return True
        return len(self._tables) > 1 and self._concat_would_copy()

    def build(self) -> Block:
        """Materialize all buffered rows and appended tables as one block."""
        pending = (
            [self._table_from_pydict(self._columns)] if self._columns else []
        )
        pending += self._tables
        if not pending:
            return self._empty_table()
        return self._concat_tables(pending)

    def num_rows(self) -> int:
        """Total number of rows added so far (buffered + compacted)."""
        return self._num_rows

    def get_estimated_memory_usage(self) -> int:
        """Estimated bytes held by this builder (tables + uncompacted rows)."""
        if self._num_rows == 0:
            return 0
        # Lazily fold in the sizes of tables not yet accounted for, then
        # advance the cursor so each table's size is computed only once.
        self._tables_size_bytes += sum(
            BlockAccessor.for_block(table).size_bytes()
            for table in self._tables[self._tables_size_cursor :]
        )
        self._tables_size_cursor = len(self._tables)
        return self._tables_size_bytes + self._uncompacted_size.size_bytes()

    def _compact_if_needed(self) -> None:
        """Compact buffered columns into a table once they grow large enough."""
        assert self._columns
        if self._uncompacted_size.size_bytes() < MAX_UNCOMPACTED_SIZE_BYTES:
            return
        self.add_block(self._table_from_pydict(self._columns))
        self._uncompacted_size = SizeEstimator()
        self._columns.clear()
        self._num_compactions += 1
+
158
+
159
class TableBlockAccessor(BlockAccessor):
    """Base accessor for tabular (Arrow / pandas) blocks."""

    # Row wrapper type; subclasses override with their format-specific row.
    ROW_TYPE: TableRow = TableRow

    def __init__(self, table: Any):
        self._table = table

    def _get_row(self, index: int, copy: bool = False) -> Union[TableRow, np.ndarray]:
        """Return row ``index`` as a single-row slice wrapped in ROW_TYPE."""
        single_row_block = self.slice(index, index + 1, copy=copy)
        return self.ROW_TYPE(single_row_block)

    @staticmethod
    def _munge_conflict(name, count):
        # Disambiguate a duplicated column name, e.g. ("col", 1) -> "col_2".
        return f"{name}_{count+1}"

    @staticmethod
    def _build_tensor_row(row: TableRow) -> np.ndarray:
        raise NotImplementedError

    def to_default(self) -> Block:
        # Always promote Arrow blocks to pandas for consistency, since
        # we lazily convert pandas->Arrow internally for efficiency.
        return self.to_pandas()

    def column_names(self) -> List[str]:
        raise NotImplementedError

    def append_column(self, name: str, data: Any) -> Block:
        raise NotImplementedError

    def to_block(self) -> Block:
        return self._table

    def iter_rows(
        self, public_row_format: bool
    ) -> Iterator[Union[Mapping, np.ndarray]]:
        """Iterate over rows, optionally converting each to a plain dict."""

        def row_generator():
            idx = 0
            # Re-check num_rows() each step, mirroring cursor-style iteration.
            while idx < self.num_rows():
                row = self._get_row(idx)
                if public_row_format and isinstance(row, TableRow):
                    yield row.as_pydict()
                else:
                    yield row
                idx += 1

        return row_generator()

    def _zip(self, acc: BlockAccessor) -> "Block":
        raise NotImplementedError

    def zip(self, other: "Block") -> "Block":
        """Horizontally combine this block with ``other`` of equal length."""
        other_acc = BlockAccessor.for_block(other)
        if not isinstance(other_acc, type(self)):
            if isinstance(self, TableBlockAccessor) and isinstance(
                other_acc, TableBlockAccessor
            ):
                # If block types are different, but still both of TableBlock type, try
                # converting both to default block type before zipping.
                left, right = TableBlockAccessor.normalize_block_types(
                    [self._table, other],
                )
                return BlockAccessor.for_block(left).zip(right)
            raise ValueError(
                "Cannot zip {} with block of type {}".format(type(self), type(other))
            )
        if other_acc.num_rows() != self.num_rows():
            raise ValueError(
                "Cannot zip self (length {}) with block of length {}".format(
                    self.num_rows(), other_acc.num_rows()
                )
            )
        return self._zip(other_acc)

    @staticmethod
    def _empty_table() -> Any:
        raise NotImplementedError

    def _sample(self, n_samples: int, sort_key: "SortKey") -> Any:
        raise NotImplementedError

    def sample(self, n_samples: int, sort_key: "SortKey") -> Any:
        """Sample up to ``n_samples`` values of the sort-key column."""
        if sort_key is None or callable(sort_key):
            raise NotImplementedError(
                f"Table sort key must be a column name, was: {sort_key}"
            )
        if self.num_rows() == 0:
            # If the pyarrow table is empty we may not have schema
            # so calling table.select() will raise an error.
            return self._empty_table()
        return self._sample(min(n_samples, self.num_rows()), sort_key)

    @classmethod
    def normalize_block_types(
        cls,
        blocks: List[Block],
        normalize_type: Optional[str] = None,
    ) -> List[Block]:
        """Normalize input blocks to the specified `normalize_type`. If the blocks
        are already all of the same type, returns the original blocks.

        Args:
            blocks: A list of TableBlocks to be normalized.
            normalize_type: The type to normalize the blocks to. If None,
                the default block type (Arrow) is used.

        Returns:
            A list of blocks of the same type.
        """
        observed_types = set()
        for blk in blocks:
            if not isinstance(BlockAccessor.for_block(blk), TableBlockAccessor):
                raise ValueError(
                    "Block type normalization is only supported for TableBlock, "
                    f"but received block of type: {type(blk)}."
                )
            observed_types.add(type(blk))

        # Return original blocks if they are all of the same type.
        if len(observed_types) <= 1:
            return blocks

        if normalize_type == "arrow":
            converted = [BlockAccessor.for_block(b).to_arrow() for b in blocks]
        elif normalize_type == "pandas":
            converted = [BlockAccessor.for_block(b).to_pandas() for b in blocks]
        else:
            converted = [BlockAccessor.for_block(b).to_default() for b in blocks]

        if any(not isinstance(b, type(converted[0])) for b in converted):
            raise ValueError(
                "Expected all blocks to be of the same type after normalization, but "
                f"got different types: {[type(b) for b in converted]}. "
                "Try using blocks of the same type to avoid the issue "
                "with block normalization."
            )
        return converted
.venv/lib/python3.11/site-packages/ray/data/_internal/torch_iterable_dataset.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import IterableDataset
2
+
3
+
4
class TorchIterableDataset(IterableDataset):
    """An ``IterableDataset`` backed by a user-supplied generator factory.

    A fresh iterator is requested from ``generator_func`` on every
    ``__iter__`` call, so the dataset can be consumed repeatedly
    (e.g. once per training epoch).
    """

    def __init__(self, generator_func):
        self.generator_func = generator_func

    def __iter__(self):
        yield from self.generator_func()
.venv/lib/python3.11/site-packages/ray/data/_internal/util.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import random
6
+ import sys
7
+ import threading
8
+ import time
9
+ import urllib.parse
10
+ from queue import Empty, Full, Queue
11
+ from types import ModuleType
12
+ from typing import (
13
+ TYPE_CHECKING,
14
+ Any,
15
+ Callable,
16
+ Generator,
17
+ Iterable,
18
+ Iterator,
19
+ List,
20
+ Optional,
21
+ Tuple,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import numpy as np
27
+
28
+ import ray
29
+ from ray._private.utils import _get_pyarrow_version
30
+ from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext
31
+
32
+ if TYPE_CHECKING:
33
+ import pandas
34
+ import pyarrow
35
+
36
+ from ray.data._internal.compute import ComputeStrategy
37
+ from ray.data._internal.planner.exchange.sort_task_spec import SortKey
38
+ from ray.data.block import Block, BlockMetadata, UserDefinedFunction
39
+ from ray.data.datasource import Datasource, Reader
40
+ from ray.util.placement_group import PlacementGroup
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ KiB = 1024 # bytes
46
+ MiB = 1024 * KiB
47
+ GiB = 1024 * MiB
48
+
49
+
50
+ SENTINEL = object()
51
+
52
+
53
+ # NOTE: Make sure that these lower and upper bounds stay in sync with version
54
+ # constraints given in python/setup.py.
55
+ # Inclusive minimum pyarrow version.
56
+ MIN_PYARROW_VERSION = "6.0.1"
57
+ RAY_DISABLE_PYARROW_VERSION_CHECK = "RAY_DISABLE_PYARROW_VERSION_CHECK"
58
+ _VERSION_VALIDATED = False
59
+ _LOCAL_SCHEME = "local"
60
+ _EXAMPLE_SCHEME = "example"
61
+
62
+
63
+ LazyModule = Union[None, bool, ModuleType]
64
+ _pyarrow_dataset: LazyModule = None
65
+
66
+
67
+ class _NullSentinel:
68
+ """Sentinel value that sorts greater than any other value."""
69
+
70
+ def __eq__(self, other):
71
+ return isinstance(other, _NullSentinel)
72
+
73
+ def __lt__(self, other):
74
+ return False
75
+
76
+ def __le__(self, other):
77
+ return isinstance(other, _NullSentinel)
78
+
79
+ def __gt__(self, other):
80
+ return True
81
+
82
+ def __ge__(self, other):
83
+ return True
84
+
85
+ def __hash__(self):
86
+ return id(self)
87
+
88
+
89
+ NULL_SENTINEL = _NullSentinel()
90
+
91
+
92
def _lazy_import_pyarrow_dataset() -> "LazyModule":
    """Import ``pyarrow.dataset`` once, caching the result (or the failure).

    Returns the module on success, or ``False`` when pyarrow isn't installed,
    so the failed import is never retried on subsequent calls.
    """
    global _pyarrow_dataset
    if _pyarrow_dataset is not None:
        return _pyarrow_dataset
    try:
        from pyarrow import dataset as _pyarrow_dataset
    except ModuleNotFoundError:
        # Cache the failure so we don't pay the import cost on every call.
        _pyarrow_dataset = False
    return _pyarrow_dataset
102
+
103
+
104
def _check_pyarrow_version():
    """Check that pyarrow's version is within the supported bounds.

    Runs at most once per process (tracked via ``_VERSION_VALIDATED``).
    Raises ImportError for a too-old pyarrow and logs a warning when the
    installed version cannot be determined. Setting the env var named by
    ``RAY_DISABLE_PYARROW_VERSION_CHECK`` to "1" disables the check.
    """
    global _VERSION_VALIDATED

    if _VERSION_VALIDATED:
        return

    if os.environ.get(RAY_DISABLE_PYARROW_VERSION_CHECK, "0") == "1":
        _VERSION_VALIDATED = True
        return

    version = _get_pyarrow_version()
    if version is None:
        # pyarrow may be vendored by another package; we can't verify it.
        logger.warning(
            "You are using the 'pyarrow' module, but the exact version is unknown "
            "(possibly carried as an internal component by another module). Please "
            f"make sure you are using pyarrow >= {MIN_PYARROW_VERSION} to ensure "
            "compatibility with Ray Dataset. "
            "If you want to disable this pyarrow version check, set the "
            f"environment variable {RAY_DISABLE_PYARROW_VERSION_CHECK}=1."
        )
    else:
        from packaging.version import parse as parse_version

        if parse_version(version) < parse_version(MIN_PYARROW_VERSION):
            # Note: raising here intentionally leaves _VERSION_VALIDATED
            # False, so the check (and error) repeats on the next call.
            raise ImportError(
                f"Dataset requires pyarrow >= {MIN_PYARROW_VERSION}, but "
                f"{version} is installed. Reinstall with "
                f'`pip install -U "pyarrow"`. '
                "If you want to disable this pyarrow version check, set the "
                f"environment variable {RAY_DISABLE_PYARROW_VERSION_CHECK}=1."
            )
    _VERSION_VALIDATED = True
135
+
136
+
137
def _autodetect_parallelism(
    parallelism: int,
    target_max_block_size: int,
    ctx: DataContext,
    datasource_or_legacy_reader: Optional[Union["Datasource", "Reader"]] = None,
    mem_size: Optional[int] = None,
    placement_group: Optional["PlacementGroup"] = None,
    avail_cpus: Optional[int] = None,
) -> Tuple[int, str, Optional[int]]:
    """Returns parallelism to use and the min safe parallelism to avoid OOMs.

    Heuristics, applied in order:

    1) Start from ``DataContext.read_op_min_num_blocks`` (default 200).
    2) Min block size: reduce parallelism that would produce tiny blocks.
    3) Max block size: increase parallelism that would produce oversized
       blocks (OOM risk during processing).
    4) Available CPUs: increase parallelism until all cluster CPUs can be
       used.

    Args:
        parallelism: The user-requested parallelism, or -1 for auto-detection.
        target_max_block_size: The target max block size to produce. Passed
            separately from the DatasetContext because it may be set per-op
            instead of per-Dataset.
        ctx: The current Dataset context to use for configs.
        datasource_or_legacy_reader: The datasource or legacy reader, used
            for data size estimation.
        mem_size: If passed, used to compute the parallelism according to
            target_max_block_size.
        placement_group: The placement group that this Dataset will execute
            inside, if any.
        avail_cpus: Override avail cpus detection (for testing only).

    Returns:
        Tuple of detected parallelism (only if -1 was specified), the reason
        for the detected parallelism (only if -1 was specified), and the
        estimated in-memory size of the dataset.
    """
    min_safe_parallelism = 1
    max_reasonable_parallelism = sys.maxsize

    if mem_size is None and datasource_or_legacy_reader:
        mem_size = datasource_or_legacy_reader.estimate_inmemory_data_size()
    if mem_size is not None and not np.isnan(mem_size):
        # Floor: keep blocks no larger than target_max_block_size (OOM safety).
        min_safe_parallelism = max(1, int(mem_size / target_max_block_size))
        # Ceiling: keep blocks no smaller than target_min_block_size (overhead).
        max_reasonable_parallelism = max(1, int(mem_size / ctx.target_min_block_size))

    reason = ""
    if parallelism < 0:
        if parallelism != -1:
            raise ValueError("`parallelism` must either be -1 or a positive integer.")

        # Honor the deprecated ``min_parallelism`` only while the new
        # ``read_op_min_num_blocks`` is still at its default.
        if (
            ctx.min_parallelism is not None
            and ctx.min_parallelism != DEFAULT_READ_OP_MIN_NUM_BLOCKS
            and ctx.read_op_min_num_blocks == DEFAULT_READ_OP_MIN_NUM_BLOCKS
        ):
            logger.warning(
                "``DataContext.min_parallelism`` is deprecated in Ray 2.10. "
                "Please specify ``DataContext.read_op_min_num_blocks`` instead."
            )
            ctx.read_op_min_num_blocks = ctx.min_parallelism

        # Start with 2x the number of cores as a baseline, with a min floor.
        if placement_group is None:
            placement_group = ray.util.get_current_placement_group()
        avail_cpus = avail_cpus or _estimate_avail_cpus(placement_group)
        parallelism = max(
            min(ctx.read_op_min_num_blocks, max_reasonable_parallelism),
            min_safe_parallelism,
            avail_cpus * 2,
        )

        # Report which constraint ended up binding.
        if parallelism == ctx.read_op_min_num_blocks:
            reason = (
                "DataContext.get_current().read_op_min_num_blocks="
                f"{ctx.read_op_min_num_blocks}"
            )
        elif parallelism == max_reasonable_parallelism:
            reason = (
                "output blocks of size at least "
                "DataContext.get_current().target_min_block_size="
                f"{ctx.target_min_block_size / (1024 * 1024)}MiB"
            )
        elif parallelism == min_safe_parallelism:
            reason = (
                "output blocks of size at most "
                "DataContext.get_current().target_max_block_size="
                f"{ctx.target_max_block_size / (1024 * 1024)}MiB"
            )
        else:
            reason = (
                "parallelism at least twice the available number "
                f"of CPUs ({avail_cpus})"
            )

        logger.debug(
            f"Autodetected parallelism={parallelism} based on "
            f"estimated_available_cpus={avail_cpus} and "
            f"estimated_data_size={mem_size}."
        )

    return parallelism, reason, mem_size
244
+
245
+
246
def _estimate_avail_cpus(cur_pg: Optional["PlacementGroup"]) -> int:
    """Estimates the available CPU parallelism for this Dataset in the cluster.

    If we aren't in a placement group, this is trivially the number of CPUs in
    the cluster. Otherwise, we try to calculate how large the placement group
    is relative to the size of the cluster.

    Args:
        cur_pg: The current placement group, if any.
    """
    cluster_cpus = int(ray.cluster_resources().get("CPU", 1))
    cluster_gpus = int(ray.cluster_resources().get("GPU", 0))

    if not cur_pg:
        return cluster_cpus

    # If we're in a placement group, we shouldn't assume the entire cluster's
    # resources are available for us to use. Estimate an upper bound on what's
    # reasonable to assume is available for datasets to use.
    pg_cpus = 0
    for bundle in cur_pg.bundle_specs:
        # Calculate the proportion of the cluster this placement group "takes
        # up", then scale cluster_cpus proportionally to avoid
        # over-parallelizing if there are many parallel Tune trials using the
        # cluster.
        cpu_fraction = bundle.get("CPU", 0) / max(1, cluster_cpus)
        gpu_fraction = bundle.get("GPU", 0) / max(1, cluster_gpus)
        dominant_fraction = max(cpu_fraction, gpu_fraction)
        # Over-parallelize by up to a factor of 2, but no more than that.
        # It's preferable to over-estimate than under-estimate.
        pg_cpus += 2 * int(dominant_fraction * cluster_cpus)

    return min(cluster_cpus, pg_cpus)
278
+
279
+
280
def _estimate_available_parallelism() -> int:
    """Estimates the available CPU parallelism for this Dataset in the cluster.

    If we are currently in a placement group, take that into account.
    """
    return _estimate_avail_cpus(ray.util.get_current_placement_group())
285
+
286
+
287
def _warn_on_high_parallelism(requested_parallelism, num_read_tasks):
    """Warn when the requested read parallelism far exceeds cluster CPU slots.

    Only fires for explicitly-requested parallelism that yields at least 5000
    read tasks AND more than 4x the available CPU slots.
    """
    available_cpu_slots = ray.available_resources().get("CPU", 1)
    excessive = (
        requested_parallelism
        and num_read_tasks > available_cpu_slots * 4
        and num_read_tasks >= 5000
    )
    if excessive:
        logger.warning(
            f"{WARN_PREFIX} The requested parallelism of {requested_parallelism} "
            "is more than 4x the number of available CPU slots in the cluster of "
            f"{available_cpu_slots}. This can "
            "lead to slowdowns during the data reading phase due to excessive "
            "task creation. Reduce the parallelism to match with the available "
            "CPU slots in the cluster, or set parallelism to -1 for Ray Data "
            "to automatically determine the parallelism. "
            "You can ignore this message if the cluster is expected to autoscale."
        )
304
+
305
+
306
+ def _check_import(obj, *, module: str, package: str) -> None:
307
+ """Check if a required dependency is installed.
308
+
309
+ If `module` can't be imported, this function raises an `ImportError` instructing
310
+ the user to install `package` from PyPI.
311
+
312
+ Args:
313
+ obj: The object that has a dependency.
314
+ module: The name of the module to import.
315
+ package: The name of the package on PyPI.
316
+ """
317
+ try:
318
+ importlib.import_module(module)
319
+ except ImportError:
320
+ raise ImportError(
321
+ f"`{obj.__class__.__name__}` depends on '{package}', but '{package}' "
322
+ f"couldn't be imported. You can install '{package}' by running `pip "
323
+ f"install {package}`."
324
+ )
325
+
326
+
327
def _resolve_custom_scheme(path: str) -> str:
    """Returns the resolved path if the given path follows a Ray-specific
    custom scheme. Otherwise, returns the path unchanged.

    The supported custom schemes are: "local", "example".
    """
    parsed = urllib.parse.urlparse(path)
    if parsed.scheme == _LOCAL_SCHEME:
        # "local://" paths simply drop the scheme prefix.
        return parsed.netloc + parsed.path
    if parsed.scheme == _EXAMPLE_SCHEME:
        # "example://" paths resolve relative to the bundled example data dir.
        example_data_path = pathlib.Path(__file__).parent.parent / "examples" / "data"
        resolved = example_data_path / (parsed.netloc + parsed.path)
        return str(resolved.resolve())
    return path
341
+
342
+
343
def _is_local_scheme(paths: Union[str, List[str]]) -> bool:
    """Returns True if the given paths are in local scheme.

    Note: The paths must all share one scheme; mixing local and non-local
    schemes raises a ValueError.
    """
    if isinstance(paths, str):
        paths = [paths]
    if isinstance(paths, pathlib.Path):
        paths = [str(paths)]
    elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths):
        raise ValueError("paths must be a path string or a list of path strings.")
    elif len(paths) == 0:
        raise ValueError("Must provide at least one path.")

    local_count = sum(
        urllib.parse.urlparse(p).scheme == _LOCAL_SCHEME for p in paths
    )
    if 0 < local_count < len(paths):
        raise ValueError(
            "The paths must all be local-scheme or not local-scheme, "
            f"but found mixed {paths}"
        )
    return local_count == len(paths)
363
+
364
+
365
+ def _truncated_repr(obj: Any) -> str:
366
+ """Utility to return a truncated object representation for error messages."""
367
+ msg = str(obj)
368
+ if len(msg) > 200:
369
+ msg = msg[:200] + "..."
370
+ return msg
371
+
372
+
373
+ def _insert_doc_at_pattern(
374
+ obj,
375
+ *,
376
+ message: str,
377
+ pattern: str,
378
+ insert_after: bool = True,
379
+ directive: Optional[str] = None,
380
+ skip_matches: int = 0,
381
+ ) -> str:
382
+ if "\n" in message:
383
+ raise ValueError(
384
+ "message shouldn't contain any newlines, since this function will insert "
385
+ f"its own linebreaks when text wrapping: {message}"
386
+ )
387
+
388
+ doc = obj.__doc__.strip()
389
+ if not doc:
390
+ doc = ""
391
+
392
+ if pattern == "" and insert_after:
393
+ # Empty pattern + insert_after means that we want to append the message to the
394
+ # end of the docstring.
395
+ head = doc
396
+ tail = ""
397
+ else:
398
+ tail = doc
399
+ i = tail.find(pattern)
400
+ skip_matches_left = skip_matches
401
+ while i != -1:
402
+ if insert_after:
403
+ # Set offset to the first character after the pattern.
404
+ offset = i + len(pattern)
405
+ else:
406
+ # Set offset to the first character in the matched line.
407
+ offset = tail[:i].rfind("\n") + 1
408
+ head = tail[:offset]
409
+ tail = tail[offset:]
410
+ skip_matches_left -= 1
411
+ if skip_matches_left <= 0:
412
+ break
413
+ elif not insert_after:
414
+ # Move past the found pattern, since we're skipping it.
415
+ tail = tail[i - offset + len(pattern) :]
416
+ i = tail.find(pattern)
417
+ else:
418
+ raise ValueError(
419
+ f"Pattern {pattern} not found after {skip_matches} skips in docstring "
420
+ f"{doc}"
421
+ )
422
+ # Get indentation of the to-be-inserted text.
423
+ after_lines = list(filter(bool, tail.splitlines()))
424
+ if len(after_lines) > 0:
425
+ lines = after_lines
426
+ else:
427
+ lines = list(filter(bool, reversed(head.splitlines())))
428
+ # Should always have at least one non-empty line in the docstring.
429
+ assert len(lines) > 0
430
+ indent = " " * (len(lines[0]) - len(lines[0].lstrip()))
431
+ # Handle directive.
432
+ message = message.strip("\n")
433
+ if directive is not None:
434
+ base = f"{indent}.. {directive}::\n"
435
+ message = message.replace("\n", "\n" + indent + " " * 4)
436
+ message = base + indent + " " * 4 + message
437
+ else:
438
+ message = indent + message.replace("\n", "\n" + indent)
439
+ # Add two blank lines before/after message, if necessary.
440
+ if insert_after ^ (pattern == "\n\n"):
441
+ # Only two blank lines before message if:
442
+ # 1. Inserting message after pattern and pattern is not two blank lines.
443
+ # 2. Inserting message before pattern and pattern is two blank lines.
444
+ message = "\n\n" + message
445
+ if (not insert_after) ^ (pattern == "\n\n"):
446
+ # Only two blank lines after message if:
447
+ # 1. Inserting message before pattern and pattern is not two blank lines.
448
+ # 2. Inserting message after pattern and pattern is two blank lines.
449
+ message = message + "\n\n"
450
+
451
+ # Insert message before/after pattern.
452
+ parts = [head, message, tail]
453
+ # Build new docstring.
454
+ obj.__doc__ = "".join(parts)
455
+
456
+
457
def _consumption_api(
    if_more_than_read: bool = False,
    datasource_metadata: Optional[str] = None,
    extra_condition: Optional[str] = None,
    delegate: Optional[str] = None,
    pattern="Examples:",
    insert_after=False,
):
    """Annotate the function with an indication that it's a consumption API,
    and that it will trigger Dataset execution.

    Builds the note text from the given conditions and splices it into the
    decorated function's docstring at ``pattern``.
    """
    base = (
        " will trigger execution of the lazy transformations performed on "
        "this dataset."
    )
    if delegate:
        message = delegate + base
    elif not if_more_than_read:
        message = "This operation" + base
    else:
        # Build the conditional preamble piece by piece.
        condition = "If this dataset consists of more than a read, "
        if datasource_metadata is not None:
            condition += (
                f"or if the {datasource_metadata} can't be determined from the "
                "metadata provided by the datasource, "
            )
        if extra_condition is not None:
            condition += extra_condition + ", "
        message = condition + "then this operation" + base

    def decorator(obj):
        _insert_doc_at_pattern(
            obj,
            message=message,
            pattern=pattern,
            insert_after=insert_after,
            directive="note",
        )
        return obj

    return decorator
498
+
499
+
500
def ConsumptionAPI(*args, **kwargs):
    """Annotate the function with an indication that it's a consumption API,
    and that it will trigger Dataset execution.
    """
    # Bare-decorator usage (@ConsumptionAPI with no arguments).
    if len(args) == 1 and not kwargs and callable(args[0]):
        return _consumption_api()(args[0])
    # Parameterized usage (@ConsumptionAPI(...)).
    return _consumption_api(*args, **kwargs)
507
+
508
+
509
def _all_to_all_api(*args, **kwargs):
    """Annotate the function with an indication that it's an all-to-all API,
    and that it is an operation that requires all inputs to be materialized
    in-memory to execute.
    """
    note_text = (
        "This operation requires all inputs to be "
        "materialized in object store for it to execute."
    )

    def decorator(obj):
        _insert_doc_at_pattern(
            obj,
            message=note_text,
            pattern="Examples:",
            insert_after=False,
            directive="note",
        )
        return obj

    return decorator
528
+
529
+
530
def AllToAllAPI(*args, **kwargs):
    """Annotate the function with an indication that it's an all-to-all API,
    and that it is an operation that requires all inputs to be materialized
    in-memory to execute.
    """
    # This should only be used as a bare decorator for dataset methods.
    assert len(args) == 1 and not kwargs and callable(args[0])
    return _all_to_all_api()(args[0])
537
+
538
+
539
def get_compute_strategy(
    fn: "UserDefinedFunction",
    fn_constructor_args: Optional[Iterable[Any]] = None,
    compute: Optional[Union[str, "ComputeStrategy"]] = None,
    concurrency: Optional[Union[int, Tuple[int, int]]] = None,
) -> "ComputeStrategy":
    """Get `ComputeStrategy` based on the function or class, and concurrency
    information.

    Args:
        fn: The function or generator to apply to a record batch, or a class
            type that can be instantiated to create such a callable.
        fn_constructor_args: Positional arguments to pass to ``fn``'s
            constructor.
        compute: Either "tasks" (default) to use Ray Tasks or an
            :class:`~ray.data.ActorPoolStrategy` to use an autoscaling actor
            pool.
        concurrency: The number of Ray workers to use concurrently.

    Returns:
        The `ComputeStrategy` for execution.

    Raises:
        ValueError: On invalid combinations of ``fn`` kind, ``compute``, and
            ``concurrency``.
    """
    # Lazily import these objects to avoid circular imports.
    from ray.data._internal.compute import ActorPoolStrategy, TaskPoolStrategy
    from ray.data.block import CallableClass

    # TODO(chengsu): disallow object that is not a function. For example,
    # an object instance of a class often indicates a bug in user code.
    is_callable_class = isinstance(fn, CallableClass)
    if not is_callable_class and fn_constructor_args is not None:
        raise ValueError(
            "``fn_constructor_args`` can only be specified if providing a "
            f"callable class instance for ``fn``, but got: {fn}."
        )

    if compute is not None:
        # Legacy code path to support `compute` argument.
        logger.warning(
            "The argument ``compute`` is deprecated in Ray 2.9. Please specify "
            "argument ``concurrency`` instead. For more information, see "
            "https://docs.ray.io/en/master/data/transforming-data.html#"
            "stateful-transforms."
        )
        wants_tasks = compute == "tasks" or isinstance(compute, TaskPoolStrategy)
        wants_actors = compute == "actors" or isinstance(compute, ActorPoolStrategy)
        if is_callable_class and wants_tasks:
            raise ValueError(
                "``compute`` must specify an actor compute strategy when using a "
                f"callable class, but got: {compute}. For example, use "
                "``compute=ray.data.ActorPoolStrategy(size=n)``."
            )
        if not is_callable_class and wants_actors:
            raise ValueError(
                f"``compute`` is specified as the actor compute strategy: {compute}, "
                f"but ``fn`` is not a callable class: {fn}. Pass a callable class or "
                "use the default ``compute`` strategy."
            )
        return compute

    if concurrency is None:
        if is_callable_class:
            raise ValueError(
                "``concurrency`` must be specified when using a callable class. "
                "For example, use ``concurrency=n`` for a pool of ``n`` workers."
            )
        return TaskPoolStrategy()

    if isinstance(concurrency, tuple):
        is_int_pair = (
            len(concurrency) == 2
            and isinstance(concurrency[0], int)
            and isinstance(concurrency[1], int)
        )
        if not is_int_pair:
            raise ValueError(
                "``concurrency`` is expected to be set as a tuple of "
                f"integers, but got: {concurrency}."
            )
        if not is_callable_class:
            raise ValueError(
                "``concurrency`` is set as a tuple of integers, but ``fn`` "
                f"is not a callable class: {fn}. Use ``concurrency=n`` to "
                "control maximum number of workers to use."
            )
        return ActorPoolStrategy(min_size=concurrency[0], max_size=concurrency[1])

    if isinstance(concurrency, int):
        if is_callable_class:
            return ActorPoolStrategy(size=concurrency)
        return TaskPoolStrategy(size=concurrency)

    raise ValueError(
        "``concurrency`` is expected to be set as an integer or a "
        f"tuple of integers, but got: {concurrency}."
    )
640
+
641
+
642
def capfirst(s: str):
    """Capitalize the first letter of a string.

    Unlike ``str.capitalize``, the remainder of the string is left untouched.

    Args:
        s: String to capitalize

    Returns:
        Capitalized string; the empty string is returned unchanged.
    """
    # s[:1] is safe on the empty string, where the previous s[0] raised
    # IndexError (e.g. via capitalize() on inputs containing "__").
    return s[:1].upper() + s[1:]
652
+
653
+
654
def capitalize(s: str):
    """Capitalize a string, removing '_' and keeping camelcase.

    Args:
        s: String to capitalize

    Returns:
        Capitalized string with no underscores.
    """
    # Split on underscores and capitalize each piece, e.g.
    # "read_parquet" -> "ReadParquet".
    return "".join(map(capfirst, s.split("_")))
664
+
665
+
666
def pandas_df_to_arrow_block(df: "pandas.DataFrame") -> "Block":
    """Convert a pandas DataFrame into an Arrow block plus its metadata.

    Returns a ``(block, metadata)`` tuple; the metadata carries fresh exec
    stats covering the metadata computation.
    """
    from ray.data.block import BlockAccessor, BlockExecStats

    arrow_block = BlockAccessor.for_block(df).to_arrow()
    stats = BlockExecStats.builder()
    metadata = BlockAccessor.for_block(arrow_block).get_metadata(
        exec_stats=stats.build()
    )
    return arrow_block, metadata
675
+
676
+
677
def ndarray_to_block(ndarray: np.ndarray, ctx: DataContext) -> "Block":
    """Wrap a numpy array into a single-column ("data") block plus metadata."""
    from ray.data.block import BlockAccessor, BlockExecStats

    # Propagate the caller's DataContext into this (possibly remote) process.
    DataContext._set_current(ctx)

    timer = BlockExecStats.builder()
    new_block = BlockAccessor.batch_to_block({"data": ndarray})
    meta = BlockAccessor.for_block(new_block).get_metadata(exec_stats=timer.build())
    return new_block, meta
686
+
687
+
688
def get_table_block_metadata(
    table: Union["pyarrow.Table", "pandas.DataFrame"]
) -> "BlockMetadata":
    """Compute the BlockMetadata of an in-memory Arrow table or pandas frame."""
    from ray.data.block import BlockAccessor, BlockExecStats

    timer = BlockExecStats.builder()
    accessor = BlockAccessor.for_block(table)
    return accessor.get_metadata(exec_stats=timer.build())
695
+
696
+
697
def unify_block_metadata_schema(
    metadata: List["BlockMetadata"],
) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
    """For the input list of BlockMetadata, return a unified schema of the
    corresponding blocks. If the metadata have no valid schema, returns None.
    """
    # Some blocks could be empty, in which case we cannot get their schema.
    # TODO(ekl) validate schema is the same across different blocks.
    from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas

    # Collect schemas from blocks that have one, skipping blocks known to be
    # empty (num_rows == 0).
    schemas_to_unify = [
        m.schema
        for m in metadata
        if m.schema is not None and (m.num_rows is None or m.num_rows > 0)
    ]
    if not schemas_to_unify:
        return None

    # Check valid pyarrow installation before attempting schema unification.
    try:
        import pyarrow as pa
    except ImportError:
        pa = None

    # If the result contains PyArrow schemas, unify them.
    if pa is not None and all(isinstance(s, pa.Schema) for s in schemas_to_unify):
        return unify_schemas(schemas_to_unify)
    # Otherwise, if the resulting schemas are simple types (e.g. int),
    # return the first schema.
    return schemas_to_unify[0]
726
+
727
+
728
def find_partition_index(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
    desired: Tuple[Union[int, float]],
    sort_key: "SortKey",
) -> int:
    """For the given block, find the index where the desired value should be
    added, to maintain sorted order.

    We do this by iterating over each column, starting with the primary sort key,
    and binary searching for the desired value in the column. Each binary search
    shortens the "range" of indices (represented by ``left`` and ``right``, which
    are indices of rows) where the desired value could be inserted.

    Args:
        table: The block to search in.
        desired: A single tuple representing the boundary to partition at.
            ``len(desired)`` must be less than or equal to the number of columns
            being sorted.
        sort_key: The sort key to use for sorting, providing the columns to be
            sorted and their directions.

    Returns:
        The index where the desired value should be inserted to maintain sorted
        order.
    """
    columns = sort_key.get_columns()
    descending = sort_key.get_descending()

    # ``left``/``right`` bracket the candidate insertion range; each sort-key
    # column narrows the range further.
    left, right = 0, len(table)
    for i in range(len(desired)):
        if left == right:
            # Range collapsed to a single position; later columns cannot
            # change the answer.
            return right
        col_name = columns[i]
        # Only the still-candidate row range needs to be materialized.
        col_vals = table[col_name].to_numpy()[left:right]
        desired_val = desired[i]

        # Handle null values - replace them with sentinel values.
        # NOTE(review): presumably NULL_SENTINEL compares after every non-null
        # value so nulls sort last — confirm against its definition.
        if desired_val is None:
            desired_val = NULL_SENTINEL

        # Replace None/NaN values in col_vals with sentinel.
        # NOTE: ``== None`` is an *element-wise* numpy comparison here (hence
        # the E711 suppression), producing a boolean mask.
        null_mask = col_vals == None  # noqa: E711
        if null_mask.any():
            col_vals = col_vals.copy()  # Make a copy to avoid modifying original
            col_vals[null_mask] = NULL_SENTINEL

        prevleft = left
        if descending[i] is True:
            # ``np.searchsorted`` expects the array to be sorted in ascending
            # order, so we pass ``sorter``, which is an array of integer indices
            # that sort ``col_vals`` into ascending order. The returned index
            # is an index into the ascending order of ``col_vals``, so we need
            # to subtract it from ``len(col_vals)`` to get the index in the
            # original descending order of ``col_vals``.
            # (The reversed-range sorter assumes the slice is sorted
            # descending, so reading it backwards yields ascending order.)
            left = prevleft + (
                len(col_vals)
                - np.searchsorted(
                    col_vals,
                    desired_val,
                    side="right",
                    sorter=np.arange(len(col_vals) - 1, -1, -1),
                )
            )
            right = prevleft + (
                len(col_vals)
                - np.searchsorted(
                    col_vals,
                    desired_val,
                    side="left",
                    sorter=np.arange(len(col_vals) - 1, -1, -1),
                )
            )
        else:
            left = prevleft + np.searchsorted(col_vals, desired_val, side="left")
            right = prevleft + np.searchsorted(col_vals, desired_val, side="right")
    # All sort columns tied: insert after the run of equal rows when the
    # primary key is descending, before it when ascending.
    return right if descending[0] is True else left
804
+
805
+
806
def find_partitions(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
    boundaries: List[Tuple[Union[int, float]]],
    sort_key: "SortKey",
):
    """Split a sorted block into ``len(boundaries) + 1`` contiguous slices.

    For each boundary value, count the number of items that are less
    than it. Since the block is sorted, these counts partition the items
    such that boundaries[i] <= x < boundaries[i + 1] for each x in
    partition[i]. If `descending` is true, `boundaries` would also be
    in descending order and we only need to count the number of items
    *greater than* the boundary value instead.
    """
    cut_points = [
        find_partition_index(table, boundary, sort_key) for boundary in boundaries
    ]

    slices = []
    start = 0
    for stop in cut_points:
        slices.append(table[start:stop])
        start = stop
    slices.append(table[start:])
    return slices
829
+
830
+
831
def get_attribute_from_class_name(class_name: str) -> Any:
    """Get Python attribute from the provided class name.

    The caller needs to make sure the provided class name includes
    full module name, and can be imported successfully.
    """
    from importlib import import_module

    parts = class_name.split(".")
    if len(parts) < 2:
        raise ValueError(f"Cannot create object from {class_name}.")

    module_path, attribute = ".".join(parts[:-1]), parts[-1]
    return getattr(import_module(module_path), attribute)
846
+
847
+
848
+ T = TypeVar("T")
849
+ U = TypeVar("U")
850
+
851
+
852
+ class _InterruptibleQueue(Queue):
853
+ """Extension of Python's `queue.Queue` providing ability to get interrupt its
854
+ method callers in other threads"""
855
+
856
+ INTERRUPTION_CHECK_FREQUENCY_SEC = 0.5
857
+
858
+ def __init__(
859
+ self, max_size: int, interrupted_event: Optional[threading.Event] = None
860
+ ):
861
+ super().__init__(maxsize=max_size)
862
+ self._interrupted_event = interrupted_event or threading.Event()
863
+
864
+ def get(self, block=True, timeout=None):
865
+ if not block or timeout is not None:
866
+ return super().get(block, timeout)
867
+
868
+ # In case when the call is blocking and no timeout is specified (ie blocking
869
+ # indefinitely) we apply the following protocol to make it interruptible:
870
+ #
871
+ # 1. `Queue.get` is invoked w/ 500ms timeout
872
+ # 2. `Empty` exception is intercepted (will be raised upon timeout elapsing)
873
+ # 3. If interrupted flag is set `InterruptedError` is raised
874
+ # 4. Otherwise, protocol retried (until interrupted or queue
875
+ # becoming non-empty)
876
+ while True:
877
+ if self._interrupted_event.is_set():
878
+ raise InterruptedError()
879
+
880
+ try:
881
+ return super().get(
882
+ block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
883
+ )
884
+ except Empty:
885
+ pass
886
+
887
+ def put(self, item, block=True, timeout=None):
888
+ if not block or timeout is not None:
889
+ super().put(item, block, timeout)
890
+ return
891
+
892
+ # In case when the call is blocking and no timeout is specified (ie blocking
893
+ # indefinitely) we apply the following protocol to make it interruptible:
894
+ #
895
+ # 1. `Queue.pet` is invoked w/ 500ms timeout
896
+ # 2. `Full` exception is intercepted (will be raised upon timeout elapsing)
897
+ # 3. If interrupted flag is set `InterruptedError` is raised
898
+ # 4. Otherwise, protocol retried (until interrupted or queue
899
+ # becomes non-full)
900
+ while True:
901
+ if self._interrupted_event.is_set():
902
+ raise InterruptedError()
903
+
904
+ try:
905
+ super().put(
906
+ item, block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
907
+ )
908
+ return
909
+ except Full:
910
+ pass
911
+
912
+
913
def make_async_gen(
    base_iterator: Iterator[T],
    fn: Callable[[Iterator[T]], Iterator[U]],
    num_workers: int = 1,
    queue_buffer_size: int = 2,
) -> Generator[U, None, None]:
    """Returns a generator (iterator) mapping items from the
    provided iterator applying provided transformation in parallel (using a
    thread-pool).

    NOTE: Even though the mapping is performed in parallel across N
    threads, this method provides crucial guarantee of preserving the
    ordering of the source iterator, ie that

        iterator = [A1, A2, ... An]
        mapped iterator = [map(A1), map(A2), ..., map(An)]

    Preserving ordering is crucial to eliminate non-determinism in producing
    content of the blocks.

    Args:
        base_iterator: Iterator yielding elements to map
        fn: Transformation to apply to each element
        num_workers: The number of threads to use in the threadpool (defaults to 1)
        queue_buffer_size: Number of objects to be buffered in its input/output
            queues (per queue; defaults to 2). Total number of objects held
            in memory could be calculated as:

                num_workers * queue_buffer_size * 2 (input and output)

    Returns:
        A generator (iterator) of the elements corresponding to the source
        elements mapped by provided transformation (while *preserving the ordering*)
    """
    # NOTE: The docstring above must be the function's *first* statement; it
    # was previously placed after the ``gen_id`` assignment, which left
    # ``make_async_gen.__doc__`` unset.

    # Random id used to label this generator's worker threads.
    gen_id = random.randint(0, 2**31 - 1)

    if num_workers < 1:
        raise ValueError("Size of threadpool must be at least 1.")

    # To apply transformations to elements in parallel *and* preserve the ordering
    # following invariants are established:
    #   - Every worker is handled by standalone thread
    #   - Every worker is assigned an input and an output queue
    #
    # And following protocol is implemented:
    #   - Filling worker traverses input iterator round-robin'ing elements across
    #     the input queues (in order!)
    #   - Transforming workers traverse respective input queue in-order: de-queueing
    #     element, applying transformation and enqueuing the result into the output
    #     queue
    #   - Generator (returned from this method) traverses output queues (in the same
    #     order as input queues) dequeues 1 mapped element at a time from each output
    #     queue and yields it
    #
    # Signal handler used to interrupt workers when terminating
    interrupted_event = threading.Event()

    input_queues = [
        _InterruptibleQueue(queue_buffer_size, interrupted_event)
        for _ in range(num_workers)
    ]
    output_queues = [
        _InterruptibleQueue(queue_buffer_size, interrupted_event)
        for _ in range(num_workers)
    ]

    # Filling worker
    def _run_filling_worker():
        try:
            # First, round-robin elements from the iterator into
            # corresponding input queues (one by one)
            for idx, item in enumerate(base_iterator):
                input_queues[idx % num_workers].put(item)

            # Enqueue sentinel objects to signal end of the line
            for idx in range(num_workers):
                input_queues[idx].put(SENTINEL)

        except InterruptedError:
            pass

        except Exception as e:
            logger.warning("Caught exception in filling worker!", exc_info=e)
            # In case of filling worker encountering an exception we have to propagate
            # it back to the (main) iterating thread. To achieve that we're traversing
            # output queues *backwards* relative to the order of iterator-thread such
            # that they are more likely to meet w/in a single iteration.
            for output_queue in reversed(output_queues):
                output_queue.put(e)

    # Transforming worker
    def _run_transforming_worker(worker_id: int):
        input_queue = input_queues[worker_id]
        output_queue = output_queues[worker_id]

        try:
            # Create iterator draining the queue, until it receives sentinel
            #
            # NOTE: `queue.get` is blocking!
            input_queue_iter = iter(input_queue.get, SENTINEL)

            mapped_iter = fn(input_queue_iter)
            for result in mapped_iter:
                # Enqueue result of the transformation
                output_queue.put(result)

            # Enqueue sentinel (to signal that transformations are completed)
            output_queue.put(SENTINEL)

        except InterruptedError:
            pass

        except Exception as e:
            logger.warning("Caught exception in transforming worker!", exc_info=e)
            # NOTE: In this case we simply enqueue the exception rather than
            # interrupting
            output_queue.put(e)

    # Start workers threads
    filling_worker_thread = threading.Thread(
        target=_run_filling_worker,
        name=f"map_tp_filling_worker-{gen_id}",
        daemon=True,
    )
    filling_worker_thread.start()

    transforming_worker_threads = [
        threading.Thread(
            target=_run_transforming_worker,
            name=f"map_tp_transforming_worker-{gen_id}-{worker_idx}",
            args=(worker_idx,),
            daemon=True,
        )
        for worker_idx in range(num_workers)
    ]

    for t in transforming_worker_threads:
        t.start()

    # Use main thread to yield output batches
    try:
        # Keep track of remaining non-empty output queues
        remaining_output_queues = output_queues

        while len(remaining_output_queues) > 0:
            # To provide deterministic ordering of the produced iterator we rely
            # on the following invariants:
            #
            #   - Elements from the original iterator are round-robin'd into
            #     input queues (in order)
            #   - Individual workers drain their respective input queues populating
            #     output queues with the results of applying transformation to the
            #     original item (and hence preserving original ordering of the input
            #     queue)
            #   - To yield from the generator output queues are traversed in the same
            #     order and one single element is dequeued (in a blocking way!) at a
            #     time from every individual output queue
            #
            non_empty_queues = []
            empty_queues = []

            # At every iteration only remaining non-empty queues
            # are traversed (to prevent blocking on exhausted queue)
            for output_queue in remaining_output_queues:
                # NOTE: This is blocking!
                item = output_queue.get()

                if isinstance(item, Exception):
                    raise item

                if item is SENTINEL:
                    empty_queues.append(output_queue)
                else:
                    non_empty_queues.append(output_queue)
                    yield item

            # Queues must only be exhausted from the tail of the round-robin
            # order; anything else indicates an ordering violation.
            assert (
                non_empty_queues + empty_queues == remaining_output_queues
            ), "Exhausted non-trailing queue!"

            remaining_output_queues = non_empty_queues

    finally:
        # Set flag to interrupt workers (to make sure no dangling
        # threads holding the objects are left behind)
        #
        # NOTE: Interrupted event is set to interrupt the running threads
        # that might be blocked otherwise waiting on inputs from respective
        # queues. However, even though we're interrupting the threads we can't
        # guarantee that threads will be interrupted in time (as this is
        # dependent on Python's GC finalizer to close the generator by raising
        # `GeneratorExit`) and hence we can't join on either filling or
        # transforming workers.
        interrupted_event.set()
1109
+
1110
+
1111
def call_with_retry(
    f: Callable[[], Any],
    description: str,
    *,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Any:
    """Retry a function with exponential backoff.

    Args:
        f: The function to retry.
        description: An imperative description of the function being retried. For
            example, "open the file".
        match: A list of strings to match in the exception message. If ``None``, any
            error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.

    Returns:
        The return value of the first successful ``f()`` call.

    Raises:
        Exception: The last exception raised by ``f`` once attempts are
            exhausted, or immediately if the error doesn't match ``match``.
    """
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    for i in range(max_attempts):
        try:
            return f()
        except Exception as e:
            # An error is retryable when no filter is given, or when its
            # message contains any of the provided substrings.
            is_retryable = match is None or any(pattern in str(e) for pattern in match)
            if is_retryable and i + 1 < max_attempts:
                # Retry with binary exponential backoff with random jitter.
                backoff = min((2 ** (i + 1)), max_backoff_s) * random.random()
                logger.debug(
                    f"Retrying {i+1} attempts to {description} after {backoff} seconds."
                )
                time.sleep(backoff)
            else:
                raise e from None
1148
+
1149
+
1150
def iterate_with_retry(
    iterable_factory: Callable[[], Iterable],
    description: str,
    *,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Any:
    """Iterate through an iterable with retries.

    If the iterable raises an exception, this function recreates and re-iterates
    through the iterable, while skipping the items that have already been yielded.

    Args:
        iterable_factory: A no-argument function that creates the iterable.
        description: An imperative description of the function being retried. For
            example, "open the file".
        match: A list of strings to match in the exception message. If ``None``, any
            error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.
    """
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    num_items_yielded = 0
    for attempt in range(max_attempts):
        try:
            iterable = iterable_factory()
            for item_index, item in enumerate(iterable):
                if item_index < num_items_yielded:
                    # Skip items that have already been yielded by a previous
                    # (failed) attempt.
                    continue

                num_items_yielded += 1
                yield item
            return
        except Exception as e:
            # An error is retryable when no filter is given, or when its
            # message contains any of the provided substrings.
            is_retryable = match is None or any(pattern in str(e) for pattern in match)
            if is_retryable and attempt + 1 < max_attempts:
                # Retry with binary exponential backoff with random jitter.
                backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random()
                logger.debug(
                    f"Retrying {attempt+1} attempts to {description} "
                    f"after {backoff} seconds."
                )
                time.sleep(backoff)
            else:
                raise e from None
1200
+
1201
+
1202
def create_dataset_tag(dataset_name: Optional[str], *args):
    """Build a tag: the dataset name (or ``"dataset"`` when ``None``) with
    each extra argument appended as an underscore-separated suffix."""
    base = dataset_name or "dataset"
    return base + "".join(f"_{arg}" for arg in args)
1207
+
1208
+
1209
def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
    """Render a byte count as a rounded decimal GB/MB/KB string.

    Counts below 1 MB are always expressed in KB, so sub-KB values render
    as "0KB" (matching the original behavior).
    """
    if num_bytes >= 1e9:
        return f"{round(num_bytes / 1e9)}GB"
    if num_bytes >= 1e6:
        return f"{round(num_bytes / 1e6)}MB"
    return f"{round(num_bytes / 1e3)}KB"
1217
+
1218
+
1219
+ def _validate_rows_per_file_args(
1220
+ *, num_rows_per_file: Optional[int] = None, min_rows_per_file: Optional[int] = None
1221
+ ) -> Optional[int]:
1222
+ """Helper method to validate and handle rows per file arguments.
1223
+
1224
+ Args:
1225
+ num_rows_per_file: Deprecated parameter for number of rows per file
1226
+ min_rows_per_file: New parameter for minimum rows per file
1227
+
1228
+ Returns:
1229
+ The effective min_rows_per_file value to use
1230
+ """
1231
+ if num_rows_per_file is not None:
1232
+ import warnings
1233
+
1234
+ warnings.warn(
1235
+ "`num_rows_per_file` is deprecated and will be removed in a future release. "
1236
+ "Use `min_rows_per_file` instead.",
1237
+ DeprecationWarning,
1238
+ stacklevel=3,
1239
+ )
1240
+ if min_rows_per_file is not None:
1241
+ raise ValueError(
1242
+ "Cannot specify both `num_rows_per_file` and `min_rows_per_file`. "
1243
+ "Use `min_rows_per_file` as `num_rows_per_file` is deprecated."
1244
+ )
1245
+ return num_rows_per_file
1246
+ return min_rows_per_file
1247
+
1248
+
1249
def is_nan(value):
    """Return whether ``value`` is a float NaN.

    Non-float inputs (strings, None, ...) are never considered NaN; a
    TypeError from the numpy check is treated the same way.
    """
    if not isinstance(value, float):
        return False
    try:
        return np.isnan(value)
    except TypeError:
        return False
1254
+
1255
+
1256
def keys_equal(keys1, keys2):
    """Compare two key sequences element-wise, treating NaN as equal to NaN
    (unlike the ``==`` operator, for which NaN != NaN)."""
    if len(keys1) != len(keys2):
        return False
    return all(
        (is_nan(lhs) and is_nan(rhs)) or lhs == rhs
        for lhs, rhs in zip(keys1, keys2)
    )
.venv/lib/python3.11/site-packages/ray/data/datasource/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.data._internal.datasource.sql_datasource import Connection
2
+ from ray.data.datasource.datasink import (
3
+ Datasink,
4
+ DummyOutputDatasink,
5
+ WriteResult,
6
+ WriteReturnType,
7
+ )
8
+ from ray.data.datasource.datasource import (
9
+ Datasource,
10
+ RandomIntRowDatasource,
11
+ Reader,
12
+ ReadTask,
13
+ )
14
+ from ray.data.datasource.file_based_datasource import (
15
+ FileBasedDatasource,
16
+ FileShuffleConfig,
17
+ _S3FileSystemWrapper,
18
+ )
19
+ from ray.data.datasource.file_datasink import (
20
+ BlockBasedFileDatasink,
21
+ RowBasedFileDatasink,
22
+ )
23
+ from ray.data.datasource.file_meta_provider import (
24
+ BaseFileMetadataProvider,
25
+ DefaultFileMetadataProvider,
26
+ FastFileMetadataProvider,
27
+ FileMetadataProvider,
28
+ )
29
+ from ray.data.datasource.filename_provider import FilenameProvider
30
+ from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
31
+ from ray.data.datasource.partitioning import (
32
+ Partitioning,
33
+ PartitionStyle,
34
+ PathPartitionFilter,
35
+ PathPartitionParser,
36
+ )
37
+
38
# Note: HuggingFaceDatasource should NOT be imported here, because
# we want to only import the Hugging Face datasets library when we use
# ray.data.from_huggingface() or HuggingFaceDatasource() directly.
# NOTE(review): "DeltaSharingDatasource" is listed in __all__ below but is not
# imported anywhere in this module, so `from ray.data.datasource import *`
# would raise AttributeError on it — verify against the upstream package.
__all__ = [
    "BaseFileMetadataProvider",
    "BlockBasedFileDatasink",
    "Connection",
    "Datasink",
    "Datasource",
    "DeltaSharingDatasource",
    "DefaultFileMetadataProvider",
    "DummyOutputDatasink",
    "FastFileMetadataProvider",
    "FileBasedDatasource",
    "FileShuffleConfig",
    "FileMetadataProvider",
    "FilenameProvider",
    "ParquetMetadataProvider",
    "PartitionStyle",
    "PathPartitionFilter",
    "PathPartitionParser",
    "Partitioning",
    "RandomIntRowDatasource",
    "ReadTask",
    "Reader",
    "RowBasedFileDatasink",
    "_S3FileSystemWrapper",
    "WriteResult",
    "WriteReturnType",
]
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.86 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasink.cpython-311.pyc ADDED
Binary file (8.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/datasource.cpython-311.pyc ADDED
Binary file (13.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_based_datasource.cpython-311.pyc ADDED
Binary file (26.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_datasink.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/file_meta_provider.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/filename_provider.cpython-311.pyc ADDED
Binary file (6.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/parquet_meta_provider.cpython-311.pyc ADDED
Binary file (11.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/partitioning.cpython-311.pyc ADDED
Binary file (23.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/__pycache__/path_util.cpython-311.pyc ADDED
Binary file (9.25 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/datasource/file_datasink.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import posixpath
3
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
4
+ from urllib.parse import urlparse
5
+
6
+ from ray._private.utils import _add_creatable_buckets_param_if_s3_uri
7
+ from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
8
+ from ray.data._internal.execution.interfaces import TaskContext
9
+ from ray.data._internal.util import _is_local_scheme, call_with_retry
10
+ from ray.data.block import Block, BlockAccessor
11
+ from ray.data.context import DataContext
12
+ from ray.data.datasource.datasink import Datasink, WriteResult
13
+ from ray.data.datasource.filename_provider import (
14
+ FilenameProvider,
15
+ _DefaultFilenameProvider,
16
+ )
17
+ from ray.data.datasource.path_util import _resolve_paths_and_filesystem
18
+ from ray.util.annotations import DeveloperAPI
19
+
20
+ if TYPE_CHECKING:
21
+ import pyarrow
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ WRITE_FILE_MAX_ATTEMPTS = 10
27
+ WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32
28
+
29
+
30
class _FileDatasink(Datasink[None]):
    """Base datasink that writes blocks as files under a directory path.

    Handles path/filesystem resolution, optional directory creation, and
    cleanup of an empty output directory; subclasses implement
    ``write_block``.
    """

    def __init__(
        self,
        path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        try_create_dir: bool = True,
        open_stream_args: Optional[Dict[str, Any]] = None,
        filename_provider: Optional[FilenameProvider] = None,
        dataset_uuid: Optional[str] = None,
        file_format: Optional[str] = None,
    ):
        """Initialize this datasink.

        Args:
            path: The folder to write files to.
            filesystem: The filesystem to write files to. If not provided, the
                filesystem is inferred from the path.
            try_create_dir: Whether to create the directory to write files to.
            open_stream_args: Arguments to pass to ``filesystem.open_output_stream``.
            filename_provider: A :class:`ray.data.datasource.FilenameProvider` that
                generates filenames for each row or block.
            dataset_uuid: The UUID of the dataset being written. If specified, it's
                included in the filename.
            file_format: The file extension. If specified, files are written with this
                extension.
        """
        if open_stream_args is None:
            open_stream_args = {}

        if filename_provider is None:
            filename_provider = _DefaultFilenameProvider(
                dataset_uuid=dataset_uuid, file_format=file_format
            )

        # Keep the original user-supplied path: it is needed below to decide
        # whether writes can be distributed (local-scheme paths cannot).
        self.unresolved_path = path
        paths, self.filesystem = _resolve_paths_and_filesystem(path, filesystem)
        assert len(paths) == 1, len(paths)
        self.path = paths[0]

        self.try_create_dir = try_create_dir
        self.open_stream_args = open_stream_args
        self.filename_provider = filename_provider
        self.dataset_uuid = dataset_uuid
        self.file_format = file_format

        # Set by on_write_start; used by on_write_complete to decide whether
        # an empty output directory should be removed again.
        self.has_created_dir = False

    def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
        # Central place for opening output files so subclasses inherit
        # ``open_stream_args`` (e.g. compression settings).
        return self.filesystem.open_output_stream(path, **self.open_stream_args)

    def on_write_start(self) -> None:
        self.has_created_dir = self._create_dir(self.path)

    def _create_dir(self, dest) -> bool:
        """Create a directory to write files to.

        If ``try_create_dir`` is ``False``, this method is a no-op.

        Returns:
            True iff a directory was actually created by this call.
        """
        from pyarrow.fs import FileType

        # We should skip creating directories in s3 unless the user specifically
        # overrides this behavior. PyArrow's s3fs implementation for create_dir
        # will attempt to check if the parent directory exists before trying to
        # create the directory (with recursive=True it will try to do this to
        # all of the directories until the root of the bucket). An IAM Policy that
        # restricts access to a subset of prefixes within the bucket might cause
        # the creation of the directory to fail even if the permissions should
        # allow the data can be written to the specified path. For example if a
        # a policy only allows users to write blobs prefixed with s3://bucket/foo
        # a call to create_dir for s3://bucket/foo/bar will fail even though it
        # should not.
        parsed_uri = urlparse(dest)
        is_s3_uri = parsed_uri.scheme == "s3"
        skip_create_dir_for_s3 = (
            is_s3_uri and not DataContext.get_current().s3_try_create_dir
        )

        if self.try_create_dir and not skip_create_dir_for_s3:
            if self.filesystem.get_file_info(dest).type is FileType.NotFound:
                # Arrow's S3FileSystem doesn't allow creating buckets by default, so we
                # add a query arg enabling bucket creation if an S3 URI is provided.
                tmp = _add_creatable_buckets_param_if_s3_uri(dest)
                self.filesystem.create_dir(tmp, recursive=True)
                return True

        return False

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        # Coalesce all input blocks of this task into a single block before
        # delegating to the subclass' write_block.
        builder = DelegatingBlockBuilder()
        for block in blocks:
            builder.add_block(block)
        block = builder.build()
        block_accessor = BlockAccessor.for_block(block)

        if block_accessor.num_rows() == 0:
            # Nothing to write; avoid creating an empty output file.
            logger.warning(f"Skipped writing empty block to {self.path}")
            return

        self.write_block(block_accessor, 0, ctx)

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        # Subclass hook: write one (coalesced) block to the destination.
        raise NotImplementedError

    def on_write_complete(self, write_result: WriteResult[None]):
        # If no rows were written, we can delete the directory.
        if self.has_created_dir and write_result.num_rows == 0:
            self.filesystem.delete_dir(self.path)

    @property
    def supports_distributed_writes(self) -> bool:
        # Writes to a node-local path can't be distributed across the cluster.
        return not _is_local_scheme(self.unresolved_path)
146
+
147
+
148
@DeveloperAPI
class RowBasedFileDatasink(_FileDatasink):
    """A datasink that writes one row to each file.

    Subclasses must implement ``write_row_to_file`` and call the superclass constructor.

    Examples:
        .. testcode::

            import io
            from typing import Any, Dict

            import pyarrow
            from PIL import Image

            from ray.data.datasource import RowBasedFileDatasink

            class ImageDatasink(RowBasedFileDatasink):
                def __init__(self, path: str, *, column: str, file_format: str = "png"):
                    super().__init__(path, file_format=file_format)
                    self._file_format = file_format
                    self._column = column

                def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
                    image = Image.fromarray(row[self._column])
                    buffer = io.BytesIO()
                    image.save(buffer, format=self._file_format)
                    file.write(buffer.getvalue())
    """  # noqa: E501

    def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
        """Write a row to a file.

        Args:
            row: The row to write.
            file: The file to write the row to.
        """
        raise NotImplementedError

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        # One output file per row; the FilenameProvider makes names unique per
        # (task, block, row).
        for row_index, row in enumerate(block.iter_rows(public_row_format=False)):
            filename = self.filename_provider.get_filename_for_row(
                row, ctx.task_idx, block_index, row_index
            )
            write_path = posixpath.join(self.path, filename)

            def write_row_to_path(row, write_path):
                with self.open_output_stream(write_path) as file:
                    self.write_row_to_file(row, file)

            logger.debug(f"Writing {write_path} file.")
            # NOTE: ``row``/``write_path`` are bound through lambda default
            # arguments so each retry callable captures *this* iteration's
            # values rather than the loop variables.
            call_with_retry(
                lambda row=row, write_path=write_path: write_row_to_path(
                    row, write_path
                ),
                description=f"write '{write_path}'",
                match=DataContext.get_current().retried_io_errors,
                max_attempts=WRITE_FILE_MAX_ATTEMPTS,
                max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
            )
208
+
209
+
210
@DeveloperAPI
class BlockBasedFileDatasink(_FileDatasink):
    """A datasink that writes multiple rows to each file.

    Subclasses must implement ``write_block_to_file`` and call the superclass
    constructor.

    Examples:
        .. testcode::

            class CSVDatasink(BlockBasedFileDatasink):
                def __init__(self, path: str):
                    super().__init__(path, file_format="csv")

                def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
                    from pyarrow import csv
                    csv.write_csv(block.to_arrow(), file)
    """  # noqa: E501

    def __init__(
        self, path, *, min_rows_per_file: Optional[int] = None, **file_datasink_kwargs
    ):
        super().__init__(path, **file_datasink_kwargs)

        # Exposed through the ``min_rows_per_write`` property below.
        self._min_rows_per_file = min_rows_per_file

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """Write a block of data to a file.

        Args:
            block: The block to write.
            file: The file to write the block to.
        """
        raise NotImplementedError

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        filename = self.filename_provider.get_filename_for_block(
            block, ctx.task_idx, block_index
        )
        write_path = posixpath.join(self.path, filename)

        def write_block_to_path():
            with self.open_output_stream(write_path) as file:
                self.write_block_to_file(block, file)

        logger.debug(f"Writing {write_path} file.")
        # Writes are retried on I/O errors matching
        # DataContext.retried_io_errors.
        call_with_retry(
            write_block_to_path,
            description=f"write '{write_path}'",
            match=DataContext.get_current().retried_io_errors,
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )

    @property
    def min_rows_per_write(self) -> Optional[int]:
        # NOTE(review): presumably consumed by the write planner to coalesce
        # blocks up to this row count before writing — confirm at call sites.
        return self._min_rows_per_file
.venv/lib/python3.11/site-packages/ray/data/datasource/partitioning.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import posixpath
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
5
+
6
+ from ray.util.annotations import DeveloperAPI, PublicAPI
7
+
8
+ if TYPE_CHECKING:
9
+ import pyarrow
10
+
11
+
12
+ PartitionDataType = Type[Union[int, float, str, bool]]
13
+
14
+
15
@DeveloperAPI
class PartitionStyle(str, Enum):
    """Supported dataset partition styles.

    Inherits from `str` to simplify plain text serialization/deserialization.

    Examples:
        >>> # Serialize to JSON text.
        >>> json.dumps(PartitionStyle.HIVE)  # doctest: +SKIP
        '"hive"'

        >>> # Deserialize from JSON text.
        >>> PartitionStyle(json.loads('"hive"'))  # doctest: +SKIP
        <PartitionStyle.HIVE: 'hive'>
    """

    # Key/value directory names, e.g. "year=2022/month=01".
    HIVE = "hive"
    # Value-only directory names, e.g. "2022/01"; requires explicit field names.
    DIRECTORY = "dir"
33
+
34
+
35
@DeveloperAPI
@dataclass
class Partitioning:
    """Partition scheme used to describe path-based partitions.

    Path-based partition formats embed all partition keys and values directly in
    their dataset file paths.

    For example, to read a dataset with
    `Hive-style partitions <https://athena.guide/articles/hive-style-partitioning>`_:

    >>> import ray
    >>> from ray.data.datasource.partitioning import Partitioning
    >>> ds = ray.data.read_csv(
    ...     "s3://anonymous@ray-example-data/iris.csv",
    ...     partitioning=Partitioning("hive"),
    ... )

    Instead, if your files are arranged in a directory structure such as:

    .. code::

        root/dog/dog_0.jpeg
        root/dog/dog_1.jpeg
        ...

        root/cat/cat_0.jpeg
        root/cat/cat_1.jpeg
        ...

    Then you can use directory-based partitioning:

    >>> import ray
    >>> from ray.data.datasource.partitioning import Partitioning
    >>> root = "s3://anonymous@air-example-data/cifar-10/images"
    >>> partitioning = Partitioning("dir", field_names=["class"], base_dir=root)
    >>> ds = ray.data.read_images(root, partitioning=partitioning)
    """

    #: The partition style - may be either HIVE or DIRECTORY.
    style: PartitionStyle
    #: "/"-delimited base directory that all partitioned paths should
    #: exist under (exclusive). File paths either outside of, or at the first
    #: level of, this directory will be considered unpartitioned. Specify
    #: `None` or an empty string to search for partitions in all file path
    #: directories.
    base_dir: Optional[str] = None
    #: The partition key field names (i.e. column names for tabular
    #: datasets). When non-empty, the order and length of partition key
    #: field names must match the order and length of partition values.
    #: Required when parsing DIRECTORY partitioned paths or generating
    #: HIVE partitioned paths.
    field_names: Optional[List[str]] = None
    #: A dictionary that maps partition key names to their desired data type. If not
    #: provided, the data type defaults to string.
    field_types: Optional[Dict[str, PartitionDataType]] = None
    #: Filesystem that will be used for partition path file I/O.
    filesystem: Optional["pyarrow.fs.FileSystem"] = None

    def __post_init__(self):
        # Normalize `None` defaults to empty containers so downstream code can
        # iterate/index without null checks.
        if self.base_dir is None:
            self.base_dir = ""

        if self.field_types is None:
            self.field_types = {}

        # Computed lazily by `_normalize_base_dir` on first property access;
        # resolving a filesystem may perform I/O, so defer it until needed.
        self._normalized_base_dir = None
        self._resolved_filesystem = None

    @property
    def normalized_base_dir(self) -> str:
        """Returns the base directory normalized for compatibility with a filesystem."""
        if self._normalized_base_dir is None:
            self._normalize_base_dir()
        return self._normalized_base_dir

    @property
    def resolved_filesystem(self) -> "pyarrow.fs.FileSystem":
        """Returns the filesystem resolved for compatibility with a base directory."""
        if self._resolved_filesystem is None:
            self._normalize_base_dir()
        return self._resolved_filesystem

    def _normalize_base_dir(self):
        """Normalizes the partition base directory for compatibility with the
        given filesystem.

        This should be called once a filesystem has been resolved to ensure that this
        base directory is correctly discovered at the root of all partitioned file
        paths.
        """
        from ray.data.datasource.path_util import _resolve_paths_and_filesystem

        paths, self._resolved_filesystem = _resolve_paths_and_filesystem(
            self.base_dir,
            self.filesystem,
        )
        assert (
            len(paths) == 1
        ), f"Expected 1 normalized base directory, but found {len(paths)}"
        normalized_base_dir = paths[0]
        # Ensure a trailing "/" so prefix matching against this base dir can't
        # accidentally match a sibling directory sharing the same name prefix.
        if len(normalized_base_dir) and not normalized_base_dir.endswith("/"):
            normalized_base_dir += "/"
        self._normalized_base_dir = normalized_base_dir
139
+
140
+
141
@DeveloperAPI
class PathPartitionParser:
    """Partition parser for path-based partition formats.

    Path-based partition formats embed all partition keys and values directly in
    their dataset file paths.

    Two path partition formats are currently supported - `HIVE` and `DIRECTORY`.

    For `HIVE` Partitioning, all partition directories under the base directory
    will be discovered based on `{key1}={value1}/{key2}={value2}` naming
    conventions. Key/value pairs do not need to be presented in the same
    order across all paths. Directory names nested under the base directory that
    don't follow this naming condition will be considered unpartitioned. If a
    partition filter is defined, then it will be called with an empty input
    dictionary for each unpartitioned file.

    For `DIRECTORY` Partitioning, all directories under the base directory will
    be interpreted as partition values of the form `{value1}/{value2}`. An
    accompanying ordered list of partition field names must also be provided,
    where the order and length of all partition values must match the order and
    length of field names. Files stored directly in the base directory will
    be considered unpartitioned. If a partition filter is defined, then it will
    be called with an empty input dictionary for each unpartitioned file. For
    example, if the base directory is `"foo"`, then `"foo.csv"` and `"foo/bar.csv"`
    would be considered unpartitioned files but `"foo/bar/baz.csv"` would be associated
    with partition `"bar"`. If the base directory is undefined, then `"foo.csv"` would
    be unpartitioned, `"foo/bar.csv"` would be associated with partition `"foo"`, and
    "foo/bar/baz.csv" would be associated with partition `("foo", "bar")`.
    """

    @staticmethod
    def of(
        style: PartitionStyle = PartitionStyle.HIVE,
        base_dir: Optional[str] = None,
        field_names: Optional[List[str]] = None,
        field_types: Optional[Dict[str, PartitionDataType]] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    ) -> "PathPartitionParser":
        """Creates a path-based partition parser using a flattened argument list.

        Args:
            style: The partition style - may be either HIVE or DIRECTORY.
            base_dir: "/"-delimited base directory to start searching for partitions
                (exclusive). File paths outside of this directory will be considered
                unpartitioned. Specify `None` or an empty string to search for
                partitions in all file path directories.
            field_names: The partition key names. Required for DIRECTORY partitioning.
                Optional for HIVE partitioning. When non-empty, the order and length of
                partition key field names must match the order and length of partition
                directories discovered. Partition key field names are not required to
                exist in the dataset schema.
            field_types: A dictionary that maps partition key names to their desired
                data type. If not provided, the data type default to string.
            filesystem: Filesystem that will be used for partition path file I/O.

        Returns:
            The new path-based partition parser.
        """
        scheme = Partitioning(style, base_dir, field_names, field_types, filesystem)
        return PathPartitionParser(scheme)

    def __init__(self, partitioning: Partitioning):
        """Creates a path-based partition parser.

        Args:
            partitioning: The path-based partition scheme. The parser starts
                searching for partitions from this scheme's base directory. File paths
                outside the base directory will be considered unpartitioned. If the
                base directory is `None` or an empty string then this will search for
                partitions in all file path directories. Field names are required for
                DIRECTORY partitioning, and optional for HIVE partitioning. When
                non-empty, the order and length of partition key field names must match
                the order and length of partition directories discovered.

        Raises:
            ValueError: If DIRECTORY partitioning is requested without field
                names, or if the partition style is unsupported.
        """
        style = partitioning.style
        field_names = partitioning.field_names
        # DIRECTORY paths carry values only, so field names are mandatory to
        # know which key each value belongs to.
        if style == PartitionStyle.DIRECTORY and not field_names:
            raise ValueError(
                "Directory partitioning requires a corresponding list of "
                "partition key field names. Please retry your request with one "
                "or more field names specified."
            )
        # Dispatch table from style to the bound parsing method.
        parsers = {
            PartitionStyle.HIVE: self._parse_hive_path,
            PartitionStyle.DIRECTORY: self._parse_dir_path,
        }
        self._parser_fn: Callable[[str], Dict[str, str]] = parsers.get(style)
        if self._parser_fn is None:
            raise ValueError(
                f"Unsupported partition style: {style}. "
                f"Supported styles: {parsers.keys()}"
            )
        self._scheme = partitioning

    def __call__(self, path: str) -> Dict[str, str]:
        """Parses partition keys and values from a single file path.

        Args:
            path: Input file path to parse.

        Returns:
            Dictionary mapping directory partition keys to values from the input file
            path. Returns an empty dictionary for unpartitioned files.
        """
        dir_path = self._dir_path_trim_base(path)
        if dir_path is None:
            # Path lies outside the base directory: treat as unpartitioned.
            return {}
        partitions: Dict[str, str] = self._parser_fn(dir_path)

        # Cast string partition values to their configured data types.
        # NOTE(review): this raises KeyError if `field_types` names a key the
        # parser did not find in the path (e.g. an unpartitioned file under the
        # base dir) — confirm callers guarantee field_types keys are present.
        for field, data_type in self._scheme.field_types.items():
            partitions[field] = _cast_value(partitions[field], data_type)

        return partitions

    @property
    def scheme(self) -> Partitioning:
        """Returns the partitioning for this parser."""
        return self._scheme

    def _dir_path_trim_base(self, path: str) -> Optional[str]:
        """Trims the normalized base directory and returns the directory path.

        Returns None if the path does not start with the normalized base directory.
        Simply returns the directory path if the base directory is undefined.
        """
        if not path.startswith(self._scheme.normalized_base_dir):
            return None
        # Drop the base-dir prefix, then the filename component.
        path = path[len(self._scheme.normalized_base_dir) :]
        return posixpath.dirname(path)

    def _parse_hive_path(self, dir_path: str) -> Dict[str, str]:
        """Hive partition path parser.

        Returns a dictionary mapping partition keys to values given a hive-style
        partition path of the form "{key1}={value1}/{key2}={value2}/..." or an empty
        dictionary for unpartitioned files.
        """
        # Only directory components with exactly one "=" count as partition
        # dirs; all other components are silently ignored.
        dirs = [d for d in dir_path.split("/") if d and (d.count("=") == 1)]
        kv_pairs = [d.split("=") for d in dirs] if dirs else []
        field_names = self._scheme.field_names
        if field_names and kv_pairs:
            # When field names were configured, enforce an exact match in both
            # count and order against the keys discovered in the path.
            if len(kv_pairs) != len(field_names):
                raise ValueError(
                    f"Expected {len(field_names)} partition value(s) but found "
                    f"{len(kv_pairs)}: {kv_pairs}."
                )
            for i, field_name in enumerate(field_names):
                if kv_pairs[i][0] != field_name:
                    raise ValueError(
                        f"Expected partition key {field_name} but found "
                        f"{kv_pairs[i][0]}"
                    )
        return dict(kv_pairs)

    def _parse_dir_path(self, dir_path: str) -> Dict[str, str]:
        """Directory partition path parser.

        Returns a dictionary mapping directory partition keys to values from a
        partition path of the form "{value1}/{value2}/..." or an empty dictionary for
        unpartitioned files.

        Requires a corresponding ordered list of partition key field names to map the
        correct key to each value.
        """
        dirs = [d for d in dir_path.split("/") if d]
        field_names = self._scheme.field_names

        # Partitioned paths must supply exactly one value per configured field.
        if dirs and len(dirs) != len(field_names):
            raise ValueError(
                f"Expected {len(field_names)} partition value(s) but found "
                f"{len(dirs)}: {dirs}."
            )

        if not dirs:
            return {}
        # Positional zip: the i-th directory value maps to the i-th field name.
        # A `None` field name skips (drops) that directory level.
        return {
            field: directory
            for field, directory in zip(field_names, dirs)
            if field is not None
        }
+ }
322
+
323
+
324
@PublicAPI(stability="beta")
class PathPartitionFilter:
    """Partition filter for path-based partition formats.

    Used to explicitly keep or reject files based on a custom filter function that
    takes partition keys and values parsed from the file's path as input.
    """

    @staticmethod
    def of(
        filter_fn: Callable[[Dict[str, str]], bool],
        style: PartitionStyle = PartitionStyle.HIVE,
        base_dir: Optional[str] = None,
        field_names: Optional[List[str]] = None,
        field_types: Optional[Dict[str, PartitionDataType]] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    ) -> "PathPartitionFilter":
        """Creates a path-based partition filter using a flattened argument list.

        Args:
            filter_fn: Callback used to filter partitions. Takes a dictionary mapping
                partition keys to values as input. Unpartitioned files are denoted with
                an empty input dictionary. Returns `True` to read a file for that
                partition or `False` to skip it. Partition keys and values are always
                strings read from the filesystem path. For example, this removes all
                unpartitioned files:

                .. code:: python

                    lambda d: True if d else False

                This raises an assertion error for any unpartitioned file found:

                .. code:: python

                    def do_assert(val, msg):
                        assert val, msg

                    lambda d: do_assert(d, "Expected all files to be partitioned!")

                And this only reads files from January, 2022 partitions:

                .. code:: python

                    lambda d: d["month"] == "January" and d["year"] == "2022"

            style: The partition style - may be either HIVE or DIRECTORY.
            base_dir: "/"-delimited base directory to start searching for partitions
                (exclusive). File paths outside of this directory will be considered
                unpartitioned. Specify `None` or an empty string to search for
                partitions in all file path directories.
            field_names: The partition key names. Required for DIRECTORY partitioning.
                Optional for HIVE partitioning. When non-empty, the order and length of
                partition key field names must match the order and length of partition
                directories discovered. Partition key field names are not required to
                exist in the dataset schema.
            field_types: A dictionary that maps partition key names to their desired
                data type. If not provided, the data type defaults to string.
            filesystem: Filesystem that will be used for partition path file I/O.

        Returns:
            The new path-based partition filter.
        """
        # Build the parser from the flattened scheme arguments, then wrap it.
        scheme = Partitioning(style, base_dir, field_names, field_types, filesystem)
        path_partition_parser = PathPartitionParser(scheme)
        return PathPartitionFilter(path_partition_parser, filter_fn)

    def __init__(
        self,
        path_partition_parser: PathPartitionParser,
        filter_fn: Callable[[Dict[str, str]], bool],
    ):
        """Creates a new path-based partition filter based on a parser.

        Args:
            path_partition_parser: The path-based partition parser.
            filter_fn: Callback used to filter partitions. Takes a dictionary mapping
                partition keys to values as input. Unpartitioned files are denoted with
                an empty input dictionary. Returns `True` to read a file for that
                partition or `False` to skip it. Partition keys and values are always
                strings read from the filesystem path. For example, this removes all
                unpartitioned files:
                ``lambda d: True if d else False``
                This raises an assertion error for any unpartitioned file found:
                ``lambda d: assert d, "Expected all files to be partitioned!"``
                And this only reads files from January, 2022 partitions:
                ``lambda d: d["month"] == "January" and d["year"] == "2022"``
        """
        self._parser = path_partition_parser
        self._filter_fn = filter_fn

    def __call__(self, paths: List[str]) -> List[str]:
        """Returns all paths that pass this partition scheme's partition filter.

        If no partition filter is set, then returns all input paths. If a base
        directory is set, then only paths under this base directory will be parsed
        for partitions. All paths outside of this base directory will automatically
        be considered unpartitioned, and passed into the filter function as empty
        dictionaries.

        Also normalizes the partition base directory for compatibility with the
        given filesystem before applying the filter.

        Args:
            paths: Paths to pass through the partition filter function. All
                paths should be normalized for compatibility with the given
                filesystem.
        Returns:
            List of paths that pass the partition filter, or all paths if no
            partition filter is defined.
        """
        filtered_paths = paths
        # Guard: a `None` filter function passes every path through unchanged.
        if self._filter_fn is not None:
            filtered_paths = [
                path for path in paths if self._filter_fn(self._parser(path))
            ]
        return filtered_paths

    @property
    def parser(self) -> PathPartitionParser:
        """Returns the path partition parser for this filter."""
        return self._parser
+ return self._parser
446
+
447
+
448
+ def _cast_value(value: str, data_type: PartitionDataType) -> Any:
449
+ if data_type is int:
450
+ return int(value)
451
+ elif data_type is float:
452
+ return float(value)
453
+ elif data_type is bool:
454
+ return value.lower() == "true"
455
+ else:
456
+ return value
.venv/lib/python3.11/site-packages/ray/data/datasource/path_util.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import sys
3
+ import urllib
4
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
5
+
6
+ from ray.data._internal.util import _resolve_custom_scheme
7
+
8
+ if TYPE_CHECKING:
9
+ import pyarrow
10
+
11
+
12
+ def _has_file_extension(path: str, extensions: Optional[List[str]]) -> bool:
13
+ """Check if a path has a file extension in the provided list.
14
+
15
+ Examples:
16
+ >>> _has_file_extension("foo.csv", ["csv"])
17
+ True
18
+ >>> _has_file_extension("foo.CSV", ["csv"])
19
+ True
20
+ >>> _has_file_extension("foo.csv", ["json", "jsonl"])
21
+ False
22
+ >>> _has_file_extension("foo.csv", None)
23
+ True
24
+
25
+ Args:
26
+ path: The path to check.
27
+ extensions: A list of extensions to check against. If `None`, any extension is
28
+ considered valid.
29
+ """
30
+ assert extensions is None or isinstance(extensions, list), type(extensions)
31
+
32
+ if extensions is None:
33
+ return True
34
+
35
+ # The user-specified extensions don't contain a leading dot, so we add it here.
36
+ extensions = [f".{ext.lower()}" for ext in extensions]
37
+ return any(path.lower().endswith(ext) for ext in extensions)
38
+
39
+
40
def _resolve_paths_and_filesystem(
    paths: Union[str, List[str]],
    filesystem: "pyarrow.fs.FileSystem" = None,
) -> Tuple[List[str], "pyarrow.fs.FileSystem"]:
    """
    Resolves and normalizes all provided paths, infers a filesystem from the
    paths and ensures that all paths use the same filesystem.

    Args:
        paths: A single file/directory path or a list of file/directory paths.
            A list of paths can contain both files and directories.
        filesystem: The filesystem implementation that should be used for
            reading these files. If None, a filesystem will be inferred. If not
            None, the provided filesystem will still be validated against all
            filesystems inferred from the provided paths to ensure
            compatibility.

    Returns:
        Tuple of the resolved/normalized path list and the (possibly inferred)
        pyarrow filesystem.
    """
    import pyarrow as pa
    from pyarrow.fs import (
        FileSystem,
        FSSpecHandler,
        PyFileSystem,
        _resolve_filesystem_and_path,
    )

    # Normalize the `paths` argument to a non-empty list of strings.
    if isinstance(paths, str):
        paths = [paths]
    if isinstance(paths, pathlib.Path):
        paths = [str(paths)]
    elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths):
        raise ValueError(
            "Expected `paths` to be a `str`, `pathlib.Path`, or `list[str]`, but got "
            f"`{paths}`."
        )
    elif len(paths) == 0:
        raise ValueError("Must provide at least one path.")

    need_unwrap_path_protocol = True
    if filesystem and not isinstance(filesystem, FileSystem):
        # Not a pyarrow filesystem: accept only fsspec filesystems, wrapped
        # into the pyarrow interface below.
        err_msg = (
            f"The filesystem passed must either conform to "
            f"pyarrow.fs.FileSystem, or "
            f"fsspec.spec.AbstractFileSystem. The provided "
            f"filesystem was: {filesystem}"
        )
        try:
            import fsspec
            from fsspec.implementations.http import HTTPFileSystem
        except ModuleNotFoundError:
            # If filesystem is not a pyarrow filesystem and fsspec isn't
            # installed, then filesystem is neither a pyarrow filesystem nor
            # an fsspec filesystem, so we raise a TypeError.
            raise TypeError(err_msg) from None
        if not isinstance(filesystem, fsspec.spec.AbstractFileSystem):
            raise TypeError(err_msg) from None
        if isinstance(filesystem, HTTPFileSystem):
            # If filesystem is fsspec HTTPFileSystem, the protocol/scheme of paths
            # should not be unwrapped/removed, because HTTPFileSystem expects full file
            # paths including protocol/scheme. This is different behavior compared to
            # file systems implementation in pyarrow.fs.FileSystem.
            need_unwrap_path_protocol = False

        # Wrap the fsspec filesystem so the rest of this function can treat it
        # uniformly as a pyarrow FileSystem.
        filesystem = PyFileSystem(FSSpecHandler(filesystem))

    resolved_paths = []
    for path in paths:
        path = _resolve_custom_scheme(path)
        try:
            resolved_filesystem, resolved_path = _resolve_filesystem_and_path(
                path, filesystem
            )
        except pa.lib.ArrowInvalid as e:
            if "Cannot parse URI" in str(e):
                # Retry with a percent-encoded URI (e.g. for paths containing
                # spaces), then decode the resolved path back afterwards.
                resolved_filesystem, resolved_path = _resolve_filesystem_and_path(
                    _encode_url(path), filesystem
                )
                resolved_path = _decode_url(resolved_path)
            elif "Unrecognized filesystem type in URI" in str(e):
                scheme = urllib.parse.urlparse(path, allow_fragments=False).scheme
                if scheme in ["http", "https"]:
                    # If scheme of path is HTTP and filesystem is not resolved,
                    # try to use fsspec HTTPFileSystem. This expects fsspec is
                    # installed.
                    try:
                        from fsspec.implementations.http import HTTPFileSystem
                    except ModuleNotFoundError:
                        raise ImportError(
                            "Please install fsspec to read files from HTTP."
                        ) from None

                    resolved_filesystem = PyFileSystem(FSSpecHandler(HTTPFileSystem()))
                    resolved_path = path
                    # HTTP paths keep their scheme (see HTTPFileSystem note above).
                    need_unwrap_path_protocol = False
                else:
                    raise
            else:
                raise
        if filesystem is None:
            # The first successfully-resolved path fixes the filesystem used
            # for the remainder of the batch.
            # NOTE(review): the protocol is not unwrapped for this first path;
            # presumably `normalize_path` below handles it — confirm upstream.
            filesystem = resolved_filesystem
        elif need_unwrap_path_protocol:
            resolved_path = _unwrap_protocol(resolved_path)
        resolved_path = filesystem.normalize_path(resolved_path)
        resolved_paths.append(resolved_path)

    return resolved_paths, filesystem
145
+
146
+
147
+ def _unwrap_protocol(path):
148
+ """
149
+ Slice off any protocol prefixes on path.
150
+ """
151
+ if sys.platform == "win32" and _is_local_windows_path(path):
152
+ # Represent as posix path such that downstream functions properly handle it.
153
+ # This is executed when 'file://' is NOT included in the path.
154
+ return pathlib.Path(path).as_posix()
155
+
156
+ parsed = urllib.parse.urlparse(path, allow_fragments=False) # support '#' in path
157
+ query = "?" + parsed.query if parsed.query else "" # support '?' in path
158
+ netloc = parsed.netloc
159
+ if parsed.scheme == "s3" and "@" in parsed.netloc:
160
+ # If the path contains an @, it is assumed to be an anonymous
161
+ # credentialed path, and we need to strip off the credentials.
162
+ netloc = parsed.netloc.split("@")[-1]
163
+
164
+ parsed_path = parsed.path
165
+ # urlparse prepends the path with a '/'. This does not work on Windows
166
+ # so if this is the case strip the leading slash.
167
+ if (
168
+ sys.platform == "win32"
169
+ and not netloc
170
+ and len(parsed_path) >= 3
171
+ and parsed_path[0] == "/" # The problematic leading slash
172
+ and parsed_path[1].isalpha() # Ensure it is a drive letter.
173
+ and parsed_path[2:4] in (":", ":/")
174
+ ):
175
+ parsed_path = parsed_path[1:]
176
+
177
+ return netloc + parsed_path + query
178
+
179
+
180
+ def _is_url(path) -> bool:
181
+ return urllib.parse.urlparse(path).scheme != ""
182
+
183
+
184
+ def _is_local_windows_path(path: str) -> bool:
185
+ """Determines if path is a Windows file-system location."""
186
+ if sys.platform != "win32":
187
+ return False
188
+
189
+ if len(path) >= 1 and path[0] == "\\":
190
+ return True
191
+ if (
192
+ len(path) >= 3
193
+ and path[1] == ":"
194
+ and (path[2] == "/" or path[2] == "\\")
195
+ and path[0].isalpha()
196
+ ):
197
+ return True
198
+ return False
199
+
200
+
201
def _encode_url(path):
    # Percent-encode the path so pyarrow can parse it as a URI; "/" and ":"
    # are kept verbatim so the scheme and path separators survive intact.
    return urllib.parse.quote(path, safe="/:")
203
+
204
+
205
def _decode_url(path):
    # Inverse of _encode_url: restore percent-encoded characters.
    return urllib.parse.unquote(path)
.venv/lib/python3.11/site-packages/ray/data/extensions/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.air.util.tensor_extensions.arrow import (
2
+ ArrowTensorTypeV2,
3
+ get_arrow_extension_tensor_types,
4
+ )
5
+ from ray.data.extensions.object_extension import (
6
+ ArrowPythonObjectArray,
7
+ ArrowPythonObjectScalar,
8
+ ArrowPythonObjectType,
9
+ PythonObjectArray,
10
+ PythonObjectDtype,
11
+ _object_extension_type_allowed,
12
+ )
13
+ from ray.data.extensions.tensor_extension import (
14
+ ArrowConversionError,
15
+ ArrowTensorArray,
16
+ ArrowTensorType,
17
+ ArrowVariableShapedTensorArray,
18
+ ArrowVariableShapedTensorType,
19
+ TensorArray,
20
+ TensorArrayElement,
21
+ TensorDtype,
22
+ column_needs_tensor_extension,
23
+ )
24
+
25
+ __all__ = [
26
+ # Tensor array extension.
27
+ "TensorDtype",
28
+ "TensorArray",
29
+ "TensorArrayElement",
30
+ "ArrowTensorType",
31
+ "ArrowTensorTypeV2",
32
+ "ArrowTensorArray",
33
+ "ArrowVariableShapedTensorType",
34
+ "ArrowVariableShapedTensorArray",
35
+ "column_needs_tensor_extension",
36
+ "ArrowConversionError",
37
+ # Object array extension
38
+ "ArrowPythonObjectArray",
39
+ "ArrowPythonObjectType",
40
+ "ArrowPythonObjectScalar",
41
+ "PythonObjectArray",
42
+ "PythonObjectDtype",
43
+ "_object_extension_type_allowed",
44
+ "get_arrow_extension_tensor_types",
45
+ ]
.venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.23 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/object_extension.cpython-311.pyc ADDED
Binary file (595 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/data/extensions/__pycache__/tensor_extension.cpython-311.pyc ADDED
Binary file (842 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/data/extensions/object_extension.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.air.util.object_extensions.arrow import ( # noqa: F401
2
+ ArrowPythonObjectArray,
3
+ ArrowPythonObjectScalar,
4
+ ArrowPythonObjectType,
5
+ _object_extension_type_allowed,
6
+ )
7
+ from ray.air.util.object_extensions.pandas import ( # noqa: F401
8
+ PythonObjectArray,
9
+ PythonObjectDtype,
10
+ )
.venv/lib/python3.11/site-packages/ray/data/extensions/tensor_extension.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.air.util.tensor_extensions.arrow import ( # noqa: F401
2
+ ArrowConversionError,
3
+ ArrowTensorArray,
4
+ ArrowTensorType,
5
+ ArrowTensorTypeV2,
6
+ ArrowVariableShapedTensorArray,
7
+ ArrowVariableShapedTensorType,
8
+ )
9
+ from ray.air.util.tensor_extensions.pandas import ( # noqa: F401
10
+ TensorArray,
11
+ TensorArrayElement,
12
+ TensorDtype,
13
+ column_needs_tensor_extension,
14
+ )
15
+ from ray.air.util.tensor_extensions.utils import create_ragged_ndarray # noqa: F401
.venv/lib/python3.11/site-packages/ray/data/preprocessors/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.data.preprocessors.chain import Chain
2
+ from ray.data.preprocessors.concatenator import Concatenator
3
+ from ray.data.preprocessors.discretizer import (
4
+ CustomKBinsDiscretizer,
5
+ UniformKBinsDiscretizer,
6
+ )
7
+ from ray.data.preprocessors.encoder import (
8
+ Categorizer,
9
+ LabelEncoder,
10
+ MultiHotEncoder,
11
+ OneHotEncoder,
12
+ OrdinalEncoder,
13
+ )
14
+ from ray.data.preprocessors.hasher import FeatureHasher
15
+ from ray.data.preprocessors.imputer import SimpleImputer
16
+ from ray.data.preprocessors.normalizer import Normalizer
17
+ from ray.data.preprocessors.scaler import (
18
+ MaxAbsScaler,
19
+ MinMaxScaler,
20
+ RobustScaler,
21
+ StandardScaler,
22
+ )
23
+ from ray.data.preprocessors.tokenizer import Tokenizer
24
+ from ray.data.preprocessors.torch import TorchVisionPreprocessor
25
+ from ray.data.preprocessors.transformer import PowerTransformer
26
+ from ray.data.preprocessors.vectorizer import CountVectorizer, HashingVectorizer
27
+
28
+ __all__ = [
29
+ "Categorizer",
30
+ "CountVectorizer",
31
+ "Chain",
32
+ "FeatureHasher",
33
+ "HashingVectorizer",
34
+ "LabelEncoder",
35
+ "MaxAbsScaler",
36
+ "MinMaxScaler",
37
+ "MultiHotEncoder",
38
+ "Normalizer",
39
+ "OneHotEncoder",
40
+ "OrdinalEncoder",
41
+ "PowerTransformer",
42
+ "RobustScaler",
43
+ "SimpleImputer",
44
+ "StandardScaler",
45
+ "Concatenator",
46
+ "Tokenizer",
47
+ "TorchVisionPreprocessor",
48
+ "CustomKBinsDiscretizer",
49
+ "UniformKBinsDiscretizer",
50
+ ]