koichi12 commited on
Commit
c8ebe32
·
verified ·
1 Parent(s): 6f8c8ab

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/ray/air/__init__.py +22 -0
  3. .venv/lib/python3.11/site-packages/ray/air/config.py +766 -0
  4. .venv/lib/python3.11/site-packages/ray/air/constants.py +94 -0
  5. .venv/lib/python3.11/site-packages/ray/air/data_batch_type.py +11 -0
  6. .venv/lib/python3.11/site-packages/ray/air/execution/__init__.py +12 -0
  7. .venv/lib/python3.11/site-packages/ray/air/execution/__pycache__/__init__.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/air/execution/resources/__init__.py +12 -0
  9. .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/fixed.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/placement_group.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/request.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/resource_manager.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/air/execution/resources/fixed.py +147 -0
  14. .venv/lib/python3.11/site-packages/ray/air/execution/resources/placement_group.py +214 -0
  15. .venv/lib/python3.11/site-packages/ray/air/execution/resources/request.py +255 -0
  16. .venv/lib/python3.11/site-packages/ray/air/execution/resources/resource_manager.py +155 -0
  17. .venv/lib/python3.11/site-packages/ray/air/result.py +283 -0
  18. .venv/lib/python3.11/site-packages/ray/air/session.py +1 -0
  19. .venv/lib/python3.11/site-packages/ray/air/util/__init__.py +0 -0
  20. .venv/lib/python3.11/site-packages/ray/air/util/tensor_extensions/pandas.py +1451 -0
  21. .venv/lib/python3.11/site-packages/ray/air/util/torch_dist.py +191 -0
  22. .venv/lib/python3.11/site-packages/ray/air/util/transform_pyarrow.py +39 -0
  23. .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc +3 -0
  24. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__init__.py +0 -0
  25. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/__init__.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/common.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_noop_latency.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_throughput.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/http_noop_latency.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/microbenchmark.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/proxy_benchmark.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/common.py +276 -0
  33. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_noop_latency.py +34 -0
  34. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_throughput.py +62 -0
  35. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/http_noop_latency.py +32 -0
  36. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/microbenchmark.py +182 -0
  37. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/proxy_benchmark.py +294 -0
  38. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__init__.py +0 -0
  39. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/common.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/serialization_benchmark.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/common.py +29 -0
  43. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/serialization_benchmark.py +163 -0
  44. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__init__.py +0 -0
  45. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/__init__.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/common.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_core_throughput.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_grpc_throughput.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_handle_throughput.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_http_throughput.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -155,3 +155,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
155
  .venv/lib/python3.11/site-packages/ray/data/__pycache__/dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
156
  .venv/lib/python3.11/site-packages/ray/data/__pycache__/read_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
157
  .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
155
  .venv/lib/python3.11/site-packages/ray/data/__pycache__/dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
156
  .venv/lib/python3.11/site-packages/ray/data/__pycache__/read_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
157
  .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
158
+ .venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/air/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Public re-exports for the `ray.air` package.
from ray.air.config import (
    CheckpointConfig,
    DatasetConfig,
    FailureConfig,
    RunConfig,
    ScalingConfig,
)
from ray.air.data_batch_type import DataBatchType
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
from ray.air.result import Result

# Names exported via `from ray.air import *`.
__all__ = [
    "DataBatchType",
    "RunConfig",
    "Result",
    "ScalingConfig",
    "DatasetConfig",
    "FailureConfig",
    "CheckpointConfig",
    "AcquiredResources",
    "ResourceRequest",
]
.venv/lib/python3.11/site-packages/ray/air/config.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import Counter, defaultdict
3
+ from dataclasses import _MISSING_TYPE, dataclass, fields
4
+ from pathlib import Path
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ List,
11
+ Mapping,
12
+ Optional,
13
+ Tuple,
14
+ Union,
15
+ )
16
+
17
+ import pyarrow.fs
18
+
19
+ from ray._private.ray_constants import RESOURCE_CONSTRAINT_PREFIX
20
+ from ray._private.storage import _get_storage_uri
21
+ from ray._private.thirdparty.tabulate.tabulate import tabulate
22
+ from ray.data.preprocessor import Preprocessor
23
+ from ray.util.annotations import Deprecated, PublicAPI
24
+ from ray.widgets import Template, make_table_html_repr
25
+
26
+ if TYPE_CHECKING:
27
+ from ray.train import SyncConfig
28
+ from ray.tune.callback import Callback
29
+ from ray.tune.execution.placement_groups import PlacementGroupFactory
30
+ from ray.tune.experimental.output import AirVerbosity
31
+ from ray.tune.search.sample import Domain
32
+ from ray.tune.stopper import Stopper
33
+ from ray.tune.utils.log import Verbosity
34
+
35
+
36
# Dict[str, List] is to support `tune.grid_search`:
# TODO(sumanthratna/matt): Upstream this to Tune.
SampleRange = Union["Domain", Dict[str, List]]


# Valid values for `CheckpointConfig.checkpoint_score_order` (see below).
MAX = "max"
MIN = "min"
# Sentinel default used to detect whether a caller explicitly passed a
# deprecated config field (compared against with `!=` in `__post_init__`).
_DEPRECATED_VALUE = "DEPRECATED"

DATASET_CONFIG_DEPRECATION_MSG = """
Use `ray.train.DataConfig` instead of DatasetConfig to configure data ingest for training. See https://docs.ray.io/en/releases-2.6.3/ray-air/check-ingest.html#migrating-from-the-legacy-datasetconfig-api for more details.
"""  # noqa: E501


logger = logging.getLogger(__name__)
51
+
52
+
53
def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) -> str:
    """Represent a dataclass, showing only fields that differ from their defaults.

    Unlike the auto-generated dataclass ``__repr__``, which prints every
    field, this helper omits fields that still hold their default value.

    Args:
        obj: The dataclass instance to represent.
        default_values: Optional mapping from field name to default value.
            Use this for defaults that are produced dynamically (e.g., in
            ``__post_init__`` or by a ``default_factory``). Fields absent
            from this mapping fall back to the default declared on the
            dataclass field itself.

    Returns:
        A ``ClassName(field=value, ...)`` style representation.
    """
    overrides = {} if default_values is None else default_values

    def _same(lhs, rhs):
        # None must be special-cased because of a pyarrow bug:
        # https://github.com/apache/arrow/issues/38535
        if lhs is None or rhs is None:
            return lhs is rhs
        return lhs == rhs

    # Required fields (no declared default) are always shown; optional fields
    # are shown only when their current value differs from the default.
    shown = {
        f.name: getattr(obj, f.name)
        for f in fields(obj)
        if isinstance(f.default, _MISSING_TYPE)
        or not _same(getattr(obj, f.name), overrides.get(f.name, f.default))
    }

    rendered = ", ".join(f"{name}={value!r}" for name, value in shown.items())
    return f"{obj.__class__.__name__}({rendered})"
98
+
99
+
100
@dataclass
@PublicAPI(stability="stable")
class ScalingConfig:
    """Configuration for scaling training.

    For more details, see :ref:`train_scaling_config`.

    Args:
        trainer_resources: Resources to allocate for the training coordinator.
            The training coordinator launches the worker group and executes
            the training function per worker, and this process does NOT require
            GPUs. The coordinator is always scheduled on the same node as the
            rank 0 worker, so one example use case is to set a minimum amount
            of resources (e.g. CPU memory) required by the rank 0 node.
            By default, this assigns 1 CPU to the training coordinator.
        num_workers: The number of workers (Ray actors) to launch.
            Each worker will reserve 1 CPU by default. The number of CPUs
            reserved by each worker can be overridden with the
            ``resources_per_worker`` argument.
        use_gpu: If True, training will be done on GPUs (1 per worker).
            Defaults to False. The number of GPUs reserved by each
            worker can be overridden with the ``resources_per_worker``
            argument.
        resources_per_worker: If specified, the resources
            defined in this Dict is reserved for each worker.
            Define the ``"CPU"`` key (case-sensitive) to
            override the number of CPUs used by each worker.
            This can also be used to request :ref:`custom resources <custom-resources>`.
        placement_strategy: The placement strategy to use for the
            placement group of the Ray actors. See :ref:`Placement Group
            Strategies <pgroup-strategy>` for the possible options.
        accelerator_type: [Experimental] If specified, Ray Train will launch the
            training coordinator and workers on the nodes with the specified type
            of accelerators.
            See :ref:`the available accelerator types <accelerator_types>`.
            Ensure that your cluster has instances with the specified accelerator type
            or is able to autoscale to fulfill the request.

    Example:

        .. code-block:: python

            from ray.train import ScalingConfig
            scaling_config = ScalingConfig(
                # Number of distributed workers.
                num_workers=2,
                # Turn on/off GPU.
                use_gpu=True,
                # Assign extra CPU/GPU/custom resources per worker.
                resources_per_worker={"GPU": 1, "CPU": 1, "memory": 1e9, "custom": 1.0},
                # Try to schedule workers on different nodes.
                placement_strategy="SPREAD",
            )

    """

    trainer_resources: Optional[Union[Dict, SampleRange]] = None
    num_workers: Union[int, SampleRange] = 1
    use_gpu: Union[bool, SampleRange] = False
    resources_per_worker: Optional[Union[Dict, SampleRange]] = None
    placement_strategy: Union[str, SampleRange] = "PACK"
    accelerator_type: Optional[str] = None

    def __post_init__(self):
        # Validate that `use_gpu` and any explicit "GPU" entry in
        # `resources_per_worker` don't contradict each other.
        if self.resources_per_worker:
            if not self.use_gpu and self.num_gpus_per_worker > 0:
                raise ValueError(
                    "`use_gpu` is False but `GPU` was found in "
                    "`resources_per_worker`. Either set `use_gpu` to True or "
                    "remove `GPU` from `resources_per_worker."
                )

            if self.use_gpu and self.num_gpus_per_worker == 0:
                raise ValueError(
                    "`use_gpu` is True but `GPU` is set to 0 in "
                    "`resources_per_worker`. Either set `use_gpu` to False or "
                    "request a positive number of `GPU` in "
                    "`resources_per_worker."
                )

    def __repr__(self):
        return _repr_dataclass(self)

    def _repr_html_(self) -> str:
        return make_table_html_repr(obj=self, title=type(self).__name__)

    def __eq__(self, o: "ScalingConfig") -> bool:
        # Two configs are equal iff they resolve to the same placement group
        # request, not necessarily field-by-field equality.
        if not isinstance(o, type(self)):
            return False
        return self.as_placement_group_factory() == o.as_placement_group_factory()

    @property
    def _resources_per_worker_not_none(self):
        # Resolved per-worker resource dict with defaults applied and
        # zero-valued entries dropped.
        if self.resources_per_worker is None:
            if self.use_gpu:
                # Note that we don't request any CPUs, which avoids possible
                # scheduling contention. Generally nodes have many more CPUs than
                # GPUs, so not requesting a CPU does not lead to oversubscription.
                resources_per_worker = {"GPU": 1}
            else:
                resources_per_worker = {"CPU": 1}
        else:
            resources_per_worker = {
                k: v for k, v in self.resources_per_worker.items() if v != 0
            }

        if self.use_gpu:
            resources_per_worker.setdefault("GPU", 1)

        if self.accelerator_type:
            # Request a tiny fraction of the accelerator-type constraint
            # resource so workers land on nodes with that accelerator.
            accelerator = f"{RESOURCE_CONSTRAINT_PREFIX}{self.accelerator_type}"
            resources_per_worker.setdefault(accelerator, 0.001)
        return resources_per_worker

    @property
    def _trainer_resources_not_none(self):
        # Resolved trainer (coordinator) resource dict with defaults applied.
        if self.trainer_resources is None:
            if self.num_workers:
                # For Google Colab, don't allocate resources to the base Trainer.
                # Colab only has 2 CPUs, and because of this resource scarcity,
                # we have to be careful on where we allocate resources. Since Colab
                # is not distributed, the concern about many parallel Ray Tune trials
                # leading to all Trainers being scheduled on the head node if we set
                # `trainer_resources` to 0 is no longer applicable.
                try:
                    import google.colab  # noqa: F401

                    trainer_num_cpus = 0
                except ImportError:
                    trainer_num_cpus = 1
            else:
                # If there are no additional workers, then always reserve 1 CPU for
                # the Trainer.
                trainer_num_cpus = 1

            trainer_resources = {"CPU": trainer_num_cpus}
        else:
            trainer_resources = {
                k: v for k, v in self.trainer_resources.items() if v != 0
            }

        return trainer_resources

    @property
    def total_resources(self):
        """Map of total resources required for the trainer."""
        total_resource_map = defaultdict(float, self._trainer_resources_not_none)
        for k, value in self._resources_per_worker_not_none.items():
            total_resource_map[k] += value * self.num_workers
        return dict(total_resource_map)

    @property
    def num_cpus_per_worker(self):
        """The number of CPUs to set per worker."""
        return self._resources_per_worker_not_none.get("CPU", 0)

    @property
    def num_gpus_per_worker(self):
        """The number of GPUs to set per worker."""
        return self._resources_per_worker_not_none.get("GPU", 0)

    @property
    def additional_resources_per_worker(self):
        """Resources per worker, not including CPU or GPU resources."""
        return {
            k: v
            for k, v in self._resources_per_worker_not_none.items()
            if k not in ["CPU", "GPU"]
        }

    def as_placement_group_factory(self) -> "PlacementGroupFactory":
        """Returns a PlacementGroupFactory to specify resources for Tune."""
        from ray.tune.execution.placement_groups import PlacementGroupFactory

        trainer_bundle = self._trainer_resources_not_none
        worker_bundle = self._resources_per_worker_not_none

        # Colocate Trainer and rank0 worker by merging their bundles
        # Note: This empty bundle is required so that the Tune actor manager schedules
        # the Trainable onto the combined bundle while taking none of its resources,
        # rather than a non-empty head bundle.
        combined_bundle = dict(Counter(trainer_bundle) + Counter(worker_bundle))
        bundles = [{}, combined_bundle] + [worker_bundle] * (self.num_workers - 1)
        return PlacementGroupFactory(bundles, strategy=self.placement_strategy)

    @classmethod
    def from_placement_group_factory(
        cls, pgf: "PlacementGroupFactory"
    ) -> "ScalingConfig":
        """Create a ScalingConfig from a Tune's PlacementGroupFactory

        Note that this is only needed for ResourceChangingScheduler, which
        modifies a trial's PlacementGroupFactory but doesn't propagate
        the changes to ScalingConfig. TrainTrainable needs to reconstruct
        a ScalingConfig from on the trial's PlacementGroupFactory.
        """

        # pgf.bundles = [{trainer + worker}, {worker}, ..., {worker}]
        num_workers = len(pgf.bundles)
        combined_resources = pgf.bundles[0]
        resources_per_worker = pgf.bundles[-1]
        use_gpu = bool(resources_per_worker.get("GPU", False))
        placement_strategy = pgf.strategy

        # In `as_placement_group_factory`, we merged the trainer resource into the
        # first worker resources bundle. We need to calculate the resources diff to
        # get the trainer resources.
        # Note: If there's only one worker, we won't be able to calculate the diff.
        # We'll have empty trainer bundle and assign all resources to the worker.
        trainer_resources = dict(
            Counter(combined_resources) - Counter(resources_per_worker)
        )

        return ScalingConfig(
            trainer_resources=trainer_resources,
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker=resources_per_worker,
            placement_strategy=placement_strategy,
        )
320
+
321
+
322
@dataclass
@Deprecated(DATASET_CONFIG_DEPRECATION_MSG)
class DatasetConfig:
    """Configuration for ingest of a single Dataset.

    See :ref:`the AIR Dataset configuration guide <data-ingest-torch>` for
    usage examples.

    This config defines how the Dataset should be read into the DataParallelTrainer.
    It configures the preprocessing, splitting, and ingest strategy per-dataset.

    DataParallelTrainers declare default DatasetConfigs for each dataset passed in the
    ``datasets`` argument. Users have the opportunity to selectively override these
    configs by passing the ``dataset_config`` argument. Trainers can also define user
    customizable values (e.g., XGBoostTrainer doesn't support streaming ingest).

    Args:
        fit: Whether to fit preprocessors on this dataset. This can be set on at most
            one dataset at a time. True by default for the "train" dataset only.
        split: Whether the dataset should be split across multiple workers.
            True by default for the "train" dataset only.
        required: Whether to raise an error if the Dataset isn't provided by the user.
            False by default.
        transform: Whether to transform the dataset with the fitted preprocessor.
            This must be enabled at least for the dataset that is fit.
            True by default.
        max_object_store_memory_fraction [Experimental]: The maximum fraction
            of Ray's shared-memory object store to use for the dataset. The
            default value is -1, meaning that the preprocessed dataset should
            be cached, which may cause spilling if its size is larger than the
            object store's capacity. Pipelined ingest (all other values, 0 or
            higher) is experimental. Note that the absolute memory capacity
            used is based on the object store capacity at invocation time; this
            does not currently cover autoscaling cases where the size of the
            cluster may change.
        global_shuffle: Whether to enable global shuffle (per pipeline window
            in streaming mode). Note that this is an expensive all-to-all operation,
            and most likely you want to use local shuffle instead.
            See https://docs.ray.io/en/master/data/faq.html and
            https://docs.ray.io/en/master/ray-air/check-ingest.html.
            False by default.
        randomize_block_order: Whether to randomize the iteration order over blocks.
            The main purpose of this is to prevent data fetching hotspots in the
            cluster when running many parallel workers / trials on the same data.
            We recommend enabling it always. True by default.
        per_epoch_preprocessor [Experimental]: A preprocessor to re-apply on
            each pass of the dataset. The main use case for this is to apply a
            random transform on a training dataset on each epoch. The
            per-epoch preprocessor will be applied *after* all other
            preprocessors and in parallel with the dataset consumer.
        use_stream_api: Deprecated. Use max_object_store_memory_fraction instead.
        stream_window_size: Deprecated. Use max_object_store_memory_fraction instead.
    """

    # TODO(ekl) could we unify DataParallelTrainer and Trainer so the same data ingest
    # strategy applies to all Trainers?

    fit: Optional[bool] = None
    split: Optional[bool] = None
    required: Optional[bool] = None
    transform: Optional[bool] = None
    max_object_store_memory_fraction: Optional[float] = None
    global_shuffle: Optional[bool] = None
    randomize_block_order: Optional[bool] = None
    per_epoch_preprocessor: Optional["Preprocessor"] = None
    # Deprecated.
    use_stream_api: Optional[int] = None
    stream_window_size: Optional[int] = None

    def __post_init__(self):
        # The whole class is deprecated: instantiating it always raises,
        # pointing users at `ray.train.DataConfig` instead.
        raise DeprecationWarning(DATASET_CONFIG_DEPRECATION_MSG)
393
+
394
+
395
@dataclass
@PublicAPI(stability="stable")
class FailureConfig:
    """Configuration related to failure handling of each training/tuning run.

    Args:
        max_failures: Tries to recover a run at least this many times.
            Will recover from the latest checkpoint if present.
            Setting to -1 will lead to infinite recovery retries.
            Setting to 0 will disable retries. Defaults to 0.
        fail_fast: Whether to fail upon the first error.
            If fail_fast='raise' provided, the original error during training will be
            immediately raised. fail_fast='raise' can easily leak resources and
            should be used with caution.
    """

    max_failures: int = 0
    fail_fast: Union[bool, str] = False

    def __post_init__(self):
        # Same check as in TuneController
        # `fail_fast` must be a bool or the (case-insensitive) string "raise".
        if not (isinstance(self.fail_fast, bool) or self.fail_fast.upper() == "RAISE"):
            raise ValueError(
                "fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
            )

        # Same check as in tune.run
        # Retrying is meaningless when the run is configured to fail fast.
        if self.fail_fast and self.max_failures != 0:
            raise ValueError(
                f"max_failures must be 0 if fail_fast={repr(self.fail_fast)}."
            )

    def __repr__(self):
        return _repr_dataclass(self)

    def _repr_html_(self):
        # Render the settings as a scrollable HTML table (used in notebooks).
        return Template("scrollableTable.html.j2").render(
            table=tabulate(
                {
                    "Setting": ["Max failures", "Fail fast"],
                    "Value": [self.max_failures, self.fail_fast],
                },
                tablefmt="html",
                showindex=False,
                headers="keys",
            ),
            max_height="none",
        )
443
+
444
+
445
@dataclass
@PublicAPI(stability="stable")
class CheckpointConfig:
    """Configurable parameters for defining the checkpointing strategy.

    Default behavior is to persist all checkpoints to disk. If
    ``num_to_keep`` is set, the default retention policy is to keep the
    checkpoints with maximum timestamp, i.e. the most recent checkpoints.

    Args:
        num_to_keep: The number of checkpoints to keep
            on disk for this run. If a checkpoint is persisted to disk after
            there are already this many checkpoints, then an existing
            checkpoint will be deleted. If this is ``None`` then checkpoints
            will not be deleted. Must be >= 1.
        checkpoint_score_attribute: The attribute that will be used to
            score checkpoints to determine which checkpoints should be kept
            on disk when there are greater than ``num_to_keep`` checkpoints.
            This attribute must be a key from the checkpoint
            dictionary which has a numerical value. Per default, the last
            checkpoints will be kept.
        checkpoint_score_order: Either "max" or "min".
            If "max", then checkpoints with highest values of
            ``checkpoint_score_attribute`` will be kept.
            If "min", then checkpoints with lowest values of
            ``checkpoint_score_attribute`` will be kept.
        checkpoint_frequency: Number of iterations between checkpoints. If 0
            this will disable checkpointing.
            Please note that most trainers will still save one checkpoint at
            the end of training.
            This attribute is only supported
            by trainers that don't take in custom training loops.
        checkpoint_at_end: If True, will save a checkpoint at the end of training.
            This attribute is only supported by trainers that don't take in
            custom training loops. Defaults to True for trainers that support it
            and False for generic function trainables.
        _checkpoint_keep_all_ranks: This experimental config is deprecated.
            This behavior is now controlled by reporting `checkpoint=None`
            in the workers that shouldn't persist a checkpoint.
            For example, if you only want the rank 0 worker to persist a checkpoint
            (e.g., in standard data parallel training), then you should save and
            report a checkpoint if `ray.train.get_context().get_world_rank() == 0`
            and `None` otherwise.
        _checkpoint_upload_from_workers: This experimental config is deprecated.
            Uploading checkpoint directly from the worker is now the default behavior.
    """

    num_to_keep: Optional[int] = None
    checkpoint_score_attribute: Optional[str] = None
    checkpoint_score_order: Optional[str] = MAX
    checkpoint_frequency: Optional[int] = 0
    checkpoint_at_end: Optional[bool] = None
    # Deprecated fields default to the `_DEPRECATED_VALUE` sentinel so that
    # `__post_init__` can detect when a caller explicitly set them.
    _checkpoint_keep_all_ranks: Optional[bool] = _DEPRECATED_VALUE
    _checkpoint_upload_from_workers: Optional[bool] = _DEPRECATED_VALUE

    def __post_init__(self):
        # Any value other than the sentinel means the deprecated flag was
        # explicitly passed, so surface the deprecation immediately.
        if self._checkpoint_keep_all_ranks != _DEPRECATED_VALUE:
            raise DeprecationWarning(
                "The experimental `_checkpoint_keep_all_ranks` config is deprecated. "
                "This behavior is now controlled by reporting `checkpoint=None` "
                "in the workers that shouldn't persist a checkpoint. "
                "For example, if you only want the rank 0 worker to persist a "
                "checkpoint (e.g., in standard data parallel training), "
                "then you should save and report a checkpoint if "
                "`ray.train.get_context().get_world_rank() == 0` "
                "and `None` otherwise."
            )

        if self._checkpoint_upload_from_workers != _DEPRECATED_VALUE:
            raise DeprecationWarning(
                "The experimental `_checkpoint_upload_from_workers` config is "
                "deprecated. Uploading checkpoint directly from the worker is "
                "now the default behavior."
            )

        if self.num_to_keep is not None and self.num_to_keep <= 0:
            raise ValueError(
                f"Received invalid num_to_keep: "
                f"{self.num_to_keep}. "
                f"Must be None or an integer >= 1."
            )
        if self.checkpoint_score_order not in (MAX, MIN):
            raise ValueError(
                f"checkpoint_score_order must be either " f'"{MAX}" or "{MIN}".'
            )

        if self.checkpoint_frequency < 0:
            raise ValueError(
                f"checkpoint_frequency must be >=0, got {self.checkpoint_frequency}"
            )

    def __repr__(self):
        return _repr_dataclass(self)

    def _repr_html_(self) -> str:
        # Render the settings as a scrollable HTML table (used in notebooks),
        # substituting human-readable placeholders for unset values.
        if self.num_to_keep is None:
            num_to_keep_repr = "All"
        else:
            num_to_keep_repr = self.num_to_keep

        if self.checkpoint_score_attribute is None:
            checkpoint_score_attribute_repr = "Most recent"
        else:
            checkpoint_score_attribute_repr = self.checkpoint_score_attribute

        if self.checkpoint_at_end is None:
            checkpoint_at_end_repr = ""
        else:
            checkpoint_at_end_repr = self.checkpoint_at_end

        return Template("scrollableTable.html.j2").render(
            table=tabulate(
                {
                    "Setting": [
                        "Number of checkpoints to keep",
                        "Checkpoint score attribute",
                        "Checkpoint score order",
                        "Checkpoint frequency",
                        "Checkpoint at end",
                    ],
                    "Value": [
                        num_to_keep_repr,
                        checkpoint_score_attribute_repr,
                        self.checkpoint_score_order,
                        self.checkpoint_frequency,
                        checkpoint_at_end_repr,
                    ],
                },
                tablefmt="html",
                showindex=False,
                headers="keys",
            ),
            max_height="none",
        )

    @property
    def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
        """Same as ``checkpoint_score_attr`` in ``tune.run``.

        Only used for Legacy API compatibility.
        """
        if self.checkpoint_score_attribute is None:
            return self.checkpoint_score_attribute
        # Legacy Tune encodes "min" ordering as a "min-" prefix on the attribute.
        prefix = ""
        if self.checkpoint_score_order == MIN:
            prefix = "min-"
        return f"{prefix}{self.checkpoint_score_attribute}"
592
+
593
+
594
@dataclass
@PublicAPI(stability="stable")
class RunConfig:
    """Runtime configuration for training and tuning runs.

    Upon resuming from a training or tuning run checkpoint,
    Ray Train/Tune will automatically apply the RunConfig from
    the previously checkpointed run.

    Args:
        name: Name of the trial or experiment. If not provided, will be deduced
            from the Trainable.
        storage_path: [Beta] Path where all results and checkpoints are persisted.
            Can be a local directory or a destination on cloud storage.
            For multi-node training/tuning runs, this must be set to a
            shared storage location (e.g., S3, NFS).
            This defaults to the local ``~/ray_results`` directory.
        storage_filesystem: [Beta] A custom filesystem to use for storage.
            If this is provided, `storage_path` should be a path with its
            prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`).
        failure_config: Failure mode configuration.
        checkpoint_config: Checkpointing configuration.
        sync_config: Configuration object for syncing. See train.SyncConfig.
        verbose: 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = default, 2 = verbose. Defaults to 1.
            If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
            uses the old verbosity settings:
            0 = silent, 1 = only status updates, 2 = status and brief
            results, 3 = status and detailed results.
        stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
            for more info. Stoppers should be serializable.
        callbacks: [DeveloperAPI] Callbacks to invoke.
            Refer to ray.tune.callback.Callback for more info.
            Callbacks should be serializable.
            Currently only stateless callbacks are supported for resumed runs.
            (any state of the callback will not be checkpointed by Tune
            and thus will not take effect in resumed runs).
        progress_reporter: [DeveloperAPI] Progress reporter for reporting
            intermediate experiment progress. Defaults to CLIReporter if
            running in command-line, or JupyterNotebookReporter if running in
            a Jupyter notebook.
        log_to_file: [DeveloperAPI] Log stdout and stderr to files in
            trial directories. If this is `False` (default), no files
            are written. If `true`, outputs are written to `trialdir/stdout`
            and `trialdir/stderr`, respectively. If this is a single string,
            this is interpreted as a file relative to the trialdir, to which
            both streams are written. If this is a Sequence (e.g. a Tuple),
            it has to have length 2 and the elements indicate the files to
            which stdout and stderr are written, respectively.

    """

    name: Optional[str] = None
    storage_path: Optional[str] = None
    storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
    failure_config: Optional[FailureConfig] = None
    checkpoint_config: Optional[CheckpointConfig] = None
    sync_config: Optional["SyncConfig"] = None
    verbose: Optional[Union[int, "AirVerbosity", "Verbosity"]] = None
    stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
    callbacks: Optional[List["Callback"]] = None
    progress_reporter: Optional[
        "ray.tune.progress_reporter.ProgressReporter"  # noqa: F821
    ] = None
    log_to_file: Union[bool, str, Tuple[str, str]] = False

    # Deprecated: use `storage_path` instead. Kept so that passing it still
    # triggers the explicit DeprecationWarning in `__post_init__`.
    local_dir: Optional[str] = None

    def __post_init__(self):
        # Imported here (not at module level) to avoid circular imports.
        from ray.train import SyncConfig
        from ray.train.constants import DEFAULT_STORAGE_PATH
        from ray.tune.experimental.output import AirVerbosity, get_air_verbosity

        if self.local_dir is not None:
            raise DeprecationWarning(
                "The `RunConfig(local_dir)` argument is deprecated. "
                "You should set the `RunConfig(storage_path)` instead."
                "See the docs: https://docs.ray.io/en/latest/train/user-guides/"
                "persistent-storage.html#setting-the-local-staging-directory"
            )

        if self.storage_path is None:
            # TODO(justinvyu): [Deprecated] Remove in 2.30
            self.storage_path = DEFAULT_STORAGE_PATH

            # If no remote path is set, try to get Ray Storage URI
            ray_storage_uri: Optional[str] = _get_storage_uri()
            if ray_storage_uri is not None:
                logger.info(
                    "Using configured Ray Storage URI as the `storage_path`: "
                    f"{ray_storage_uri}"
                )
                self.storage_path = ray_storage_uri

        # Fill in default sub-configs for anything the user left unset.
        if not self.failure_config:
            self.failure_config = FailureConfig()

        if not self.sync_config:
            self.sync_config = SyncConfig()

        if not self.checkpoint_config:
            self.checkpoint_config = CheckpointConfig()

        if self.verbose is None:
            # Default `verbose` value. For new output engine,
            # this is AirVerbosity.DEFAULT.
            # For old output engine, this is Verbosity.V3_TRIAL_DETAILS
            # Todo (krfricke): Currently uses number to pass test_configs::test_repr
            self.verbose = get_air_verbosity(AirVerbosity.DEFAULT) or 3

        # Normalize pathlib paths to plain POSIX strings.
        if isinstance(self.storage_path, Path):
            self.storage_path = self.storage_path.as_posix()

    def __repr__(self):
        # Imported here (not at module level) to avoid circular imports.
        from ray.train import SyncConfig

        return _repr_dataclass(
            self,
            default_values={
                "failure_config": FailureConfig(),
                "sync_config": SyncConfig(),
                "checkpoint_config": CheckpointConfig(),
            },
        )

    def _repr_html_(self) -> str:
        """Render this run config and its sub-configs as HTML for notebooks."""
        reprs = []
        if self.failure_config is not None:
            reprs.append(
                Template("title_data_mini.html.j2").render(
                    title="Failure Config", data=self.failure_config._repr_html_()
                )
            )
        if self.sync_config is not None:
            reprs.append(
                Template("title_data_mini.html.j2").render(
                    title="Sync Config", data=self.sync_config._repr_html_()
                )
            )
        if self.checkpoint_config is not None:
            reprs.append(
                Template("title_data_mini.html.j2").render(
                    title="Checkpoint Config", data=self.checkpoint_config._repr_html_()
                )
            )

        # Create a divider between each displayed repr
        subconfigs = [Template("divider.html.j2").render()] * (2 * len(reprs) - 1)
        subconfigs[::2] = reprs

        settings = Template("scrollableTable.html.j2").render(
            table=tabulate(
                {
                    "Name": self.name,
                    "Local results directory": self.local_dir,
                    "Verbosity": self.verbose,
                    "Log to file": self.log_to_file,
                }.items(),
                tablefmt="html",
                headers=["Setting", "Value"],
                showindex=False,
            ),
            max_height="300px",
        )

        return Template("title_data.html.j2").render(
            title="RunConfig",
            data=Template("run_config.html.j2").render(
                subconfigs=subconfigs,
                settings=settings,
            ),
        )
.venv/lib/python3.11/site-packages/ray/air/constants.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Constants shared across Ray AIR, Train, and Tune.

NOTE(review): several values below name on-disk files or reported result
keys; changing them changes persisted formats — verify consumers first.
"""

# Key to denote the preprocessor in the checkpoint dict.
PREPROCESSOR_KEY = "_preprocessor"

# Key to denote the model in the checkpoint dict.
MODEL_KEY = "model"

# Key to denote which dataset is the evaluation dataset.
# Only used in trainers which do not support multiple
# evaluation datasets.
EVALUATION_DATASET_KEY = "evaluation"

# Key to denote which dataset is the training dataset.
# This is the dataset that the preprocessor is fit on.
TRAIN_DATASET_KEY = "train"

# Name to use for the column when representing tensors in table format.
TENSOR_COLUMN_NAME = "__value__"

# The maximum length of strings returned by `__repr__` for AIR objects constructed with
# default values.
MAX_REPR_LENGTH = int(80 * 1.5)

# Timeout used when putting exceptions raised by runner thread into the queue.
_ERROR_REPORT_TIMEOUT = 10

# Timeout when fetching new results after signaling the training function to continue.
_RESULT_FETCH_TIMEOUT = 0.2

# Timeout for fetching exceptions raised by the training function.
_ERROR_FETCH_TIMEOUT = 1

# The key used to identify whether we have already warned about ray.air.session
# functions being used outside of the session
SESSION_MISUSE_LOG_ONCE_KEY = "air_warn_session_misuse"

# Name of attribute in Checkpoint storing current Tune ID for restoring
# training with Ray Train
CHECKPOINT_ID_ATTR = "_current_checkpoint_id"

# Name of the marker dropped by the Trainable. If a worker detects
# the presence of the marker in the trial dir, it will use lazy
# checkpointing.
LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker"


# The timestamp of when the result is generated.
# Default to when the result is processed by tune.
TIMESTAMP = "timestamp"

# (Auto-filled) Time in seconds this iteration took to run.
# This may be overridden to override the system-computed time difference.
TIME_THIS_ITER_S = "time_this_iter_s"

# (Auto-filled) The index of this training iteration.
TRAINING_ITERATION = "training_iteration"

# File that stores parameters of the trial.
EXPR_PARAM_FILE = "params.json"

# Pickle File that stores parameters of the trial.
EXPR_PARAM_PICKLE_FILE = "params.pkl"

# File that stores the progress of the trial.
EXPR_PROGRESS_FILE = "progress.csv"

# File that stores results of the trial.
EXPR_RESULT_FILE = "result.json"

# File that stores the pickled error file
EXPR_ERROR_PICKLE_FILE = "error.pkl"

# File that stores the error file
EXPR_ERROR_FILE = "error.txt"

# File that stores the checkpoint metadata
CHECKPOINT_TUNE_METADATA_FILE = ".tune_metadata"

# ==================================================
# Environment Variables
# ==================================================

# Integer value which if set will copy files in reported AIR directory
# checkpoints instead of moving them (if worker is on the same node as Trainable)
COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV = (
    "TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING"
)

# NOTE: When adding a new environment variable, please track it in this list.
# TODO(ml-team): Most env var constants should get moved here.
AIR_ENV_VARS = {
    COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
    "RAY_AIR_FULL_TRACEBACKS",
    "RAY_AIR_NEW_OUTPUT",
}
.venv/lib/python3.11/site-packages/ray/air/data_batch_type.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING, Dict, Union

if TYPE_CHECKING:
    import numpy
    import pandas  # noqa: F401
    import pyarrow

# TODO de-dup with ray.data.block.DataBatch
# A user-facing data batch: a single tensor, an Arrow table, a pandas
# DataFrame, or a dict of named numpy arrays.
# Fix: the original read `"pyarrow.Table" "pandas.DataFrame"` (no comma), so
# implicit string-literal concatenation silently fused the two members into
# the bogus forward reference "pyarrow.Tablepandas.DataFrame".
DataBatchType = Union[
    "numpy.ndarray",
    "pyarrow.Table",
    "pandas.DataFrame",
    Dict[str, "numpy.ndarray"],
]
.venv/lib/python3.11/site-packages/ray/air/execution/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public entry points of Ray AIR's execution resource management."""

from ray.air.execution.resources.fixed import FixedResourceManager
from ray.air.execution.resources.placement_group import PlacementGroupResourceManager
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
from ray.air.execution.resources.resource_manager import ResourceManager

__all__ = [
    "ResourceRequest",
    "AcquiredResources",
    "ResourceManager",
    "FixedResourceManager",
    "PlacementGroupResourceManager",
]
.venv/lib/python3.11/site-packages/ray/air/execution/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (679 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Resource backends and request primitives for AIR execution."""

from ray.air.execution.resources.fixed import FixedResourceManager
from ray.air.execution.resources.placement_group import PlacementGroupResourceManager
from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
from ray.air.execution.resources.resource_manager import ResourceManager

__all__ = [
    "ResourceRequest",
    "AcquiredResources",
    "ResourceManager",
    "FixedResourceManager",
    "PlacementGroupResourceManager",
]
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/fixed.cpython-311.pyc ADDED
Binary file (7.71 kB). View file
 
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/placement_group.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/request.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/air/execution/resources/__pycache__/resource_manager.cpython-311.pyc ADDED
Binary file (7.74 kB). View file
 
.venv/lib/python3.11/site-packages/ray/air/execution/resources/fixed.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional
3
+
4
+ import ray
5
+ from ray import LOCAL_MODE, SCRIPT_MODE
6
+ from ray.air.execution.resources.request import (
7
+ AcquiredResources,
8
+ RemoteRayEntity,
9
+ ResourceRequest,
10
+ )
11
+ from ray.air.execution.resources.resource_manager import ResourceManager
12
+ from ray.util.annotations import DeveloperAPI
13
+
14
# Avoid numerical errors by multiplying and subtracting with this number.
# Compare: 0.99 - 0.33 = 0.65999... vs (0.99 * 1000 - 0.33 * 1000) / 1000 = 0.66
# Used by FixedResourceManager when subtracting fractional resources.
_DIGITS = 100000
17
+
18
+
19
@DeveloperAPI
@dataclass
class FixedAcquiredResources(AcquiredResources):
    """Acquired resources handed out by a ``FixedResourceManager``."""

    bundles: List[Dict[str, float]]

    def _annotate_remote_entity(
        self, entity: RemoteRayEntity, bundle: Dict[str, float], bundle_index: int
    ) -> RemoteRayEntity:
        # CPU, GPU, and memory map onto dedicated `options()` keywords;
        # every remaining entry is passed through as a custom resource.
        custom_resources = {
            key: amount
            for key, amount in bundle.items()
            if key not in ("CPU", "GPU", "memory")
        }
        return entity.options(
            num_cpus=bundle.get("CPU", 0),
            num_gpus=bundle.get("GPU", 0),
            memory=bundle.get("memory", 0.0),
            resources=custom_resources,
        )
38
+
39
+
40
@DeveloperAPI
class FixedResourceManager(ResourceManager):
    """Fixed budget based resource manager.

    This resource manager keeps track of a fixed set of resources. When resources
    are acquired, they are subtracted from the budget. When resources are freed,
    they are added back to the budget.

    The resource manager still requires resources to be requested before they become
    available. However, because the resource requests are virtual, this will not
    trigger autoscaling.

    Additionally, resources are not reserved on request, only on acquisition. Thus,
    acquiring a resource can change the availability of other requests. Note that
    this behavior may be changed in future implementations.

    The fixed resource manager does not support placement strategies. Using
    ``STRICT_SPREAD`` will result in an error. ``STRICT_PACK`` will succeed only
    within a placement group bundle. All other placement group arguments will be
    ignored.

    Args:
        total_resources: Budget of resources to manage. Defaults to all available
            resources in the current task or all cluster resources (if outside a task).

    """

    # Wrapper class handed out by ``acquire_resources()``.
    _resource_cls: AcquiredResources = FixedAcquiredResources

    def __init__(self, total_resources: Optional[Dict[str, float]] = None):
        rtc = ray.get_runtime_context()

        if not total_resources:
            # Outside of a Ray task/actor the budget is the whole cluster;
            # inside one it is only the resources assigned to that task/actor.
            if rtc.worker.mode in {None, SCRIPT_MODE, LOCAL_MODE}:
                total_resources = ray.cluster_resources()
            else:
                total_resources = rtc.get_assigned_resources()

        # If we are in a placement group, all of our resources will be in a bundle
        # and thus fulfill requirements of STRICT_PACK - but only if child tasks
        # are captured by the pg.
        self._allow_strict_pack = (
            ray.util.get_current_placement_group() is not None
            and rtc.should_capture_child_tasks_in_placement_group
        )

        self._total_resources = total_resources
        # Requests are virtual: tracked, but not subtracted from the budget.
        self._requested_resources = []
        # Only acquired requests count against the budget.
        self._used_resources = []

    @property
    def _available_resources(self) -> Dict[str, float]:
        # Budget minus everything currently acquired.
        available_resources = self._total_resources.copy()

        for used_resources in self._used_resources:
            all_resources = used_resources.required_resources
            for k, v in all_resources.items():
                # Scaled arithmetic (see _DIGITS) so fractional resources do
                # not accumulate float rounding errors over many cycles.
                available_resources[k] = (
                    available_resources[k] * _DIGITS - v * _DIGITS
                ) / _DIGITS
        return available_resources

    def request_resources(self, resource_request: ResourceRequest):
        # STRICT_SPREAD can never be satisfied by a single fixed budget;
        # STRICT_PACK only when we already live inside a capturing PG bundle.
        if resource_request.strategy == "STRICT_SPREAD" or (
            not self._allow_strict_pack and resource_request.strategy == "STRICT_PACK"
        ):
            raise RuntimeError(
                f"Requested a resource with placement strategy "
                f"{resource_request.strategy}, but this cannot be fulfilled by a "
                f"FixedResourceManager. In a nested setting, please set the inner "
                f"placement strategy to be less restrictive (i.e. no STRICT_ strategy)."
            )

        self._requested_resources.append(resource_request)

    def cancel_resource_request(self, resource_request: ResourceRequest):
        # Raises ValueError if the request was never made (list.remove).
        self._requested_resources.remove(resource_request)

    def has_resources_ready(self, resource_request: ResourceRequest) -> bool:
        # Resources must have been requested first, even though requests
        # are virtual for this manager.
        if resource_request not in self._requested_resources:
            return False

        available_resources = self._available_resources
        all_resources = resource_request.required_resources
        for k, v in all_resources.items():
            if available_resources.get(k, 0.0) < v:
                return False
        return True

    def acquire_resources(
        self, resource_request: ResourceRequest
    ) -> Optional[AcquiredResources]:
        # Returns None (rather than raising) when the budget cannot cover
        # the request right now.
        if not self.has_resources_ready(resource_request):
            return None

        self._used_resources.append(resource_request)
        return self._resource_cls(
            bundles=resource_request.bundles, resource_request=resource_request
        )

    def free_resources(self, acquired_resource: AcquiredResources):
        # Removing the request from the used list implicitly restores the
        # budget (see `_available_resources`).
        resources = acquired_resource.resource_request
        self._used_resources.remove(resources)

    def clear(self):
        # Reset internal state
        self._requested_resources = []
        self._used_resources = []
.venv/lib/python3.11/site-packages/ray/air/execution/resources/placement_group.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Set
5
+
6
+ import ray
7
+ from ray.air.execution.resources.request import (
8
+ AcquiredResources,
9
+ RemoteRayEntity,
10
+ ResourceRequest,
11
+ )
12
+ from ray.air.execution.resources.resource_manager import ResourceManager
13
+ from ray.util.annotations import DeveloperAPI
14
+ from ray.util.placement_group import PlacementGroup, remove_placement_group
15
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
16
+
17
+
18
@DeveloperAPI
@dataclass
class PlacementGroupAcquiredResources(AcquiredResources):
    """Acquired resources backed by a CREATED placement group."""

    placement_group: PlacementGroup

    def _annotate_remote_entity(
        self, entity: RemoteRayEntity, bundle: Dict[str, float], bundle_index: int
    ) -> RemoteRayEntity:
        # CPU, GPU, and memory map onto dedicated `options()` keywords;
        # every remaining entry is passed through as a custom resource.
        custom_resources = {
            key: amount
            for key, amount in bundle.items()
            if key not in ("CPU", "GPU", "memory")
        }
        # Pin the entity to the given bundle of this placement group and
        # capture its children so nested tasks share the reservation.
        strategy = PlacementGroupSchedulingStrategy(
            placement_group=self.placement_group,
            placement_group_bundle_index=bundle_index,
            placement_group_capture_child_tasks=True,
        )
        return entity.options(
            scheduling_strategy=strategy,
            num_cpus=bundle.get("CPU", 0),
            num_gpus=bundle.get("GPU", 0),
            memory=bundle.get("memory", 0.0),
            resources=custom_resources,
        )
42
+
43
+
44
@DeveloperAPI
class PlacementGroupResourceManager(ResourceManager):
    """Resource manager using placement groups as the resource backend.

    This manager will use placement groups to fulfill resource requests. Requesting
    a resource will schedule the placement group. Acquiring a resource will
    return a ``PlacementGroupAcquiredResources`` that can be used to schedule
    Ray tasks and actors on the placement group. Freeing an acquired resource
    will destroy the associated placement group.

    Ray core does not emit events when resources are available. Instead, the
    scheduling state has to be periodically updated.

    Per default, placement group scheduling state is refreshed every time when
    resource state is inquired, but not more often than once every ``update_interval_s``
    seconds. Alternatively, staging futures can be retrieved (and awaited) with
    ``get_resource_futures()`` and state update can be force with ``update_state()``.

    Args:
        update_interval_s: Minimum interval in seconds between updating scheduling
            state of placement groups.

    """

    # Wrapper class handed out by ``acquire_resources()``.
    _resource_cls: AcquiredResources = PlacementGroupAcquiredResources

    def __init__(self, update_interval_s: float = 0.1):
        # Internally, the placement group lifecycle is like this:
        # - Resources are requested with ``request_resources()``
        # - A placement group is scheduled ("staged")
        # - A ``PlacementGroup.ready()`` future is scheduled ("staging future")
        # - We update the scheduling state when we need to
        #   (e.g. when ``has_resources_ready()`` is called)
        # - When staging futures resolve, a placement group is moved from "staging"
        #   to "ready"
        # - When a resource request is canceled, we remove a placement group from
        #   "staging". If there are not staged placement groups
        #   (because they are already "ready"), we remove one from "ready" instead.
        # - When a resource is acquired, the pg is removed from "ready" and moved
        #   to "acquired"
        # - When a resource is freed, the pg is removed from "acquired" and destroyed

        # Mapping of placement group to request
        self._pg_to_request: Dict[PlacementGroup, ResourceRequest] = {}

        # PGs that are staged but not "ready", yet (i.e. not CREATED)
        self._request_to_staged_pgs: Dict[
            ResourceRequest, Set[PlacementGroup]
        ] = defaultdict(set)

        # PGs that are CREATED and can be used by tasks and actors
        self._request_to_ready_pgs: Dict[
            ResourceRequest, Set[PlacementGroup]
        ] = defaultdict(set)

        # Staging futures used to update internal state.
        # We keep a double mapping here for better lookup efficiency.
        self._staging_future_to_pg: Dict[ray.ObjectRef, PlacementGroup] = dict()
        self._pg_to_staging_future: Dict[PlacementGroup, ray.ObjectRef] = dict()

        # Set of acquired PGs. We keep track of these here to make sure we
        # only free PGs that this manager managed.
        self._acquired_pgs: Set[PlacementGroup] = set()

        # Minimum time between updates of the internal state
        self.update_interval_s = update_interval_s
        # Backdate so the very first state inquiry triggers an update.
        self._last_update = time.monotonic() - self.update_interval_s - 1

    def get_resource_futures(self) -> List[ray.ObjectRef]:
        """Return the pending ``pg.ready()`` futures of all staged PGs."""
        return list(self._staging_future_to_pg.keys())

    def _maybe_update_state(self):
        # Throttle: only refresh once per ``update_interval_s`` window.
        now = time.monotonic()
        if now > self._last_update + self.update_interval_s:
            self.update_state()

    def update_state(self):
        """Move placement groups whose futures resolved from staged to ready."""
        # timeout=0 makes this a non-blocking poll of the staging futures.
        ready, not_ready = ray.wait(
            list(self._staging_future_to_pg.keys()),
            num_returns=len(self._staging_future_to_pg),
            timeout=0,
        )
        for future in ready:
            # Remove staging future
            pg = self._staging_future_to_pg.pop(future)
            self._pg_to_staging_future.pop(pg)
            # Fetch resource request
            request = self._pg_to_request[pg]
            # Remove from staging, add to ready
            self._request_to_staged_pgs[request].remove(pg)
            self._request_to_ready_pgs[request].add(pg)
        self._last_update = time.monotonic()

    def request_resources(self, resource_request: ResourceRequest):
        """Schedule a placement group for the request and stage it."""
        pg = resource_request.to_placement_group()
        self._pg_to_request[pg] = resource_request
        self._request_to_staged_pgs[resource_request].add(pg)

        future = pg.ready()
        self._staging_future_to_pg[future] = pg
        self._pg_to_staging_future[pg] = future

    def cancel_resource_request(self, resource_request: ResourceRequest):
        """Cancel one staged (or ready) placement group for this request.

        Raises:
            RuntimeError: If no placement group was staged or is ready for
                this request (i.e. more cancellations than requests).
        """
        if self._request_to_staged_pgs[resource_request]:
            pg = self._request_to_staged_pgs[resource_request].pop()

            # PG was staging
            future = self._pg_to_staging_future.pop(pg)
            self._staging_future_to_pg.pop(future)

            # Cancel the pg.ready task.
            # Otherwise, it will be pending node assignment forever.
            ray.cancel(future)
        else:
            # PG might be ready.
            # Fix: ``set.pop()`` raises ``KeyError`` on an empty set, so the
            # previous ``if not pg`` check after the pop was unreachable dead
            # code (a PlacementGroup instance is also always truthy). Catch
            # the ``KeyError`` to raise the intended, informative error.
            try:
                pg = self._request_to_ready_pgs[resource_request].pop()
            except KeyError as exc:
                raise RuntimeError(
                    "Cannot cancel resource request: No placement group was "
                    f"staged or is ready. Make sure to not cancel more resource "
                    f"requests than you've created. Request: {resource_request}"
                ) from exc

        self._pg_to_request.pop(pg)
        # Use the module-level import for consistency with the rest of the
        # class (same function as ``ray.util.remove_placement_group``).
        remove_placement_group(pg)

    def has_resources_ready(self, resource_request: ResourceRequest) -> bool:
        """Return True if a CREATED placement group exists for the request."""
        if not self._request_to_ready_pgs[resource_request]:
            # Only update state if needed (throttled by update_interval_s).
            self._maybe_update_state()

        return bool(self._request_to_ready_pgs[resource_request])

    def acquire_resources(
        self, resource_request: ResourceRequest
    ) -> Optional[PlacementGroupAcquiredResources]:
        """Take ownership of a ready placement group, or return None."""
        if not self.has_resources_ready(resource_request):
            return None

        pg = self._request_to_ready_pgs[resource_request].pop()
        self._acquired_pgs.add(pg)

        return self._resource_cls(placement_group=pg, resource_request=resource_request)

    def free_resources(self, acquired_resource: PlacementGroupAcquiredResources):
        """Destroy the placement group backing ``acquired_resource``."""
        pg = acquired_resource.placement_group

        self._acquired_pgs.remove(pg)
        remove_placement_group(pg)
        self._pg_to_request.pop(pg)

    def clear(self):
        """Remove every managed placement group and reset internal state."""
        if not ray.is_initialized():
            # Nothing to clean up (e.g. during interpreter shutdown).
            return

        for staged_pgs in self._request_to_staged_pgs.values():
            for staged_pg in staged_pgs:
                remove_placement_group(staged_pg)

        for ready_pgs in self._request_to_ready_pgs.values():
            for ready_pg in ready_pgs:
                remove_placement_group(ready_pg)

        for acquired_pg in self._acquired_pgs:
            remove_placement_group(acquired_pg)

        # Reset internal state
        self.__init__(update_interval_s=self.update_interval_s)

    def __del__(self):
        # Best-effort cleanup of any placement groups still alive.
        self.clear()
.venv/lib/python3.11/site-packages/ray/air/execution/resources/request.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import json
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+ from inspect import signature
6
+ from typing import Dict, List, Union
7
+
8
+ import ray
9
+ from ray.util import placement_group
10
+ from ray.util.annotations import DeveloperAPI
11
+
12
+ RemoteRayEntity = Union[ray.remote_function.RemoteFunction, ray.actor.ActorClass]
13
+
14
+
15
+ def _sum_bundles(bundles: List[Dict[str, float]]) -> Dict[str, float]:
16
+ """Sum all resources in a list of resource bundles.
17
+
18
+ Args:
19
+ bundles: List of resource bundles.
20
+
21
+ Returns: Dict containing all resources summed up.
22
+ """
23
+ resources = {}
24
+ for bundle in bundles:
25
+ for k, v in bundle.items():
26
+ resources[k] = resources.get(k, 0) + v
27
+ return resources
28
+
29
+
30
+ @DeveloperAPI
31
+ class ResourceRequest:
32
+ """Request for resources.
33
+
34
+ This class is used to define a resource request. A resource request comprises one
35
+ or more bundles of resources and instructions on the scheduling behavior.
36
+
37
+ The resource request can be submitted to a resource manager, which will
38
+ schedule the resources. Depending on the resource backend, this may instruct
39
+ Ray to scale up (autoscaling).
40
+
41
+ Resource requests are compatible with the most fine-grained low-level resource
42
+ backend, which are Ray placement groups.
43
+
44
+ Args:
45
+ bundles: A list of bundles which represent the resources requirements.
46
+ E.g. ``[{"CPU": 1, "GPU": 1}]``.
47
+ strategy: The scheduling strategy to acquire the bundles.
48
+
49
+ - "PACK": Packs Bundles into as few nodes as possible.
50
+ - "SPREAD": Places Bundles across distinct nodes as even as possible.
51
+ - "STRICT_PACK": Packs Bundles into one node. The group is
52
+ not allowed to span multiple nodes.
53
+ - "STRICT_SPREAD": Packs Bundles across distinct nodes.
54
+ *args: Passed to the call of ``placement_group()``, if applicable.
55
+ **kwargs: Passed to the call of ``placement_group()``, if applicable.
56
+
57
+ """
58
+
59
    def __init__(
        self,
        bundles: List[Dict[str, Union[int, float]]],
        strategy: str = "PACK",
        *args,
        **kwargs,
    ):
        if not bundles:
            raise ValueError("Cannot initialize a ResourceRequest with zero bundles.")

        # Remove empty resource keys
        self._bundles = [
            {k: float(v) for k, v in bundle.items() if v != 0} for bundle in bundles
        ]

        # Check if the head bundle is empty (no resources defined or all resources
        # are 0 (and thus removed in the previous step)
        if not self._bundles[0]:
            # This is when the head bundle doesn't need resources.
            self._head_bundle_is_empty = True
            self._bundles.pop(0)

            if not self._bundles:
                raise ValueError(
                    "Cannot initialize a ResourceRequest with an empty head "
                    "and zero worker bundles."
                )
        else:
            self._head_bundle_is_empty = False

        self._strategy = strategy
        # Extra positional/keyword args are forwarded to `placement_group()`
        # and validated against its signature in `_bind()`.
        self._args = args
        self._kwargs = kwargs

        # Caches: lazily computed hash and the bound `placement_group()`
        # signature (set by `_bind()` below).
        self._hash = None
        self._bound = None

        self._bind()
97
+
98
    @property
    def head_bundle_is_empty(self) -> bool:
        """Returns True if head bundle is empty while child bundles
        need resources.

        This is considered an internal API within Tune.
        """
        return self._head_bundle_is_empty
106
+
107
+ @property
108
+ @DeveloperAPI
109
+ def head_cpus(self) -> float:
110
+ """Returns the number of cpus in the head bundle."""
111
+ return 0.0 if self._head_bundle_is_empty else self._bundles[0].get("CPU", 0.0)
112
+
113
    @property
    @DeveloperAPI
    def bundles(self) -> List[Dict[str, float]]:
        """Returns a deep copy of resource bundles"""
        # Copy so callers can mutate the returned bundles without
        # corrupting this request's internal state.
        return deepcopy(self._bundles)
118
+
119
    @property
    def required_resources(self) -> Dict[str, float]:
        """Returns a dict containing the sums of all resources"""
        # Delegates to the module-level helper to aggregate resource amounts
        # across all bundles (helper defined elsewhere in this module).
        return _sum_bundles(self._bundles)
123
+
124
    @property
    @DeveloperAPI
    def strategy(self) -> str:
        """Returns the placement strategy (e.g. "PACK" or "SPREAD")."""
        return self._strategy
129
+
130
+ def _bind(self):
131
+ """Bind the args and kwargs to the `placement_group()` signature.
132
+
133
+ We bind the args and kwargs, so we can compare equality of two resource
134
+ requests. The main reason for this is that the `placement_group()` API
135
+ can evolve independently from the ResourceRequest API (e.g. adding new
136
+ arguments). Then, `ResourceRequest(bundles, strategy, arg=arg)` should
137
+ be the same as `ResourceRequest(bundles, strategy, arg)`.
138
+ """
139
+ sig = signature(placement_group)
140
+ try:
141
+ self._bound = sig.bind(
142
+ self._bundles, self._strategy, *self._args, **self._kwargs
143
+ )
144
+ except Exception as exc:
145
+ raise RuntimeError(
146
+ "Invalid definition for resource request. Please check "
147
+ "that you passed valid arguments to the ResourceRequest "
148
+ "object."
149
+ ) from exc
150
+
151
    def to_placement_group(self):
        """Create a Ray placement group from the bound request arguments."""
        return placement_group(*self._bound.args, **self._bound.kwargs)
153
+
154
+ def __eq__(self, other: "ResourceRequest"):
155
+ return (
156
+ isinstance(other, ResourceRequest)
157
+ and self._bound == other._bound
158
+ and self.head_bundle_is_empty == other.head_bundle_is_empty
159
+ )
160
+
161
+ def __hash__(self):
162
+ if not self._hash:
163
+ # Cache hash
164
+ self._hash = hash(
165
+ json.dumps(
166
+ {"args": self._bound.args, "kwargs": self._bound.kwargs},
167
+ sort_keys=True,
168
+ indent=0,
169
+ ensure_ascii=True,
170
+ )
171
+ )
172
+ return self._hash
173
+
174
+ def __getstate__(self):
175
+ state = self.__dict__.copy()
176
+ state.pop("_hash", None)
177
+ state.pop("_bound", None)
178
+ return state
179
+
180
    def __setstate__(self, state):
        """Restore pickled state and rebuild derived attributes."""
        self.__dict__.update(state)
        # The cached hash and bound signature were dropped in __getstate__;
        # reset them and re-bind against the placement_group() signature.
        self._hash = None
        self._bound = None
        self._bind()
185
+
186
+ def __repr__(self) -> str:
187
+ return (
188
+ f"<ResourceRequest (_bound={self._bound}, "
189
+ f"head_bundle_is_empty={self.head_bundle_is_empty})>"
190
+ )
191
+
192
+
193
@DeveloperAPI
@dataclass
class AcquiredResources(abc.ABC):
    """Base class for resources that have been acquired.

    Acquired resources can be associated to Ray objects, which can then be
    scheduled using these resources.

    Internally this can point e.g. to a placement group, a placement
    group bundle index, or just raw resources.

    The main API is the `annotate_remote_entities` method. This will associate
    remote Ray objects (tasks and actors) with the acquired resources by setting
    the Ray remote options to use the acquired resources.
    """

    # The request these resources were acquired for.
    resource_request: ResourceRequest

    def annotate_remote_entities(
        self, entities: List[RemoteRayEntity]
    ) -> List[Union[RemoteRayEntity]]:
        """Return remote ray entities (tasks/actors) to use the acquired resources.

        Entities are matched to bundles positionally: the first entity uses
        the first bundle, the second entity the second bundle, and so on.

        Args:
            entities: Remote Ray entities to annotate with the acquired resources.

        Raises:
            RuntimeError: If more entities are passed than bundles are available.
        """
        bundles = self.resource_request.bundles
        has_empty_head = self.resource_request.head_bundle_is_empty

        # An empty head bundle still counts towards the bundle total.
        num_bundles = len(bundles) + int(has_empty_head)

        if len(entities) > num_bundles:
            raise RuntimeError(
                f"The number of callables to annotate ({len(entities)}) cannot "
                f"exceed the number of available bundles ({num_bundles})."
            )

        annotated = []

        if has_empty_head:
            # The empty head bundle occupies the first bundle index and
            # carries no resources.
            annotated.append(
                self._annotate_remote_entity(entities[0], {}, bundle_index=0)
            )
            # Remaining entities map onto the worker bundles.
            entities = entities[1:]

        annotated.extend(
            self._annotate_remote_entity(entity, bundle, bundle_index=index)
            for index, (entity, bundle) in enumerate(zip(entities, bundles))
        )

        return annotated

    def _annotate_remote_entity(
        self, entity: RemoteRayEntity, bundle: Dict[str, float], bundle_index: int
    ) -> RemoteRayEntity:
        # Backend-specific: subclasses bind the entity to the concrete
        # resource (e.g. a placement group bundle).
        raise NotImplementedError
.venv/lib/python3.11/site-packages/ray/air/execution/resources/resource_manager.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List, Optional
3
+
4
+ import ray
5
+ from ray.air.execution.resources.request import AcquiredResources, ResourceRequest
6
+ from ray.util.annotations import DeveloperAPI
7
+
8
+
9
@DeveloperAPI
class ResourceManager(abc.ABC):
    """Resource manager interface.

    A resource manager can be used to request resources from a Ray cluster and
    allocate them to remote Ray tasks or actors.

    Resources have to be requested before they can be acquired.

    Resources managed by the resource manager can be in three states:

    1. "Requested": The resources have been requested but are not yet available to
       schedule remote Ray objects. The resource request may trigger autoscaling,
       and can be cancelled if no longer needed.
    2. "Ready": The requested resources are now available to schedule remote Ray
       objects. They can be acquired and subsequently used by remote Ray objects.
       The resource request can still be cancelled if no longer needed.
    3. "Acquired": The resources have been acquired by a caller to use for scheduling
       remote Ray objects. Note that it is the responsibility of the caller to
       schedule the Ray objects with these resources.
       The associated resource request has been completed and can no longer be
       cancelled. The acquired resources can be freed by the resource manager when
       they are no longer used.

    The flow is as follows:

    .. code-block:: python

        # Create resource manager
        resource_manager = ResourceManager()

        # Create resource request
        resource_request = ResourceRequest([{"CPU": 4}])

        # Pass to resource manager
        resource_manager.request_resources(resource_request)

        # Wait until ready
        while not resource_manager.has_resources_ready(resource_request):
            time.sleep(1)

        # Once ready, acquire resources
        acquired_resource = resource_manager.acquire_resources(resource_request)

        # Bind to remote task or actor (returns one annotated entity per input)
        [annotated_remote_fn] = acquired_resource.annotate_remote_entities(
            [remote_fn])

        # Run remote function. This will use the acquired resources
        ray.get(annotated_remote_fn.remote())

        # After using the resources, free
        resource_manager.free_resources(acquired_resource)

    """

    def request_resources(self, resource_request: ResourceRequest) -> None:
        """Request resources.

        Depending on the backend, resources can trigger autoscaling. Requested
        resources can be ready or not ready. Once they are "ready", they can
        be acquired and used by remote Ray objects.

        Resource requests can be cancelled anytime using ``cancel_resource_request()``.
        Once acquired, the resource request is removed. Acquired resources can be
        freed with ``free_resources()``.
        """
        raise NotImplementedError

    def cancel_resource_request(self, resource_request: ResourceRequest) -> None:
        """Cancel resource request.

        Resource requests can be cancelled anytime before a resource is acquired.
        Acquiring a resource will remove the associated resource request.
        Acquired resources can be freed with ``free_resources()``.
        """
        raise NotImplementedError

    def has_resources_ready(self, resource_request: ResourceRequest) -> bool:
        """Returns True if resources for the given request are ready to be acquired."""
        raise NotImplementedError

    def acquire_resources(
        self, resource_request: ResourceRequest
    ) -> Optional[AcquiredResources]:
        """Acquire resources. Returns None if resources are not ready to be acquired.

        Acquiring resources will remove the associated resource request.
        Acquired resources can be returned with ``free_resources()``.
        """
        raise NotImplementedError

    def free_resources(self, acquired_resource: AcquiredResources) -> None:
        """Free acquired resources from usage and return them to the resource manager.

        Freeing resources will return the resources to the manager, but there are
        no guarantees about the tasks and actors scheduled on the resources. The caller
        should make sure that any references to tasks or actors scheduled on the
        resources have been removed before calling ``free_resources()``.
        """
        raise NotImplementedError

    def get_resource_futures(self) -> List[ray.ObjectRef]:
        """Return futures for resources to await.

        Depending on the backend, we use resource futures to determine availability
        of resources (e.g. placement groups) or resolution of requests.
        In this case, the futures can be awaited externally by the caller.

        When a resource future resolved, the caller may call ``update_state()``
        to force the resource manager to update its internal state immediately.
        """
        # Default: no futures. Backends that track futures override this.
        return []

    def update_state(self) -> None:
        """Update internal state of the resource manager.

        The resource manager may have internal state that needs periodic updating.
        For instance, depending on the backend, resource futures can be awaited
        externally (with ``get_resource_futures()``).

        If such a future resolved, the caller can instruct the resource
        manager to update its internal state immediately.
        """
        # Intentionally a no-op by default; not abstract so stateless
        # backends don't need to override it.
        pass

    def clear(self) -> None:
        """Reset internal state and clear all resources.

        Calling this method will reset the resource manager to its initialization state.
        All resources will be removed.

        Clearing the state will remove tracked resources from the manager, but there are
        no guarantees about the tasks and actors scheduled on the resources. The caller
        should make sure that any references to tasks or actors scheduled on the
        resources have been removed before calling ``clear()``.
        """
        raise NotImplementedError

    def __reduce__(self):
        """We disallow serialization.

        Shared resource managers should live on an actor.
        """
        raise ValueError(
            f"Resource managers cannot be serialized. Resource manager: {str(self)}"
        )
.venv/lib/python3.11/site-packages/ray/air/result.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import json
3
+ import logging
4
+ import os
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import pandas as pd
10
+ import pyarrow
11
+
12
+ import ray
13
+ from ray.air.constants import (
14
+ EXPR_ERROR_PICKLE_FILE,
15
+ EXPR_PROGRESS_FILE,
16
+ EXPR_RESULT_FILE,
17
+ )
18
+ from ray.util.annotations import PublicAPI
19
+
20
+ if TYPE_CHECKING:
21
+ from ray.train import Checkpoint
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
@PublicAPI(stability="stable")
@dataclass
class Result:
    """The final result of a ML training run or a Tune trial.

    This is the output produced by ``Trainer.fit``.
    ``Tuner.fit`` outputs a :class:`~ray.tune.ResultGrid` that is a collection
    of ``Result`` objects.

    This API is the recommended way to access the outputs such as:
    - checkpoints (``Result.checkpoint``)
    - the history of reported metrics (``Result.metrics_dataframe``, ``Result.metrics``)
    - errors encountered during a training run (``Result.error``)

    The constructor is a private API -- use ``Result.from_path`` to create a result
    object from a directory.

    Attributes:
        metrics: The latest set of reported metrics.
        checkpoint: The latest checkpoint.
        error: The execution error of the Trainable run, if the trial finishes in error.
        path: Path pointing to the result directory on persistent storage. This can
            point to a remote storage location (e.g. S3) or to a local location (path
            on the head node). The path is accessible via the result's associated
            `filesystem`. For instance, for a result stored in S3 at
            ``s3://bucket/location``, ``path`` will have the value ``bucket/location``.
        metrics_dataframe: The full result dataframe of the Trainable.
            The dataframe is indexed by iterations and contains reported
            metrics. Note that the dataframe columns are indexed with the
            *flattened* keys of reported metrics, so the format of this dataframe
            may be slightly different than ``Result.metrics``, which is an unflattened
            dict of the latest set of reported metrics.
        best_checkpoints: A list of tuples of the best checkpoints and
            their associated metrics. The number of
            saved checkpoints is determined by :class:`~ray.train.CheckpointConfig`
            (by default, all checkpoints will be saved).
    """

    metrics: Optional[Dict[str, Any]]
    checkpoint: Optional["Checkpoint"]
    error: Optional[Exception]
    path: str
    metrics_dataframe: Optional["pd.DataFrame"] = None
    best_checkpoints: Optional[List[Tuple["Checkpoint", Dict[str, Any]]]] = None
    # Filesystem used to access `path`; falls back to the local FS (see
    # `filesystem` property).
    _storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
    # Attributes included in repr output, in display order.
    _items_to_repr = ["error", "metrics", "path", "filesystem", "checkpoint"]

    @property
    def config(self) -> Optional[Dict[str, Any]]:
        """The config associated with the result."""
        if not self.metrics:
            return None
        return self.metrics.get("config", None)

    @property
    def filesystem(self) -> pyarrow.fs.FileSystem:
        """Return the filesystem that can be used to access the result path.

        Returns:
            pyarrow.fs.FileSystem implementation.
        """
        # Default to the local filesystem if no custom filesystem was set.
        return self._storage_filesystem or pyarrow.fs.LocalFileSystem()

    def _repr(self, indent: int = 0) -> str:
        """Construct the representation with specified number of space indent."""
        from ray.tune.experimental.output import BLACKLISTED_KEYS
        from ray.tune.result import AUTO_RESULT_KEYS

        shown_attributes = {k: getattr(self, k) for k in self._items_to_repr}
        # Show only the error type name, not the whole exception; drop the
        # entry entirely when there is no error.
        if self.error:
            shown_attributes["error"] = type(self.error).__name__
        else:
            shown_attributes.pop("error")

        shown_attributes["filesystem"] = shown_attributes["filesystem"].type_name

        if self.metrics:
            # Hide auto-populated / blacklisted bookkeeping keys from the repr.
            exclude = set(AUTO_RESULT_KEYS)
            exclude.update(BLACKLISTED_KEYS)
            shown_attributes["metrics"] = {
                k: v for k, v in self.metrics.items() if k not in exclude
            }

        cls_indent = " " * indent
        kws_indent = " " * (indent + 2)

        kws = [
            f"{kws_indent}{key}={value!r}" for key, value in shown_attributes.items()
        ]
        kws_repr = ",\n".join(kws)
        return "{0}{1}(\n{2}\n{0})".format(cls_indent, type(self).__name__, kws_repr)

    def __repr__(self) -> str:
        return self._repr(indent=0)

    @staticmethod
    def _read_file_as_str(
        storage_filesystem: pyarrow.fs.FileSystem,
        storage_path: str,
    ) -> str:
        """Opens a file as an input stream reading all byte content sequentially and
        decoding read bytes as utf-8 string.

        Args:
            storage_filesystem: The filesystem to use.
            storage_path: The source to open for reading.
        """

        with storage_filesystem.open_input_stream(storage_path) as f:
            return f.readall().decode()

    @classmethod
    def from_path(
        cls,
        path: Union[str, os.PathLike],
        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
    ) -> "Result":
        """Restore a Result object from local or remote trial directory.

        Args:
            path: A path of a trial directory on local or remote storage
                (ex: s3://bucket/path or /tmp/ray_results).
            storage_filesystem: A custom filesystem to use. If not provided,
                this will be auto-resolved by pyarrow. If provided, the path
                is assumed to be prefix-stripped already, and must be a valid path
                on the filesystem.

        Returns:
            A :py:class:`Result` object of that trial.

        Raises:
            RuntimeError: If the trial directory does not exist, or contains
                neither a result JSON nor a progress CSV file.
        """
        # TODO(justinvyu): Fix circular dependency.
        from ray.train import Checkpoint
        from ray.train._internal.storage import (
            _exists_at_fs_path,
            _list_at_fs_path,
            get_fs_and_path,
        )
        from ray.train.constants import CHECKPOINT_DIR_NAME

        fs, fs_path = get_fs_and_path(path, storage_filesystem)
        if not _exists_at_fs_path(fs, fs_path):
            raise RuntimeError(f"Trial folder {fs_path} doesn't exist!")

        # Restore metrics from result.json
        # (one JSON dict per line, i.e. JSON Lines format).
        result_json_file = Path(fs_path, EXPR_RESULT_FILE).as_posix()
        progress_csv_file = Path(fs_path, EXPR_PROGRESS_FILE).as_posix()
        if _exists_at_fs_path(fs, result_json_file):
            lines = cls._read_file_as_str(fs, result_json_file).split("\n")
            json_list = [json.loads(line) for line in lines if line]
            # Nested metric dicts are flattened into "/"-separated columns.
            metrics_df = pd.json_normalize(json_list, sep="/")
            latest_metrics = json_list[-1] if json_list else {}
        # Fallback to restore from progress.csv
        elif _exists_at_fs_path(fs, progress_csv_file):
            metrics_df = pd.read_csv(
                io.StringIO(cls._read_file_as_str(fs, progress_csv_file))
            )
            latest_metrics = (
                metrics_df.iloc[-1].to_dict() if not metrics_df.empty else {}
            )
        else:
            raise RuntimeError(
                f"Failed to restore the Result object: Neither {EXPR_RESULT_FILE}"
                f" nor {EXPR_PROGRESS_FILE} exists in the trial folder!"
            )

        # Restore all checkpoints from the checkpoint folders.
        # Sorted lexicographically, which orders "checkpoint_<padded index>"
        # directories by checkpoint index.
        checkpoint_dir_names = sorted(
            _list_at_fs_path(
                fs,
                fs_path,
                file_filter=lambda file_info: file_info.type
                == pyarrow.fs.FileType.Directory
                and file_info.base_name.startswith("checkpoint_"),
            )
        )

        if checkpoint_dir_names:
            checkpoints = [
                Checkpoint(
                    path=Path(fs_path, checkpoint_dir_name).as_posix(), filesystem=fs
                )
                for checkpoint_dir_name in checkpoint_dir_names
            ]

            # Pair each checkpoint with the last metrics row that reported it.
            metrics = []
            for checkpoint_dir_name in checkpoint_dir_names:
                metrics_corresponding_to_checkpoint = metrics_df[
                    metrics_df[CHECKPOINT_DIR_NAME] == checkpoint_dir_name
                ]
                if metrics_corresponding_to_checkpoint.empty:
                    logger.warning(
                        "Could not find metrics corresponding to "
                        f"{checkpoint_dir_name}. These will default to an empty dict."
                    )
                metrics.append(
                    {}
                    if metrics_corresponding_to_checkpoint.empty
                    else metrics_corresponding_to_checkpoint.iloc[-1].to_dict()
                )

            latest_checkpoint = checkpoints[-1]
            # TODO(justinvyu): These are ordered by checkpoint index, since we don't
            # know the metric to order these with.
            best_checkpoints = list(zip(checkpoints, metrics))
        else:
            best_checkpoints = latest_checkpoint = None

        # Restore the trial error if it exists
        error = None
        error_file_path = Path(fs_path, EXPR_ERROR_PICKLE_FILE).as_posix()
        if _exists_at_fs_path(fs, error_file_path):
            with fs.open_input_stream(error_file_path) as f:
                error = ray.cloudpickle.load(f)

        return Result(
            metrics=latest_metrics,
            checkpoint=latest_checkpoint,
            path=fs_path,
            _storage_filesystem=fs,
            metrics_dataframe=metrics_df,
            best_checkpoints=best_checkpoints,
            error=error,
        )

    @PublicAPI(stability="alpha")
    def get_best_checkpoint(self, metric: str, mode: str) -> Optional["Checkpoint"]:
        """Get the best checkpoint from this trial based on a specific metric.

        Any checkpoints without an associated metric value will be filtered out.

        Args:
            metric: The key for checkpoints to order on.
            mode: One of ["min", "max"].

        Returns:
            :class:`Checkpoint <ray.train.Checkpoint>` object.

        Raises:
            RuntimeError: If the trial has no checkpoints, or if no checkpoint
                has a value for ``metric``.
            ValueError: If ``mode`` is not "min" or "max".
        """
        if not self.best_checkpoints:
            raise RuntimeError("No checkpoint exists in the trial directory!")

        if mode not in ["max", "min"]:
            raise ValueError(
                f'Unsupported mode: {mode}. Please choose from ["min", "max"]!'
            )

        op = max if mode == "max" else min
        # Keep only (checkpoint, metrics) pairs that actually report `metric`.
        valid_checkpoints = [
            ckpt_info for ckpt_info in self.best_checkpoints if metric in ckpt_info[1]
        ]

        if not valid_checkpoints:
            raise RuntimeError(
                f"Invalid metric name {metric}! "
                f"You may choose from the following metrics: {self.metrics.keys()}."
            )

        return op(valid_checkpoints, key=lambda x: x[1][metric])[0]
.venv/lib/python3.11/site-packages/ray/air/session.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from ray.train._internal.session import * # noqa: F401,F403
.venv/lib/python3.11/site-packages/ray/air/util/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/air/util/tensor_extensions/pandas.py ADDED
@@ -0,0 +1,1451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from
2
+ # https://github.com/CODAIT/text-extensions-for-pandas/blob/dc03278689fe1c5f131573658ae19815ba25f33e/text_extensions_for_pandas/array/tensor.py
3
+ # and
4
+ # https://github.com/CODAIT/text-extensions-for-pandas/blob/dc03278689fe1c5f131573658ae19815ba25f33e/text_extensions_for_pandas/array/arrow_conversion.py
5
+
6
+ #
7
+ # Copyright (c) 2020 IBM Corp.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # Modifications:
21
+ # - Added ArrowTensorType.to_pandas_type()
22
+ # - Added ArrowTensorArray.__getitem__()
23
+ # - Added ArrowTensorArray.__iter__()
24
+ # - Added support for column casts to extension types.
25
+ # - Fleshed out docstrings and examples.
26
+ # - Fixed TensorArray.isna() so it returns an appropriate ExtensionArray.
27
+ # - Added different (more vectorized) TensorArray.take() operation.
28
+ # - Added support for more reducers (agg funcs) to TensorArray.
29
+ # - Added support for logical operators to TensorArray(Element).
30
+ # - Added support for heterogeneously-shaped tensors.
31
+ # - Miscellaneous small bug fixes and optimizations.
32
+
33
+ import numbers
34
+ import os
35
+ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
36
+
37
+ import numpy as np
38
+ import pandas as pd
39
+ import pyarrow as pa
40
+ from packaging.version import Version
41
+ from pandas._typing import Dtype
42
+ from pandas.compat import set_function_name
43
+ from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
44
+ from pandas.core.indexers import check_array_indexer, validate_indices
45
+
46
+ from ray.air.util.tensor_extensions.utils import (
47
+ _create_possibly_ragged_ndarray,
48
+ _is_ndarray_variable_shaped_tensor,
49
+ )
50
+ from ray.util.annotations import PublicAPI
51
+
52
+ try:
53
+ from pandas.core.dtypes.generic import ABCIndex
54
+ except ImportError:
55
+ # ABCIndexClass changed to ABCIndex in Pandas 1.3
56
+ from pandas.core.dtypes.generic import ABCIndexClass as ABCIndex
57
+
58
+
59
+ #############################################
60
+ # Begin patching of ExtensionArrayFormatter #
61
+ #############################################
62
+
63
+
64
def _format_strings_patched(self) -> List[str]:
    """Monkeypatched replacement for the pandas extension-array formatter.

    Intercepts formatting of ``TensorArray`` columns so multi-dimensional
    tensor elements render as nested string arrays; all other extension
    arrays are delegated to the stock implementation
    (``self._format_strings_orig``).
    """
    from pandas.core.construction import extract_array
    from pandas.io.formats.format import format_array

    if not isinstance(self.values, TensorArray):
        return self._format_strings_orig()

    values = extract_array(self.values, extract_numpy=True)
    array = np.asarray(values)

    # 1-D tensor columns need no special handling.
    if array.ndim == 1:
        return self._format_strings_orig()

    def _run_format_array(arr, fmt):
        # Forward this formatter's display options to pandas' format_array.
        return format_array(
            arr,
            fmt,
            float_format=self.float_format,
            na_rep=self.na_rep,
            digits=self.digits,
            space=self.space,
            justify=self.justify,
            decimal=self.decimal,
            leading_space=self.leading_space,
            quoting=self.quoting,
        )

    flat_formatter = self.formatter
    if flat_formatter is None:
        flat_formatter = values._formatter(boxed=True)

    # Format the flattened data, then restore the original shape, preserving
    # memory order (use ravel_compat in v1.3.0).
    flat = array.ravel("K")
    formatted_flat = np.asarray(_run_format_array(flat, flat_formatter))
    order = "F" if array.flags.f_contiguous else "C"
    formatted = formatted_flat.reshape(array.shape, order=order)

    # Format the array of nested strings with the default formatter.
    return _run_format_array(formatted, None)
104
+
105
+
106
def _format_strings_patched_v1_0_0(self) -> List[str]:
    """Monkeypatched extension-array formatter for TensorArray columns.

    Like ``_format_strings_patched``, but finishes with a slimmed-down
    string formatter instead of a second ``format_array`` pass, working
    around https://github.com/pandas-dev/pandas/issues/33770.
    Non-TensorArray columns are delegated to the stock implementation.
    """
    from functools import partial

    from pandas.core.construction import extract_array
    from pandas.io.formats.format import format_array
    from pandas.io.formats.printing import pprint_thing

    if not isinstance(self.values, TensorArray):
        return self._format_strings_orig()

    values = extract_array(self.values, extract_numpy=True)
    array = np.asarray(values)

    # 1-D tensor columns need no special handling.
    if array.ndim == 1:
        return self._format_strings_orig()

    def format_array_wrap(array_, formatter_):
        # Forward this formatter's display options to pandas' format_array.
        # NOTE: unlike the non-v1_0_0 variant, `quoting` is not forwarded here.
        fmt_values = format_array(
            array_,
            formatter_,
            float_format=self.float_format,
            na_rep=self.na_rep,
            digits=self.digits,
            space=self.space,
            justify=self.justify,
            decimal=self.decimal,
            leading_space=self.leading_space,
        )
        return fmt_values

    flat_formatter = self.formatter
    if flat_formatter is None:
        flat_formatter = values._formatter(boxed=True)

    # Flatten array, call function, reshape (use ravel_compat in v1.3.0)
    flat_array = array.ravel("K")
    fmt_flat_array = np.asarray(format_array_wrap(flat_array, flat_formatter))
    order = "F" if array.flags.f_contiguous else "C"
    fmt_array = fmt_flat_array.reshape(array.shape, order=order)

    # Slimmed down version of GenericArrayFormatter due to:
    # https://github.com/pandas-dev/pandas/issues/33770
    def format_strings_slim(array_, leading_space):
        formatter = partial(
            pprint_thing,
            escape_chars=("\t", "\r", "\n"),
        )

        def _format(x):
            return str(formatter(x))

        fmt_values = []
        for v in array_:
            # Prepend a space unless leading_space is explicitly False.
            tpl = "{v}" if leading_space is False else " {v}"
            fmt_values.append(tpl.format(v=_format(v)))
        return fmt_values

    return format_strings_slim(fmt_array, self.leading_space)
164
+
165
+
166
# Env var toggle letting users opt out of the formatter monkeypatch below.
_FORMATTER_ENABLED_ENV_VAR = "TENSOR_COLUMN_EXTENSION_FORMATTER_ENABLED"

if os.getenv(_FORMATTER_ENABLED_ENV_VAR, "1") == "1":
    # The pandas-internal formatter class was made private (renamed with a
    # leading underscore) in pandas 2.2.0.
    if Version(pd.__version__) < Version("2.2.0"):
        from pandas.io.formats.format import ExtensionArrayFormatter

        formatter_cls = ExtensionArrayFormatter
    else:
        from pandas.io.formats.format import _ExtensionArrayFormatter

        formatter_cls = _ExtensionArrayFormatter
    # Keep the original implementation around; the patched methods delegate
    # back to it for non-TensorArray values and 1-D arrays.
    formatter_cls._format_strings_orig = formatter_cls._format_strings
    if Version("1.1.0") <= Version(pd.__version__) < Version("1.3.0"):
        formatter_cls._format_strings = _format_strings_patched
    else:
        # NOTE(review): every version outside [1.1.0, 1.3.0) — including
        # pandas >= 1.3 — falls through to the v1.0.0-style patch; presumably
        # intentional, but confirm against newer pandas internals.
        formatter_cls._format_strings = _format_strings_patched_v1_0_0
    # Marker so other code can detect that the patch has been applied.
    formatter_cls._patched_by_ray_datasets = True

###########################################
# End patching of ExtensionArrayFormatter #
###########################################
187
+
188
+
189
@PublicAPI(stability="beta")
@pd.api.extensions.register_extension_dtype
class TensorDtype(pd.api.extensions.ExtensionDtype):
    """
    Pandas extension type for a column of homogeneous-typed tensors.

    This extension supports tensors in which the elements have different shapes.
    However, each tensor element must be non-ragged, i.e. each tensor element must
    have a well-defined, non-ragged shape.

    See:
    https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/base.py
    for up-to-date interface documentation and the subclassing contract.

    Examples:
        >>> # Create a DataFrame with a list of ndarrays as a column.
        >>> import pandas as pd
        >>> import numpy as np
        >>> import ray
        >>> df = pd.DataFrame({
        ...     "one": [1, 2, 3],
        ...     "two": list(np.arange(24).reshape((3, 2, 2, 2)))})
        >>> # Note the opaque np.object dtype for this column.
        >>> df.dtypes # doctest: +SKIP
        one     int64
        two    object
        dtype: object
        >>> # Cast column to our TensorDtype extension type; note the
        >>> # (shape, dtype) argument order.
        >>> from ray.data.extensions import TensorDtype
        >>> df["two"] = df["two"].astype(TensorDtype((3, 2, 2, 2), np.int64))
        >>> # The column dtype is now TensorDtype instead of np.object.
        >>> df.dtypes # doctest: +SKIP
        one                                          int64
        two    TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
        dtype: object
        >>> # Pandas is now aware of this tensor column, and we can do the
        >>> # typical DataFrame operations on this column.
        >>> col = 2 * (df["two"] + 10)
        >>> type(col) # doctest: +SKIP
        pandas.core.series.Series
        >>> # Aggregations that return a single row's value yield our
        >>> # TensorArrayElement type, a light ndarray wrapper.
        >>> tensor = col.mean()
        >>> type(tensor.to_numpy()) # doctest: +SKIP
        numpy.ndarray
        >>> # In addition to doing Pandas operations on the tensor column,
        >>> # you can now put the DataFrame into a Dataset.
        >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
        >>> # Internally, this column is represented as the corresponding
        >>> # Arrow tensor extension type.
        >>> ds.schema() # doctest: +SKIP
        one: int64
        two: extension<arrow.py_extension_type<ArrowTensorType>>
        >>> # The tensor extension type is preserved along the
        >>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
        >>> # conversion chain.
        >>> ds.write_parquet("/some/path") # doctest: +SKIP
        >>> read_ds = ray.data.read_parquet("/some/path") # doctest: +SKIP
        >>> read_df = ray.get(read_ds.to_pandas_refs())[0] # doctest: +SKIP
        >>> read_df.equals(df) # doctest: +SKIP
        True
    """

    # NOTE(Clark): This is apparently required to prevent integer indexing
    # errors, but is an undocumented ExtensionDtype attribute. See issue:
    # https://github.com/CODAIT/text-extensions-for-pandas/issues/166
    base = None

    def __init__(self, shape: Tuple[Optional[int], ...], dtype: np.dtype):
        """
        Args:
            shape: Shape of each tensor element (without the row axis); a tuple
                of Nones when elements are variable-shaped.
            dtype: NumPy dtype of the underlying tensor elements.
        """
        self._shape = shape
        self._dtype = dtype

    @property
    def type(self):
        """
        The scalar type for the array, e.g. ``int``.

        It's expected ``ExtensionArray[item]`` returns an instance
        of ``ExtensionDtype.type`` for scalar ``item``, assuming
        that value is valid (not NA). NA values do not need to be
        instances of `type`.
        """
        return TensorArrayElement

    @property
    def element_dtype(self):
        """
        The dtype of the underlying tensor elements.
        """
        return self._dtype

    @property
    def element_shape(self):
        """
        The shape of the underlying tensor elements. This will be a tuple of
        Nones if the corresponding TensorArray for this TensorDtype holds
        variable-shaped tensor elements.
        """
        return self._shape

    @property
    def is_variable_shaped(self):
        """
        Whether the corresponding TensorArray for this TensorDtype holds
        variable-shaped tensor elements.
        """
        # FIX: this previously read ``self.shape``, which is not an attribute
        # of TensorDtype (the element shape is stored as ``self._shape``) and
        # raised AttributeError when accessed.
        return all(dim_size is None for dim_size in self._shape)

    @property
    def name(self) -> str:
        """
        A string identifying the data type.

        Will be used for display in, e.g. ``Series.dtype``.
        """
        return f"numpy.ndarray(shape={self._shape}, dtype={self._dtype})"

    @classmethod
    def construct_from_string(cls, string: str):
        """
        Construct this type from a string.

        Accepts the forms produced by ``TensorDtype.name``/``repr``, e.g.
        ``"TensorDtype(shape=(1, 2, 3), dtype=int64)"`` or
        ``"numpy.ndarray(shape=(1, 2, 3), dtype=int64)"``.

        Parameters
        ----------
        string : str
            The name of the type, for example
            ``TensorDtype(shape=(1, 2, 3), dtype=int64)``.

        Returns
        -------
        TensorDtype
            Instance of the dtype.

        Raises
        ------
        TypeError
            If a class cannot be constructed from this 'string'.
        """
        import ast
        import re

        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )
        # Upstream code uses exceptions as part of its normal control flow and
        # will pass this method bogus class names.
        regex = (
            r"^(TensorDtype|numpy.ndarray)"
            r"\(shape=(\((?:(?:\d+|None),?\s?)*\)), dtype=(\w+)\)$"
        )
        m = re.search(regex, string)
        err_msg = (
            f"Cannot construct a '{cls.__name__}' from '{string}'; expected a string "
            "like 'TensorDtype(shape=(1, 2, 3), dtype=int64)'."
        )
        if m is None:
            raise TypeError(err_msg)
        groups = m.groups()
        if len(groups) != 3:
            raise TypeError(err_msg)
        _, shape, dtype = groups
        # The shape group is a literal tuple string; parse it safely.
        shape = ast.literal_eval(shape)
        dtype = np.dtype(dtype)
        return cls(shape, dtype)

    @classmethod
    def construct_array_type(cls):
        """
        Return the array type associated with this dtype.

        Returns
        -------
        type
        """
        return TensorArray

    def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]):
        """
        Convert a pyarrow (chunked) array to a TensorArray.

        This and TensorArray.__arrow_array__ make up the
        Pandas extension type + array <--> Arrow extension type + array
        interoperability protocol. See
        https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
        for more information.
        """
        if isinstance(array, pa.ChunkedArray):
            if array.num_chunks > 1:
                # TODO(Clark): Remove concat and construct from list with
                # shape.
                values = np.concatenate(
                    [chunk.to_numpy() for chunk in array.iterchunks()]
                )
            else:
                values = array.chunk(0).to_numpy()
        else:
            values = array.to_numpy()

        return TensorArray(values)

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return str(self)

    @property
    def _is_boolean(self):
        """
        Whether this extension array should be considered boolean.

        By default, ExtensionArrays are assumed to be non-numeric.
        Setting this to True will affect the behavior of several places,
        e.g.

        * is_bool
        * boolean indexing

        Returns
        -------
        bool
        """
        # This is needed to support returning a TensorArray from .isnan().
        from pandas.core.dtypes.common import is_bool_dtype

        return is_bool_dtype(self._dtype)
475
+
476
+
477
class _TensorOpsMixin(pd.api.extensions.ExtensionScalarOpsMixin):
    """
    Mixin for TensorArray operator support, applying operations on the
    underlying ndarrays.
    """

    @classmethod
    def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):
        """
        Add support for binary operators by unwrapping, applying, and
        rewrapping.

        Overrides ExtensionScalarOpsMixin._create_method; ``coerce_to_dtype``
        and ``result_dtype`` are accepted for signature compatibility but are
        not used here.
        """

        # NOTE(Clark): This overrides, but coerce_to_dtype, result_dtype might
        # not be needed

        def _binop(self, other):
            lvalues = self._tensor

            if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndex)):
                # Rely on Pandas to unbox and dispatch to us.
                return NotImplemented

            # divmod returns a tuple
            # NOTE: `op_name` is assigned in the enclosing scope *after* this
            # closure is defined; late binding resolves it at call time.
            if op_name in ["__divmod__", "__rdivmod__"]:
                # TODO(Clark): Add support for divmod and rdivmod.
                # div, mod = result
                raise NotImplementedError

            # Unwrap tensor wrappers to their raw ndarrays before applying op.
            if isinstance(other, (TensorArray, TensorArrayElement)):
                rvalues = other._tensor
            else:
                rvalues = other

            result = op(lvalues, rvalues)

            # Force a TensorArray if rvalue is not a scalar.
            if isinstance(self, TensorArrayElement) and (
                not isinstance(other, TensorArrayElement) or not np.isscalar(other)
            ):
                result_wrapped = TensorArray(result)
            else:
                result_wrapped = cls(result)

            return result_wrapped

        op_name = f"__{op.__name__}__"
        # Give the closure the dunder name pandas expects for this operator.
        return set_function_name(_binop, op_name, cls)

    @classmethod
    def _create_logical_method(cls, op):
        # Logical ops (&, |, ^) reuse the same unwrap/apply/rewrap machinery.
        return cls._create_method(op)
529
+
530
+
531
+ class _TensorScalarCastMixin:
532
+ """
533
+ Mixin for casting scalar tensors to a particular numeric type.
534
+ """
535
+
536
+ def _scalarfunc(self, func: Callable[[Any], Any]):
537
+ return func(self._tensor)
538
+
539
+ def __complex__(self):
540
+ return self._scalarfunc(complex)
541
+
542
+ def __float__(self):
543
+ return self._scalarfunc(float)
544
+
545
+ def __int__(self):
546
+ return self._scalarfunc(int)
547
+
548
+ def __hex__(self):
549
+ return self._scalarfunc(hex)
550
+
551
+ def __oct__(self):
552
+ return self._scalarfunc(oct)
553
+
554
+
555
@PublicAPI(stability="beta")
class TensorArrayElement(_TensorOpsMixin, _TensorScalarCastMixin):
    """
    A single row of a TensorArray: a thin wrapper around one NumPy ndarray.
    """

    def __init__(self, values: np.ndarray):
        """
        Wrap a NumPy ndarray as a single tensor-column element.

        Args:
            values: ndarray that underlies this TensorArray element.
        """
        self._tensor = values

    def __repr__(self):
        return repr(self._tensor)

    def __str__(self):
        return str(self._tensor)

    @property
    def numpy_dtype(self):
        """
        The numpy dtype of the backing ndarray.
        """
        return self._tensor.dtype

    @property
    def numpy_ndim(self):
        """
        The number of dimensions of the backing ndarray.
        """
        return self._tensor.ndim

    @property
    def numpy_shape(self):
        """
        The shape (tuple of ints) of the backing ndarray.
        """
        return self._tensor.shape

    @property
    def numpy_size(self):
        """
        The total number of elements in the backing ndarray.
        """
        return self._tensor.size

    def to_numpy(self):
        """
        Return the values of this element as a NumPy ndarray.
        """
        return np.asarray(self._tensor)

    def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
        # NumPy protocol hook: lets np.asarray()/np.array() unwrap us.
        return np.asarray(self._tensor, dtype=dtype, **kwargs)
616
+
617
+
618
+ @PublicAPI(stability="beta")
619
+ class TensorArray(
620
+ pd.api.extensions.ExtensionArray,
621
+ _TensorOpsMixin,
622
+ _TensorScalarCastMixin,
623
+ ):
624
+ """
625
+ Pandas `ExtensionArray` representing a tensor column, i.e. a column
626
+ consisting of ndarrays as elements.
627
+
628
+ This extension supports tensors in which the elements have different shapes.
629
+ However, each tensor element must be non-ragged, i.e. each tensor element must have
630
+ a well-defined, non-ragged shape.
631
+
632
+ Examples:
633
+ >>> # Create a DataFrame with a list of ndarrays as a column.
634
+ >>> import pandas as pd
635
+ >>> import numpy as np
636
+ >>> import ray
637
+ >>> from ray.data.extensions import TensorArray
638
+ >>> df = pd.DataFrame({
639
+ ... "one": [1, 2, 3],
640
+ ... "two": TensorArray(np.arange(24).reshape((3, 2, 2, 2)))})
641
+ >>> # Note that the column dtype is TensorDtype.
642
+ >>> df.dtypes # doctest: +SKIP
643
+ one int64
644
+ two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
645
+ dtype: object
646
+ >>> # Pandas is aware of this tensor column, and we can do the
647
+ >>> # typical DataFrame operations on this column.
648
+ >>> col = 2 * (df["two"] + 10)
649
+ >>> # The ndarrays underlying the tensor column will be manipulated,
650
+ >>> # but the column itself will continue to be a Pandas type.
651
+ >>> type(col) # doctest: +SKIP
652
+ pandas.core.series.Series
653
+ >>> col # doctest: +SKIP
654
+ 0 [[[ 2 4]
655
+ [ 6 8]]
656
+ [[10 12]
657
+ [14 16]]]
658
+ 1 [[[18 20]
659
+ [22 24]]
660
+ [[26 28]
661
+ [30 32]]]
662
+ 2 [[[34 36]
663
+ [38 40]]
664
+ [[42 44]
665
+ [46 48]]]
666
+ Name: two, dtype: TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
667
+ >>> # Once you do an aggregation on that column that returns a single
668
+ >>> # row's value, you get back our TensorArrayElement type.
669
+ >>> tensor = col.mean() # doctest: +SKIP
670
+ >>> type(tensor) # doctest: +SKIP
671
+ ray.data.extensions.tensor_extension.TensorArrayElement
672
+ >>> tensor # doctest: +SKIP
673
+ array([[[18., 20.],
674
+ [22., 24.]],
675
+ [[26., 28.],
676
+ [30., 32.]]])
677
+ >>> # This is a light wrapper around a NumPy ndarray, and can easily
678
+ >>> # be converted to an ndarray.
679
+ >>> type(tensor.to_numpy()) # doctest: +SKIP
680
+ numpy.ndarray
681
+ >>> # In addition to doing Pandas operations on the tensor column,
682
+ >>> # you can now put the DataFrame into a Dataset.
683
+ >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
684
+ >>> # Internally, this column is represented the corresponding
685
+ >>> # Arrow tensor extension type.
686
+ >>> ds.schema() # doctest: +SKIP
687
+ one: int64
688
+ two: extension<arrow.py_extension_type<ArrowTensorType>>
689
+ >>> # You can write the dataset to Parquet.
690
+ >>> ds.write_parquet("/some/path") # doctest: +SKIP
691
+ >>> # And you can read it back.
692
+ >>> read_ds = ray.data.read_parquet("/some/path") # doctest: +SKIP
693
+ >>> read_ds.schema() # doctest: +SKIP
694
+ one: int64
695
+ two: extension<arrow.py_extension_type<ArrowTensorType>>
696
+
697
+ >>> read_df = ray.get(read_ds.to_pandas_refs())[0] # doctest: +SKIP
698
+ >>> read_df.dtypes # doctest: +SKIP
699
+ one int64
700
+ two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
701
+ dtype: object
702
+ >>> # The tensor extension type is preserved along the
703
+ >>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
704
+ >>> # conversion chain.
705
+ >>> read_df.equals(df) # doctest: +SKIP
706
+ True
707
+ """
708
+
709
    # Reduction name -> NumPy reducer; maps the pandas reduction names that
    # `_reduce` supports onto the NumPy functions applied to the backing
    # ndarray.
    SUPPORTED_REDUCERS = {
        "sum": np.sum,
        "all": np.all,
        "any": np.any,
        "min": np.min,
        "max": np.max,
        "mean": np.mean,
        "median": np.median,
        "prod": np.prod,
        "std": np.std,
        "var": np.var,
    }
721
+
722
    # See https://github.com/pandas-dev/pandas/blob/master/pandas/core/arrays/base.py
    # for interface documentation and the subclassing contract.
    def __init__(
        self,
        values: Union[
            np.ndarray,
            ABCSeries,
            Sequence[Union[np.ndarray, TensorArrayElement]],
            TensorArrayElement,
            Any,
        ],
    ):
        """
        Args:
            values: A NumPy ndarray or sequence of NumPy ndarrays of equal
                shape.

        Raises:
            TypeError: If ``values`` is not an ndarray / sequence of ndarrays,
                is an existing TensorArray, or is an object-typed ndarray whose
                elements are not ndarray-like.
        """
        # Try to convert some well-known objects to ndarrays before handing off to
        # ndarray handling logic.
        if isinstance(values, ABCSeries):
            values = _create_possibly_ragged_ndarray(values)
        elif isinstance(values, Sequence):
            # Unwrap any TensorArrayElements to their raw ndarrays first.
            values = [
                np.asarray(v) if isinstance(v, TensorArrayElement) else v
                for v in values
            ]
            values = _create_possibly_ragged_ndarray(values)
        elif isinstance(values, TensorArrayElement):
            # A single element becomes a length-1 array.
            values = np.array([np.asarray(values)], copy=False)

        if isinstance(values, np.ndarray):
            if values.dtype.type is np.object_:
                if len(values) == 0:
                    # Tensor is empty, pass through to create empty TensorArray.
                    pass
                elif all(
                    isinstance(v, (np.ndarray, TensorArrayElement, Sequence))
                    and not isinstance(v, str)
                    for v in values
                ):
                    values = [np.asarray(v) for v in values]
                    # Try to convert ndarrays of ndarrays/TensorArrayElements with an
                    # opaque object type to a properly typed ndarray of ndarrays.
                    values = _create_possibly_ragged_ndarray(values)
                else:
                    raise TypeError(
                        "Expected a well-typed ndarray or an object-typed ndarray of "
                        "ndarray pointers, but got an object-typed ndarray whose "
                        f"subndarrays are of type {type(values[0])}."
                    )
        elif isinstance(values, TensorArray):
            raise TypeError("Use the copy() method to create a copy of a TensorArray.")
        else:
            raise TypeError(
                "Expected a numpy.ndarray or sequence of numpy.ndarray, "
                f"but received {values} of type {type(values).__name__} instead."
            )
        assert isinstance(values, np.ndarray)
        self._tensor = values
        # Lazily computed and cached by the `is_variable_shaped` property.
        self._is_variable_shaped = None
782
+
783
+ @classmethod
784
+ def _from_sequence(
785
+ cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
786
+ ):
787
+ """
788
+ Construct a new ExtensionArray from a sequence of scalars.
789
+
790
+ Parameters
791
+ ----------
792
+ scalars : Sequence
793
+ Each element will be an instance of the scalar type for this
794
+ array, ``cls.dtype.type`` or be converted into this type in this
795
+ method.
796
+ dtype : dtype, optional
797
+ Construct for this particular dtype. This should be a Dtype
798
+ compatible with the ExtensionArray.
799
+ copy : bool, default False
800
+ If True, copy the underlying data.
801
+
802
+ Returns
803
+ -------
804
+ ExtensionArray
805
+ """
806
+ if copy and isinstance(scalars, np.ndarray):
807
+ scalars = scalars.copy()
808
+ elif isinstance(scalars, TensorArray):
809
+ scalars = scalars._tensor.copy() if copy else scalars._tensor
810
+ return TensorArray(scalars)
811
+
812
+ @classmethod
813
+ def _from_factorized(
814
+ cls, values: np.ndarray, original: pd.api.extensions.ExtensionArray
815
+ ):
816
+ """
817
+ Reconstruct an ExtensionArray after factorization.
818
+
819
+ Parameters
820
+ ----------
821
+ values : ndarray
822
+ An integer ndarray with the factorized values.
823
+ original : ExtensionArray
824
+ The original ExtensionArray that factorize was called on.
825
+
826
+ See Also
827
+ --------
828
+ factorize : Top-level factorize method that dispatches here.
829
+ ExtensionArray.factorize : Encode the extension array as an enumerated
830
+ type.
831
+ """
832
+ raise NotImplementedError
833
+
834
    def __getitem__(
        self, item: Union[int, slice, np.ndarray]
    ) -> Union["TensorArray", "TensorArrayElement"]:
        """
        Select a subset of self.

        Parameters
        ----------
        item : int, slice, or ndarray
            * int: The position in 'self' to get.
            * slice: A slice object, where 'start', 'stop', and 'step' are
              integers or None
            * ndarray: A 1-d boolean NumPy ndarray the same length as 'self'

        Returns
        -------
        item : scalar or ExtensionArray

        Notes
        -----
        For scalar ``item``, return a scalar value suitable for the array's
        type. This should be an instance of ``self.dtype.type``.
        For slice ``key``, return an instance of ``ExtensionArray``, even
        if the slice is length 0 or 1.
        For a boolean mask, return an instance of ``ExtensionArray``, filtered
        to the values where ``item`` is True.
        """
        # Return scalar if single value is selected, a TensorArrayElement for
        # single array element, or TensorArray for slice.
        if isinstance(item, int):
            value = self._tensor[item]
            if np.isscalar(value):
                return value
            else:
                return TensorArrayElement(value)
        else:
            # BEGIN workaround for Pandas issue #42430
            # (strip a leading Ellipsis from a tuple indexer that pandas
            # erroneously passes through).
            if isinstance(item, tuple) and len(item) > 1 and item[0] == Ellipsis:
                if len(item) > 2:
                    # Hopefully this case is not possible, but can't be sure
                    raise ValueError(
                        "Workaround Pandas issue #42430 not "
                        "implemented for tuple length > 2"
                    )
                item = item[1]
            # END workaround for issue #42430
            if isinstance(item, TensorArray):
                item = np.asarray(item)
            # Normalize/validate the indexer per the pandas indexing contract.
            item = check_array_indexer(self, item)
            return TensorArray(self._tensor[item])
884
+
885
+ def __len__(self) -> int:
886
+ """
887
+ Length of this array.
888
+
889
+ Returns
890
+ -------
891
+ length : int
892
+ """
893
+ return len(self._tensor)
894
+
895
+ @property
896
+ def dtype(self) -> pd.api.extensions.ExtensionDtype:
897
+ """
898
+ An instance of 'ExtensionDtype'.
899
+ """
900
+ if self.is_variable_shaped:
901
+ # A tensor is only considered variable-shaped if it's non-empty, so no
902
+ # non-empty check is needed here.
903
+ dtype = self._tensor[0].dtype
904
+ shape = (None,) * self._tensor[0].ndim
905
+ else:
906
+ dtype = self.numpy_dtype
907
+ shape = self.numpy_shape[1:]
908
+ return TensorDtype(shape, dtype)
909
+
910
+ @property
911
+ def is_variable_shaped(self):
912
+ """
913
+ Whether this TensorArray holds variable-shaped tensor elements.
914
+ """
915
+ if self._is_variable_shaped is None:
916
+ self._is_variable_shaped = _is_ndarray_variable_shaped_tensor(self._tensor)
917
+ return self._is_variable_shaped
918
+
919
+ @property
920
+ def nbytes(self) -> int:
921
+ """
922
+ The number of bytes needed to store this object in memory.
923
+ """
924
+ return self._tensor.nbytes
925
+
926
+ def isna(self) -> "TensorArray":
927
+ """
928
+ A 1-D array indicating if each value is missing.
929
+
930
+ Returns
931
+ -------
932
+ na_values : Union[np.ndarray, ExtensionArray]
933
+ In most cases, this should return a NumPy ndarray. For
934
+ exceptional cases like ``SparseArray``, where returning
935
+ an ndarray would be expensive, an ExtensionArray may be
936
+ returned.
937
+
938
+ Notes
939
+ -----
940
+ If returning an ExtensionArray, then
941
+
942
+ * ``na_values._is_boolean`` should be True
943
+ * `na_values` should implement :func:`ExtensionArray._reduce`
944
+ * ``na_values.any`` and ``na_values.all`` should be implemented
945
+ """
946
+ if self._tensor.dtype.type is np.object_:
947
+ # Avoid comparing with __eq__ because the elements of the tensor
948
+ # may do something funny with that operation.
949
+ return np.array(
950
+ [self._tensor[i] is None for i in range(len(self))], dtype=bool
951
+ )
952
+ elif self._tensor.dtype.type is np.str_:
953
+ return np.all(self._tensor == "", axis=tuple(range(1, self._tensor.ndim)))
954
+ else:
955
+ return np.all(
956
+ np.isnan(self._tensor), axis=tuple(range(1, self._tensor.ndim))
957
+ )
958
+
959
+ def take(
960
+ self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None
961
+ ) -> "TensorArray":
962
+ """
963
+ Take elements from an array.
964
+
965
+ Parameters
966
+ ----------
967
+ indices : sequence of int
968
+ Indices to be taken.
969
+ allow_fill : bool, default False
970
+ How to handle negative values in `indices`.
971
+
972
+ * False: negative values in `indices` indicate positional indices
973
+ from the right (the default). This is similar to
974
+ :func:`numpy.take`.
975
+
976
+ * True: negative values in `indices` indicate
977
+ missing values. These values are set to `fill_value`. Any other
978
+ other negative values raise a ``ValueError``.
979
+
980
+ fill_value : any, optional
981
+ Fill value to use for NA-indices when `allow_fill` is True.
982
+ This may be ``None``, in which case the default NA value for
983
+ the type, ``self.dtype.na_value``, is used.
984
+
985
+ For many ExtensionArrays, there will be two representations of
986
+ `fill_value`: a user-facing "boxed" scalar, and a low-level
987
+ physical NA value. `fill_value` should be the user-facing version,
988
+ and the implementation should handle translating that to the
989
+ physical version for processing the take if necessary.
990
+
991
+ Returns
992
+ -------
993
+ ExtensionArray
994
+
995
+ Raises
996
+ ------
997
+ IndexError
998
+ When the indices are out of bounds for the array.
999
+ ValueError
1000
+ When `indices` contains negative values other than ``-1``
1001
+ and `allow_fill` is True.
1002
+
1003
+ See Also
1004
+ --------
1005
+ numpy.take : Take elements from an array along an axis.
1006
+ api.extensions.take : Take elements from an array.
1007
+
1008
+ Notes
1009
+ -----
1010
+ ExtensionArray.take is called by ``Series.__getitem__``, ``.loc``,
1011
+ ``iloc``, when `indices` is a sequence of values. Additionally,
1012
+ it's called by :meth:`Series.reindex`, or any other method
1013
+ that causes realignment, with a `fill_value`.
1014
+
1015
+ Examples
1016
+ --------
1017
+ Here's an example implementation, which relies on casting the
1018
+ extension array to object dtype. This uses the helper method
1019
+ :func:`pandas.api.extensions.take`.
1020
+
1021
+ .. code-block:: python
1022
+
1023
+ def take(self, indices, allow_fill=False, fill_value=None):
1024
+ from pandas.core.algorithms import take
1025
+
1026
+ # If the ExtensionArray is backed by an ndarray, then
1027
+ # just pass that here instead of coercing to object.
1028
+ data = self.astype(object)
1029
+
1030
+ if allow_fill and fill_value is None:
1031
+ fill_value = self.dtype.na_value
1032
+
1033
+ # fill value should always be translated from the scalar
1034
+ # type for the array, to the physical storage type for
1035
+ # the data, before passing to take.
1036
+
1037
+ result = take(data, indices, fill_value=fill_value,
1038
+ allow_fill=allow_fill)
1039
+ return self._from_sequence(result, dtype=self.dtype)
1040
+ """
1041
+ if allow_fill:
1042
+ # With allow_fill being True, negative values in `indices` indicate
1043
+ # missing values and should be set to `fill_value`.
1044
+ indices = np.asarray(indices, dtype=np.intp)
1045
+ validate_indices(indices, len(self._tensor))
1046
+
1047
+ # Check if there are missing indices to fill, otherwise we can
1048
+ # delegate to NumPy ndarray .take().
1049
+ has_missing = np.any(indices < 0)
1050
+ if has_missing:
1051
+ if fill_value is None:
1052
+ fill_value = np.nan
1053
+
1054
+ # Create an array populated with fill value.
1055
+ values = np.full((len(indices),) + self._tensor.shape[1:], fill_value)
1056
+
1057
+ # Put tensors at the given positive indices into array.
1058
+ is_nonneg = indices >= 0
1059
+ np.put(values, np.where(is_nonneg)[0], self._tensor[indices[is_nonneg]])
1060
+
1061
+ return TensorArray(values)
1062
+
1063
+ # Delegate take to NumPy array.
1064
+ values = self._tensor.take(indices, axis=0)
1065
+
1066
+ return TensorArray(values)
1067
+
1068
+ def copy(self) -> "TensorArray":
1069
+ """
1070
+ Return a copy of the array.
1071
+
1072
+ Returns
1073
+ -------
1074
+ ExtensionArray
1075
+ """
1076
+ # TODO(Clark): Copy cached properties.
1077
+ return TensorArray(self._tensor.copy())
1078
+
1079
+ @classmethod
1080
+ def _concat_same_type(cls, to_concat: Sequence["TensorArray"]) -> "TensorArray":
1081
+ """
1082
+ Concatenate multiple array of this dtype.
1083
+
1084
+ Parameters
1085
+ ----------
1086
+ to_concat : sequence of this type
1087
+
1088
+ Returns
1089
+ -------
1090
+ ExtensionArray
1091
+ """
1092
+ should_flatten = False
1093
+ shape = None
1094
+ for a in to_concat:
1095
+ if shape is None:
1096
+ shape = a.dtype.element_shape
1097
+ if a.is_variable_shaped or a.dtype.element_shape != shape:
1098
+ should_flatten = True
1099
+ break
1100
+ if should_flatten:
1101
+ concated = TensorArray(
1102
+ np.array([e for a in to_concat for e in a._tensor], dtype=object)
1103
+ )
1104
+ else:
1105
+ concated = TensorArray(np.concatenate([a._tensor for a in to_concat]))
1106
+ return concated
1107
+
1108
    def __setitem__(self, key: Union[int, np.ndarray], value: Any) -> None:
        """
        Set one or more values inplace.

        This method is not required to satisfy the pandas extension array
        interface.

        Parameters
        ----------
        key : int, ndarray, or slice
            When called from, e.g. ``Series.__setitem__``, ``key`` will be
            one of

            * scalar int
            * ndarray of integers.
            * boolean ndarray
            * slice object

        value : ExtensionDtype.type, Sequence[ExtensionDtype.type], or object
            value or values to be set of ``key``.

        Returns
        -------
        None
        """
        # Normalize/validate the indexer per the pandas indexing contract.
        key = check_array_indexer(self, key)
        # Unwrap tensor wrappers (scalar, list, or Series-backed) to raw
        # ndarrays before assigning; the order of these checks matters.
        if isinstance(value, TensorArrayElement) or np.isscalar(value):
            value = np.asarray(value)
        if isinstance(value, list):
            value = [
                np.asarray(v) if isinstance(v, TensorArrayElement) else v for v in value
            ]
        if isinstance(value, ABCSeries) and isinstance(value.dtype, TensorDtype):
            value = value.values
        if value is None or isinstance(value, Sequence) and len(value) == 0:
            # Treat None/empty assignment as "set to missing" (NaN-filled).
            self._tensor[key] = np.full_like(self._tensor[key], np.nan)
        elif isinstance(key, (int, slice, np.ndarray)):
            self._tensor[key] = value
        else:
            raise NotImplementedError(
                f"__setitem__ with key type '{type(key)}' not implemented"
            )
1150
+
1151
+ def __contains__(self, item) -> bool:
1152
+ """
1153
+ Return for `item in self`.
1154
+ """
1155
+ if isinstance(item, TensorArrayElement):
1156
+ np_item = np.asarray(item)
1157
+ if np_item.size == 1 and np.isnan(np_item).all():
1158
+ return self.isna().any()
1159
+ return super().__contains__(item)
1160
+
1161
+ def __repr__(self):
1162
+ return self._tensor.__repr__()
1163
+
1164
+ def __str__(self):
1165
+ return self._tensor.__str__()
1166
+
1167
+ def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
1168
+ # TODO(Clark): return self._tensor, np.nan
1169
+ raise NotImplementedError
1170
+
1171
+ def _reduce(self, name: str, skipna: bool = True, **kwargs):
1172
+ """
1173
+ Return a scalar result of performing the reduction operation.
1174
+
1175
+ Parameters
1176
+ ----------
1177
+ name : str
1178
+ Name of the function, supported values are:
1179
+ { any, all, min, max, sum, mean, median, prod,
1180
+ std, var, sem, kurt, skew }.
1181
+ skipna : bool, default True
1182
+ If True, skip NaN values.
1183
+ **kwargs
1184
+ Additional keyword arguments passed to the reduction function.
1185
+ Currently, `ddof` is the only supported kwarg.
1186
+
1187
+ Returns
1188
+ -------
1189
+ scalar
1190
+
1191
+ Raises
1192
+ ------
1193
+ TypeError : subclass does not define reductions
1194
+ """
1195
+ supported_kwargs = ["ddof"]
1196
+ reducer_kwargs = {}
1197
+ for kw in supported_kwargs:
1198
+ try:
1199
+ reducer_kwargs[kw] = kwargs[kw]
1200
+ except KeyError:
1201
+ pass
1202
+ try:
1203
+ return TensorArrayElement(
1204
+ self.SUPPORTED_REDUCERS[name](self._tensor, axis=0, **reducer_kwargs)
1205
+ )
1206
+ except KeyError:
1207
+ raise NotImplementedError(f"'{name}' aggregate not implemented.") from None
1208
+
1209
+ def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
1210
+ return np.asarray(self._tensor, dtype=dtype, **kwargs)
1211
+
1212
+ def __array_ufunc__(self, ufunc: Callable, method: str, *inputs, **kwargs):
1213
+ """
1214
+ Supports NumPy ufuncs without requiring sloppy coercion to an
1215
+ ndarray.
1216
+ """
1217
+ out = kwargs.get("out", ())
1218
+ for x in inputs + out:
1219
+ if not isinstance(x, (TensorArray, np.ndarray, numbers.Number)):
1220
+ return NotImplemented
1221
+
1222
+ # Defer to the implementation of the ufunc on unwrapped values.
1223
+ inputs = tuple(x._tensor if isinstance(x, TensorArray) else x for x in inputs)
1224
+ if out:
1225
+ kwargs["out"] = tuple(
1226
+ x._tensor if isinstance(x, TensorArray) else x for x in out
1227
+ )
1228
+ result = getattr(ufunc, method)(*inputs, **kwargs)
1229
+
1230
+ if type(result) is tuple:
1231
+ # Multiple return values.
1232
+ return tuple(type(self)(x) for x in result)
1233
+ elif method == "at":
1234
+ # No return value.
1235
+ return None
1236
+ else:
1237
+ # One return value.
1238
+ return type(self)(result)
1239
+
1240
+ def to_numpy(
1241
+ self,
1242
+ dtype: np.dtype = None,
1243
+ copy: bool = False,
1244
+ na_value: Any = pd.api.extensions.no_default,
1245
+ ):
1246
+ """
1247
+ Convert to a NumPy ndarray.
1248
+
1249
+ .. versionadded:: 1.0.0
1250
+
1251
+ This is similar to :meth:`numpy.asarray`, but may provide additional
1252
+ control over how the conversion is done.
1253
+
1254
+ Parameters
1255
+ ----------
1256
+ dtype : str or numpy.dtype, optional
1257
+ The dtype to pass to :meth:`numpy.asarray`.
1258
+ copy : bool, default False
1259
+ Whether to ensure that the returned value is a not a view on
1260
+ another array. Note that ``copy=False`` does not *ensure* that
1261
+ ``to_numpy()`` is no-copy. Rather, ``copy=True`` ensure that
1262
+ a copy is made, even if not strictly necessary.
1263
+ na_value : Any, optional
1264
+ The value to use for missing values. The default value depends
1265
+ on `dtype` and the type of the array.
1266
+
1267
+ Returns
1268
+ -------
1269
+ numpy.ndarray
1270
+ """
1271
+ if dtype is not None:
1272
+ dtype = pd.api.types.pandas_dtype(dtype)
1273
+ if copy:
1274
+ values = np.array(self._tensor, dtype=dtype, copy=True)
1275
+ else:
1276
+ values = self._tensor.astype(dtype)
1277
+ elif copy:
1278
+ values = self._tensor.copy()
1279
+ else:
1280
+ values = self._tensor
1281
+ return values
1282
+
1283
+ @property
1284
+ def numpy_dtype(self):
1285
+ """
1286
+ Get the dtype of the tensor.
1287
+ :return: The numpy dtype of the backing ndarray
1288
+ """
1289
+ return self._tensor.dtype
1290
+
1291
+ @property
1292
+ def numpy_ndim(self):
1293
+ """
1294
+ Get the number of tensor dimensions.
1295
+ :return: integer for the number of dimensions
1296
+ """
1297
+ return self._tensor.ndim
1298
+
1299
+ @property
1300
+ def numpy_shape(self):
1301
+ """
1302
+ Get the shape of the tensor.
1303
+ :return: A tuple of integers for the numpy shape of the backing ndarray
1304
+ """
1305
+ return self._tensor.shape
1306
+
1307
+ @property
1308
+ def numpy_size(self):
1309
+ """
1310
+ Get the size of the tensor.
1311
+ :return: integer for the number of elements in the tensor
1312
+ """
1313
+ return self._tensor.size
1314
+
1315
+ def astype(self, dtype, copy=True):
1316
+ """
1317
+ Cast to a NumPy array with 'dtype'.
1318
+
1319
+ Parameters
1320
+ ----------
1321
+ dtype : str or dtype
1322
+ Typecode or data-type to which the array is cast.
1323
+ copy : bool, default True
1324
+ Whether to copy the data, even if not necessary. If False,
1325
+ a copy is made only if the old dtype does not match the
1326
+ new dtype.
1327
+
1328
+ Returns
1329
+ -------
1330
+ array : ndarray
1331
+ NumPy ndarray with 'dtype' for its dtype.
1332
+ """
1333
+ dtype = pd.api.types.pandas_dtype(dtype)
1334
+
1335
+ if isinstance(dtype, TensorDtype):
1336
+ values = TensorArray(self._tensor.copy()) if copy else self
1337
+ elif not (
1338
+ pd.api.types.is_object_dtype(dtype) and pd.api.types.is_string_dtype(dtype)
1339
+ ):
1340
+ values = np.array([str(t) for t in self._tensor])
1341
+ if isinstance(dtype, pd.StringDtype):
1342
+ return dtype.construct_array_type()._from_sequence(values, copy=False)
1343
+ else:
1344
+ return values
1345
+ elif pd.api.types.is_object_dtype(dtype):
1346
+ # Interpret astype(object) as "cast to an array of numpy arrays"
1347
+ values = np.empty(len(self), dtype=object)
1348
+ for i in range(len(self)):
1349
+ values[i] = self._tensor[i]
1350
+ else:
1351
+ values = self._tensor.astype(dtype, copy=copy)
1352
+ return values
1353
+
1354
+ def any(self, axis=None, out=None, keepdims=False):
1355
+ """
1356
+ Test whether any array element along a given axis evaluates to True.
1357
+
1358
+ See numpy.any() documentation for more information
1359
+ https://numpy.org/doc/stable/reference/generated/numpy.any.html#numpy.any
1360
+
1361
+ :param axis: Axis or axes along which a logical OR reduction is
1362
+ performed.
1363
+ :param out: Alternate output array in which to place the result.
1364
+ :param keepdims: If this is set to True, the axes which are reduced are
1365
+ left in the result as dimensions with size one.
1366
+ :return: single boolean unless axis is not None else TensorArray
1367
+ """
1368
+ result = self._tensor.any(axis=axis, out=out, keepdims=keepdims)
1369
+ return result if axis is None else TensorArray(result)
1370
+
1371
+ def all(self, axis=None, out=None, keepdims=False):
1372
+ """
1373
+ Test whether all array elements along a given axis evaluate to True.
1374
+
1375
+ :param axis: Axis or axes along which a logical AND reduction is
1376
+ performed.
1377
+ :param out: Alternate output array in which to place the result.
1378
+ :param keepdims: If this is set to True, the axes which are reduced are
1379
+ left in the result as dimensions with size one.
1380
+ :return: single boolean unless axis is not None else TensorArray
1381
+ """
1382
+ result = self._tensor.all(axis=axis, out=out, keepdims=keepdims)
1383
+ return result if axis is None else TensorArray(result)
1384
+
1385
+ def __arrow_array__(self, type=None):
1386
+ """
1387
+ Convert this TensorArray to an ArrowTensorArray extension array.
1388
+
1389
+ This and TensorDtype.__from_arrow__ make up the
1390
+ Pandas extension type + array <--> Arrow extension type + array
1391
+ interoperability protocol. See
1392
+ https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
1393
+ for more information.
1394
+ """
1395
+ from ray.air.util.tensor_extensions.arrow import (
1396
+ ArrowTensorArray,
1397
+ ArrowVariableShapedTensorArray,
1398
+ )
1399
+
1400
+ if self.is_variable_shaped:
1401
+ return ArrowVariableShapedTensorArray.from_numpy(self._tensor)
1402
+ else:
1403
+ return ArrowTensorArray.from_numpy(self._tensor)
1404
+
1405
+ @property
1406
+ def _is_boolean(self):
1407
+ """
1408
+ Whether this extension array should be considered boolean.
1409
+
1410
+ By default, ExtensionArrays are assumed to be non-numeric.
1411
+ Setting this to True will affect the behavior of several places,
1412
+ e.g.
1413
+
1414
+ * is_bool
1415
+ * boolean indexing
1416
+
1417
+ Returns
1418
+ -------
1419
+ bool
1420
+ """
1421
+ # This is needed to support returning a TensorArray from .isnan().
1422
+ return self.dtype._is_boolean()
1423
+
1424
+
1425
# Install the arithmetic, comparison, and logical operator methods from the
# pandas ops mixin onto both the element and array classes.
TensorArrayElement._add_arithmetic_ops()
TensorArray._add_arithmetic_ops()
TensorArrayElement._add_comparison_ops()
TensorArray._add_comparison_ops()
TensorArrayElement._add_logical_ops()
TensorArray._add_logical_ops()
1433
+
1434
+
1435
@PublicAPI(stability="beta")
def column_needs_tensor_extension(s: pd.Series) -> bool:
    """Return whether the provided pandas Series column needs a tensor extension
    representation, which provides more efficient slicing and interop with ML
    frameworks.

    Args:
        s: The pandas Series column that may need to be represented using the
            tensor extension.

    Returns:
        Whether the provided Series needs a tensor extension representation.
    """
    # NOTE: O(1) check — only the dtype and the first element are inspected.
    if s.dtype.type is not np.object_ or s.empty:
        return False
    return isinstance(s.iloc[0], np.ndarray)
.venv/lib/python3.11/site-packages/ray/air/util/torch_dist.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file is modeled after ray/python/ray/train/torch/config.py
2
+
3
+ The logics are duplicated right now to allow maximum flexibility for
4
+ setting up PyTorch DDP process groups outside the context of Ray Train.
5
+ Eventually, these use cases should be consolidated.
6
+ """
7
+
8
+ import os
9
+ from abc import ABC
10
+ from collections import defaultdict
11
+ from datetime import timedelta
12
+ from typing import Callable, List, T
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+
17
+ import ray
18
+ from ray.actor import ActorHandle
19
+ from ray.air._internal.torch_utils import get_devices
20
+ from ray.train._internal.utils import get_address_and_port
21
+
22
+
23
class TorchDistributedWorker(ABC):
    """Interface required by ``init_torch_dist_process_group()``.

    Modeled after RayTrainerWorker: allows arbitrary functions to be executed
    on a remote DDP worker.
    """

    def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Run ``func(*args, **kwargs)`` on this worker and return its result.

        Args:
            func: The function to execute.
            args, kwargs: The arguments to pass into func.
        """
        result = func(*args, **kwargs)
        return result
38
+
39
+
40
def _init_torch_distributed(
    init_method: str,
    backend: str,
    rank: int,
    world_size: int,
    local_rank: int,
    local_world_size: int,
    master_addr: str,
    master_port: str,
    gpu_ids: List[int],
    **init_process_group_kwargs,
):
    """Initialize the torch.distributed backend on this worker.

    Sets up the rendezvous URL (env:// or tcp://), NCCL-specific environment
    configuration when applicable, joins the process group, and exports the
    standard rank/world-size environment variables.
    """
    if init_method == "env":
        os.environ["MASTER_ADDR"] = str(master_addr)
        os.environ["MASTER_PORT"] = str(master_port)
        url = "env://"
    elif init_method == "tcp":
        url = f"tcp://{master_addr}:{master_port}"
    else:
        raise ValueError(
            f"The provided init_method ("
            f"{init_method}) is not supported. Must "
            f"be either 'env' or 'tcp'."
        )

    if backend == "nccl":
        # Same as in Ray Train.
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
        # All workers on the same node must share one set of visible GPUs;
        # otherwise they cannot communicate with each other.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gid) for gid in gpu_ids)

    group_kwargs = dict(init_process_group_kwargs)
    group_kwargs.update(
        backend=backend,
        init_method=url,
        rank=rank,
        world_size=world_size,
    )
    group_kwargs.setdefault("timeout", timedelta(seconds=1800))
    dist.init_process_group(**group_kwargs)

    # Export the standard torch distributed environment variables.
    env_vars = {
        "RANK": rank,
        "LOCAL_RANK": local_rank,
        "WORLD_SIZE": world_size,
        "LOCAL_WORLD_SIZE": local_world_size,
    }
    for var, value in env_vars.items():
        os.environ[var] = str(value)
89
+
90
+
91
def _get_node_and_gpu_ids():
    """Return ``(node_id, gpu_ids)`` for the worker this runs on."""
    runtime_ctx = ray.get_runtime_context()
    return runtime_ctx.get_node_id(), ray.get_gpu_ids()
96
+
97
+
98
def init_torch_dist_process_group(
    workers: List[ActorHandle],
    backend: str = "gloo",
    init_method: str = "env",
    **init_process_group_kwargs,
) -> List[int]:
    """Initialize a torch distributed process group.

    Note: this util assumes that the order of the workers passed in
    are their global ranks.

    Args:
        workers: A list of TorchDistributedWorker actors.
        backend: The torch distributed backend to use,
            possible choices are "gloo" or "nccl".
        init_method: The initialization method to use,
            possible choices are "env" or "tcp".
        init_process_group_kwargs: Additional kwargs to pass to the call to
            :meth:`torch.distributed.init_process_group`.

    Returns:
        Local ranks on their respective nodes for the list of workers.
    """
    if not dist.is_available():
        raise RuntimeError("Distributed torch is not available.")

    # Build a map from node_id to workers on that node.
    node_and_gpu_ids = ray.get(
        [w.execute.remote(_get_node_and_gpu_ids) for w in workers]
    )
    # All the workers on a specific node.
    node_to_workers = defaultdict(list)
    # All the gpu ids visible to all the workers on a specific node.
    node_to_gpu_ids = defaultdict(set)
    for i, (node_id, gpu_ids) in enumerate(node_and_gpu_ids):
        node_to_workers[node_id].append(i)
        # Force list.
        if not isinstance(gpu_ids, list):
            gpu_ids = [gpu_ids]
        # It is possible for a worker to have access to multiple GPUs.
        for gpu_id in gpu_ids:
            node_to_gpu_ids[node_id].add(gpu_id)

    # Assume the first worker is the master.
    master_addr, master_port = ray.get(workers[0].execute.remote(get_address_and_port))

    setup_futures = []
    world_size = len(workers)
    local_ranks = []
    for rank, worker in enumerate(workers):
        node_id = node_and_gpu_ids[rank][0]
        local_rank = node_to_workers[node_id].index(rank)
        local_world_size = len(node_to_workers[node_id])
        setup_futures.append(
            worker.execute.remote(
                _init_torch_distributed,
                init_method=init_method,
                backend=backend,
                rank=rank,
                world_size=world_size,
                local_rank=local_rank,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                # BUG FIX: `list(set)` does NOT sort; use sorted() so that
                # CUDA_VISIBLE_DEVICES is identical (and deterministically
                # ordered) for every worker on the same node.
                gpu_ids=sorted(node_to_gpu_ids[node_id]),
                **init_process_group_kwargs,
            )
        )
        local_ranks.append(local_rank)

    # Wait for all workers to join the process group.
    ray.get(setup_futures)

    return local_ranks
174
+
175
+
176
def _shutdown_torch_distributed():
    """Tear down the torch.distributed process group and free CUDA caches."""
    dist.destroy_process_group()

    if not torch.cuda.is_available():
        return

    # Release cached CUDA memory held by this process.
    for device in get_devices():
        with torch.cuda.device(device):
            torch.cuda.empty_cache()
188
+
189
+
190
def shutdown_torch_dist_process_group(workers: List[ActorHandle]):
    """Destroy the torch process group on every worker and wait for completion."""
    futures = [w.execute.remote(_shutdown_torch_distributed) for w in workers]
    ray.get(futures)
.venv/lib/python3.11/site-packages/ray/air/util/transform_pyarrow.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import pyarrow
3
+ except ImportError:
4
+ pyarrow = None
5
+
6
+
7
def _is_column_extension_type(ca: "pyarrow.ChunkedArray") -> bool:
    """Whether the provided Arrow Table column uses an Arrow extension type."""
    column_type = ca.type
    return isinstance(column_type, pyarrow.ExtensionType)
12
+
13
+
14
def _concatenate_extension_column(ca: "pyarrow.ChunkedArray") -> "pyarrow.Array":
    """Concatenate chunks of an extension column into a contiguous array.

    This concatenation is required for creating copies and for .take() to work
    on extension arrays.
    See https://issues.apache.org/jira/browse/ARROW-16503.

    Raises:
        ValueError: If ``ca`` is not an extension-typed column.
    """
    from ray.air.util.tensor_extensions.arrow import (
        ArrowTensorArray,
        get_arrow_extension_tensor_types,
    )

    if not _is_column_extension_type(ca):
        # BUG FIX: the message previously lacked the `f` prefix, so the literal
        # text "{ca}" was raised instead of the offending array.
        raise ValueError(f"Chunked array isn't an extension array: {ca}")

    tensor_extension_types = get_arrow_extension_tensor_types()

    if ca.num_chunks == 0:
        # Create empty storage array.
        storage = pyarrow.array([], type=ca.type.storage_type)
    elif isinstance(ca.type, tensor_extension_types):
        return ArrowTensorArray._concat_same_type(ca.chunks)
    else:
        storage = pyarrow.concat_arrays([c.storage for c in ca.chunks])

    return ca.type.__arrow_ext_class__().from_storage(ca.type, storage)
.venv/lib/python3.11/site-packages/ray/serve/_private/__pycache__/deployment_state.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fed1e2f11cb7a8c80b117de5ff1af60d479c6ccb4594d860948a52f971a6ec45
3
+ size 125314
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (202 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/common.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_noop_latency.cpython-311.pyc ADDED
Binary file (2.04 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/handle_throughput.cpython-311.pyc ADDED
Binary file (2.47 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/http_noop_latency.cpython-311.pyc ADDED
Binary file (2.08 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/microbenchmark.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/__pycache__/proxy_benchmark.cpython-311.pyc ADDED
Binary file (18.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/common.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ import random
5
+ import string
6
+ import time
7
+ from functools import partial
8
+ from typing import Any, Callable, Coroutine, List, Optional, Tuple
9
+
10
+ import aiohttp
11
+ import aiohttp.client_exceptions
12
+ import grpc
13
+ import numpy as np
14
+ import pandas as pd
15
+ from starlette.responses import StreamingResponse
16
+ from tqdm import tqdm
17
+
18
+ from ray import serve
19
+ from ray.serve.generated import serve_pb2, serve_pb2_grpc
20
+ from ray.serve.handle import DeploymentHandle
21
+
22
+
23
async def run_latency_benchmark(
    f: Callable, num_requests: int, *, num_warmup_requests: int = 100
) -> pd.Series:
    """Measure per-call latency of ``f`` in milliseconds.

    Runs ``num_warmup_requests`` un-recorded warmup calls followed by
    ``num_requests`` recorded calls; returns the recorded latencies.
    """
    if inspect.iscoroutinefunction(f):
        invoke = f
    else:

        async def invoke():
            f()

    latencies = []
    total_calls = num_requests + num_warmup_requests
    for i in tqdm(range(total_calls)):
        start = time.perf_counter()
        await invoke()
        elapsed_ms = 1000 * (time.perf_counter() - start)
        # Don't include warm-up requests.
        if i >= num_warmup_requests:
            latencies.append(elapsed_ms)

    return pd.Series(latencies)
44
+
45
+
46
async def run_throughput_benchmark(
    fn: Callable[[], List[float]],
    multiplier: int = 1,
    num_trials: int = 10,
    trial_runtime: float = 1,
) -> Tuple[float, float, pd.Series]:
    """Benchmarks throughput of a function.

    Args:
        fn: The function to benchmark. If it returns anything, it must be a
            list of latencies.
        multiplier: The number of requests/tokens (or other appropriate unit)
            completed in one call to `fn`.
        num_trials: The number of trials to run.
        trial_runtime: How long each trial runs; `fn` is called repeatedly
            for this duration.

    Returns (mean, stddev, latencies).
    """
    # Warmup so the first trial isn't penalized by startup costs.
    warmup_deadline = time.time() + 0.1
    while time.time() < warmup_deadline:
        await fn()

    # Benchmark
    throughputs = []
    latencies = []
    for _ in tqdm(range(num_trials)):
        trial_start = time.perf_counter()
        calls = 0
        while time.perf_counter() - trial_start < trial_runtime:
            maybe_latencies = await fn()
            if maybe_latencies:
                latencies.extend(maybe_latencies)
            calls += 1
        elapsed = time.perf_counter() - trial_start
        throughputs.append(multiplier * calls / elapsed)

    mean = round(np.mean(throughputs), 2)
    stddev = round(np.std(throughputs), 2)
    return mean, stddev, pd.Series(latencies)
87
+
88
+
89
async def do_single_http_batch(
    *,
    batch_size: int = 100,
    url: str = "http://localhost:8000",
    stream: bool = False,
) -> List[float]:
    """Sends a batch of concurrent HTTP GETs; returns e2e latencies (ms)."""

    # aiohttp caps client connections at 100 by default, so configure the
    # TCPConnector limit explicitly for larger batches.
    connector = aiohttp.TCPConnector(limit=batch_size)
    async with aiohttp.ClientSession(
        connector=connector, raise_for_status=True
    ) as session:

        async def timed_get():
            begin = time.perf_counter()
            try:
                if stream:
                    async with session.get(url) as resp:
                        async for _chunk, _ in resp.content.iter_chunks():
                            pass
                else:
                    await session.get(url)
            except aiohttp.client_exceptions.ClientConnectionError:
                # Connection failures still count toward the measured latency.
                pass

            return 1000 * (time.perf_counter() - begin)

        return await asyncio.gather(*[timed_get() for _ in range(batch_size)])
121
+
122
+
123
async def do_single_grpc_batch(
    *, batch_size: int = 100, target: str = "localhost:9000"
):
    """Sends a batch of concurrent unary gRPC calls; returns e2e latencies (ms).

    FIX: the channel is now opened as an async context manager so it is closed
    when the batch completes (it previously leaked).
    """
    async with grpc.aio.insecure_channel(target) as channel:
        stub = serve_pb2_grpc.RayServeBenchmarkServiceStub(channel)
        payload = serve_pb2.StringData(data="")

        async def do_query():
            start = time.perf_counter()

            await stub.grpc_call(payload)

            end = time.perf_counter()
            return 1000 * (end - start)

        return await asyncio.gather(*[do_query() for _ in range(batch_size)])
139
+
140
+
141
async def collect_profile_events(coro: Coroutine):
    """Awaits ``coro`` under VizTracer and saves the profiling events."""

    from viztracer import VizTracer

    tracer = VizTracer()
    tracer.start()
    await coro
    tracer.stop()
    tracer.save()
153
+
154
+
155
def generate_payload(size: int = 100, chars=string.ascii_uppercase + string.digits):
    """Return a random string of ``size`` characters drawn from ``chars``."""
    selected = [random.choice(chars) for _ in range(size)]
    return "".join(selected)
157
+
158
+
159
class Blackhole:
    """Sink that silently discards whatever is passed to it."""

    def sink(self, o):
        """Discard ``o``; returns nothing."""
        return None
162
+
163
+
164
@serve.deployment
class Noop:
    """Deployment that immediately returns an empty byte string."""

    def __init__(self):
        # Quiet per-request serve logs while benchmarking.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)

    def __call__(self, *args, **kwargs):
        return b""
171
+
172
+
173
@serve.deployment
class Streamer:
    """Deployment that streams a fixed number of tokens with a fixed delay."""

    def __init__(self, tokens_per_request: int, inter_token_delay_ms: int = 10):
        # Quiet per-request serve logs while benchmarking.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)
        self._tokens_per_request = tokens_per_request
        self._inter_token_delay_s = inter_token_delay_ms / 1000

    async def stream(self):
        """Yield one token per configured delay interval."""
        for _token_idx in range(self._tokens_per_request):
            await asyncio.sleep(self._inter_token_delay_s)
            yield b"hi"

    async def __call__(self):
        return StreamingResponse(self.stream())
187
+
188
+
189
@serve.deployment
class IntermediateRouter:
    """Deployment that proxies a downstream streaming handle token-by-token."""

    def __init__(self, handle: DeploymentHandle):
        # Quiet per-request serve logs while benchmarking.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)
        self._handle = handle.options(stream=True)

    async def stream(self):
        """Forward every token produced by the downstream stream."""
        async for token in self._handle.stream.remote():
            yield token

    def __call__(self):
        return StreamingResponse(self.stream())
201
+
202
+
203
@serve.deployment
class Benchmarker:
    """Driver deployment that benchmarks a downstream handle.

    Supports unary and streaming requests; latencies are in milliseconds.
    """

    def __init__(
        self,
        handle: DeploymentHandle,
        stream: bool = False,
    ):
        # Quiet per-request serve logs while benchmarking.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)
        self._handle = handle.options(stream=stream)
        self._stream = stream

    async def do_single_request(self, payload: Any = None) -> float:
        """Completes a single unary request. Returns e2e latency in ms."""
        start = time.perf_counter()

        if payload is None:
            await self._handle.remote()
        else:
            await self._handle.remote(payload)

        return 1000 * (time.perf_counter() - start)

    async def _do_single_stream(self) -> float:
        """Consumes a single streaming request. Returns e2e latency in ms."""
        start = time.perf_counter()

        async for _ in self._handle.stream.remote():
            pass

        return 1000 * (time.perf_counter() - start)

    async def _do_single_batch(self, batch_size: int) -> List[float]:
        """Runs ``batch_size`` concurrent requests (streaming or unary)."""
        issue = self._do_single_stream if self._stream else self.do_single_request
        return await asyncio.gather(*[issue() for _ in range(batch_size)])

    async def run_latency_benchmark(
        self, *, num_requests: int, payload: Any = None
    ) -> pd.Series:
        """Runs the module-level latency benchmark against the handle."""

        async def issue_one():
            await self.do_single_request(payload)

        return await run_latency_benchmark(issue_one, num_requests=num_requests)

    async def run_throughput_benchmark(
        self,
        *,
        batch_size: int,
        num_trials: int,
        trial_runtime: float,
        tokens_per_request: Optional[float] = None,
    ) -> Tuple[float, float]:
        """Runs the module-level throughput benchmark against the handle.

        For streaming, throughput is measured in tokens (``tokens_per_request``
        must be provided); for unary, in requests.
        """
        if self._stream:
            assert tokens_per_request
            multiplier = tokens_per_request * batch_size
        else:
            multiplier = batch_size

        return await run_throughput_benchmark(
            fn=partial(self._do_single_batch, batch_size=batch_size),
            multiplier=multiplier,
            num_trials=num_trials,
            trial_runtime=trial_runtime,
        )
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_noop_latency.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time

import click
import pandas as pd

from ray import serve
from ray.serve._private.benchmarks.common import Benchmarker, Noop
from ray.serve.handle import DeploymentHandle


@click.command(help="Benchmark no-op DeploymentHandle latency.")
@click.option("--num-replicas", type=int, default=1)
@click.option("--num-requests", type=int, default=100)
def main(num_replicas: int, num_requests: int):
    """Deploy a no-op app behind a Benchmarker and print latency stats."""
    h: DeploymentHandle = serve.run(
        Benchmarker.bind(Noop.options(num_replicas=num_replicas).bind())
    )

    # BUG FIX: Benchmarker.run_latency_benchmark declares keyword-only
    # parameters, so `num_requests` must be passed by keyword; the previous
    # positional call raised TypeError.
    latencies: pd.Series = h.run_latency_benchmark.remote(
        num_requests=num_requests,
    ).result()

    # Let the logs flush to avoid interwoven output.
    time.sleep(1)

    print(
        "Latency (ms) for noop DeploymentHandle requests "
        f"(num_replicas={num_replicas},num_requests={num_requests}):"
    )
    print(latencies.describe(percentiles=[0.5, 0.9, 0.95, 0.99]))


if __name__ == "__main__":
    main()
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/handle_throughput.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import click

from ray import serve

# BUG FIX: `common` defines `Noop`, not `Hello`; importing `Hello` raised
# ImportError at startup.
from ray.serve._private.benchmarks.common import Benchmarker, Noop
from ray.serve.handle import DeploymentHandle


@click.command(help="Benchmark deployment handle throughput.")
@click.option(
    "--batch-size",
    type=int,
    default=100,
    help="Number of requests to send to downstream deployment in each trial.",
)
@click.option(
    "--num-replicas",
    type=int,
    default=1,
    help="Number of replicas in the downstream deployment.",
)
@click.option(
    "--num-trials",
    type=int,
    default=5,
    help="Number of trials of the benchmark to run.",
)
@click.option(
    "--trial-runtime",
    type=int,
    default=1,
    help="Duration to run each trial of the benchmark for (seconds).",
)
def main(
    batch_size: int,
    num_replicas: int,
    num_trials: int,
    trial_runtime: float,
):
    """Deploy a no-op app behind a Benchmarker and print throughput stats."""
    app = Benchmarker.bind(
        Noop.options(
            num_replicas=num_replicas, ray_actor_options={"num_cpus": 0}
        ).bind(),
    )
    h: DeploymentHandle = serve.run(app)

    # BUG FIX: run_throughput_benchmark returns (mean, stddev, latencies);
    # unpacking the result into only two names raised ValueError.
    mean, stddev, _latencies = h.run_throughput_benchmark.remote(
        batch_size=batch_size,
        num_trials=num_trials,
        trial_runtime=trial_runtime,
    ).result()

    print(
        "DeploymentHandle throughput {}: {} +- {} requests/s".format(
            f"(num_replicas={num_replicas}, batch_size={batch_size})",
            mean,
            stddev,
        )
    )


if __name__ == "__main__":
    main()
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/http_noop_latency.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import click
4
+ import pandas as pd
5
+ import requests
6
+
7
+ from ray import serve
8
+ from ray.serve._private.benchmarks.common import Noop, run_latency_benchmark
9
+
10
+
11
+ @click.command(help="Benchmark no-op HTTP latency.")
12
+ @click.option("--num-replicas", type=int, default=1)
13
+ @click.option("--num-requests", type=int, default=100)
14
+ def main(num_replicas: int, num_requests: int):
15
+ serve.run(Noop.options(num_replicas=num_replicas).bind())
16
+
17
+ latencies: pd.Series = asyncio.new_event_loop().run_until_complete(
18
+ run_latency_benchmark(
19
+ lambda: requests.get("http://localhost:8000"),
20
+ num_requests=num_requests,
21
+ )
22
+ )
23
+
24
+ print(
25
+ "Latency (ms) for noop HTTP requests "
26
+ f"(num_replicas={num_replicas},num_requests={num_requests}):"
27
+ )
28
+ print(latencies.describe(percentiles=[0.5, 0.9, 0.95, 0.99]))
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/microbenchmark.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runs several scenarios with varying max batch size, max concurrent queries,
2
+ # number of replicas, and with intermediate serve handles (to simulate ensemble
3
+ # models) either on or off.
4
+
5
+ import asyncio
6
+ import logging
7
+ from pprint import pprint
8
+ from typing import Dict, Union
9
+
10
+ import aiohttp
11
+ from starlette.requests import Request
12
+
13
+ import ray
14
+ from ray import serve
15
+ from ray.serve._private.benchmarks.common import run_throughput_benchmark
16
+ from ray.serve.handle import DeploymentHandle
17
+
18
+ NUM_CLIENTS = 8
19
+ CALLS_PER_BATCH = 100
20
+
21
+
22
+ async def fetch(session, data):
23
+ async with session.get("http://localhost:8000/", data=data) as response:
24
+ response = await response.text()
25
+ assert response == "ok", response
26
+
27
+
28
+ @ray.remote
29
+ class Client:
30
+ def ready(self):
31
+ return "ok"
32
+
33
+ async def do_queries(self, num, data):
34
+ async with aiohttp.ClientSession() as session:
35
+ for _ in range(num):
36
+ await fetch(session, data)
37
+
38
+
39
+ def build_app(
40
+ intermediate_handles: bool,
41
+ num_replicas: int,
42
+ max_batch_size: int,
43
+ max_ongoing_requests: int,
44
+ ):
45
+ @serve.deployment(max_ongoing_requests=1000)
46
+ class Upstream:
47
+ def __init__(self, handle: DeploymentHandle):
48
+ self._handle = handle
49
+
50
+ # Turn off access log.
51
+ logging.getLogger("ray.serve").setLevel(logging.WARNING)
52
+
53
+ async def __call__(self, req: Request):
54
+ return await self._handle.remote(await req.body())
55
+
56
+ @serve.deployment(
57
+ num_replicas=num_replicas,
58
+ max_ongoing_requests=max_ongoing_requests,
59
+ )
60
+ class Downstream:
61
+ def __init__(self):
62
+ # Turn off access log.
63
+ logging.getLogger("ray.serve").setLevel(logging.WARNING)
64
+
65
+ @serve.batch(max_batch_size=max_batch_size)
66
+ async def batch(self, reqs):
67
+ return [b"ok"] * len(reqs)
68
+
69
+ async def __call__(self, req: Union[bytes, Request]):
70
+ if max_batch_size > 1:
71
+ return await self.batch(req)
72
+ else:
73
+ return b"ok"
74
+
75
+ if intermediate_handles:
76
+ return Upstream.bind(Downstream.bind())
77
+ else:
78
+ return Downstream.bind()
79
+
80
+
81
+ async def trial(
82
+ intermediate_handles: bool,
83
+ num_replicas: int,
84
+ max_batch_size: int,
85
+ max_ongoing_requests: int,
86
+ data_size: str,
87
+ ) -> Dict[str, float]:
88
+ results = {}
89
+
90
+ trial_key_base = (
91
+ f"replica:{num_replicas}/batch_size:{max_batch_size}/"
92
+ f"concurrent_queries:{max_ongoing_requests}/"
93
+ f"data_size:{data_size}/intermediate_handle:{intermediate_handles}"
94
+ )
95
+
96
+ print(
97
+ f"intermediate_handles={intermediate_handles},"
98
+ f"num_replicas={num_replicas},"
99
+ f"max_batch_size={max_batch_size},"
100
+ f"max_ongoing_requests={max_ongoing_requests},"
101
+ f"data_size={data_size}"
102
+ )
103
+
104
+ app = build_app(
105
+ intermediate_handles, num_replicas, max_batch_size, max_ongoing_requests
106
+ )
107
+ serve.run(app)
108
+
109
+ if data_size == "small":
110
+ data = None
111
+ elif data_size == "large":
112
+ data = b"a" * 1024 * 1024
113
+ else:
114
+ raise ValueError("data_size should be 'small' or 'large'.")
115
+
116
+ async with aiohttp.ClientSession() as session:
117
+
118
+ async def single_client():
119
+ for _ in range(CALLS_PER_BATCH):
120
+ await fetch(session, data)
121
+
122
+ single_client_avg_tps, single_client_std_tps = await run_throughput_benchmark(
123
+ single_client,
124
+ multiplier=CALLS_PER_BATCH,
125
+ )
126
+ print(
127
+ "\t{} {} +- {} requests/s".format(
128
+ "single client {} data".format(data_size),
129
+ single_client_avg_tps,
130
+ single_client_std_tps,
131
+ )
132
+ )
133
+ key = f"num_client:1/{trial_key_base}"
134
+ results[key] = single_client_avg_tps
135
+
136
+ clients = [Client.remote() for _ in range(NUM_CLIENTS)]
137
+ ray.get([client.ready.remote() for client in clients])
138
+
139
+ async def many_clients():
140
+ ray.get([a.do_queries.remote(CALLS_PER_BATCH, data) for a in clients])
141
+
142
+ multi_client_avg_tps, _ = await run_throughput_benchmark(
143
+ many_clients,
144
+ multiplier=CALLS_PER_BATCH * len(clients),
145
+ )
146
+
147
+ results[f"num_client:{len(clients)}/{trial_key_base}"] = multi_client_avg_tps
148
+ return results
149
+
150
+
151
+ async def main():
152
+ results = {}
153
+ for intermediate_handles in [False, True]:
154
+ for num_replicas in [1, 8]:
155
+ for max_batch_size, max_ongoing_requests in [
156
+ (1, 1),
157
+ (1, 10000),
158
+ (10000, 10000),
159
+ ]:
160
+ # TODO(edoakes): large data causes broken pipe errors.
161
+ for data_size in ["small"]:
162
+ results.update(
163
+ await trial(
164
+ intermediate_handles,
165
+ num_replicas,
166
+ max_batch_size,
167
+ max_ongoing_requests,
168
+ data_size,
169
+ )
170
+ )
171
+
172
+ print("Results from all conditions:")
173
+ pprint(results)
174
+ return results
175
+
176
+
177
+ if __name__ == "__main__":
178
+ ray.init()
179
+ serve.start()
180
+ loop = asyncio.new_event_loop()
181
+ asyncio.set_event_loop(loop)
182
+ loop.run_until_complete(main())
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/proxy_benchmark.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runs some request ping to compare HTTP and gRPC performances in TPS and latency.
2
+ # Note: this takes around 1 hour to run.
3
+
4
+ import asyncio
5
+ import json
6
+ import logging
7
+ import time
8
+ from random import random
9
+ from typing import Callable, Dict
10
+
11
+ import aiohttp
12
+ import numpy as np
13
+ import pandas as pd
14
+ from grpc import aio
15
+ from starlette.requests import Request
16
+
17
+ import ray
18
+ from ray import serve
19
+ from ray.serve._private.common import RequestProtocol
20
+ from ray.serve.config import gRPCOptions
21
+ from ray.serve.generated import serve_pb2, serve_pb2_grpc
22
+ from ray.serve.handle import DeploymentHandle
23
+
24
+ CALLS_PER_BATCH = 100
25
+ DELTA = 10**-7
26
+
27
+
28
+ async def get_query_tps(name: str, fn: Callable, multiplier: int = CALLS_PER_BATCH):
29
+ """Get query TPS.
30
+
31
+ Run the function for 0.5 seconds 10 times to calculate how many requests can
32
+ be completed. And use those stats to calculate the mean and std of TPS.
33
+ """
34
+ # warmup
35
+ start = time.time()
36
+ while time.time() - start < 0.1:
37
+ await fn()
38
+ # real run
39
+ stats = []
40
+ for _ in range(10):
41
+ count = 0
42
+ start = time.time()
43
+ while time.time() - start < 0.5:
44
+ await fn()
45
+ count += 1
46
+ end = time.time()
47
+ stats.append(multiplier * count / (end - start))
48
+ tps_mean = round(np.mean(stats), 2)
49
+ tps_std = round(np.std(stats), 2)
50
+ print(f"\t{name} {tps_mean} +- {tps_std} requests/s")
51
+
52
+ return tps_mean, tps_std
53
+
54
+
55
+ async def get_query_latencies(name: str, fn: Callable):
56
+ """Get query latencies.
57
+
58
+ Take all the latencies from the function and calculate the mean and std.
59
+ """
60
+ many_client_results = np.asarray(await fn())
61
+ many_client_results.flatten()
62
+ latency_ms_mean = round(np.mean(many_client_results) * 1000, 2)
63
+ latency_ms_std = round(np.std(many_client_results) * 1000, 2)
64
+ print(f"\t{name} {latency_ms_mean} +- {latency_ms_std} ms")
65
+
66
+ return latency_ms_mean, latency_ms_std
67
+
68
+
69
+ async def fetch_http(session, data):
70
+ data_json = {"nums": data}
71
+ response = await session.get("http://localhost:8000/", json=data_json)
72
+ response_text = await response.read()
73
+ float(response_text.decode())
74
+
75
+
76
+ async def fetch_grpc(stub, data):
77
+ result = await stub.grpc_call(serve_pb2.RawData(nums=data))
78
+ result.output
79
+
80
+
81
+ @ray.remote
82
+ class HTTPClient:
83
+ def ready(self):
84
+ return "ok"
85
+
86
+ async def do_queries(self, num, data):
87
+ async with aiohttp.ClientSession() as session:
88
+ for _ in range(num):
89
+ await fetch_http(session, data)
90
+
91
+ async def time_queries(self, num, data):
92
+ stats = []
93
+ async with aiohttp.ClientSession() as session:
94
+ for _ in range(num):
95
+ start = time.time()
96
+ await fetch_http(session, data)
97
+ end = time.time()
98
+ stats.append(end - start)
99
+
100
+ return stats
101
+
102
+
103
+ @ray.remote
104
+ class gRPCClient:
105
+ def __init__(self):
106
+ channel = aio.insecure_channel("localhost:9000")
107
+ self.stub = serve_pb2_grpc.RayServeBenchmarkServiceStub(channel)
108
+
109
+ def ready(self):
110
+ return "ok"
111
+
112
+ async def do_queries(self, num, data):
113
+ for _ in range(num):
114
+ await fetch_grpc(self.stub, data)
115
+
116
+ async def time_queries(self, num, data):
117
+ stats = []
118
+ for _ in range(num):
119
+ start = time.time()
120
+ await fetch_grpc(self.stub, data)
121
+ end = time.time()
122
+ stats.append(end - start)
123
+ return stats
124
+
125
+
126
+ def build_app(
127
+ num_replicas: int,
128
+ max_ongoing_requests: int,
129
+ data_size: int,
130
+ ):
131
+ @serve.deployment(max_ongoing_requests=1000)
132
+ class DataPreprocessing:
133
+ def __init__(self, handle: DeploymentHandle):
134
+ self._handle = handle
135
+
136
+ # Turn off access log.
137
+ logging.getLogger("ray.serve").setLevel(logging.WARNING)
138
+
139
+ def normalize(self, raw: np.ndarray) -> np.ndarray:
140
+ return (raw - np.min(raw)) / (np.max(raw) - np.min(raw) + DELTA)
141
+
142
+ async def __call__(self, req: Request):
143
+ """HTTP entrypoint.
144
+
145
+ It parses the request, normalize the data, and send to model for inference.
146
+ """
147
+ body = json.loads(await req.body())
148
+ raw = np.asarray(body["nums"])
149
+ processed = self.normalize(raw)
150
+ return await self._handle.remote(processed)
151
+
152
+ async def grpc_call(self, raq_data):
153
+ """gRPC entrypoint.
154
+
155
+ It parses the request, normalize the data, and send to model for inference.
156
+ """
157
+ raw = np.asarray(raq_data.nums)
158
+ processed = self.normalize(raw)
159
+ output = await self._handle.remote(processed)
160
+ return serve_pb2.ModelOutput(output=output)
161
+
162
+ async def call_with_string(self, raq_data):
163
+ """gRPC entrypoint."""
164
+ return serve_pb2.ModelOutput(output=0)
165
+
166
+ @serve.deployment(
167
+ num_replicas=num_replicas,
168
+ max_ongoing_requests=max_ongoing_requests,
169
+ )
170
+ class ModelInference:
171
+ def __init__(self):
172
+ # Turn off access log.
173
+ logging.getLogger("ray.serve").setLevel(logging.WARNING)
174
+ self._model = np.random.randn(data_size, data_size)
175
+
176
+ async def __call__(self, processed: np.ndarray) -> float:
177
+ # Run a dot product with a random matrix to simulate a model inference.
178
+ model_output = np.dot(processed, self._model)
179
+ return sum(model_output)
180
+
181
+ return DataPreprocessing.bind(ModelInference.bind())
182
+
183
+
184
+ async def trial(
185
+ num_replicas: int,
186
+ max_ongoing_requests: int,
187
+ data_size: int,
188
+ num_clients: int,
189
+ proxy: RequestProtocol,
190
+ ) -> Dict[str, float]:
191
+ # Generate input data as array of random floats.
192
+ data = [random() for _ in range(data_size)]
193
+
194
+ # Build and deploy the app.
195
+ app = build_app(
196
+ num_replicas=num_replicas,
197
+ max_ongoing_requests=max_ongoing_requests,
198
+ data_size=data_size,
199
+ )
200
+ serve.run(app)
201
+
202
+ # Start clients.
203
+ if proxy == RequestProtocol.GRPC:
204
+ clients = [gRPCClient.remote() for _ in range(num_clients)]
205
+ elif proxy == RequestProtocol.HTTP:
206
+ clients = [HTTPClient.remote() for _ in range(num_clients)]
207
+ ray.get([client.ready.remote() for client in clients])
208
+
209
+ async def client_time_queries():
210
+ return ray.get([a.time_queries.remote(CALLS_PER_BATCH, data) for a in clients])
211
+
212
+ async def client_do_queries():
213
+ ray.get([a.do_queries.remote(CALLS_PER_BATCH, data) for a in clients])
214
+
215
+ trial_key_base = (
216
+ f"proxy:{proxy}/"
217
+ f"num_client:{num_clients}/"
218
+ f"replica:{num_replicas}/"
219
+ f"concurrent_queries:{max_ongoing_requests}/"
220
+ f"data_size:{data_size}"
221
+ )
222
+ tps_mean, tps_sdt = await get_query_tps(
223
+ trial_key_base,
224
+ client_do_queries,
225
+ )
226
+ latency_ms_mean, latency_ms_std = await get_query_latencies(
227
+ trial_key_base,
228
+ client_time_queries,
229
+ )
230
+
231
+ results = {
232
+ "proxy": proxy.value,
233
+ "num_client": num_clients,
234
+ "replica": num_replicas,
235
+ "concurrent_queries": max_ongoing_requests,
236
+ "data_size": data_size,
237
+ "tps_mean": tps_mean,
238
+ "tps_sdt": tps_sdt,
239
+ "latency_ms_mean": latency_ms_mean,
240
+ "latency_ms_std": latency_ms_std,
241
+ }
242
+
243
+ return results
244
+
245
+
246
+ async def main():
247
+ start_time = time.time()
248
+ results = []
249
+ for num_replicas in [1, 8]:
250
+ for max_ongoing_requests in [1, 10_000]:
251
+ for data_size in [1, 100, 10_000]:
252
+ for num_clients in [1, 8]:
253
+ for proxy in [RequestProtocol.GRPC, RequestProtocol.HTTP]:
254
+ results.append(
255
+ await trial(
256
+ num_replicas=num_replicas,
257
+ max_ongoing_requests=max_ongoing_requests,
258
+ data_size=data_size,
259
+ num_clients=num_clients,
260
+ proxy=proxy,
261
+ )
262
+ )
263
+
264
+ print(f"Total time: {time.time() - start_time}s")
265
+ print("results", results)
266
+
267
+ df = pd.DataFrame.from_dict(results)
268
+ df = df.sort_values(
269
+ by=["proxy", "num_client", "replica", "concurrent_queries", "data_size"]
270
+ )
271
+ print("Results from all conditions:")
272
+ # Print the results in with tab separated so we can copy into google sheets.
273
+ for i in range(len(df.index)):
274
+ row = list(df.iloc[i])
275
+ print("\t".join(map(str, row)))
276
+
277
+
278
+ if __name__ == "__main__":
279
+ ray.init()
280
+
281
+ grpc_port = 9000
282
+ grpc_servicer_functions = [
283
+ "ray.serve.generated.serve_pb2_grpc."
284
+ "add_RayServeBenchmarkServiceServicer_to_server",
285
+ ]
286
+ serve.start(
287
+ grpc_options=gRPCOptions(
288
+ port=grpc_port,
289
+ grpc_servicer_functions=grpc_servicer_functions,
290
+ )
291
+ )
292
+ loop = asyncio.new_event_loop()
293
+ asyncio.set_event_loop(loop)
294
+ loop.run_until_complete(main())
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (216 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/common.cpython-311.pyc ADDED
Binary file (1.53 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/__pycache__/serialization_benchmark.cpython-311.pyc ADDED
Binary file (7.37 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/common.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+
4
+ from pydantic import BaseModel
5
+
6
+ #
7
+ # NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
8
+ #
9
+ # Payloads in this module are purposefully extracted from benchmark file to force
10
+ # Ray's cloudpickle behavior when it does NOT serialize the class definition itself
11
+ # along with its payload (instead relying on it being imported)
12
+ #
13
+
14
+
15
+ class PayloadPydantic(BaseModel):
16
+ text: Optional[str] = None
17
+ floats: Optional[List[float]] = None
18
+ ints: Optional[List[int]] = None
19
+ ts: Optional[float] = None
20
+ reason: Optional[str] = None
21
+
22
+
23
+ @dataclass
24
+ class PayloadDataclass:
25
+ text: Optional[str] = None
26
+ floats: Optional[List[float]] = None
27
+ ints: Optional[List[int]] = None
28
+ ts: Optional[float] = None
29
+ reason: Optional[str] = None
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/serialization/serialization_benchmark.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import enum
3
+ import pickle
4
+ import time
5
+ from typing import Any, Callable
6
+
7
+ import click
8
+ import msgpack
9
+
10
+ from ray._private.serialization import SerializationContext
11
+ from ray.cloudpickle import cloudpickle_fast
12
+ from ray.serve._private.benchmarks.common import (
13
+ collect_profile_events,
14
+ run_latency_benchmark,
15
+ )
16
+ from ray.serve._private.benchmarks.serialization.common import (
17
+ PayloadDataclass,
18
+ PayloadPydantic,
19
+ )
20
+
21
+
22
+ class PayloadType(enum.Enum):
23
+ PYDANTIC = "pydantic"
24
+ DATACLASS = "dataclass"
25
+
26
+
27
+ class SerializerType(enum.Enum):
28
+ RAY = "ray"
29
+ PICKLE = "pickle"
30
+ CLOUDPICKLE = "cloudpickle"
31
+ MSGPACK = "msgpack"
32
+
33
+
34
+ _PERCENTILES = [0.5, 0.99]
35
+
36
+
37
+ sc = SerializationContext(None)
38
+
39
+
40
+ def _create_model(cls):
41
+ return cls(
42
+ text="Test output",
43
+ floats=[float(f) for f in range(1, 100)],
44
+ ints=list(range(1, 100)),
45
+ ts=time.time(),
46
+ reason="Success!",
47
+ )
48
+
49
+
50
+ def _blackhole(o):
51
+ """Placeholder to be used in the benchmark to make sure runtime
52
+ doesn't optimize out unused results"""
53
+ pass
54
+
55
+
56
+ async def run_serializer_benchmark(
57
+ model, serializer: Callable[[Any], bytes], iterations: int
58
+ ):
59
+ def _serde_loop():
60
+ bs = serializer(model)
61
+ _blackhole(bs)
62
+
63
+ pd = await run_latency_benchmark(_serde_loop, iterations)
64
+
65
+ print("Latencies (ms):\n", pd.describe(percentiles=_PERCENTILES))
66
+
67
+
68
+ @click.command(help="Benchmark serialization latency")
69
+ @click.option(
70
+ "--trials",
71
+ type=int,
72
+ default=1000,
73
+ help="Total number of trials to run in a single benchmark run",
74
+ )
75
+ @click.option(
76
+ "--batch-size",
77
+ type=int,
78
+ default=10,
79
+ help="Controls how many objects are contained in a serialized batch",
80
+ )
81
+ @click.option(
82
+ "--payload-type",
83
+ type=PayloadType,
84
+ help="Target type of the payload to be benchmarked (supported: pydantic, "
85
+ "dataclass)",
86
+ )
87
+ @click.option(
88
+ "--serializer",
89
+ type=SerializerType,
90
+ help="Target type of the serializer to be benchmarked (supported: ray, pickle, "
91
+ "cloudpickle, msgpack)",
92
+ )
93
+ @click.option(
94
+ "--profile-events",
95
+ type=bool,
96
+ default=False,
97
+ )
98
+ def main(
99
+ trials: int,
100
+ batch_size: int,
101
+ payload_type: PayloadType,
102
+ serializer: SerializerType,
103
+ profile_events: bool,
104
+ ):
105
+ if serializer == SerializerType.RAY:
106
+
107
+ def _serialize(obj):
108
+ so = sc.serialize(obj)
109
+ bs = so.to_bytes()
110
+ return bs
111
+
112
+ elif serializer == SerializerType.CLOUDPICKLE:
113
+
114
+ def _serialize(obj):
115
+ bs = cloudpickle_fast.dumps(obj)
116
+ return bs
117
+
118
+ elif serializer == SerializerType.PICKLE:
119
+
120
+ def _serialize(obj):
121
+ bs = pickle.dumps(obj)
122
+ return bs
123
+
124
+ elif serializer == SerializerType.MSGPACK:
125
+
126
+ def _dumps(obj):
127
+ bs = msgpack.dumps(obj.__dict__)
128
+ # print(f"Bytes ({len(bs)}): ", bs)
129
+ return bs
130
+
131
+ def _loads(bs):
132
+ dict = msgpack.loads(bs)
133
+ return PayloadPydantic(**dict)
134
+
135
+ sc._register_cloudpickle_serializer(PayloadPydantic, _dumps, _loads)
136
+
137
+ def _serialize(obj):
138
+ so = sc.serialize(obj)
139
+ bs = so.to_bytes()
140
+ return bs
141
+
142
+ else:
143
+ raise NotImplementedError(serializer)
144
+
145
+ if payload_type == PayloadType.PYDANTIC:
146
+ model = _create_model(PayloadPydantic)
147
+ elif payload_type == PayloadType.DATACLASS:
148
+ model = _create_model(PayloadDataclass)
149
+ else:
150
+ raise NotImplementedError(f"Not supported ({payload_type})")
151
+
152
+ payload = [model.copy(deep=True) for _ in range(batch_size)]
153
+
154
+ routine = run_serializer_benchmark(payload, _serialize, trials)
155
+
156
+ if profile_events:
157
+ routine = collect_profile_events(routine)
158
+
159
+ asyncio.run(routine)
160
+
161
+
162
+ if __name__ == "__main__":
163
+ main()
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (212 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/common.cpython-311.pyc ADDED
Binary file (7.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_core_throughput.cpython-311.pyc ADDED
Binary file (4.08 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_grpc_throughput.cpython-311.pyc ADDED
Binary file (9.07 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_handle_throughput.cpython-311.pyc ADDED
Binary file (4.19 kB). View file
 
.venv/lib/python3.11/site-packages/ray/serve/_private/benchmarks/streaming/__pycache__/streaming_http_throughput.cpython-311.pyc ADDED
Binary file (7.7 kB). View file